astrosbd committed on
Commit
ea287ef
·
verified ·
1 Parent(s): c093367

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -148
app.py CHANGED
@@ -27,57 +27,60 @@ import base64
27
  import io
28
 
29
  # --------------------------------------------------------------------------------------
30
- # PATCHED MODEL LOADING
31
  # --------------------------------------------------------------------------------------
32
 
33
- def patch_transformers_for_radio():
34
- """Patch transformers to handle missing ls1 parameters in C-RADIOv3-B"""
35
  try:
36
- import transformers.modeling_utils
 
 
37
 
38
- # Store original function
39
- if not hasattr(transformers.modeling_utils, '_original_load_state_dict'):
40
- transformers.modeling_utils._original_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
 
 
 
41
 
42
- def patched_load_state_dict_into_meta_model(model, state_dict, device_map=None,
43
- offload_folder=None, dtype=None,
44
- offload_state_dict=None,
45
- offload_buffers=None,
46
- keep_in_fp32_modules=None,
47
- tied_params=None,
48
- **kwargs):
49
- """Patched loader that ignores missing ls1 keys"""
50
-
51
- # Filter out any existing ls1 fake keys if they exist
52
- filtered_state = {k: v for k, v in state_dict.items()
53
- if not ('ls1.gamma' in k or 'ls1.grandma' in k)}
54
-
55
- # Try loading with the original function
56
- try:
57
- return transformers.modeling_utils._original_load_state_dict(
58
- model, filtered_state, device_map, offload_folder, dtype,
59
- offload_state_dict, offload_buffers, keep_in_fp32_modules,
60
- tied_params, **kwargs
61
- )
62
- except KeyError as e:
63
- if "ls1.gamma" in str(e) or "ls1.grandma" in str(e):
64
- print(f"⚠️ Ignoring missing layer scaling parameters: {e}")
65
- # Return empty dicts to indicate successful loading
66
- return {}, {}
67
- raise
68
 
69
- # Apply the patch
70
- transformers.modeling_utils._load_state_dict_into_meta_model = patched_load_state_dict_into_meta_model
71
- print(" Applied compatibility patch for C-RADIOv3-B")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  return True
73
 
74
  except Exception as e:
75
- print(f"⚠️ Could not apply patch: {e}")
76
  return False
77
 
78
- # Apply the patch at module load time
79
- patch_transformers_for_radio()
80
-
81
  # --------------------------------------------------------------------------------------
82
  # Check Detectron2
83
  # --------------------------------------------------------------------------------------
@@ -128,92 +131,92 @@ else:
128
 
129
  print(f"🖥️ Using device: {DEVICE}")
130
 
131
- # Global variables for C model
132
  image_processor = None
133
  model = None
134
  ai_detection_classifier = None
135
  _preloaded = False
 
136
 
137
  # --------------------------------------------------------------------------------------
138
- # FIXED Model Loading
139
  # --------------------------------------------------------------------------------------
140
 
141
  def preload_models():
142
- """Preload models with compatibility fixes"""
143
- global image_processor, model, _preloaded
144
 
145
  if _preloaded:
146
  print("✅ Models already loaded")
147
  return True
148
 
149
- print("🔄 Preloading C-RADIOv3-B model...")
150
 
151
- try:
152
- hf_repo = os.getenv('MODEL_REPO', 'nvidia/C-RADIOv3-B')
153
-
154
- if hf_repo == 'fallback':
155
- hf_repo = 'nvidia/C-RADIOv3-B'
156
-
157
- print(f"📦 Loading from: {hf_repo}")
158
-
159
- # Method 1: Try with patched loader
160
  try:
161
- # Ensure patch is applied
162
- patch_transformers_for_radio()
163
 
164
- # Load image processor
165
- from transformers import CLIPImageProcessor, AutoImageProcessor
166
- try:
167
- image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
168
- except:
169
- image_processor = AutoImageProcessor.from_pretrained(hf_repo)
170
 
171
- # Suppress the specific warning we know about
172
  with warnings.catch_warnings():
173
- warnings.filterwarnings("ignore", message="Couldn't find the key")
174
-
175
- # Load model with low_cpu_mem_usage=False to avoid meta model issues
176
- model = AutoModel.from_pretrained(
177
- hf_repo,
178
- trust_remote_code=True,
179
- low_cpu_mem_usage=False, # Important: disable meta model loading
180
- ignore_mismatched_sizes=True
181
- )
182
-
183
- model = model.to(DEVICE)
184
- model.eval()
185
-
186
- print("✅ C-RADIOv3-B model loaded successfully with compatibility fixes!")
187
- _preloaded = True
188
- return True
189
-
190
- except Exception as e1:
191
- print(f"⚠️ Method 1 failed: {e1}")
192
-
193
- # Method 2: Try loading without trust_remote_code
194
- try:
195
- print("Trying alternative loading method...")
196
-
197
- # Use a simpler CLIP model as fallback
198
- from transformers import CLIPModel, CLIPProcessor
199
-
200
- fallback_model = "openai/clip-vit-base-patch32"
201
- print(f"Loading fallback model: {fallback_model}")
202
-
203
- image_processor = CLIPProcessor.from_pretrained(fallback_model)
204
- model = CLIPModel.from_pretrained(fallback_model)
205
- model = model.to(DEVICE)
206
- model.eval()
207
-
208
- print("✅ Loaded fallback CLIP model successfully!")
209
- _preloaded = True
210
- return True
211
-
212
- except Exception as e2:
213
- print(f"⚠️ Method 2 failed: {e2}")
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  except Exception as e:
216
- print(f"❌ Could not preload model: {e}")
217
  traceback.print_exc()
218
 
219
  return False
@@ -317,7 +320,6 @@ def run_damage_detection(pil_image: Image.Image, score_thresh: float = 0.5):
317
 
318
  except Exception as e:
319
  print(f"⚠️ Stage 1 error: {e}")
320
- traceback.print_exc()
321
  # Fallback to simulator
322
  rgb = np.array(pil_image.convert("RGB"))
323
  boxes = simulate_damage_detection(rgb, seed_from=rgb)
@@ -375,13 +377,13 @@ def preprocess_image(image) -> Optional[Image.Image]:
375
  else:
376
  image = np.clip(image, 0, 255).astype(np.uint8)
377
 
378
- pil = Image.fromarray(image, 'RGB')
379
  else:
380
  # Try to convert whatever it is
381
  arr = np.array(image)
382
  if arr.dtype != np.uint8:
383
  arr = np.clip(arr, 0, 255).astype(np.uint8)
384
- pil = Image.fromarray(arr, 'RGB')
385
 
386
  # Handle EXIF orientation
387
  pil = ImageOps.exif_transpose(pil)
@@ -393,8 +395,8 @@ def preprocess_image(image) -> Optional[Image.Image]:
393
  return None
394
 
395
  def extract_features(image, return_stats=False):
396
- """Extract features with proper handling for different model types."""
397
- global image_processor, model
398
 
399
  if image_processor is None or model is None:
400
  raise Exception("Model not initialized")
@@ -410,42 +412,40 @@ def extract_features(image, return_stats=False):
410
  # Process image
411
  inputs = image_processor(images=image, return_tensors='pt', do_resize=True)
412
 
413
- # Handle different processor outputs
414
  if hasattr(inputs, 'pixel_values'):
415
  pixel_values = inputs.pixel_values.to(DEVICE)
416
  else:
417
- pixel_values = inputs['input_ids'].to(DEVICE) if 'input_ids' in inputs else inputs.to(DEVICE)
418
 
419
- # Get features
420
  with torch.no_grad():
421
- outputs = model(pixel_values)
422
-
423
- # Handle different model outputs
424
- if hasattr(model, 'get_image_features'):
425
- # CLIP model
426
- features = model.get_image_features(pixel_values)
427
- elif isinstance(outputs, dict):
428
- # Dictionary output
429
- if 'features' in outputs:
430
- features = outputs['features']
431
- elif 'last_hidden_state' in outputs:
432
- features = outputs['last_hidden_state']
433
- elif 'pooler_output' in outputs:
434
- features = outputs['pooler_output']
435
  else:
436
- # Take the first tensor value
437
- features = next(iter(outputs.values()))
438
- elif isinstance(outputs, (list, tuple)):
439
- # Tuple/list output - take last element
440
- features = outputs[-1] if len(outputs) > 1 else outputs[0]
441
- else:
442
- # Direct tensor output
443
- features = outputs
 
 
 
 
 
 
 
 
 
444
 
445
  # Pool if needed
446
  if features.ndim == 3: # (B, T, C)
447
  features = features.mean(dim=1)
448
- elif features.ndim == 4: # (B, C, H, W)
449
  features = features.mean(dim=(2, 3))
450
 
451
  # Normalize and flatten
@@ -458,7 +458,8 @@ def extract_features(image, return_stats=False):
458
  "std": float(features.std()),
459
  "min": float(features.min()),
460
  "max": float(features.max()),
461
- "shape": features.shape
 
462
  }
463
  return features, stats
464
 
@@ -648,10 +649,14 @@ def create_gradio_interface():
648
  print(f"⚠️ Stage 1 error: {e}")
649
 
650
  # Status display
651
- if isinstance(detailed_result, dict) and detailed_result.get("is_demo"):
652
- status_html = '<div style="padding: 10px; background: #fef3c7; border-radius: 8px;"><p style="margin: 0; color: #f59e0b;">⚠️ Running in Demo Mode</p></div>'
 
 
 
 
653
  else:
654
- status_html = '<div style="padding: 10px; background: #d1fae5; border-radius: 8px;"><p style="margin: 0; color: #10b981;">✅ Analysis Complete</p></div>'
655
 
656
  return simple_result, detailed_result, status_html, dmg_results, annotated
657
 
@@ -688,12 +693,17 @@ def create_gradio_interface():
688
  with gr.Accordion("ℹ️ About", open=False):
689
  gr.Markdown("""
690
  ### Pipeline
691
- - **Stage 1**: Detectron2 damage detection (optional)
692
- - **Stage 2**: Visual features + AI detection classifier
 
 
 
 
 
693
 
694
- ### Notes
695
- - Falls back to demo mode if models are unavailable
696
- - C-RADIOv3-B model includes compatibility fixes for layer scaling issues
697
  """)
698
 
699
  return app
@@ -712,14 +722,15 @@ if __name__ == "__main__":
712
 
713
  # Preload models with fixes
714
  if preload_models():
715
- print(" Models preloaded successfully")
 
716
  else:
717
- print("⚠️ Running in demo mode")
718
 
719
  # Load classifier
720
  model_path = huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH
721
  if load_ai_detection_classifier(model_path):
722
- print("✅ Classifier loaded")
723
 
724
  print("=" * 60)
725
 
 
27
  import io
28
 
29
  # --------------------------------------------------------------------------------------
30
+ # FIXED PATCHING FOR C-RADIOv3-B
31
  # --------------------------------------------------------------------------------------
32
 
33
def patch_dinov2_architecture(repo_id="nvidia/C-RADIOv3-B", cache_dir=".cache"):
    """Build and import a copy of ``dinov2_arch.py`` that tolerates missing keys.

    Downloads ``dinov2_arch.py`` from ``repo_id``, swaps the statement that
    raises ``KeyError`` when neither ``key_a`` nor ``key_b`` is in the state
    dict for a warning-and-return fallback, writes the result to
    ``<cache_dir>/dinov2_arch_patched.py`` and registers the imported module
    in ``sys.modules`` under the name ``dinov2_arch``.

    Args:
        repo_id: Hugging Face Hub repository that hosts ``dinov2_arch.py``.
        cache_dir: Directory used both as the Hub download cache and as the
            destination of the patched file.

    Returns:
        True when the target statement was found, rewritten and the patched
        module executed successfully; False otherwise. Never raises: every
        failure is reported with ``print`` and turned into False.

    NOTE(review): the module is registered as plain ``dinov2_arch``. Whether
    ``AutoModel.from_pretrained(..., trust_remote_code=True)`` resolves the
    remote code under that name is not visible here; confirm that the patch
    actually reaches the loader.
    """
    import importlib.util
    import re
    import sys

    # Statement in the downloaded source that aborts loading.
    target_stmt = 'raise KeyError(f"Couldn\'t find the key {key_a} nor {key_b} in the state dict!")'
    # Capture the indentation in front of the statement so the replacement can
    # be emitted at exactly the same nesting depth.
    raise_line = re.compile(r"^([ \t]*)" + re.escape(target_stmt), re.MULTILINE)

    # Replacement statements, written relative to the indentation of the
    # original ``raise``. They reference ``self``, ``key_a`` and ``key_b``,
    # which are presumably in scope where the ``raise`` sat; verify upstream.
    patch_lines = (
        "# Patched: Use default values instead of raising error",
        "import torch.nn as nn",
        "if not hasattr(self, 'ls1'):",
        "    self.ls1 = nn.Identity()  # Use identity as fallback",
        'print(f"Warning: Missing keys {key_a} and {key_b}, using Identity layer as fallback")',
        "return",
    )

    def _indented_patch(match):
        # Re-indent every replacement line to the depth of the matched raise.
        indent = match.group(1)
        return "\n".join(indent + line for line in patch_lines)

    _absent = object()  # sentinel: no 'dinov2_arch' module was registered before

    try:
        from huggingface_hub import hf_hub_download

        # Download the dinov2_arch.py file
        dinov2_path = hf_hub_download(
            repo_id=repo_id,
            filename="dinov2_arch.py",
            cache_dir=cache_dir,
        )

        with open(dinov2_path, "r", encoding="utf-8") as src:
            dinov2_code = src.read()

        dinov2_code, n_patched = raise_line.subn(_indented_patch, dinov2_code)
        if n_patched == 0:
            # Nothing was rewritten, so importing the copy would change
            # nothing; do not claim the patch was applied.
            print("⚠️ Could not patch DINOv2 architecture: target raise statement not found")
            return False

        # Save patched version
        os.makedirs(cache_dir, exist_ok=True)
        patched_path = os.path.join(cache_dir, "dinov2_arch_patched.py")
        with open(patched_path, "w", encoding="utf-8") as dst:
            dst.write(dinov2_code)

        # Import the patched version
        spec = importlib.util.spec_from_file_location("dinov2_arch_patched", patched_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot build an import spec for {patched_path}")
        patched_module = importlib.util.module_from_spec(spec)

        previous = sys.modules.get("dinov2_arch", _absent)
        sys.modules["dinov2_arch"] = patched_module
        try:
            spec.loader.exec_module(patched_module)
        except Exception:
            # Do not leave a half-initialised module visible to later imports.
            if previous is _absent:
                sys.modules.pop("dinov2_arch", None)
            else:
                sys.modules["dinov2_arch"] = previous
            raise

        print("✅ Applied architecture patch for DINOv2")
        return True

    except Exception as e:
        print(f"⚠️ Could not patch DINOv2 architecture: {e}")
        return False
83
 
 
 
 
84
  # --------------------------------------------------------------------------------------
85
  # Check Detectron2
86
  # --------------------------------------------------------------------------------------
 
131
 
132
  print(f"🖥️ Using device: {DEVICE}")
133
 
134
+ # Global variables for model
135
  image_processor = None
136
  model = None
137
  ai_detection_classifier = None
138
  _preloaded = False
139
+ _use_clip_fallback = False
140
 
141
  # --------------------------------------------------------------------------------------
142
+ # SIMPLIFIED Model Loading - Direct CLIP fallback
143
  # --------------------------------------------------------------------------------------
144
 
145
  def preload_models():
146
+ """Preload models - try RADIO first, fall back to CLIP"""
147
+ global image_processor, model, _preloaded, _use_clip_fallback
148
 
149
  if _preloaded:
150
  print("✅ Models already loaded")
151
  return True
152
 
153
+ print("🔄 Loading visual encoder model...")
154
 
155
+ # Try to load C-RADIOv3-B first
156
+ hf_repo = os.getenv('MODEL_REPO', 'nvidia/C-RADIOv3-B')
157
+
158
+ if hf_repo != 'fallback':
 
 
 
 
 
159
  try:
160
+ print(f"📦 Attempting to load: {hf_repo}")
 
161
 
162
+ # Try patching first
163
+ patch_dinov2_architecture()
 
 
 
 
164
 
165
+ # Try loading with various workarounds
166
  with warnings.catch_warnings():
167
+ warnings.filterwarnings("ignore")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ try:
170
+ # Method 1: Load without meta model
171
+ from transformers import AutoModel, CLIPImageProcessor
172
+
173
+ image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
174
+
175
+ # Load with specific settings to avoid issues
176
+ model = AutoModel.from_pretrained(
177
+ hf_repo,
178
+ trust_remote_code=True,
179
+ low_cpu_mem_usage=False,
180
+ torch_dtype=torch.float32
181
+ )
182
+
183
+ model = model.to(DEVICE)
184
+ model.eval()
185
+
186
+ print(f"✅ Successfully loaded {hf_repo}")
187
+ _preloaded = True
188
+ _use_clip_fallback = False
189
+ return True
190
+
191
+ except KeyError as ke:
192
+ if "ls1.gamma" in str(ke) or "ls1.grandma" in str(ke):
193
+ print(f"⚠️ Known C-RADIOv3-B issue: {ke}")
194
+ else:
195
+ print(f"⚠️ Unexpected error: {ke}")
196
+ except Exception as e:
197
+ print(f"⚠️ Could not load {hf_repo}: {e}")
198
+
199
+ except Exception as e:
200
+ print(f"⚠️ Error during RADIO loading: {e}")
201
+
202
+ # Fall back to CLIP model which we know works
203
+ try:
204
+ print("📦 Loading fallback CLIP model...")
205
+ from transformers import CLIPModel, CLIPProcessor
206
+
207
+ clip_model = "openai/clip-vit-base-patch32"
208
+ image_processor = CLIPProcessor.from_pretrained(clip_model)
209
+ model = CLIPModel.from_pretrained(clip_model)
210
+ model = model.to(DEVICE)
211
+ model.eval()
212
+
213
+ print(f"✅ Successfully loaded fallback {clip_model}")
214
+ _preloaded = True
215
+ _use_clip_fallback = True
216
+ return True
217
+
218
  except Exception as e:
219
+ print(f"❌ Could not load any model: {e}")
220
  traceback.print_exc()
221
 
222
  return False
 
320
 
321
  except Exception as e:
322
  print(f"⚠️ Stage 1 error: {e}")
 
323
  # Fallback to simulator
324
  rgb = np.array(pil_image.convert("RGB"))
325
  boxes = simulate_damage_detection(rgb, seed_from=rgb)
 
377
  else:
378
  image = np.clip(image, 0, 255).astype(np.uint8)
379
 
380
+ pil = Image.fromarray(image)
381
  else:
382
  # Try to convert whatever it is
383
  arr = np.array(image)
384
  if arr.dtype != np.uint8:
385
  arr = np.clip(arr, 0, 255).astype(np.uint8)
386
+ pil = Image.fromarray(arr)
387
 
388
  # Handle EXIF orientation
389
  pil = ImageOps.exif_transpose(pil)
 
395
  return None
396
 
397
  def extract_features(image, return_stats=False):
398
+ """Extract features - handles both CLIP and RADIO models."""
399
+ global image_processor, model, _use_clip_fallback
400
 
401
  if image_processor is None or model is None:
402
  raise Exception("Model not initialized")
 
412
  # Process image
413
  inputs = image_processor(images=image, return_tensors='pt', do_resize=True)
414
 
415
+ # Get the right input tensor
416
  if hasattr(inputs, 'pixel_values'):
417
  pixel_values = inputs.pixel_values.to(DEVICE)
418
  else:
419
+ pixel_values = inputs['pixel_values'].to(DEVICE)
420
 
421
+ # Extract features based on model type
422
  with torch.no_grad():
423
+ if _use_clip_fallback and hasattr(model, 'get_image_features'):
424
+ # CLIP model
425
+ features = model.get_image_features(pixel_values)
 
 
 
 
 
 
 
 
 
 
 
426
  else:
427
+ # RADIO or other model
428
+ outputs = model(pixel_values)
429
+
430
+ # Handle different output formats
431
+ if isinstance(outputs, dict):
432
+ if 'features' in outputs:
433
+ features = outputs['features']
434
+ elif 'last_hidden_state' in outputs:
435
+ features = outputs['last_hidden_state']
436
+ elif 'pooler_output' in outputs:
437
+ features = outputs['pooler_output']
438
+ else:
439
+ features = list(outputs.values())[0]
440
+ elif isinstance(outputs, (tuple, list)):
441
+ features = outputs[-1] if len(outputs) > 1 else outputs[0]
442
+ else:
443
+ features = outputs
444
 
445
  # Pool if needed
446
  if features.ndim == 3: # (B, T, C)
447
  features = features.mean(dim=1)
448
+ elif features.ndim == 4: # (B, C, H, W)
449
  features = features.mean(dim=(2, 3))
450
 
451
  # Normalize and flatten
 
458
  "std": float(features.std()),
459
  "min": float(features.min()),
460
  "max": float(features.max()),
461
+ "shape": features.shape,
462
+ "model_type": "CLIP" if _use_clip_fallback else "RADIO"
463
  }
464
  return features, stats
465
 
 
649
  print(f"⚠️ Stage 1 error: {e}")
650
 
651
  # Status display
652
+ if isinstance(detailed_result, dict):
653
+ if detailed_result.get("is_demo"):
654
+ status_html = '<div style="padding: 10px; background: #fef3c7; border-radius: 8px;"><p style="margin: 0; color: #f59e0b;">⚠️ Running in Demo Mode (using fallback model)</p></div>'
655
+ else:
656
+ model_info = detailed_result.get('feature_stats', {}).get('model_type', 'Unknown')
657
+ status_html = f'<div style="padding: 10px; background: #d1fae5; border-radius: 8px;"><p style="margin: 0; color: #10b981;">✅ Analysis Complete (using {model_info} model)</p></div>'
658
  else:
659
+ status_html = '<div style="padding: 10px; background: #fee2e2; border-radius: 8px;"><p style="margin: 0; color: #dc2626;">❌ Analysis Failed</p></div>'
660
 
661
  return simple_result, detailed_result, status_html, dmg_results, annotated
662
 
 
693
  with gr.Accordion("ℹ️ About", open=False):
694
  gr.Markdown("""
695
  ### Pipeline
696
+ - **Stage 1**: Detectron2 damage detection (simulated if not available)
697
+ - **Stage 2**: Visual feature extraction + AI detection classifier
698
+
699
+ ### Models
700
+ - **Primary**: C-RADIOv3-B visual encoder (if available)
701
+ - **Fallback**: CLIP-ViT-B-32 (reliable alternative)
702
+ - **Classifier**: Scikit-learn model for AI detection
703
 
704
+ ### Status
705
+ - The app will show which model is being used in the status display
706
+ - Falls back gracefully if primary models are unavailable
707
  """)
708
 
709
  return app
 
722
 
723
  # Preload models with fixes
724
  if preload_models():
725
+ model_type = "CLIP" if _use_clip_fallback else "RADIO"
726
+ print(f"✅ Visual encoder loaded ({model_type})")
727
  else:
728
+ print("⚠️ Running in full demo mode")
729
 
730
  # Load classifier
731
  model_path = huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH
732
  if load_ai_detection_classifier(model_path):
733
+ print("✅ AI detection classifier loaded")
734
 
735
  print("=" * 60)
736