"""Gradio app serving a YOLO molecular-structure (atom/bond) detector.

Model weights are downloaded from the Hugging Face Hub, checked against a
pinned SHA-256 digest, loaded with Ultralytics, and exposed through a simple
image-in / image-out Gradio interface.
"""

import gradio as gr
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
import torch.serialization
import os
import hashlib
import warnings
from typing import Optional

# ===== IMPORT ALL ULTRALYTICS MODULES =====
from torch.nn import (
    Sequential,
    Conv2d,
    BatchNorm2d,
    SiLU,
    ReLU,
    LeakyReLU,
    MaxPool2d,
    Upsample,
    ModuleList,
)
from ultralytics.nn.modules import (
    Conv,
    Concat,
    Bottleneck,
    C2f,
    SPPF,
    Detect,
    DFL,
    C2fAttn,
    ImagePoolingAttn,  # common attention modules
    HGStem,
    HGBlock,  # additional blocks
    AIFI,  # additional modules
    Segment,
    Pose,
    Classify,
    RTDETRDecoder,  # task-specific heads
)
from ultralytics.nn.tasks import (
    DetectionModel,
    SegmentationModel,
    PoseModel,
    ClassificationModel,
)

# ===== SAFE GLOBALS CONFIGURATION =====
# Allow-list these classes for torch.load(weights_only=True) unpickling.
torch.serialization.add_safe_globals([
    # Torch modules
    Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d,
    Upsample, ModuleList,
    # Ultralytics modules
    DetectionModel, SegmentationModel, PoseModel, ClassificationModel,
    Conv, Concat, Bottleneck, C2f, SPPF, Detect, DFL,
    C2fAttn, ImagePoolingAttn, HGStem, HGBlock, AIFI,
    Segment, Pose, Classify, RTDETRDecoder,
])

# ===== MODEL CONFIG =====
MODEL_REPO = "Safi029/ABD-model"
MODEL_FILE = "ABD.pt"
EXPECTED_SHA256 = "c3335b0cc6c504c4ac74b62bf2bc9aa06ecf402fa71184ec88f40a1f37979859"


# ===== HELPER FUNCTIONS =====
def verify_model(file_path: str) -> bool:
    """Return True if the SHA-256 of ``file_path`` equals ``EXPECTED_SHA256``.

    The file is hashed in 8 KiB chunks so large weight files are not read
    into memory at once. The computed digest is printed for diagnostics.
    """
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(8192):
            sha256.update(chunk)
    actual_hash = sha256.hexdigest()
    print(f"🔍 Model SHA256: {actual_hash}")
    return actual_hash == EXPECTED_SHA256.lower()


def download_model() -> str:
    """Return a local path to a verified copy of the model weights.

    A cached copy at ``models/<MODEL_FILE>`` is reused when its hash matches.
    Otherwise the file is force-downloaded from ``MODEL_REPO``.

    Raises:
        ValueError: if the freshly downloaded file fails hash verification.
    """
    os.makedirs("models", exist_ok=True)
    model_path = os.path.join("models", MODEL_FILE)

    if not os.path.exists(model_path) or not verify_model(model_path):
        print("⬇️ Downloading model...")
        # Trust the path the Hub client reports instead of assuming where
        # it placed the file under local_dir.
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            local_dir="models",
            force_download=True,
        )
        if not verify_model(model_path):
            raise ValueError("❌ Downloaded model failed verification!")

    return model_path


def load_model(model_path: str) -> YOLO:
    """Load the YOLO detector from ``model_path`` and smoke-test it.

    ``torch.load`` is temporarily patched to force ``weights_only=False``.
    This works around the PyTorch 2.6+ default that refuses full pickles.
    The patch is undone as soon as the YOLO constructor returns or raises.
    ONLY USE THIS WITH MODEL FILES FROM A TRUSTED SOURCE: full unpickling
    can execute arbitrary code.

    Raises:
        RuntimeError: if loading or the dummy forward pass fails (the
            original exception is chained as ``__cause__``).
    """
    print("🔧 Loading model...")
    original_load = torch.load

    def _load_full_pickle(*args, **kwargs):
        # Override rather than append the keyword: callers (e.g. newer
        # Ultralytics) may already pass weights_only, and passing it twice
        # raises "got multiple values for keyword argument".
        kwargs["weights_only"] = False
        return original_load(*args, **kwargs)

    try:
        torch.load = _load_full_pickle
        try:
            model = YOLO(model_path, task="detect")
        finally:
            # Always restore the real torch.load, even if construction fails.
            torch.load = original_load

        # Smoke test: one forward pass on a blank 640x640 RGB batch.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 640, 640)
            model(dummy)

        print("✅ Model loaded and verified!")
        return model
    except Exception as e:
        raise RuntimeError(f"Model loading failed: {str(e)}") from e


# ===== GRADIO INTERFACE =====
def create_interface(model: YOLO) -> gr.Interface:
    """Build the Gradio interface that runs ``model`` on an uploaded image."""

    def detect_structure(image: Optional[Image.Image]) -> Optional[Image.Image]:
        """Run detection on ``image`` and return the annotated result.

        Returns None when no image was supplied. On inference failure, logs
        the error and returns a plain red 300x100 placeholder image.
        """
        if image is None:
            # Gradio sends None when the user submits with no image. Ultralytics
            # would otherwise fall back to its bundled sample images.
            return None
        try:
            results = model(image)
            # Results.plot() returns a BGR array; reverse the channel axis
            # so PIL (which expects RGB) shows the correct colors.
            return Image.fromarray(results[0].plot()[..., ::-1])
        except Exception as e:
            print(f"❌ Inference error: {e}")
            error_img = Image.new("RGB", (300, 100), color="red")
            return error_img

    return gr.Interface(
        fn=detect_structure,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=gr.Image(type="pil", label="Detection Results"),
        title="YOLOv8 Molecular Structure Detector",
        description="🔬 Detect atoms and bonds in molecular structures",
        examples=[["example.jpg"]] if os.path.exists("example.jpg") else None,
    )


# ===== MAIN APPLICATION =====
def main():
    """Download, verify, and load the model, then launch the Gradio server.

    Any failure is printed and then re-raised so the process exits non-zero.
    """
    try:
        print(f"PyTorch: {torch.__version__}")
        print(f"CUDA: {torch.cuda.is_available()}")

        # Download and load model
        model_path = download_model()
        model = load_model(model_path)

        # Create and launch interface
        demo = create_interface(model)
        print("🚀 Starting Gradio interface...")
        demo.launch(
            server_name="0.0.0.0",
            share=False,
            server_port=7860,
        )
    except Exception as e:
        print(f"❌ Fatal error: {str(e)}")
        raise


if __name__ == "__main__":
    # Suppress torch.load warnings
    warnings.filterwarnings("ignore", category=UserWarning, message="torch.load")
    main()