Safi029 commited on
Commit
e39c092
Β·
verified Β·
1 Parent(s): aaff9bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -61
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from ultralytics import YOLO
3
- from huggingface_hub import hf_hub_download, hf_hub_url
4
  from PIL import Image
5
  import torch
6
  import torch.serialization
@@ -9,14 +9,12 @@ import hashlib
9
  from typing import Optional
10
 
11
  # ===== MODEL COMPONENTS =====
12
- # Import all model components that may appear in custom YOLOv8 models
13
  from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
14
  from ultralytics.nn.modules.conv import Conv
15
  from ultralytics.nn.modules.block import Bottleneck, C2f, SPPF
16
  from ultralytics.nn.tasks import DetectionModel
17
 
18
  # ===== SAFE GLOBALS =====
19
- # Add all components to safe globals for secure loading
20
  torch.serialization.add_safe_globals([
21
  DetectionModel,
22
  Sequential,
@@ -37,91 +35,82 @@ torch.serialization.add_safe_globals([
37
  # ===== MODEL CONFIG =====
38
  MODEL_REPO = "Safi029/ABD-model"
39
  MODEL_FILE = "ABD.pt"
40
- EXPECTED_SHA256 = "a1b2c3d4..." # Replace with actual hash of your model file
41
 
42
  # ===== HELPER FUNCTIONS =====
43
- def verify_model(file_path: str, expected_hash: str) -> bool:
44
  """Verify model integrity using SHA256 hash"""
45
- if not os.path.exists(file_path):
46
- return False
47
-
48
  sha256 = hashlib.sha256()
49
  with open(file_path, "rb") as f:
50
  while chunk := f.read(8192):
51
  sha256.update(chunk)
52
- return sha256.hexdigest() == expected_hash.lower()
 
 
53
 
54
  def download_model() -> str:
55
- """Download model with verification"""
56
  os.makedirs("models", exist_ok=True)
57
  model_path = os.path.join("models", MODEL_FILE)
58
 
59
  # Download if not exists or verification fails
60
- if not os.path.exists(model_path) or not verify_model(model_path, EXPECTED_SHA256):
61
  print("⬇️ Downloading model...")
62
  hf_hub_download(
63
  repo_id=MODEL_REPO,
64
  filename=MODEL_FILE,
65
  local_dir="models",
66
- local_dir_use_symlinks=False,
67
- resume_download=True,
68
  force_download=True
69
  )
70
 
71
- if not verify_model(model_path, EXPECTED_SHA256):
72
  raise ValueError("❌ Downloaded model failed verification!")
73
 
74
  return model_path
75
 
76
  def load_model(model_path: str) -> YOLO:
77
- """Safely load YOLO model with error handling"""
78
- try:
79
- print("πŸ”§ Loading model...")
80
- model = YOLO(model_path)
81
- # Verify model loaded correctly by running a dummy inference
82
- dummy_input = torch.randn(1, 3, 640, 640)
83
- model(dummy_input) # Test forward pass
84
- print("βœ… Model loaded and verified!")
85
- return model
86
- except Exception as e:
87
- raise RuntimeError(f"Model loading failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  # ===== MAIN APPLICATION =====
90
  def main():
91
  try:
92
- # 1. Download and verify model
93
- model_path = download_model()
94
 
95
- # 2. Load model
 
96
  model = load_model(model_path)
97
 
98
- # 3. Gradio interface
99
- def detect_structure(image: Image.Image) -> Image.Image:
100
- """Run detection on input image"""
101
- try:
102
- results = model(image)
103
- return Image.fromarray(results[0].plot())
104
- except Exception as e:
105
- print(f"❌ Inference error: {e}")
106
- error_img = Image.new("RGB", (300, 100), color="red")
107
- # Add error text
108
- error_img.text((10, 45), f"Error: {str(e)[:50]}...")
109
- return error_img
110
-
111
- # 4. Create interface
112
- demo = gr.Interface(
113
- fn=detect_structure,
114
- inputs=gr.Image(type="pil", label="Input Image"),
115
- outputs=gr.Image(type="pil", label="Detection Results"),
116
- title="YOLOv8 Molecular Structure Detector",
117
- description=(
118
- "πŸ”¬ Detect atoms and bonds in molecular structures. "
119
- "Upload an image of a chemical structure to see detection results."
120
- ),
121
- allow_flagging="never"
122
- )
123
-
124
- # 5. Launch
125
  print("πŸš€ Starting Gradio interface...")
126
  demo.launch(server_name="0.0.0.0", share=False)
127
 
@@ -130,12 +119,6 @@ def main():
130
  raise
131
 
132
  if __name__ == "__main__":
133
- # Check PyTorch and CUDA availability
134
- print(f"PyTorch version: {torch.__version__}")
135
- print(f"CUDA available: {torch.cuda.is_available()}")
136
-
137
  main()
138
 
139
 
140
-
141
-
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
+ from huggingface_hub import hf_hub_download
4
  from PIL import Image
5
  import torch
6
  import torch.serialization
 
9
  from typing import Optional
10
 
11
  # ===== MODEL COMPONENTS =====
 
12
  from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
13
  from ultralytics.nn.modules.conv import Conv
14
  from ultralytics.nn.modules.block import Bottleneck, C2f, SPPF
15
  from ultralytics.nn.tasks import DetectionModel
16
 
17
  # ===== SAFE GLOBALS =====
 
18
  torch.serialization.add_safe_globals([
19
  DetectionModel,
20
  Sequential,
 
35
  # ===== MODEL CONFIG =====
36
  MODEL_REPO = "Safi029/ABD-model"
37
  MODEL_FILE = "ABD.pt"
38
+ EXPECTED_SHA256 = "c3335b0cc6c504c4ac74b62bf2bc9aa06ecf402fa71184ec88f40a1f37979859"
39
 
40
  # ===== HELPER FUNCTIONS =====
41
+ def verify_model(file_path: str) -> bool:
42
  """Verify model integrity using SHA256 hash"""
 
 
 
43
  sha256 = hashlib.sha256()
44
  with open(file_path, "rb") as f:
45
  while chunk := f.read(8192):
46
  sha256.update(chunk)
47
+ actual_hash = sha256.hexdigest()
48
+ print(f"πŸ” Model SHA256: {actual_hash}")
49
+ return actual_hash == EXPECTED_SHA256.lower()
50
 
51
  def download_model() -> str:
52
+ """Download and verify model"""
53
  os.makedirs("models", exist_ok=True)
54
  model_path = os.path.join("models", MODEL_FILE)
55
 
56
  # Download if not exists or verification fails
57
+ if not os.path.exists(model_path) or not verify_model(model_path):
58
  print("⬇️ Downloading model...")
59
  hf_hub_download(
60
  repo_id=MODEL_REPO,
61
  filename=MODEL_FILE,
62
  local_dir="models",
 
 
63
  force_download=True
64
  )
65
 
66
+ if not verify_model(model_path):
67
  raise ValueError("❌ Downloaded model failed verification!")
68
 
69
  return model_path
70
 
71
  def load_model(model_path: str) -> YOLO:
72
+ """Safely load YOLO model"""
73
+ print("πŸ”§ Loading model...")
74
+ model = YOLO(model_path)
75
+ # Verify model works with dummy input
76
+ dummy_input = torch.randn(1, 3, 640, 640)
77
+ model(dummy_input) # Test forward pass
78
+ print("βœ… Model loaded and verified!")
79
+ return model
80
+
81
+ # ===== GRADIO INTERFACE =====
82
+ def create_interface(model):
83
+ def detect_structure(image: Image.Image) -> Image.Image:
84
+ """Run detection on input image"""
85
+ try:
86
+ results = model(image)
87
+ return Image.fromarray(results[0].plot())
88
+ except Exception as e:
89
+ print(f"❌ Inference error: {e}")
90
+ error_img = Image.new("RGB", (300, 100), color="red")
91
+ return error_img
92
+
93
+ return gr.Interface(
94
+ fn=detect_structure,
95
+ inputs=gr.Image(type="pil", label="Input Image"),
96
+ outputs=gr.Image(type="pil", label="Detection Results"),
97
+ title="YOLOv8 Molecular Structure Detector",
98
+ description="πŸ”¬ Detect atoms and bonds in molecular structures",
99
+ allow_flagging="never"
100
+ )
101
 
102
  # ===== MAIN APPLICATION =====
103
  def main():
104
  try:
105
+ print(f"PyTorch: {torch.__version__}")
106
+ print(f"CUDA: {torch.cuda.is_available()}")
107
 
108
+ # Download and load model
109
+ model_path = download_model()
110
  model = load_model(model_path)
111
 
112
+ # Create and launch interface
113
+ demo = create_interface(model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  print("πŸš€ Starting Gradio interface...")
115
  demo.launch(server_name="0.0.0.0", share=False)
116
 
 
119
  raise
120
 
121
  if __name__ == "__main__":
 
 
 
 
122
  main()
123
 
124