Spaces:
Sleeping
Sleeping
| # app.py - Gradio version for Hugging Face Spaces deployment | |
| import gradio as gr | |
| import asyncio | |
| import aiohttp | |
| import json | |
| import base64 | |
| import hashlib | |
| import logging | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Tuple | |
| from PIL import Image | |
| import io | |
| import os | |
| # Optional integrations | |
| try: | |
| from rdkit import Chem | |
| from rdkit.Chem import rdMolDescriptors | |
| from rdkit import DataStructs | |
| RDKIT_AVAILABLE = True | |
| except ImportError: | |
| RDKIT_AVAILABLE = False | |
# Send INFO-and-above records to stderr via the root logger (basicConfig is a
# no-op if the root logger already has handlers), and obtain this module's
# own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class StructureRecognitionService:
    """Chemical structure recognition service for the Gradio interface.

    Sends an image to two remote recognition models (MolScribe and
    ChemGrapher) concurrently, then cross-validates their SMILES predictions
    into a single ensemble answer with a heuristic confidence score.

    Endpoint URLs come from the ``MOLSCRIBE_ENDPOINT`` and
    ``CHEMGRAPH_ENDPOINT`` environment variables (placeholder URLs otherwise).
    """

    # Heuristic confidence multipliers applied by _validate_ensemble_results.
    SINGLE_MODEL_FACTOR = 0.8
    EXACT_MATCH_FACTOR = 1.2
    STRUCTURAL_MATCH_FACTOR = 1.1
    DISAGREEMENT_FACTOR = 0.7
    # Morgan-fingerprint Tanimoto similarity at or above which two different
    # SMILES strings are treated as the same structure.
    SIMILARITY_THRESHOLD = 0.85
    # How many characters of a non-200 response body to keep in the error.
    MAX_ERROR_BODY_CHARS = 200

    def __init__(self):
        self.molscribe_endpoint = os.getenv("MOLSCRIBE_ENDPOINT", "https://your-molscribe-api.com/predict")
        self.chemgraph_endpoint = os.getenv("CHEMGRAPH_ENDPOINT", "https://your-chemgraph-api.com/predict")
        # Total per-request timeout in seconds, passed to aiohttp.ClientTimeout.
        self.model_timeout = 60

    async def process_image(self, image: Image.Image, doc_id: str, image_id: str, reference_id: str) -> Dict:
        """Run both models on *image* concurrently and build the combined record.

        Args:
            image: PIL image to recognize; re-encoded as base64 PNG for transport.
            doc_id: Caller-supplied identifier, echoed back unchanged.
            image_id: Caller-supplied identifier, echoed back unchanged.
            reference_id: Caller-supplied identifier, echoed back unchanged.

        Returns:
            Dict with the three IDs, both per-model result dicts, the
            validation record, ``ensemble_smiles``, ``ensemble_confidence``,
            ``agreement_level``, ``processing_time`` (seconds) and an ISO-8601
            ``timestamp``.
        """
        img_buffer = io.BytesIO()
        image.save(img_buffer, format='PNG')
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

        start_time = datetime.now()
        # Query both models concurrently; exceptions come back as values.
        results = await asyncio.gather(
            self._call_model(self.molscribe_endpoint, img_base64, "MolScribe"),
            self._call_model(self.chemgraph_endpoint, img_base64, "ChemGrapher"),
            return_exceptions=True,
        )
        molscribe_result = self._as_result(results[0], 'MolScribe')
        chemgraph_result = self._as_result(results[1], 'ChemGrapher')

        validation = self._validate_ensemble_results(molscribe_result, chemgraph_result)
        processing_time = (datetime.now() - start_time).total_seconds()

        return {
            'doc_id': doc_id,
            'image_id': image_id,
            'reference_id': reference_id,
            'molscribe_result': molscribe_result,
            'chemgraph_result': chemgraph_result,
            'validation': validation,
            'ensemble_smiles': validation.get('best_smiles'),
            'ensemble_confidence': validation.get('ensemble_confidence', 0),
            'agreement_level': validation.get('agreement_level', 'unknown'),
            'processing_time': processing_time,
            'timestamp': datetime.now().isoformat(),
        }

    @staticmethod
    def _as_result(outcome, service_name: str) -> Dict:
        """Return *outcome* as a result dict, wrapping a returned exception as a failure."""
        # BaseException rather than Exception: gather(return_exceptions=True)
        # can hand back a CancelledError, which is not an Exception subclass
        # and would otherwise be passed on as if it were a result dict.
        if isinstance(outcome, BaseException):
            return {'success': False, 'error': str(outcome), 'service': service_name}
        return outcome

    async def _call_model(self, endpoint: str, image_base64: str, service_name: str) -> Dict:
        """POST the base64 image to one model endpoint and normalize the reply.

        HTTP errors, timeouts and any other ``Exception`` are converted into a
        failure dict rather than raised.

        Returns:
            On success: ``success`` (True), ``smiles``, ``confidence`` (float),
            ``processing_time`` (seconds) and ``service``. On failure:
            ``success`` (False), ``error``, ``processing_time`` and ``service``.
        """
        payload = {'image': image_base64, 'format': 'base64'}
        start_time = datetime.now()

        def failure(message: str) -> Dict:
            # Shared shape for every failure path; timing measured at failure.
            return {
                'success': False,
                'error': message,
                'processing_time': (datetime.now() - start_time).total_seconds(),
                'service': service_name,
            }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    endpoint,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.model_timeout),
                ) as response:
                    if response.status != 200:
                        # Keep a bounded snippet of the body: it usually explains
                        # the failure and was previously read but discarded.
                        error_text = await response.text(errors='replace')
                        logger.warning("%s returned HTTP %s", service_name, response.status)
                        return failure(
                            f'{service_name} API error {response.status}: '
                            f'{error_text[:self.MAX_ERROR_BODY_CHARS]}'
                        )
                    result = await response.json()
                    return {
                        'success': True,
                        # A null smiles becomes '' so downstream code sees a string.
                        'smiles': result.get('smiles') or '',
                        # A missing or null confidence counts as 0 instead of
                        # crashing float(None) and discarding a valid SMILES.
                        'confidence': float(result.get('confidence') or 0),
                        'processing_time': (datetime.now() - start_time).total_seconds(),
                        'service': service_name,
                    }
        except asyncio.TimeoutError:
            logger.warning("%s timed out after %ss", service_name, self.model_timeout)
            return failure(f'{service_name} timeout')
        except Exception as e:
            logger.warning("%s call failed: %s", service_name, e)
            return failure(f'{service_name} failed: {str(e)}')

    @staticmethod
    def _has_prediction(result: Dict) -> bool:
        """True if *result* reports success AND carries a non-blank SMILES string."""
        smiles = result.get('smiles')
        return bool(result.get('success', False)) and isinstance(smiles, str) and bool(smiles.strip())

    @staticmethod
    def _clamp_confidence(value: float) -> float:
        """Limit an ensemble confidence to [0.0, 1.0].

        Assumes model confidences are probabilities in [0, 1]; the match
        bonuses (x1.2 / x1.1) would otherwise push scores above 1.
        """
        return max(0.0, min(1.0, float(value)))

    def _validate_ensemble_results(self, molscribe_result: Dict, chemgraph_result: Dict) -> Dict:
        """Cross-validate the two model predictions and choose the ensemble answer.

        Decision order:
          1. neither model produced a SMILES -> ``both_failed``
          2. exactly one did                 -> that model, confidence x 0.8
          3. canonical SMILES identical      -> ``exact_match``, min confidence x 1.2
          4. (RDKit only) Tanimoto >= 0.85   -> ``structural_match``, max confidence x 1.1
          5. otherwise                       -> more confident model, confidence x 0.7

        Ensemble confidences are clamped to [0, 1]. ``best_smiles`` is always
        canonicalized (a no-op without RDKit or for unparsable input).

        Returns:
            Dict with ``agreement_level``, ``ensemble_confidence``,
            ``best_smiles``, ``validation_notes`` and ``structural_similarity``.
        """
        validation = {
            'agreement_level': 'unknown',
            'ensemble_confidence': 0.0,
            'best_smiles': None,
            'validation_notes': [],
            'structural_similarity': 0.0,
        }
        # An empty SMILES is not a usable prediction, even if the call "succeeded".
        molscribe_success = self._has_prediction(molscribe_result)
        chemgraph_success = self._has_prediction(chemgraph_result)

        if not molscribe_success and not chemgraph_success:
            validation.update({
                'agreement_level': 'both_failed',
                'validation_notes': ['Both models failed'],
            })
            return validation

        # Only one model produced a prediction: use it, with a discounted confidence.
        if molscribe_success and not chemgraph_success:
            validation.update({
                'agreement_level': 'molscribe_only',
                'ensemble_confidence': self._clamp_confidence(
                    molscribe_result.get('confidence', 0) * self.SINGLE_MODEL_FACTOR),
                'best_smiles': self._canonicalize_smiles(molscribe_result['smiles']),
                'validation_notes': ['Only MolScribe succeeded'],
            })
            return validation
        if chemgraph_success and not molscribe_success:
            validation.update({
                'agreement_level': 'chemgraph_only',
                'ensemble_confidence': self._clamp_confidence(
                    chemgraph_result.get('confidence', 0) * self.SINGLE_MODEL_FACTOR),
                'best_smiles': self._canonicalize_smiles(chemgraph_result['smiles']),
                'validation_notes': ['Only ChemGrapher succeeded'],
            })
            return validation

        # Both models produced a prediction.
        molscribe_smiles = molscribe_result['smiles']
        chemgraph_smiles = chemgraph_result['smiles']
        molscribe_conf = molscribe_result.get('confidence', 0)
        chemgraph_conf = chemgraph_result.get('confidence', 0)

        # Compare canonical forms so two spellings of the same molecule count
        # as an exact match. Without RDKit this is a plain string comparison.
        molscribe_canonical = self._canonicalize_smiles(molscribe_smiles)
        chemgraph_canonical = self._canonicalize_smiles(chemgraph_smiles)
        if molscribe_canonical == chemgraph_canonical:
            validation.update({
                'agreement_level': 'exact_match',
                'ensemble_confidence': self._clamp_confidence(
                    min(molscribe_conf, chemgraph_conf) * self.EXACT_MATCH_FACTOR),
                'best_smiles': molscribe_canonical,
                'validation_notes': ['Exact SMILES match - high confidence'],
                'structural_similarity': 1.0,
            })
            return validation

        # Near-identical structures (fingerprint similarity) still count as agreement.
        if RDKIT_AVAILABLE:
            similarity = self._calculate_similarity(molscribe_smiles, chemgraph_smiles)
            validation['structural_similarity'] = similarity
            if similarity >= self.SIMILARITY_THRESHOLD:
                best_smiles = molscribe_canonical if molscribe_conf >= chemgraph_conf else chemgraph_canonical
                validation.update({
                    'agreement_level': 'structural_match',
                    'ensemble_confidence': self._clamp_confidence(
                        max(molscribe_conf, chemgraph_conf) * self.STRUCTURAL_MATCH_FACTOR),
                    'best_smiles': best_smiles,
                    'validation_notes': [f'Structural similarity {similarity:.3f}'],
                })
                return validation

        # Genuine disagreement: prefer the more confident model, heavily discounted.
        if molscribe_conf >= chemgraph_conf:
            validation.update({
                'agreement_level': 'disagreement_molscribe_preferred',
                'ensemble_confidence': self._clamp_confidence(molscribe_conf * self.DISAGREEMENT_FACTOR),
                'best_smiles': molscribe_canonical,
                'validation_notes': [f'Disagreement - MolScribe preferred (conf: {molscribe_conf:.2f})'],
            })
        else:
            validation.update({
                'agreement_level': 'disagreement_chemgraph_preferred',
                'ensemble_confidence': self._clamp_confidence(chemgraph_conf * self.DISAGREEMENT_FACTOR),
                'best_smiles': chemgraph_canonical,
                'validation_notes': [f'Disagreement - ChemGrapher preferred (conf: {chemgraph_conf:.2f})'],
            })
        return validation

    def _calculate_similarity(self, smiles1: str, smiles2: str) -> float:
        """Tanimoto similarity of radius-2 Morgan fingerprints; 0.0 if either SMILES is unparsable."""
        try:
            mol1 = Chem.MolFromSmiles(smiles1)
            mol2 = Chem.MolFromSmiles(smiles2)
            if mol1 is None or mol2 is None:
                return 0.0
            fp1 = rdMolDescriptors.GetMorganFingerprintAsBitVect(mol1, 2)
            fp2 = rdMolDescriptors.GetMorganFingerprintAsBitVect(mol2, 2)
            return DataStructs.TanimotoSimilarity(fp1, fp2)
        except Exception:
            # Best-effort: similarity is advisory, so treat RDKit errors as "not similar".
            logger.debug("Similarity calculation failed for %r / %r", smiles1, smiles2, exc_info=True)
            return 0.0

    def _canonicalize_smiles(self, smiles: str) -> str:
        """Return RDKit's canonical SMILES for *smiles*, or *smiles* unchanged.

        The input is returned as-is when RDKit is unavailable, the input is
        empty, or RDKit cannot parse it.
        """
        if not RDKIT_AVAILABLE or not smiles:
            return smiles
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol) if mol is not None else smiles
        except Exception:
            logger.debug("Canonicalization failed for %r", smiles, exc_info=True)
            return smiles
# Module-level service instance shared by every request handled by
# process_structure_sync. Its constructor reads the endpoint URLs from the
# environment once, at import time, and sets the per-request timeout.
service = StructureRecognitionService()
def _format_model_section(title, model_result):
    """Render one model's result dict as a Markdown bullet list for the summary."""
    lines = [
        f"**{title}:**",
        f"- Success: {model_result.get('success', False)}",
        f"- SMILES: {model_result.get('smiles', 'Failed')}",
        f"- Confidence: {model_result.get('confidence', 0):.2f}",
    ]
    # Failed calls carry an 'error' message; show it rather than hiding why.
    if model_result.get('error'):
        lines.append(f"- Error: {model_result['error']}")
    return "\n".join(lines)


def process_structure_sync(image, doc_id, image_id, reference_id):
    """Synchronous Gradio handler: validate inputs, run the ensemble, format output.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        doc_id: Document identifier text (surrounding whitespace is stripped).
        image_id: Image identifier text (surrounding whitespace is stripped).
        reference_id: Reference identifier text (surrounding whitespace is stripped).

    Returns:
        A 3-tuple ``(summary_markdown, best_smiles, json_text)``. On invalid
        input or any processing failure, the first element is an error
        message and the other two are empty strings.
    """
    if image is None:
        return "❌ Please upload an image", "", ""
    # Whitespace-only IDs are as good as missing.
    doc_id, image_id, reference_id = (
        (value or "").strip() for value in (doc_id, image_id, reference_id)
    )
    if not doc_id or not image_id or not reference_id:
        return "❌ Please fill in all ID fields", "", ""

    try:
        # asyncio.run creates a fresh event loop and always closes it, even if
        # process_image raises (the manual new_event_loop/close pair leaked the
        # loop on error). Like the original, this assumes the calling thread
        # has no event loop already running.
        result = asyncio.run(
            service.process_image(image, doc_id, image_id, reference_id)
        )

        # 'ensemble_smiles' is always present but may be None, so .get()'s
        # default would never apply; use `or` instead.
        best_smiles = result.get('ensemble_smiles') or ''
        confidence = result.get('ensemble_confidence', 0)
        agreement = result.get('agreement_level', 'unknown')
        notes = result['validation'].get('validation_notes', [])

        # Blank lines between sections so Markdown renders them separately
        # instead of collapsing everything into one paragraph.
        summary = "\n\n".join([
            f"**Ensemble Result:** {best_smiles or 'No SMILES generated'}",
            f"**Confidence:** {confidence:.2f}",
            f"**Agreement Level:** {agreement}",
            f"**Processing Time:** {result.get('processing_time', 0):.2f}s",
            _format_model_section("MolScribe", result['molscribe_result']),
            _format_model_section("ChemGrapher", result['chemgraph_result']),
            "**Validation Notes:**\n" + "\n".join(f"- {note}" for note in notes),
        ])

        # Full record as JSON for API consumers.
        json_output = json.dumps(result, indent=2)
        return summary, best_smiles, json_output
    except Exception as e:
        # UI boundary: report the failure to the user and keep the traceback in the logs.
        logger.exception("Structure recognition failed for doc_id=%s image_id=%s", doc_id, image_id)
        return f"❌ Processing failed: {str(e)}", "", ""
# Gradio interface: inputs on the left, results on the right, raw JSON below.
# (Emoji and bullet characters below repair mis-encoded text in the UI strings.)
with gr.Blocks(title="Chemical Structure Recognition Service", theme=gr.themes.Soft()) as app:
    gr.HTML("""
    <h1 style='text-align: center; color: #2563eb;'>🧪 Chemical Structure Recognition Service</h1>
    <p style='text-align: center;'>Ensemble approach using MolScribe + ChemGrapher with cross-validation</p>
    <p style='text-align: center; color: #666;'>Upload chemical structure images and get validated SMILES predictions</p>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📥 Input</h3>")
            image_input = gr.Image(
                label="Chemical Structure Image",
                type="pil",
                height=300,
            )
            doc_id_input = gr.Textbox(
                label="Document ID",
                placeholder="doc_123",
                value="doc_001",
            )
            image_id_input = gr.Textbox(
                label="Image ID",
                placeholder="img_456",
                value="img_001",
            )
            reference_id_input = gr.Textbox(
                label="Reference ID",
                placeholder="ref_789",
                value="ref_001",
            )
            process_btn = gr.Button("🔬 Recognize Structure", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Results</h3>")
            summary_output = gr.Markdown(
                label="Processing Summary",
                value="Upload an image and click 'Recognize Structure' to get results...",
            )
            smiles_output = gr.Textbox(
                label="Best SMILES Result",
                placeholder="Ensemble SMILES will appear here...",
            )

    with gr.Row():
        gr.HTML("<h3>📋 Complete JSON Output</h3>")
    json_output = gr.Code(
        label="Full API Response",
        language="json",
        lines=20,
    )

    # Event handlers: the output order matches the 3-tuple returned by
    # process_structure_sync (summary, best SMILES, JSON).
    process_btn.click(
        process_structure_sync,
        inputs=[image_input, doc_id_input, image_id_input, reference_id_input],
        outputs=[summary_output, smiles_output, json_output],
    )

    # Footer with service status.
    # NOTE(review): this renders the configured endpoint URLs to every visitor;
    # confirm they contain no credentials or private hostnames.
    gr.HTML(f"""
    <div style='text-align: center; margin-top: 20px; color: #666; border-top: 1px solid #eee; padding-top: 20px;'>
        <p><strong>Service Status:</strong></p>
        <p>🧬 RDKit Available: {'✅' if RDKIT_AVAILABLE else '❌'}</p>
        <p>🔗 MolScribe Endpoint: {service.molscribe_endpoint}</p>
        <p>🔗 ChemGrapher Endpoint: {service.chemgraph_endpoint}</p>
        <p><strong>Features:</strong> Ensemble Recognition • Cross-Validation • Structural Similarity Analysis</p>
    </div>
    """)
if __name__ == "__main__":
    # Public share links are opt-in. Hugging Face Spaces already serves the
    # app publicly, and Gradio does not support share links there. For local
    # runs, an always-on share link would silently expose the service to the
    # internet. Set GRADIO_SHARE=true (or 1/yes) to enable one.
    share_enabled = os.getenv("GRADIO_SHARE", "").strip().lower() in ("1", "true", "yes")
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=share_enabled,
    )