astrosbd committed on
Commit c093367 · verified · 1 Parent(s): 8568bc5

Update app.py

Files changed (1)
  1. app.py +704 -240
app.py CHANGED
@@ -1,268 +1,732 @@
 
1
  import os
2
  import torch
3
  import json
4
- from huggingface_hub import hf_hub_download
5
- import safetensors.torch
6
 
7
- print("=" * 60)
8
- print("C-RADIOv3-B Model Deep Inspection")
9
- print("=" * 60)
10
 
11
- # Step 1: Download and inspect the model file directly
12
- def inspect_model_weights():
13
- """Directly inspect the safetensors file to see what keys exist"""
14
-
15
- print("\n๐Ÿ“ฅ Downloading model weights for inspection...")
16
-
17
- # Download the model file
18
- model_path = hf_hub_download(
19
- repo_id="nvidia/C-RADIOv3-B",
20
- filename="model.safetensors"
21
- )
22
-
23
- print(f"Downloaded to: {model_path}")
24
-
25
- # Load the safetensors file
26
- print("\n๐Ÿ” Inspecting model weights...")
27
- state_dict = safetensors.torch.load_file(model_path)
28
-
29
- # Analyze the keys
30
- all_keys = list(state_dict.keys())
31
- print(f"Total keys in model: {len(all_keys)}")
32
-
33
- # Look for ls1 related keys
34
- ls1_keys = [k for k in all_keys if 'ls1' in k.lower()]
35
- ls_keys = [k for k in all_keys if 'ls' in k.lower()]
36
- gamma_keys = [k for k in all_keys if 'gamma' in k.lower()]
37
- block_keys = [k for k in all_keys if k.startswith('blocks.')]
38
-
39
- print(f"\n๐Ÿ“Š Key Analysis:")
40
- print(f" Keys with 'ls1': {len(ls1_keys)}")
41
- print(f" Keys with 'ls': {len(ls_keys)}")
42
- print(f" Keys with 'gamma': {len(gamma_keys)}")
43
- print(f" Keys starting with 'blocks.': {len(block_keys)}")
44
-
45
- # Show first few block keys
46
- print(f"\n๐Ÿ“ First 20 block keys:")
47
- for i, key in enumerate(sorted([k for k in all_keys if k.startswith('blocks.0.')])[:20]):
48
- print(f" {key}")
49
-
50
- # Check what's actually in blocks.0
51
- blocks_0_keys = [k for k in all_keys if k.startswith('blocks.0.')]
52
- print(f"\n๐Ÿ”Ž All blocks.0 submodules:")
53
- submodules = set()
54
- for key in blocks_0_keys:
55
- parts = key.split('.')
56
- if len(parts) > 2:
57
- submodules.add(parts[2])
58
- for submodule in sorted(submodules):
59
- count = len([k for k in blocks_0_keys if f'blocks.0.{submodule}.' in k])
60
- print(f" blocks.0.{submodule}.*: {count} parameters")
61
-
62
- return state_dict, all_keys
63
 
64
- # Step 2: Check the model architecture expectations
65
- def inspect_model_code():
66
- """Download and inspect the model code to understand what it expects"""
67
-
68
- print("\n๐Ÿ“œ Downloading model code...")
69
-
70
- # Download the dinov2_arch.py file
71
- dinov2_path = hf_hub_download(
72
- repo_id="nvidia/C-RADIOv3-B",
73
- filename="dinov2_arch.py"
74
- )
75
-
76
- print(f"Downloaded dinov2_arch.py to: {dinov2_path}")
77
-
78
- # Read the problematic part of the code
79
- with open(dinov2_path, 'r') as f:
80
- lines = f.readlines()
81
-
82
- # Find the error location (around line 309)
83
- print("\n๐Ÿ” Code around line 309 (error location):")
84
- for i in range(max(0, 308-10), min(len(lines), 308+10)):
85
- if i == 308: # Line 309 (0-indexed)
86
- print(f">>> {i+1}: {lines[i].rstrip()}")
87
- else:
88
- print(f" {i+1}: {lines[i].rstrip()}")
 
89
 
90
- # Look for _load_from_state_dict method
91
- print("\n๐Ÿ“– Looking for _load_from_state_dict method...")
92
- for i, line in enumerate(lines):
93
- if '_load_from_state_dict' in line:
94
- print(f"Found at line {i+1}: {line.rstrip()}")
95
- # Show context
96
- for j in range(max(0, i-2), min(len(lines), i+15)):
97
- print(f" {j+1}: {lines[j].rstrip()}")
98
- break
99
-
100
- # Step 3: Create a working loader
101
- def create_fixed_loader():
102
- """Create a fixed loading function that handles the missing keys"""
103
 
104
- print("\n๐Ÿ”ง Creating Fixed Model Loader...")
105
 
106
- # Create a custom model loading function
107
- code = '''
108
- import torch
109
- from transformers import AutoModel, AutoConfig
110
- import warnings
111
-
112
- class RADIOModelFixed:
113
- @staticmethod
114
- def from_pretrained(repo_id="nvidia/C-RADIOv3-B"):
115
- """Load RADIO model with compatibility fixes"""
116
-
117
- print("Loading with compatibility fixes...")
118
 
119
- # First, modify the environment to skip the problematic check
120
- import sys
121
- import transformers.modeling_utils as mu
122
 
123
- # Store original function
124
- original_load = mu._load_state_dict_into_meta_model
125
 
126
- def patched_load(model, state_dict, device_map=None, offload_folder=None,
127
- dtype=None, offload_state_dict=None, tie_weights=True,
128
- **kwargs):
129
- """Patched loader that handles missing ls1 keys"""
130
 
131
- # Create a modified state dict with dummy ls1 keys if needed
132
- modified_state = state_dict.copy()
133
 
134
- # Check if we need to add dummy ls1 keys
135
- block_keys = [k for k in state_dict.keys() if k.startswith('blocks.')]
136
- if block_keys and not any('ls1' in k for k in block_keys):
137
- print(" Adding compatibility keys for ls1 layers...")
138
 
139
- # Find all blocks
140
- block_indices = set()
141
- for key in block_keys:
142
- parts = key.split('.')
143
- if len(parts) > 1 and parts[1].isdigit():
144
- block_indices.add(int(parts[1]))
145
-
146
- # Add dummy ls1 parameters for each block
147
- for idx in block_indices:
148
- # These will be ignored but prevent the error
149
- if f'blocks.{idx}.norm1.weight' in state_dict:
150
- # Use norm1 as a template for shape
151
- template = state_dict[f'blocks.{idx}.norm1.weight']
152
- modified_state[f'blocks.{idx}.ls1.gamma'] = torch.ones_like(template)
153
- else:
154
- # Default to scalar
155
- modified_state[f'blocks.{idx}.ls1.gamma'] = torch.tensor(1.0)
156
 
157
- # Call original with modified state
158
- return original_load(model, modified_state, device_map, offload_folder,
159
- dtype, offload_state_dict, tie_weights, **kwargs)
160
 
161
- # Temporarily replace the function
162
- mu._load_state_dict_into_meta_model = patched_load
163
 
164
- try:
165
- # Load the model
166
- model = AutoModel.from_pretrained(
167
- repo_id,
168
- trust_remote_code=True,
169
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
170
- )
171
- print(" Model loaded successfully with compatibility fixes!")
172
 
173
- finally:
174
- # Restore original function
175
- mu._load_state_dict_into_meta_model = original_load
176
 
177
- return model
178
 
179
- # Usage:
180
- model = RADIOModelFixed.from_pretrained()
181
- '''
182
-
183
- print(code)
184
-
185
- # Save to file
186
- with open('radio_loader_fixed.py', 'w') as f:
187
- f.write(code)
188
 
189
- print("\nโœ… Fixed loader saved to 'radio_loader_fixed.py'")
 
190
 
191
- # Step 4: Alternative loading approach
192
- def try_alternative_loading():
193
- """Try loading the model with different strategies"""
194
-
195
- print("\n๐Ÿ”„ Trying Alternative Loading Methods...")
196
-
197
- from transformers import AutoModel, AutoConfig
198
- import transformers.modeling_utils
199
-
200
- repo_id = "nvidia/C-RADIOv3-B"
201
-
202
- # Method 1: Load config first and check architecture
203
- print("\n1๏ธโƒฃ Checking model config...")
204
- config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
205
- print(f" Architecture: {config.architectures}")
206
- print(f" Model type: {config.model_type}")
207
-
208
- # Method 2: Try loading without state dict verification
209
- print("\n2๏ธโƒฃ Attempting to load with strict=False...")
210
-
211
- # Monkey-patch the DINOv2 architecture
212
- import importlib.util
213
- import sys
214
-
215
- # Download the dinov2_arch.py
216
- dinov2_path = hf_hub_download(repo_id=repo_id, filename="dinov2_arch.py")
217
-
218
- # Load it as a module
219
- spec = importlib.util.spec_from_file_location("dinov2_arch_patched", dinov2_path)
220
- dinov2_module = importlib.util.module_from_spec(spec)
221
-
222
- # Patch the _load_from_state_dict method before loading
223
- original_code = open(dinov2_path, 'r').read()
224
-
225
- # Replace the error-raising code
226
- patched_code = original_code.replace(
227
- 'raise KeyError(f"Couldn\'t find the key {key_a} nor {key_b} in the state dict!")',
228
- '''
229
- print(f" Warning: Missing keys {key_a} and {key_b}, using defaults")
230
- # Use identity/ones as default
231
- if "gamma" in key_a:
232
- setattr(self, key_a.split(".")[-1], torch.nn.Parameter(torch.ones(self.dim)))
233
- elif "beta" in key_a:
234
- setattr(self, key_a.split(".")[-1], torch.nn.Parameter(torch.zeros(self.dim)))
235
- return
236
- '''
237
- )
238
 
239
- # Save patched version
240
- patched_path = "dinov2_arch_patched.py"
241
- with open(patched_path, 'w') as f:
242
- f.write(patched_code)
243
 
244
- print(f" Created patched architecture file: {patched_path}")
245
 
246
- print("\nโœ… Alternative loading methods prepared")
247
 
248
- # Run all inspections
249
  if __name__ == "__main__":
250
- try:
251
- # Step 1: Inspect weights
252
- state_dict, keys = inspect_model_weights()
253
-
254
- # Step 2: Inspect code
255
- inspect_model_code()
256
-
257
- # Step 3: Create fixed loader
258
- create_fixed_loader()
259
-
260
- # Step 4: Try alternatives
261
- try_alternative_loading()
262
-
263
- print("\n" + "=" * 60)
264
- print("DIAGNOSIS COMPLETE")
265
- print("=" * 60)
266
-
267
-
268
-
1
+ #!/usr/bin/env python3
2
  import os
3
+ import sys
4
+ import traceback
5
+ from typing import Optional, Tuple, Dict, Any, List
6
+ import warnings
7
+
8
+ import importlib.util
9
+ import time
10
+ import cv2
11
  import torch
12
+ import numpy as np
13
+ import gradio as gr
14
+ from PIL import Image, ImageOps
15
+ from torchvision import transforms
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import traceback
19
+ from torchvision.models import vit_b_16
20
+ from transformers import AutoModel, CLIPImageProcessor, AutoConfig
21
+ import joblib
22
+ import zipfile
23
  import json
24
+ from datetime import datetime
25
+ import requests
26
+ import base64
27
+ import io
28
 
29
+ # --------------------------------------------------------------------------------------
30
+ # PATCHED MODEL LOADING
31
+ # --------------------------------------------------------------------------------------
32
 
33
+ def patch_transformers_for_radio():
34
+ """Patch transformers to handle missing ls1 parameters in C-RADIOv3-B"""
35
+ try:
36
+ import transformers.modeling_utils
37
+
38
+ # Store original function
39
+ if not hasattr(transformers.modeling_utils, '_original_load_state_dict'):
40
+ transformers.modeling_utils._original_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
41
+
42
+ def patched_load_state_dict_into_meta_model(model, state_dict, device_map=None,
43
+ offload_folder=None, dtype=None,
44
+ offload_state_dict=None,
45
+ offload_buffers=None,
46
+ keep_in_fp32_modules=None,
47
+ tied_params=None,
48
+ **kwargs):
49
+ """Patched loader that ignores missing ls1 keys"""
50
+
51
+ # Filter out any existing ls1 fake keys if they exist
52
+ filtered_state = {k: v for k, v in state_dict.items()
53
+ if not ('ls1.gamma' in k or 'ls1.grandma' in k)}
54
+
55
+ # Try loading with the original function
56
+ try:
57
+ return transformers.modeling_utils._original_load_state_dict(
58
+ model, filtered_state, device_map, offload_folder, dtype,
59
+ offload_state_dict, offload_buffers, keep_in_fp32_modules,
60
+ tied_params, **kwargs
61
+ )
62
+ except KeyError as e:
63
+ if "ls1.gamma" in str(e) or "ls1.grandma" in str(e):
64
+ print(f"โš ๏ธ Ignoring missing layer scaling parameters: {e}")
65
+ # Return empty dicts to indicate successful loading
66
+ return {}, {}
67
+ raise
68
+
69
+ # Apply the patch
70
+ transformers.modeling_utils._load_state_dict_into_meta_model = patched_load_state_dict_into_meta_model
71
+ print("โœ… Applied compatibility patch for C-RADIOv3-B")
72
+ return True
73
+
74
+ except Exception as e:
75
+ print(f"โš ๏ธ Could not apply patch: {e}")
76
+ return False
77
 
78
+ # Apply the patch at module load time
79
+ patch_transformers_for_radio()
80
+
81
+ # --------------------------------------------------------------------------------------
82
+ # Check Detectron2
83
+ # --------------------------------------------------------------------------------------
84
+
85
+ DETECTRON2_AVAILABLE = False
86
+ try:
87
+ from detectron2.engine import DefaultPredictor
88
+ from detectron2.config import get_cfg
89
+ from detectron2.utils.visualizer import Visualizer, ColorMode
90
+ from detectron2 import model_zoo
91
+ DETECTRON2_AVAILABLE = True
92
+ print("โœ… Detectron2 imported successfully")
93
+ except ImportError as e:
94
+ print(f"โš ๏ธ Detectron2 not available: {e}")
95
+ DETECTRON2_AVAILABLE = False
96
+
97
+ # Try to download model from Hugging Face
98
+ huggingface_model_path = None
99
+ try:
100
+ from huggingface_hub import hf_hub_download
101
+
102
+ repo = os.getenv('PRIVATE_REPO', 'fallback')
103
+ token = os.getenv('key')
104
+
105
+ if repo != 'fallback' and token:
106
+ huggingface_model_path = hf_hub_download(
107
+ repo_id=repo,
108
+ filename="V1.pkl",
109
+ token=token
110
+ )
111
+ print(f"โœ… Model downloaded from Hugging Face: {huggingface_model_path}")
112
+ except Exception as e:
113
+ print(f"โš ๏ธ Could not download model from Hugging Face: {e}")
114
+ print("๐Ÿ”„ Will use demo mode with simulated results")
115
+ huggingface_model_path = None
116
+
117
+ # --------------------------------------------------------------------------------------
118
+ # Basics
119
+ # --------------------------------------------------------------------------------------
120
+
121
+ # Initialize device for model
122
+ if torch.backends.mps.is_available():
123
+ DEVICE = torch.device("mps")
124
+ elif torch.cuda.is_available():
125
+ DEVICE = torch.device("cuda")
126
+ else:
127
+ DEVICE = torch.device("cpu")
128
+
129
+ print(f"๐Ÿ–ฅ๏ธ Using device: {DEVICE}")
130
+
131
+ # Global variables for C model
132
+ image_processor = None
133
+ model = None
134
+ ai_detection_classifier = None
135
+ _preloaded = False
136
+
137
+ # --------------------------------------------------------------------------------------
138
+ # FIXED Model Loading
139
+ # --------------------------------------------------------------------------------------
140
+
141
+ def preload_models():
142
+ """Preload models with compatibility fixes"""
143
+ global image_processor, model, _preloaded
144
 
145
+ if _preloaded:
146
+ print("โœ… Models already loaded")
147
+ return True
148
 
149
+ print("๐Ÿ”„ Preloading C-RADIOv3-B model...")
150
 
151
+ try:
152
+ hf_repo = os.getenv('MODEL_REPO', 'nvidia/C-RADIOv3-B')
153
 
154
+ if hf_repo == 'fallback':
155
+ hf_repo = 'nvidia/C-RADIOv3-B'
 
156
 
157
+ print(f"๐Ÿ“ฆ Loading from: {hf_repo}")
 
158
 
159
+ # Method 1: Try with patched loader
160
+ try:
161
+ # Ensure patch is applied
162
+ patch_transformers_for_radio()
163
 
164
+ # Load image processor
165
+ from transformers import CLIPImageProcessor, AutoImageProcessor
166
+ try:
167
+ image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
168
+ except:
169
+ image_processor = AutoImageProcessor.from_pretrained(hf_repo)
170
 
171
+ # Suppress the specific warning we know about
172
+ with warnings.catch_warnings():
173
+ warnings.filterwarnings("ignore", message="Couldn't find the key")
 
174
 
175
+ # Load model with low_cpu_mem_usage=False to avoid meta model issues
176
+ model = AutoModel.from_pretrained(
177
+ hf_repo,
178
+ trust_remote_code=True,
179
+ low_cpu_mem_usage=False, # Important: disable meta model loading
180
+ ignore_mismatched_sizes=True
181
+ )
182
+
183
+ model = model.to(DEVICE)
184
+ model.eval()
185
+
186
+ print("โœ… C-RADIOv3-B model loaded successfully with compatibility fixes!")
187
+ _preloaded = True
188
+ return True
189
+
190
+ except Exception as e1:
191
+ print(f"โš ๏ธ Method 1 failed: {e1}")
192
 
193
+ # Method 2: Try loading without trust_remote_code
194
+ try:
195
+ print("Trying alternative loading method...")
196
+
197
+ # Use a simpler CLIP model as fallback
198
+ from transformers import CLIPModel, CLIPProcessor
199
+
200
+ fallback_model = "openai/clip-vit-base-patch32"
201
+ print(f"Loading fallback model: {fallback_model}")
202
+
203
+ image_processor = CLIPProcessor.from_pretrained(fallback_model)
204
+ model = CLIPModel.from_pretrained(fallback_model)
205
+ model = model.to(DEVICE)
206
+ model.eval()
207
+
208
+ print("โœ… Loaded fallback CLIP model successfully!")
209
+ _preloaded = True
210
+ return True
211
+
212
+ except Exception as e2:
213
+ print(f"โš ๏ธ Method 2 failed: {e2}")
214
+
215
+ except Exception as e:
216
+ print(f"โŒ Could not preload model: {e}")
217
+ traceback.print_exc()
218
+
219
+ return False
220
+
221
+ # --------------------------------------------------------------------------------------
222
+ # Paths
223
+ # --------------------------------------------------------------------------------------
224
+ DEFAULT_AI_DETECTION_MODEL_PATH = "./output/V1.pkl"
225
+ DEFAULT_DAMAGE_MODEL_PATH = "./output/model_final.pth"
226
+
227
+ # --------------------------------------------------------------------------------------
228
+ # Damage Detection (Stage 1)
229
+ # --------------------------------------------------------------------------------------
230
+
231
+ _damage_predictor = None
232
+
233
+ def load_damage_model(model_path: str, device_str: str = None):
234
+ """Load fine-tuned Detectron2 model once (Stage 1)."""
235
+ global _damage_predictor
236
+ if _damage_predictor is not None:
237
+ return _damage_predictor
238
+
239
+ if (not DETECTRON2_AVAILABLE) or (not model_path) or (not os.path.exists(model_path)):
240
+ print("โ„น๏ธ Stage 1 damage model not available; using simulator")
241
+ return None
242
+
243
+ try:
244
+ cfg = get_cfg()
245
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
246
+ cfg.MODEL.WEIGHTS = model_path
247
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
248
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
249
+
250
+ if device_str is None:
251
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
252
+ cfg.MODEL.DEVICE = device_str
253
+
254
+ _damage_predictor = DefaultPredictor(cfg)
255
+ print(f"โœ… Damage model loaded on {device_str}")
256
+ return _damage_predictor
257
+ except Exception as e:
258
+ print(f"โŒ Could not load Detectron2 model: {e}")
259
+ return None
260
+
261
+ def simulate_damage_detection(rgb_image: np.ndarray, seed_from: np.ndarray = None) -> List[Dict[str, Any]]:
262
+ """Deterministic fake detections for demo mode."""
263
+ import hashlib, random
264
+ h, w = rgb_image.shape[:2]
265
+ if seed_from is None:
266
+ seed_from = rgb_image
267
+ img_hash = hashlib.md5(seed_from.tobytes()).hexdigest()
268
+ seed = int(img_hash[:8], 16) % 10_000
269
+ random.seed(seed)
270
+ n = random.randint(0, 3)
271
+ boxes = []
272
+ for _ in range(n):
273
+ x1 = random.randint(0, max(0, w - w//3))
274
+ y1 = random.randint(0, max(0, h - h//3))
275
+ x2 = min(w-1, x1 + random.randint(w//8, w//3))
276
+ y2 = min(h-1, y1 + random.randint(h//8, h//3))
277
+ conf = round(random.uniform(0.6, 0.95), 3)
278
+ boxes.append({"bbox":[x1,y1,x2,y2], "score":conf, "label":"damage"})
279
+ return boxes
280
+
281
+ def run_damage_detection(pil_image: Image.Image, score_thresh: float = 0.5):
282
+ """Run damage detection with fallback."""
283
+ try:
284
+ rgb = np.array(pil_image.convert("RGB"))
285
+ predictor = load_damage_model(DEFAULT_DAMAGE_MODEL_PATH)
286
 
287
+ if predictor is None:
288
+ boxes = simulate_damage_detection(rgb, seed_from=rgb)
289
+ annotated = rgb.copy()
290
+ for i, b in enumerate(boxes):
291
+ x1,y1,x2,y2 = b["bbox"]
292
+ cv2.rectangle(annotated, (x1,y1), (x2,y2), (255,255,0), 2)
293
+ cv2.putText(annotated, f"Damage {i+1} {b['score']*100:.1f}%",
294
+ (x1, max(0,y1-8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2)
295
+ return boxes, annotated, True, "predictor not available"
296
 
297
+ # Real inference
298
+ outputs = predictor(rgb)
299
+ instances = outputs["instances"].to("cpu")
300
+ boxes = []
301
+ if len(instances) > 0:
302
+ pred_boxes = instances.pred_boxes.tensor.numpy()
303
+ scores = instances.scores.numpy()
304
+ for i, (box, sc) in enumerate(zip(pred_boxes, scores)):
305
+ if sc >= score_thresh:
306
+ x1,y1,x2,y2 = [int(v) for v in box]
307
+ boxes.append({"bbox":[x1,y1,x2,y2], "score":float(sc), "label":"damage"})
308
+
309
+ annotated = rgb.copy()
310
+ for i, b in enumerate(boxes):
311
+ x1,y1,x2,y2 = b["bbox"]
312
+ cv2.rectangle(annotated, (x1,y1), (x2,y2), (255,255,0), 2)
313
+ cv2.putText(annotated, f"Damage {i+1} {b['score']*100:.1f}%",
314
+ (x1, max(0,y1-8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2)
315
+
316
+ return boxes, annotated, False, None
317
+
318
+ except Exception as e:
319
+ print(f"โš ๏ธ Stage 1 error: {e}")
320
+ traceback.print_exc()
321
+ # Fallback to simulator
322
+ rgb = np.array(pil_image.convert("RGB"))
323
+ boxes = simulate_damage_detection(rgb, seed_from=rgb)
324
+ annotated = rgb.copy()
325
+ for i, b in enumerate(boxes):
326
+ x1,y1,x2,y2 = b["bbox"]
327
+ cv2.rectangle(annotated, (x1,y1), (x2,y2), (255,255,0), 2)
328
+ cv2.putText(annotated, f"Damage {i+1} {b['score']*100:.1f}%",
329
+ (x1, max(0,y1-8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2)
330
+ return boxes, annotated, True, "stage1 error"
331
+
332
+ # --------------------------------------------------------------------------------------
333
+ # Stage 2: Feature extraction + classifier
334
+ # --------------------------------------------------------------------------------------
335
+
336
+ def load_ai_detection_classifier(model_path):
337
+ """Load the AI detection classifier (joblib)."""
338
+ global ai_detection_classifier
339
+ if ai_detection_classifier is not None:
340
+ print("โœ… Classifier already loaded, reusing...")
341
+ return ai_detection_classifier
342
+
343
+ if model_path is None or not os.path.exists(model_path):
344
+ print(f"โŒ AI detection model not found at: {model_path}")
345
+ return None
346
+
347
+ try:
348
+ ai_detection_classifier = joblib.load(model_path)
349
+ print(f"โœ… AI detection classifier loaded from {model_path}")
350
+ print(f" Classifier type: {type(ai_detection_classifier).__name__}")
351
+ return ai_detection_classifier
352
+ except Exception as e:
353
+ print(f"โŒ Error loading classifier: {e}")
354
+ return None
355
+
356
+ def preprocess_image(image) -> Optional[Image.Image]:
357
+ """Robust image preprocessing."""
358
+ try:
359
+ if image is None:
360
+ return None
361
+
362
+ if isinstance(image, Image.Image):
363
+ pil = image
364
+ elif isinstance(image, str):
365
+ pil = Image.open(image)
366
+ elif isinstance(image, np.ndarray):
367
+ if image.ndim == 2:
368
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
369
+ elif image.ndim == 3 and image.shape[2] == 4:
370
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
371
+
372
+ if image.dtype != np.uint8:
373
+ if image.max() <= 1.0:
374
+ image = (image * 255).astype(np.uint8)
375
+ else:
376
+ image = np.clip(image, 0, 255).astype(np.uint8)
377
 
378
+ pil = Image.fromarray(image, 'RGB')
379
+ else:
380
+ # Try to convert whatever it is
381
+ arr = np.array(image)
382
+ if arr.dtype != np.uint8:
383
+ arr = np.clip(arr, 0, 255).astype(np.uint8)
384
+ pil = Image.fromarray(arr, 'RGB')
385
+
386
+ # Handle EXIF orientation
387
+ pil = ImageOps.exif_transpose(pil)
388
+ return pil.convert("RGB")
389
 
390
+ except Exception as e:
391
+ print(f"โŒ Preprocess error: {e}")
392
+ traceback.print_exc()
393
+ return None
394
 
395
+ def extract_features(image, return_stats=False):
396
+ """Extract features with proper handling for different model types."""
397
+ global image_processor, model
 
398
 
399
+ if image_processor is None or model is None:
400
+ raise Exception("Model not initialized")
401
 
402
+ if not isinstance(image, Image.Image):
403
+ image = preprocess_image(image)
404
+ if image is None:
405
+ raise Exception("Failed to preprocess image")
406
+
407
+ # Resize to 224x224
408
+ image = image.resize((224, 224), Image.Resampling.LANCZOS)
409
+
410
+ # Process image
411
+ inputs = image_processor(images=image, return_tensors='pt', do_resize=True)
412
+
413
+ # Handle different processor outputs
414
+ if hasattr(inputs, 'pixel_values'):
415
+ pixel_values = inputs.pixel_values.to(DEVICE)
416
+ else:
417
+ pixel_values = inputs['input_ids'].to(DEVICE) if 'input_ids' in inputs else inputs.to(DEVICE)
418
+
419
+ # Get features
420
+ with torch.no_grad():
421
+ outputs = model(pixel_values)
422
+
423
+ # Handle different model outputs
424
+ if hasattr(model, 'get_image_features'):
425
+ # CLIP model
426
+ features = model.get_image_features(pixel_values)
427
+ elif isinstance(outputs, dict):
428
+ # Dictionary output
429
+ if 'features' in outputs:
430
+ features = outputs['features']
431
+ elif 'last_hidden_state' in outputs:
432
+ features = outputs['last_hidden_state']
433
+ elif 'pooler_output' in outputs:
434
+ features = outputs['pooler_output']
435
+ else:
436
+ # Take the first tensor value
437
+ features = next(iter(outputs.values()))
438
+ elif isinstance(outputs, (list, tuple)):
439
+ # Tuple/list output - take last element
440
+ features = outputs[-1] if len(outputs) > 1 else outputs[0]
441
+ else:
442
+ # Direct tensor output
443
+ features = outputs
444
+
445
+ # Pool if needed
446
+ if features.ndim == 3: # (B, T, C)
447
+ features = features.mean(dim=1)
448
+ elif features.ndim == 4: # (B, C, H, W)
449
+ features = features.mean(dim=(2, 3))
450
+
451
+ # Normalize and flatten
452
+ features = features.detach().flatten()
453
+ features = F.normalize(features, p=2, dim=-1).cpu().numpy()
454
+
455
+ if return_stats:
456
+ stats = {
457
+ "mean": float(features.mean()),
458
+ "std": float(features.std()),
459
+ "min": float(features.min()),
460
+ "max": float(features.max()),
461
+ "shape": features.shape
462
+ }
463
+ return features, stats
464
+
465
+ return features
466
+
467
+ def simulate_prediction(image) -> Dict[str, Any]:
468
+ """Fallback simulation when models/classifier aren't available."""
469
+ import hashlib, random
470
+
471
+ if isinstance(image, Image.Image):
472
+ arr = np.array(image.convert("RGB"))
473
+ elif isinstance(image, np.ndarray):
474
+ arr = image
475
+ else:
476
+ arr = np.array(preprocess_image(image) or Image.new("RGB",(16,16),(0,0,0)))
477
+
478
+ img_hash = hashlib.md5(arr.tobytes()).hexdigest()
479
+ seed = int(img_hash[:8], 16) % 1000
480
+ random.seed(seed)
481
+ ai_prob = random.uniform(0.1, 0.9)
482
+ is_ai = ai_prob > 0.5
483
+ confidence_level = "HIGH" if abs(ai_prob - 0.5) > 0.3 else "MEDIUM" if abs(ai_prob - 0.5) > 0.15 else "LOW"
484
+
485
+ return {
486
+ "prediction": "AI-Generated" if is_ai else "Real",
487
+ "ai_probability": ai_prob,
488
+ "real_probability": 1 - ai_prob,
489
+ "confidence": confidence_level,
490
+ "is_demo": True
491
+ }
492
+
493
+ def _predict_with_classifier(classifier, features: np.ndarray) -> Tuple[int, float, float]:
494
+ """Predict with classifier."""
495
+ features = features.reshape(1, -1)
496
+ pred = int(classifier.predict(features)[0])
497
 
498
+ ai_prob = real_prob = 0.5
499
 
500
+ if hasattr(classifier, "predict_proba"):
501
+ try:
502
+ probs = classifier.predict_proba(features)[0]
503
+ if len(probs) >= 2:
504
+ real_prob = float(probs[0])
505
+ ai_prob = float(probs[1])
506
+ else:
507
+ ai_prob = float(probs[0]) if pred == 1 else 1 - float(probs[0])
508
+ real_prob = 1 - ai_prob
509
+ except:
510
+ pass
511
+ elif hasattr(classifier, "decision_function"):
512
+ try:
513
+ df = float(classifier.decision_function(features)[0])
514
+ ai_prob = 1.0 / (1.0 + np.exp(-df))
515
+ real_prob = 1.0 - ai_prob
516
+ except:
517
+ pass
518
 
519
+ return pred, ai_prob, real_prob
520
+
521
+ # --------------------------------------------------------------------------------------
522
+ # Gradio Interface
523
+ # --------------------------------------------------------------------------------------
524
+
525
+ def create_gradio_interface():
526
+ """Create the Gradio interface."""
527
+
528
+ with gr.Blocks(title="AI Image Detection", css=".gradio-container { font-family: Inter, system-ui; }") as app:
529
+ gr.HTML("""
530
+ <div style="text-align: center; padding: 20px; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px;">
531
+ <h1 style="margin: 0;">๐Ÿค– AI Image Detection</h1>
532
+ <p style="margin: 10px 0 0 0;">Stage 1 (Damage) + Stage 2 (AI Detection)</p>
533
+ </div>
534
+ """)
535
+
536
+ with gr.Row():
537
+ with gr.Column():
538
+ input_image = gr.Image(
539
+ type="numpy",
540
+ label="Upload Image",
541
+ height=400
542
+ )
543
+
544
+ with gr.Row():
545
+ predict_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")
546
+ clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg")
547
+
548
+ enable_damage = gr.Checkbox(value=True, label="Enable Stage 1 (Damage Detection)")
549
+ damage_thresh = gr.Slider(0.1, 0.95, value=0.5, step=0.05, label="Damage Score Threshold")
550
+
551
+ with gr.Column():
552
+ output_text = gr.Textbox(
553
+ label="Prediction Result",
554
+ placeholder="Upload an image and click Analyze",
555
+ interactive=False,
556
+ lines=2
557
+ )
558
+ output_json = gr.JSON(label="Detailed Analysis (Stage 2)")
559
+ damage_json = gr.JSON(label="Stage 1: Damage Detections")
560
+ annotated_image = gr.Image(label="Annotated Output")
561
+ status_display = gr.HTML("""
562
+ <div style="padding: 10px; background: #f0f4f8; border-radius: 8px; margin-top: 10px;">
563
+ <p style="margin: 0; color: #64748b;">Ready for analysis...</p>
564
+ </div>
565
+ """)
566
+
567
+ def analyze_with_status(image, enable_damage, damage_thresh):
568
+ """Analyze image."""
569
+ if image is None:
570
+ return (
571
+ "โŒ No image provided",
572
+ {"error": "No image provided"},
573
+ '<div style="padding: 10px; background: #fee2e2; border-radius: 8px;"><p style="margin: 0; color: #dc2626;">โŒ No image provided</p></div>',
574
+ [],
575
+ None
576
+ )
577
+
578
+ # Initialize models
579
+ model_initialized = (image_processor is not None and model is not None) or preload_models()
580
+ model_path = huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH
581
+ classifier = ai_detection_classifier or load_ai_detection_classifier(model_path)
582
+
583
+ demo_reasons = []
584
+ if not model_initialized:
585
+ demo_reasons.append("feature extractor missing")
586
+ if classifier is None:
587
+ demo_reasons.append("classifier missing")
588
+
589
+ # Stage 2: AI Detection
590
+ try:
591
+ if demo_reasons:
592
+ result = simulate_prediction(preprocess_image(image))
593
+ result["demo_reasons"] = demo_reasons
594
+ simple_result = f"{result['prediction']} (AI: {result['ai_probability']:.2%}) [Demo]"
595
+ detailed_result = result
596
+ else:
597
+ feats, stats = extract_features(preprocess_image(image), return_stats=True)
598
+ pred, ai_prob, real_prob = _predict_with_classifier(classifier, feats)
599
+ is_ai = pred == 1
600
+ result_text = "AI-Generated" if is_ai else "Real"
601
+ conf_score = max(ai_prob, real_prob)
602
+ confidence = "HIGH" if conf_score > 0.80 else "MEDIUM" if conf_score > 0.60 else "LOW"
603
+
604
+ detailed_result = {
605
+ "prediction": result_text,
606
+ "ai_probability": ai_prob,
607
+ "real_probability": real_prob,
608
+ "confidence": confidence,
609
+ "confidence_score": conf_score,
610
+ "is_demo": False,
611
+ "feature_stats": stats
612
+ }
613
+ simple_result = f"{result_text} (Confidence: {conf_score:.2%})"
614
+
615
+ except Exception as e:
616
+ print(f"โŒ Stage 2 error: {e}")
617
+ traceback.print_exc()
618
+ result = simulate_prediction(preprocess_image(image))
619
+ result["demo_reasons"] = ["stage2 error"]
620
+ simple_result = f"{result['prediction']} (AI: {result['ai_probability']:.2%}) [Demo]"
621
+ detailed_result = result
622
+
623
+ # Stage 1: Damage Detection
624
+ dmg_results = []
625
+ annotated = None
626
+
627
+ if enable_damage:
628
+ try:
629
+ pil = preprocess_image(image)
630
+ if pil:
631
+ boxes, annotated_rgb, demo, reason = run_damage_detection(pil, float(damage_thresh))
632
+ dmg_results = boxes
633
+ annotated = annotated_rgb
634
+
635
+ # Add verdict overlay
636
+ if annotated is not None and isinstance(detailed_result, dict):
637
+ is_ai = (detailed_result.get("prediction") == "AI-Generated")
638
+ ai_prob = float(detailed_result.get("ai_probability", 0.5))
639
+ color = (0,0,255) if is_ai else (0,255,0)
640
+ verdict = detailed_result.get("prediction", "Unknown")
641
+
642
+ cv2.putText(annotated, verdict, (30, 50),
643
+ cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 3)
644
+ cv2.putText(annotated, f"Confidence: {ai_prob*100:.1f}%",
645
+ (30, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
646
+
647
+ except Exception as e:
648
+ print(f"โš ๏ธ Stage 1 error: {e}")
649
+
650
+ # Status display
651
+ if isinstance(detailed_result, dict) and detailed_result.get("is_demo"):
652
+ status_html = '<div style="padding: 10px; background: #fef3c7; border-radius: 8px;"><p style="margin: 0; color: #f59e0b;">⚠️ Running in Demo Mode</p></div>'
653
+ else:
654
+ status_html = '<div style="padding: 10px; background: #d1fae5; border-radius: 8px;"><p style="margin: 0; color: #10b981;">✅ Analysis Complete</p></div>'
655
+
656
+ return simple_result, detailed_result, status_html, dmg_results, annotated
657
+
658
+ def clear_all():
659
+ """Clear all fields."""
660
+ return (
661
+ None,
662
+ "",
663
+ {},
664
+ '<div style="padding: 10px; background: #f0f4f8; border-radius: 8px;"><p style="margin: 0; color: #64748b;">Ready for analysis...</p></div>',
665
+ [],
666
+ None
667
+ )
668
+
669
+ # Wire up events
670
+ predict_btn.click(
671
+ fn=analyze_with_status,
672
+ inputs=[input_image, enable_damage, damage_thresh],
673
+ outputs=[output_text, output_json, status_display, damage_json, annotated_image]
674
+ )
675
+
676
+ clear_btn.click(
677
+ fn=clear_all,
678
+ outputs=[input_image, output_text, output_json, status_display, damage_json, annotated_image]
679
+ )
680
+
681
+ # Auto-analyze on image change
682
+ input_image.change(
683
+ fn=analyze_with_status,
684
+ inputs=[input_image, enable_damage, damage_thresh],
685
+ outputs=[output_text, output_json, status_display, damage_json, annotated_image]
686
+ )
687
+
688
+ with gr.Accordion("ℹ️ About", open=False):
689
+ gr.Markdown("""
690
+ ### Pipeline
691
+ - **Stage 1**: Detectron2 damage detection (optional)
692
+ - **Stage 2**: Visual features + AI detection classifier
693
+
694
+ ### Notes
695
+ - Falls back to demo mode if models are unavailable
696
+ - C-RADIOv3-B model includes compatibility fixes for layer scaling issues
697
+ """)
698
+
699
+ return app
700
+
701
+ # --------------------------------------------------------------------------------------
702
+ # Main
703
+ # --------------------------------------------------------------------------------------
704
 
 
705
  if __name__ == "__main__":
706
+ print("=" * 60)
707
+ print("๐Ÿš€ Starting AI Image Detection App")
708
+ print("=" * 60)
709
+ print(f"๐Ÿ“ Device: {DEVICE}")
710
+ print(f"๐Ÿ“ฆ Classifier: {huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH}")
711
+ print(f"๐Ÿ› ๏ธ Damage Model: {DEFAULT_DAMAGE_MODEL_PATH}")
712
+
713
+ # Preload models with fixes
714
+ if preload_models():
715
+ print("โœ… Models preloaded successfully")
716
+ else:
717
+ print("โš ๏ธ Running in demo mode")
718
+
719
+ # Load classifier
720
+ model_path = huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH
721
+ if load_ai_detection_classifier(model_path):
722
+ print("โœ… Classifier loaded")
723
+
724
+ print("=" * 60)
725
+
726
+ app = create_gradio_interface()
727
+ app.launch(
728
+ share=False,
729
+ server_name="0.0.0.0",
730
+ server_port=7860,
731
+ show_error=True
732
+ )