File size: 5,068 Bytes
1940078
 
e39c092
1940078
e7673ab
7832262
aaff9bd
 
e4ee29f
aaff9bd
1940078
e4ee29f
aaff9bd
e4ee29f
 
 
 
 
 
 
 
 
 
e7673ab
3263193
 
3e286ee
3263193
 
 
 
 
e4ee29f
3263193
 
e4ee29f
 
 
 
 
3e286ee
81b4674
aaff9bd
 
 
e39c092
5e1593e
aaff9bd
e39c092
aaff9bd
 
 
 
 
e39c092
 
 
1940078
aaff9bd
e39c092
aaff9bd
 
 
e39c092
aaff9bd
 
 
 
 
 
 
 
e39c092
aaff9bd
 
 
1940078
aaff9bd
c6047af
e39c092
c6047af
e4ee29f
 
 
 
 
3263193
 
e4ee29f
 
 
c6047af
 
 
 
 
 
 
e4ee29f
 
 
c6047af
e39c092
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6047af
e39c092
9508a99
aaff9bd
 
 
e39c092
 
aaff9bd
e39c092
 
aaff9bd
 
e39c092
 
aaff9bd
c6047af
 
 
 
 
aaff9bd
 
 
 
 
 
e4ee29f
 
aaff9bd
e4ee29f
5e1593e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
import torch.serialization
import os
import hashlib
import warnings
from typing import Optional

# ===== IMPORT ALL ULTRALYTICS MODULES =====
from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
from ultralytics.nn.modules import (
    Conv, Concat, 
    Bottleneck, C2f, SPPF, 
    Detect, DFL,  # Added DFL
    C2fAttn, ImagePoolingAttn,  # Common attention modules
    HGStem, HGBlock,  # Additional blocks
    AIFI,  # Additional modules
    Segment, Pose, Classify, RTDETRDecoder  # Task-specific heads
)
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, PoseModel, ClassificationModel

# ===== SAFE GLOBALS CONFIGURATION =====
# Register every torch / ultralytics class that may appear inside the pickled
# checkpoint so torch.load(weights_only=True) on PyTorch 2.6+ can unpickle it
# without raising UnpicklingError for "unsafe" globals.
torch.serialization.add_safe_globals([
    # Torch modules
    Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, 
    MaxPool2d, Upsample, ModuleList,
    
    # Ultralytics modules
    DetectionModel, SegmentationModel, PoseModel, ClassificationModel,
    Conv, Concat,
    Bottleneck, C2f, SPPF,
    Detect, DFL,  # Added DFL
    C2fAttn, ImagePoolingAttn,
    HGStem, HGBlock,
    AIFI,
    Segment, Pose, Classify, RTDETRDecoder
])

# ===== MODEL CONFIG =====
# Hugging Face Hub repository and filename the detector weights are pulled from.
MODEL_REPO = "Safi029/ABD-model"
MODEL_FILE = "ABD.pt"
# SHA256 digest (lowercase hex) the downloaded .pt file must match.
EXPECTED_SHA256 = "c3335b0cc6c504c4ac74b62bf2bc9aa06ecf402fa71184ec88f40a1f37979859"

# ===== HELPER FUNCTIONS =====
def verify_model(file_path: str) -> bool:
    """Verify model integrity by comparing its SHA256 digest to EXPECTED_SHA256.

    Args:
        file_path: Path to the model file on disk.

    Returns:
        True only when the file exists and its SHA256 matches EXPECTED_SHA256.
    """
    # Robustness fix: a missing file is simply "not verified" instead of an
    # uncaught FileNotFoundError, making this a pure predicate callers can use
    # without an existence pre-check.
    if not os.path.isfile(file_path):
        return False
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        # Stream in 8 KiB chunks so large checkpoints aren't loaded into memory.
        while chunk := f.read(8192):
            sha256.update(chunk)
    actual_hash = sha256.hexdigest()
    print(f"πŸ” Model SHA256: {actual_hash}")
    # .lower() tolerates an upper/mixed-case constant; hexdigest() is lowercase.
    return actual_hash == EXPECTED_SHA256.lower()

def download_model() -> str:
    """Ensure a verified local copy of the model exists and return its path.

    Returns:
        Filesystem path to the verified model weights under ``models/``.

    Raises:
        ValueError: If the freshly downloaded file fails SHA256 verification.
    """
    os.makedirs("models", exist_ok=True)
    model_path = os.path.join("models", MODEL_FILE)

    # A cached copy that passes verification short-circuits the download.
    if os.path.exists(model_path) and verify_model(model_path):
        return model_path

    print("⬇️ Downloading model...")
    hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE,
        local_dir="models",
        force_download=True,
    )

    if not verify_model(model_path):
        raise ValueError("❌ Downloaded model failed verification!")

    return model_path

def load_model(model_path: str) -> YOLO:
    """Safely load a YOLO model and smoke-test it with a dummy forward pass.

    Args:
        model_path: Path to the ``.pt`` checkpoint to load.

    Returns:
        The loaded and smoke-tested YOLO model.

    Raises:
        RuntimeError: If loading or the smoke test fails; the original
            exception is chained as ``__cause__``.
    """
    print("πŸ”§ Loading model...")

    # Temporary monkey patch for PyTorch 2.6+ weights_only restriction.
    # ONLY USE IF YOU TRUST THE MODEL SOURCE!
    original_load = torch.load

    def _patched_load(*args, **kwargs):
        # setdefault (not an unconditional keyword) avoids a duplicate-keyword
        # TypeError when a caller already passes weights_only explicitly.
        kwargs.setdefault("weights_only", False)
        return original_load(*args, **kwargs)

    torch.load = _patched_load
    try:
        model = YOLO(model_path, task='detect')

        # Smoke test with a small dummy input to surface weight/architecture
        # problems now rather than on the first user request.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 640, 640)
            model(dummy)
        print("βœ… Model loaded and verified!")
        return model
    except Exception as e:
        # Chain the cause so the full traceback of the underlying failure
        # is preserved for debugging.
        raise RuntimeError(f"Model loading failed: {str(e)}") from e
    finally:
        # try/finally guarantees torch.load is restored on every path,
        # replacing the fragile `'original_load' in locals()` bookkeeping.
        torch.load = original_load

# ===== GRADIO INTERFACE =====
def create_interface(model):
    """Build the Gradio Interface that wraps *model* for image detection.

    Args:
        model: A loaded YOLO model, callable on PIL images.

    Returns:
        A configured ``gr.Interface`` (not yet launched).
    """
    def detect_structure(image: Image.Image) -> Image.Image:
        """Run detection on the input image and return the annotated result."""
        try:
            results = model(image)
            # Fix: ultralytics Results.plot() returns a BGR numpy array, so
            # feeding it to Image.fromarray directly produces color-swapped
            # output. Reverse the channel axis to get a correct RGB image.
            return Image.fromarray(results[0].plot()[..., ::-1])
        except Exception as e:
            print(f"❌ Inference error: {e}")
            # Best-effort fallback: show a solid red placeholder so the UI
            # still returns an image instead of crashing the request.
            error_img = Image.new("RGB", (300, 100), color="red")
            return error_img

    return gr.Interface(
        fn=detect_structure,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=gr.Image(type="pil", label="Detection Results"),
        title="YOLOv8 Molecular Structure Detector",
        description="πŸ”¬ Detect atoms and bonds in molecular structures",
        # Only advertise the example if the file actually exists in the CWD.
        examples=[["example.jpg"]] if os.path.exists("example.jpg") else None
    )

# ===== MAIN APPLICATION =====
def main():
    """Application entry point: fetch the model, load it, and serve the UI."""
    try:
        # Environment diagnostics, printed up front for easier log triage.
        print(f"PyTorch: {torch.__version__}")
        print(f"CUDA: {torch.cuda.is_available()}")

        # Acquire verified weights, then construct the model from them.
        weights_path = download_model()
        detector = load_model(weights_path)

        # Wire the model into the web UI and start serving on all interfaces.
        app = create_interface(detector)
        print("πŸš€ Starting Gradio interface...")
        app.launch(
            server_name="0.0.0.0",
            share=False,
            server_port=7860
        )

    except Exception as e:
        # Log the failure, then re-raise so the process exits non-zero.
        print(f"❌ Fatal error: {str(e)}")
        raise

if __name__ == "__main__":
    # Suppress torch.load warnings
    # NOTE(review): filterwarnings' ``message`` argument is a regex matched at
    # the *start* of the warning text — confirm torch's warning actually begins
    # with "torch.load", otherwise this filter silently matches nothing.
    warnings.filterwarnings("ignore", category=UserWarning, message="torch.load")
    main()