Safi029 commited on
Commit
e4ee29f
·
verified ·
1 Parent(s): 3263193

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -10
app.py CHANGED
@@ -6,14 +6,21 @@ import torch
6
  import torch.serialization
7
  import os
8
  import hashlib
 
9
  from typing import Optional
10
 
11
- # ===== IMPORT ALL NECESSARY MODULES =====
12
  from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
13
- from ultralytics.nn.modules.conv import Conv, Concat
14
- from ultralytics.nn.modules.block import Bottleneck, C2f, SPPF
15
- from ultralytics.nn.modules.head import Detect # Added Detect
16
- from ultralytics.nn.tasks import DetectionModel
 
 
 
 
 
 
17
 
18
  # ===== SAFE GLOBALS CONFIGURATION =====
19
  # Add all components to safe globals
@@ -23,10 +30,14 @@ torch.serialization.add_safe_globals([
23
  MaxPool2d, Upsample, ModuleList,
24
 
25
  # Ultralytics modules
26
- DetectionModel,
27
  Conv, Concat,
28
  Bottleneck, C2f, SPPF,
29
- Detect, # Added Detect head
 
 
 
 
30
  ])
31
 
32
  # ===== MODEL CONFIG =====
@@ -68,10 +79,16 @@ def load_model(model_path: str) -> YOLO:
68
  """Safely load YOLO model with error handling"""
69
  print("🔧 Loading model...")
70
  try:
71
- # Temporary workaround for PyTorch 2.6+ weights_only=True issues
72
- # Only use if you trust the model source!
 
 
 
73
  model = YOLO(model_path, task='detect')
74
 
 
 
 
75
  # Test with small dummy input
76
  with torch.no_grad():
77
  dummy = torch.zeros(1, 3, 640, 640)
@@ -79,6 +96,9 @@ def load_model(model_path: str) -> YOLO:
79
  print("✅ Model loaded and verified!")
80
  return model
81
  except Exception as e:
 
 
 
82
  raise RuntimeError(f"Model loading failed: {str(e)}")
83
 
84
  # ===== GRADIO INTERFACE =====
@@ -126,6 +146,8 @@ def main():
126
  raise
127
 
128
  if __name__ == "__main__":
 
 
129
  main()
130
-
131
 
 
6
  import torch.serialization
7
  import os
8
  import hashlib
9
+ import warnings
10
  from typing import Optional
11
 
12
+ # ===== IMPORT ALL ULTRALYTICS MODULES =====
13
  from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
14
+ from ultralytics.nn.modules import (
15
+ Conv, Concat,
16
+ Bottleneck, C2f, SPPF,
17
+ Detect, DFL, # Added DFL
18
+ C2fAttn, ImagePoolingAttn, # Common attention modules
19
+ HGStem, HGBlock, # Additional blocks
20
+ AIFI, # Additional modules
21
+ Segment, Pose, Classify, RTDETRDecoder # Task-specific heads
22
+ )
23
+ from ultralytics.nn.tasks import DetectionModel, SegmentationModel, PoseModel, ClassificationModel
24
 
25
  # ===== SAFE GLOBALS CONFIGURATION =====
26
  # Add all components to safe globals
 
30
  MaxPool2d, Upsample, ModuleList,
31
 
32
  # Ultralytics modules
33
+ DetectionModel, SegmentationModel, PoseModel, ClassificationModel,
34
  Conv, Concat,
35
  Bottleneck, C2f, SPPF,
36
+ Detect, DFL, # Added DFL
37
+ C2fAttn, ImagePoolingAttn,
38
+ HGStem, HGBlock,
39
+ AIFI,
40
+ Segment, Pose, Classify, RTDETRDecoder
41
  ])
42
 
43
  # ===== MODEL CONFIG =====
 
79
  """Safely load YOLO model with error handling"""
80
  print("🔧 Loading model...")
81
  try:
82
+ # Temporary monkey patch for PyTorch 2.6+ weights_only restriction
83
+ # ONLY USE IF YOU TRUST THE MODEL SOURCE!
84
+ original_load = torch.load
85
+ torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, weights_only=False)
86
+
87
  model = YOLO(model_path, task='detect')
88
 
89
+ # Restore original torch.load
90
+ torch.load = original_load
91
+
92
  # Test with small dummy input
93
  with torch.no_grad():
94
  dummy = torch.zeros(1, 3, 640, 640)
 
96
  print("✅ Model loaded and verified!")
97
  return model
98
  except Exception as e:
99
+ # Ensure original torch.load is restored even if error occurs
100
+ if 'original_load' in locals():
101
+ torch.load = original_load
102
  raise RuntimeError(f"Model loading failed: {str(e)}")
103
 
104
  # ===== GRADIO INTERFACE =====
 
146
  raise
147
 
148
  if __name__ == "__main__":
149
+ # Suppress torch.load warnings
150
+ warnings.filterwarnings("ignore", category=UserWarning, message="torch.load")
151
  main()
152
+
153