File size: 5,068 Bytes
1940078 e39c092 1940078 e7673ab 7832262 aaff9bd e4ee29f aaff9bd 1940078 e4ee29f aaff9bd e4ee29f e7673ab 3263193 3e286ee 3263193 e4ee29f 3263193 e4ee29f 3e286ee 81b4674 aaff9bd e39c092 5e1593e aaff9bd e39c092 aaff9bd e39c092 1940078 aaff9bd e39c092 aaff9bd e39c092 aaff9bd e39c092 aaff9bd 1940078 aaff9bd c6047af e39c092 c6047af e4ee29f 3263193 e4ee29f c6047af e4ee29f c6047af e39c092 c6047af e39c092 9508a99 aaff9bd e39c092 aaff9bd e39c092 aaff9bd e39c092 aaff9bd c6047af aaff9bd e4ee29f aaff9bd e4ee29f 5e1593e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import gradio as gr
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
import torch.serialization
import os
import hashlib
import warnings
from typing import Optional
# ===== IMPORT ALL ULTRALYTICS MODULES =====
from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
from ultralytics.nn.modules import (
Conv, Concat,
Bottleneck, C2f, SPPF,
Detect, DFL, # Added DFL
C2fAttn, ImagePoolingAttn, # Common attention modules
HGStem, HGBlock, # Additional blocks
AIFI, # Additional modules
Segment, Pose, Classify, RTDETRDecoder # Task-specific heads
)
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, PoseModel, ClassificationModel
# ===== SAFE GLOBALS CONFIGURATION =====
# Allow-list the module classes that may appear inside a pickled checkpoint so
# torch's weights-only unpickler is permitted to reconstruct them. The torch.nn
# building blocks and the ultralytics classes are registered in a single call.
torch.serialization.add_safe_globals(
    [
        # torch.nn building blocks
        Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU,
        MaxPool2d, Upsample, ModuleList,
    ]
    + [
        # ultralytics task-level model wrappers
        DetectionModel, SegmentationModel, PoseModel, ClassificationModel,
        # ultralytics layers, attention blocks and task heads
        Conv, Concat,
        Bottleneck, C2f, SPPF,
        Detect, DFL,
        C2fAttn, ImagePoolingAttn,
        HGStem, HGBlock,
        AIFI,
        Segment, Pose, Classify, RTDETRDecoder,
    ]
)
# ===== MODEL CONFIG =====
# Hugging Face Hub repository id and file name of the checkpoint to download.
MODEL_REPO: str = "Safi029/ABD-model"
MODEL_FILE: str = "ABD.pt"
# Hex SHA-256 digest the local checkpoint must match; verify_model lowercases
# it before comparing against the computed hexdigest.
EXPECTED_SHA256: str = "c3335b0cc6c504c4ac74b62bf2bc9aa06ecf402fa71184ec88f40a1f37979859"
# ===== HELPER FUNCTIONS =====
def verify_model(
    file_path: str,
    expected_sha256: Optional[str] = None,
    chunk_size: int = 1 << 20,
) -> bool:
    """Verify a file's integrity against a SHA-256 hex digest.

    The file is streamed in chunks, so large checkpoints are never read into
    memory all at once. The computed digest is printed for diagnostics.

    Args:
        file_path: Path of the file to hash.
        expected_sha256: Hex digest to compare against (whitespace-trimmed,
            case-insensitive). Defaults to the module-level ``EXPECTED_SHA256``.
        chunk_size: Bytes read per iteration; must be positive.

    Returns:
        True if the file's digest equals the expected digest, else False.

    Raises:
        ValueError: If ``chunk_size`` is not positive (``read(0)`` would hash
            nothing and ``read(-1)`` would slurp the whole file).
        OSError: If the file cannot be opened or read.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if expected_sha256 is None:
        expected_sha256 = EXPECTED_SHA256
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            sha256.update(chunk)
    actual_hash = sha256.hexdigest()
    print(f"π Model SHA256: {actual_hash}")
    # hexdigest() is always lowercase; normalize the expected value to match.
    return actual_hash == expected_sha256.strip().lower()
def download_model() -> str:
    """Return the local checkpoint path, downloading it when needed.

    A copy already present under ``models/`` is reused if it passes the hash
    check. Otherwise the file is re-fetched from the Hub (forced) and checked
    again.

    Returns:
        Path of the verified checkpoint, ``models/<MODEL_FILE>``.

    Raises:
        ValueError: If the freshly downloaded file still fails verification.
    """
    target_dir = "models"
    os.makedirs(target_dir, exist_ok=True)
    local_path = os.path.join(target_dir, MODEL_FILE)

    # Short-circuit keeps the original order: hash only if the file exists.
    if os.path.exists(local_path) and verify_model(local_path):
        return local_path

    print("β¬οΈ Downloading model...")
    hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE,
        local_dir=target_dir,
        force_download=True,
    )
    if verify_model(local_path):
        return local_path
    raise ValueError("β Downloaded model failed verification!")
def load_model(model_path: str) -> YOLO:
    """Load a YOLO checkpoint and smoke-test it with one dummy forward pass.

    Args:
        model_path: Local path of the ``.pt`` checkpoint to load.

    Returns:
        The loaded ``YOLO`` model, constructed with ``task='detect'``.

    Raises:
        RuntimeError: If construction or the dummy inference fails. The
            underlying exception is chained as ``__cause__``.
    """
    print("π§ Loading model...")
    # Temporary monkey patch for the PyTorch 2.6+ weights_only restriction:
    # force full unpickling while the checkpoint is being loaded.
    # ONLY USE IF YOU TRUST THE MODEL SOURCE! Unpickling can execute code.
    original_load = torch.load

    def _load_full_pickle(*args, **kwargs):
        # Assign rather than passing weights_only=False next to **kwargs. A caller
        # that already supplies weights_only would otherwise get
        # "TypeError: got multiple values for keyword argument 'weights_only'".
        kwargs["weights_only"] = False
        return original_load(*args, **kwargs)

    try:
        torch.load = _load_full_pickle
        try:
            model = YOLO(model_path, task='detect')
        finally:
            # Restore the real torch.load whether or not construction succeeded.
            torch.load = original_load
        # Smoke test: one all-zeros 3x640x640 image (batch of 1) through the model.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 640, 640)
            model(dummy)
        print("✅ Model loaded and verified!")
        return model
    except Exception as e:
        raise RuntimeError(f"Model loading failed: {str(e)}") from e
# ===== GRADIO INTERFACE =====
def create_interface(model):
    """Build the Gradio interface that runs *model* on uploaded images.

    Args:
        model: A loaded ultralytics ``YOLO`` model. It is called once per
            submitted image.

    Returns:
        A ``gr.Interface`` ready for ``launch()``. An ``example.jpg`` in the
        working directory is offered as an example when present.
    """
    def _placeholder() -> Image.Image:
        # Plain red image shown instead of a result when detection cannot run.
        return Image.new("RGB", (300, 100), color="red")

    def detect_structure(image: Image.Image) -> Image.Image:
        """Run detection on *image* and return the annotated picture.

        Errors are printed and replaced by a red placeholder so the UI never
        crashes.
        """
        if image is None:
            # Gradio passes None when nothing was uploaded. ultralytics would
            # substitute its bundled sample images for a None source, so stop here.
            return _placeholder()
        try:
            results = model(image)
            # Results.plot() returns a BGR-ordered numpy array (OpenCV
            # convention). Reverse the channel axis to RGB for PIL, otherwise
            # red and blue are swapped. .copy() makes the reversed view contiguous.
            annotated_rgb = results[0].plot()[..., ::-1].copy()
            return Image.fromarray(annotated_rgb)
        except Exception as e:
            print(f"β Inference error: {e}")
            return _placeholder()

    return gr.Interface(
        fn=detect_structure,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=gr.Image(type="pil", label="Detection Results"),
        title="YOLOv8 Molecular Structure Detector",
        description="π¬ Detect atoms and bonds in molecular structures",
        examples=[["example.jpg"]] if os.path.exists("example.jpg") else None
    )
# ===== MAIN APPLICATION =====
def main():
    """Run the app: report the runtime, load the model, and serve the UI.

    The Gradio server listens on 0.0.0.0:7860 with sharing disabled. Any
    exception is printed and then re-raised unchanged.
    """
    try:
        print(f"PyTorch: {torch.__version__}")
        print(f"CUDA: {torch.cuda.is_available()}")

        # Obtain the verified checkpoint, load it, and wrap it in the UI.
        weights = download_model()
        detector = load_model(weights)
        app = create_interface(detector)

        print("π Starting Gradio interface...")
        app.launch(server_name="0.0.0.0", share=False, server_port=7860)
    except Exception as exc:
        print(f"β Fatal error: {str(exc)}")
        raise
if __name__ == "__main__":
    # Suppress torch.load warnings. The ``message`` argument is a regex matched
    # against the *start* of the warning text, so the leading ".*" is required
    # to match "torch.load" anywhere in it. PyTorch's weights_only notice is a
    # FutureWarning (not a UserWarning subclass), so both categories are filtered.
    warnings.filterwarnings("ignore", category=UserWarning, message=r".*torch\.load")
    warnings.filterwarnings("ignore", category=FutureWarning, message=r".*torch\.load")
    main()
|