Safi029 commited on
Commit
c6047af
Β·
verified Β·
1 Parent(s): e39c092

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -10,15 +10,17 @@ from typing import Optional
10
 
11
  # ===== MODEL COMPONENTS =====
12
  from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
13
- from ultralytics.nn.modules.conv import Conv
14
  from ultralytics.nn.modules.block import Bottleneck, C2f, SPPF
15
  from ultralytics.nn.tasks import DetectionModel
16
 
17
  # ===== SAFE GLOBALS =====
 
18
  torch.serialization.add_safe_globals([
19
  DetectionModel,
20
  Sequential,
21
  Conv,
 
22
  Conv2d,
23
  BatchNorm2d,
24
  SiLU,
@@ -53,7 +55,6 @@ def download_model() -> str:
53
  os.makedirs("models", exist_ok=True)
54
  model_path = os.path.join("models", MODEL_FILE)
55
 
56
- # Download if not exists or verification fails
57
  if not os.path.exists(model_path) or not verify_model(model_path):
58
  print("⬇️ Downloading model...")
59
  hf_hub_download(
@@ -69,14 +70,18 @@ def download_model() -> str:
69
  return model_path
70
 
71
  def load_model(model_path: str) -> YOLO:
72
- """Safely load YOLO model"""
73
  print("πŸ”§ Loading model...")
74
- model = YOLO(model_path)
75
- # Verify model works with dummy input
76
- dummy_input = torch.randn(1, 3, 640, 640)
77
- model(dummy_input) # Test forward pass
78
- print("βœ… Model loaded and verified!")
79
- return model
 
 
 
 
80
 
81
  # ===== GRADIO INTERFACE =====
82
  def create_interface(model):
@@ -96,7 +101,7 @@ def create_interface(model):
96
  outputs=gr.Image(type="pil", label="Detection Results"),
97
  title="YOLOv8 Molecular Structure Detector",
98
  description="πŸ”¬ Detect atoms and bonds in molecular structures",
99
- allow_flagging="never"
100
  )
101
 
102
  # ===== MAIN APPLICATION =====
@@ -112,7 +117,11 @@ def main():
112
  # Create and launch interface
113
  demo = create_interface(model)
114
  print("πŸš€ Starting Gradio interface...")
115
- demo.launch(server_name="0.0.0.0", share=False)
 
 
 
 
116
 
117
  except Exception as e:
118
  print(f"❌ Fatal error: {str(e)}")
 
10
 
11
  # ===== MODEL COMPONENTS =====
12
  from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
13
+ from ultralytics.nn.modules.conv import Conv, Concat # Added Concat
14
  from ultralytics.nn.modules.block import Bottleneck, C2f, SPPF
15
  from ultralytics.nn.tasks import DetectionModel
16
 
17
  # ===== SAFE GLOBALS =====
18
+ # Add all components including Concat to safe globals
19
  torch.serialization.add_safe_globals([
20
  DetectionModel,
21
  Sequential,
22
  Conv,
23
+ Concat, # Explicitly added
24
  Conv2d,
25
  BatchNorm2d,
26
  SiLU,
 
55
  os.makedirs("models", exist_ok=True)
56
  model_path = os.path.join("models", MODEL_FILE)
57
 
 
58
  if not os.path.exists(model_path) or not verify_model(model_path):
59
  print("⬇️ Downloading model...")
60
  hf_hub_download(
 
70
  return model_path
71
 
72
  def load_model(model_path: str) -> YOLO:
73
+ """Safely load YOLO model with error handling"""
74
  print("πŸ”§ Loading model...")
75
+ try:
76
+ model = YOLO(model_path)
77
+ # Test with small dummy input
78
+ with torch.no_grad():
79
+ dummy = torch.zeros(1, 3, 640, 640)
80
+ model(dummy)
81
+ print("βœ… Model loaded and verified!")
82
+ return model
83
+ except Exception as e:
84
+ raise RuntimeError(f"Model loading failed: {str(e)}")
85
 
86
  # ===== GRADIO INTERFACE =====
87
  def create_interface(model):
 
101
  outputs=gr.Image(type="pil", label="Detection Results"),
102
  title="YOLOv8 Molecular Structure Detector",
103
  description="πŸ”¬ Detect atoms and bonds in molecular structures",
104
+ examples=[["example.jpg"]] if os.path.exists("example.jpg") else None
105
  )
106
 
107
  # ===== MAIN APPLICATION =====
 
117
  # Create and launch interface
118
  demo = create_interface(model)
119
  print("πŸš€ Starting Gradio interface...")
120
+ demo.launch(
121
+ server_name="0.0.0.0",
122
+ share=False,
123
+ server_port=7860
124
+ )
125
 
126
  except Exception as e:
127
  print(f"❌ Fatal error: {str(e)}")