astrosbd commited on
Commit
b382923
·
verified ·
1 Parent(s): 042bdbd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -726
app.py CHANGED
@@ -1,88 +1,3 @@
1
- #!/usr/bin/env python3
2
- import os
3
- import sys
4
- import traceback
5
- from typing import Optional, Tuple, Dict, Any, List
6
-
7
- import importlib.util
8
- import time
9
- import cv2
10
- import torch
11
- import numpy as np
12
- import gradio as gr
13
- from PIL import Image, ImageOps
14
- from torchvision import transforms
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- import traceback
18
- from torchvision.models import vit_b_16
19
- from transformers import AutoModel, CLIPImageProcessor
20
- import joblib
21
- import zipfile
22
- import json
23
- from datetime import datetime
24
- import requests
25
- import base64
26
- import io
27
-
28
-
29
- # Check if detectron2 is installed and attempt installation if needed
30
- if importlib.util.find_spec("detectron") is None:
31
- print("🔄 Detectron2 not found. Attempting installation...")
32
- print("Installing PyTorch and Detectron2...")
33
- os.system("pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu")
34
- os.system("pip install git+https://github.com/facebookresearch/detectron2.git")
35
- print("Installation complete!")
36
-
37
- # Optional Detectron2 import
38
- DETECTRON2_AVAILABLE = False
39
- try:
40
- print("Attempting to import Detectron2...")
41
- from detectron2.engine import DefaultPredictor
42
- from detectron2.config import get_cfg
43
- from detectron2.utils.visualizer import Visualizer, ColorMode
44
- from detectron2 import model_zoo
45
-
46
- DETECTRON2_AVAILABLE = True
47
- print("✅ Detectron2 imported successfully")
48
- except ImportError as e:
49
- print(f"⚠️ Detectron2 not available: {e}")
50
- DETECTRON2_AVAILABLE = False
51
-
52
- # Try to download model from Hugging Face
53
- huggingface_model_path = None
54
- try:
55
- from huggingface_hub import hf_hub_download
56
-
57
- # Try to download from your repository
58
- huggingface_model_path = hf_hub_download(
59
- repo_id=os.getenv('PRIVATE_REPO', 'fallback'),
60
- filename="V1.pkl",
61
- token=os.getenv('key')
62
- )
63
- print(f"✅ Model downloaded from Hugging Face: {huggingface_model_path}")
64
- except Exception as e:
65
- print(f"⚠️ Could not download model from Hugging Face: {e}")
66
- print("🔄 Will use demo mode with simulated results")
67
- huggingface_model_path = None
68
- # --------------------------------------------------------------------------------------
69
- # Basics
70
-
71
- # Initialize device for model
72
- if torch.backends.mps.is_available():
73
- RADIO_DEVICE = torch.device("mps")
74
- elif torch.cuda.is_available():
75
- RADIO_DEVICE = torch.device("cuda")
76
- else:
77
- RADIO_DEVICE = torch.device("cpu")
78
-
79
- # Global variables for C model
80
- radio_l_image_processor = None
81
- radio_l_model = None
82
- ai_detection_classifier = None
83
-
84
-
85
- # Preload the C model at startup
86
  def preload_models():
87
  """Preload models at startup to improve response time"""
88
  global radio_l_image_processor, radio_l_model
@@ -92,651 +7,28 @@ def preload_models():
92
  hf_repo = os.getenv('MODEL_REPO', 'fallback')
93
  if hf_repo and hf_repo != 'fallback':
94
  from transformers import AutoModel, CLIPImageProcessor
95
- radio_l_image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
 
96
  radio_l_model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  radio_l_model = radio_l_model.to(RADIO_DEVICE)
98
  radio_l_model.eval()
99
  print("✅ C model preloaded successfully!")
100
  return True
101
  except Exception as e:
102
- print(f"⚠️ Could not preload C model: {e}")
103
- return False
104
-
105
-
106
- # Add current directory to path (for local modules if any)
107
- if os.getcwd() not in sys.path:
108
- sys.path.append(os.getcwd())
109
-
110
- # --------------------------------------------------------------------------------------
111
- # Hugging Face download (robust)
112
- # --------------------------------------------------------------------------------------
113
-
114
-
115
-
116
- # --------------------------------------------------------------------------------------
117
- # Paths / Devices
118
- # --------------------------------------------------------------------------------------
119
- DEFAULT_AI_DETECTION_MODEL_PATH = "./output/V1.pkl"
120
- DEFAULT_DAMAGE_MODEL_PATH = "./output/model_final.pth" # Stage 1 (Detectron2)
121
-
122
- if torch.backends.mps.is_available():
123
- DEVICE = torch.device("mps")
124
- elif torch.cuda.is_available():
125
- DEVICE = torch.device("cuda")
126
- else:
127
- DEVICE = torch.device("cpu")
128
-
129
- print(f"🖥️ Using device: {DEVICE}")
130
-
131
- # --------------------------------------------------------------------------------------
132
- # Globals
133
- # --------------------------------------------------------------------------------------
134
- image_processor: Optional[CLIPImageProcessor] = None
135
- model: Optional[AutoModel] = None
136
- ai_detection_classifier = None
137
- _preloaded = False
138
-
139
- # --------------------------------------------------------------------------------------
140
- # Detectron2 (Stage 1) availability & loader
141
- # --------------------------------------------------------------------------------------
142
- DETECTRON2_AVAILABLE = False
143
- try:
144
- import importlib.util as _imp
145
- if _imp.find_spec("detectron2") is not None:
146
- from detectron2.engine import DefaultPredictor
147
- from detectron2.config import get_cfg
148
- from detectron2 import model_zoo
149
- DETECTRON2_AVAILABLE = True
150
- print("✅ Detectron2 detected")
151
- else:
152
- print("ℹ️ Detectron2 not installed; Stage 1 will run in demo mode")
153
- except Exception as _e:
154
- print(f"ℹ️ Detectron2 unavailable ({_e}); Stage 1 will run in demo mode")
155
- DETECTRON2_AVAILABLE = False
156
-
157
- _damage_predictor = None
158
-
159
- def load_damage_model(model_path: str, device_str: str = None):
160
- """Load fine-tuned Detectron2 model once (Stage 1)."""
161
- global _damage_predictor
162
- if _damage_predictor is not None:
163
- return _damage_predictor
164
-
165
- if (not DETECTRON2_AVAILABLE) or (not model_path) or (not os.path.exists(model_path)):
166
- print("ℹ️ Stage 1 damage model not available; using simulator")
167
- return None
168
-
169
- try:
170
- cfg = get_cfg()
171
- cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
172
- cfg.MODEL.WEIGHTS = model_path
173
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
174
- cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 # fine-tuned for single 'damage' class
175
-
176
- if device_str is None:
177
- device_str = "cuda" if torch.cuda.is_available() else "cpu"
178
- cfg.MODEL.DEVICE = device_str
179
-
180
- _damage_predictor = DefaultPredictor(cfg)
181
- print(f"✅ Damage model loaded on {device_str}")
182
- return _damage_predictor
183
- except Exception as e:
184
- print(f"❌ Could not load Detectron2 model: {e}")
185
- return None
186
-
187
- def simulate_damage_detection(rgb_image: np.ndarray, seed_from: np.ndarray = None) -> List[Dict[str, Any]]:
188
- """Deterministic fake detections for demo mode."""
189
- import hashlib, random
190
- h, w = rgb_image.shape[:2]
191
- if seed_from is None:
192
- seed_from = rgb_image
193
- img_hash = hashlib.md5(seed_from.tobytes()).hexdigest()
194
- seed = int(img_hash[:8], 16) % 10_000
195
- random.seed(seed)
196
- n = random.randint(0, 3)
197
- boxes = []
198
- for _ in range(n):
199
- x1 = random.randint(0, max(0, w - w//3))
200
- y1 = random.randint(0, max(0, h - h//3))
201
- x2 = min(w-1, x1 + random.randint(w//8, w//3))
202
- y2 = min(h-1, y1 + random.randint(h//8, h//3))
203
- conf = round(random.uniform(0.6, 0.95), 3)
204
- boxes.append({"bbox":[x1,y1,x2,y2], "score":conf, "label":"damage"})
205
- return boxes
206
-
207
- def auto_install_dependencies():
208
- """Attempt to install dependencies if needed"""
209
- try:
210
- import importlib.util
211
-
212
- # Check for PyTorch
213
- if importlib.util.find_spec("torch") is None:
214
- print("Installing PyTorch...")
215
- os.system("pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu")
216
-
217
- # Check for Detectron2
218
- if importlib.util.find_spec("detectron2") is None:
219
- print("Installing Detectron2...")
220
- os.system("pip install git+https://github.com/facebookresearch/detectron2.git")
221
-
222
- # Check for Gradio
223
- if importlib.util.find_spec("gradio") is None:
224
- print("Installing Gradio...")
225
- os.system("pip install gradio")
226
-
227
- print("Dependencies installation complete!")
228
- return True
229
- except Exception as e:
230
- print(f"Error installing dependencies: {e}")
231
- return False
232
-
233
-
234
-
235
-
236
- def run_damage_detection(pil_image: Image.Image, score_thresh: float = 0.5):
237
- """
238
- Returns:
239
- damage_boxes: list of dicts {bbox:[x1,y1,x2,y2], score:float, label:str}
240
- annotated: numpy RGB image with boxes annotated (or None on error)
241
- demo: bool
242
- reason: str|None
243
- """
244
- try:
245
- rgb = np.array(pil_image.convert("RGB"))
246
- predictor = load_damage_model(DEFAULT_DAMAGE_MODEL_PATH)
247
- if predictor is None:
248
- boxes = simulate_damage_detection(rgb, seed_from=rgb)
249
- annotated = rgb.copy()
250
- for i, b in enumerate(boxes):
251
- x1,y1,x2,y2 = b["bbox"]
252
- cv2.rectangle(annotated, (x1,y1), (x2,y2), (255,255,0), 2)
253
- cv2.putText(annotated, f"Damage {i+1} {b['score']*100:.1f}%",
254
- (x1, max(0,y1-8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2)
255
- return boxes, annotated, True, "predictor not available"
256
- # Real inference
257
- outputs = predictor(rgb)
258
- instances = outputs["instances"].to("cpu")
259
- boxes = []
260
- if len(instances) > 0:
261
- pred_boxes = instances.pred_boxes.tensor.numpy()
262
- scores = instances.scores.numpy()
263
- for i, (box, sc) in enumerate(zip(pred_boxes, scores)):
264
- if sc >= score_thresh:
265
- x1,y1,x2,y2 = [int(v) for v in box]
266
- boxes.append({"bbox":[x1,y1,x2,y2], "score":float(sc), "label":"damage"})
267
- annotated = rgb.copy()
268
- for i, b in enumerate(boxes):
269
- x1,y1,x2,y2 = b["bbox"]
270
- cv2.rectangle(annotated, (x1,y1), (x2,y2), (255,255,0), 2)
271
- cv2.putText(annotated, f"Damage {i+1} {b['score']*100:.1f}%",
272
- (x1, max(0,y1-8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2)
273
- return boxes, annotated, False, None
274
- except Exception as e:
275
- print(f"⚠️ Stage 1 error: {e}")
276
- traceback.print_exc()
277
- # fall back to simulator
278
- rgb = np.array(pil_image.convert("RGB"))
279
- boxes = simulate_damage_detection(rgb, seed_from=rgb)
280
- annotated = rgb.copy()
281
- for i, b in enumerate(boxes):
282
- x1,y1,x2,y2 = b["bbox"]
283
- cv2.rectangle(annotated, (x1,y1), (x2,y2), (255,255,0), 2)
284
- cv2.putText(annotated, f"Damage {i+1} {b['score']*100:.1f}%",
285
- (x1, max(0,y1-8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2)
286
- return boxes, annotated, True, "stage1 error"
287
-
288
- # --------------------------------------------------------------------------------------
289
- # Stage 2: RADIO feature extractor + classifier
290
- # --------------------------------------------------------------------------------------
291
-
292
-
293
- def load_ai_detection_classifier(model_path):
294
- """Load the AI detection classifier (joblib)."""
295
- global ai_detection_classifier
296
- if ai_detection_classifier is not None:
297
- print("✅ Classifier already loaded, reusing...")
298
- return ai_detection_classifier
299
-
300
- if model_path is None or not os.path.exists(model_path):
301
- print(f"❌ AI detection model not found at: {model_path}")
302
- return None
303
-
304
- try:
305
- ai_detection_classifier = joblib.load(model_path)
306
- print(f"✅ AI detection classifier loaded from {model_path}")
307
- print(f" Classifier type: {type(ai_detection_classifier).__name__}")
308
- return ai_detection_classifier
309
- except Exception as e:
310
- print(f"❌ Error loading classifier: {e}")
311
- return None
312
-
313
- def preprocess_image(image) -> Optional[Image.Image]:
314
- """Robust image preprocessing with EXIF orientation and dtype/range fixes."""
315
- try:
316
- if image is None:
317
- return None
318
-
319
- if isinstance(image, Image.Image):
320
- pil = image
321
- elif isinstance(image, (str,)):
322
- arr = cv2.imread(image, cv2.IMREAD_UNCHANGED)
323
- if arr is None:
324
- return None
325
- if arr.ndim == 3:
326
- arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
327
- if arr.dtype != np.uint8:
328
- arr = np.clip(arr, 0, 255).astype(np.uint8)
329
- pil = Image.fromarray(arr if arr.ndim == 3 else np.stack([arr]*3, axis=-1), 'RGB')
330
- elif isinstance(image, dict) and "path" in image:
331
- arr = cv2.imread(image["path"], cv2.IMREAD_UNCHANGED)
332
- if arr is None:
333
- return None
334
- if arr.ndim == 3:
335
- arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
336
- if arr.dtype != np.uint8:
337
- arr = np.clip(arr, 0, 255).astype(np.uint8)
338
- pil = Image.fromarray(arr if arr.ndim == 3 else np.stack([arr]*3, axis=-1), 'RGB')
339
- else:
340
- # assume numpy-like
341
- arr = np.array(image)
342
- if arr.ndim == 2:
343
- arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
344
- elif arr.ndim == 3 and arr.shape[2] == 4:
345
- arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
346
- if arr.dtype != np.uint8:
347
- if arr.max() <= 1.0:
348
- arr = np.clip(arr, 0, 1)
349
- arr = (arr * 255.0).astype(np.uint8)
350
- else:
351
- arr = np.clip(arr, 0, 255).astype(np.uint8)
352
- pil = Image.fromarray(arr, 'RGB')
353
-
354
- # Normalize EXIF orientation
355
- pil = ImageOps.exif_transpose(pil)
356
- return pil.convert("RGB")
357
- except Exception as e:
358
- print(f"❌ Preprocess error: {e}")
359
  traceback.print_exc()
360
- return None
361
-
362
- def _forward_radio(pixel_values: torch.Tensor) -> torch.Tensor:
363
- with torch.no_grad():
364
- if DEVICE.type == "cuda":
365
- from torch.cuda.amp import autocast
366
- with autocast():
367
- out = model(pixel_values)
368
- else:
369
- out = model(pixel_values)
370
-
371
- # Accept dict / tuple / single tensor
372
- if isinstance(out, dict):
373
- feats = out.get("features") or out.get("last_hidden_state") or next(iter(out.values()))
374
- elif isinstance(out, (list, tuple)):
375
- feats = out[-1]
376
- else:
377
- feats = out
378
- return feats
379
-
380
- def extract_features(image, return_stats=False):
381
- """Extract normalized features using RADIO model."""
382
- global image_processor, model
383
- if image_processor is None or model is None:
384
- raise Exception("RADIO model not initialized")
385
-
386
- if not isinstance(image, Image.Image):
387
- image = preprocess_image(image)
388
- if image is None:
389
- raise Exception("Failed to preprocess image")
390
-
391
- image = image.resize((224, 224), Image.Resampling.LANCZOS)
392
- pixel_values = image_processor(images=image, return_tensors='pt', do_resize=True).pixel_values.to(DEVICE)
393
-
394
- feats = _forward_radio(pixel_values)
395
-
396
- # Pool if sequence or feature map
397
- if feats.ndim == 3: # (B, T, C)
398
- feats = feats.mean(dim=1)
399
- elif feats.ndim == 4: # (B, C, H, W)
400
- feats = feats.mean(dim=(2, 3))
401
-
402
- feats = feats.detach().flatten()
403
- feats = F.normalize(feats, p=2, dim=-1).cpu().flatten()
404
- out_np = feats.numpy()
405
-
406
- if return_stats:
407
- stats = {
408
- "mean": float(out_np.mean()),
409
- "std": float(out_np.std()),
410
- "min": float(out_np.min()),
411
- "max": float(out_np.max()),
412
- "shape": out_np.shape
413
- }
414
- return out_np, stats
415
- return out_np
416
-
417
- def _predict_with_classifier(classifier, features: np.ndarray) -> Tuple[int, float, float]:
418
- """Predict with proba mapping respecting classifier.classes_.
419
- Returns (pred_label, ai_prob, real_prob). Assumes label '1' == AI, '0' == real.
420
- """
421
- pred = int(classifier.predict(features)[0])
422
- ai_prob = real_prob = 0.5
423
- if hasattr(classifier, "predict_proba"):
424
- probs = classifier.predict_proba(features)[0]
425
- if hasattr(classifier, "classes_"):
426
- classes = list(classifier.classes_)
427
- if 1 in classes:
428
- ai_prob = float(probs[classes.index(1)])
429
- if 0 in classes:
430
- real_prob = float(probs[classes.index(0)])
431
- if 0 not in classes or 1 not in classes:
432
- m = float(probs.max()); ai_prob = m; real_prob = 1.0 - m
433
- else:
434
- m = float(probs.max()); ai_prob = m; real_prob = 1.0 - m
435
- elif hasattr(classifier, "decision_function"):
436
- df = float(classifier.decision_function(features)[0])
437
- ai_prob = 1.0 / (1.0 + np.exp(-df))
438
- real_prob = 1.0 - ai_prob
439
- else:
440
- ai_prob = float(pred); real_prob = 1.0 - ai_prob
441
- return pred, ai_prob, real_prob
442
-
443
- def simulate_prediction(image) -> Dict[str, Any]:
444
- """Fallback simulation when models/classifier aren't available."""
445
- import hashlib, random
446
- if isinstance(image, Image.Image):
447
- arr = np.array(image.convert("RGB"))
448
- elif isinstance(image, np.ndarray):
449
- arr = image
450
- else:
451
- arr = np.array(preprocess_image(image) or Image.new("RGB",(16,16),(0,0,0)))
452
-
453
- img_hash = hashlib.md5(arr.tobytes()).hexdigest()
454
- seed = int(img_hash[:8], 16) % 1000
455
- random.seed(seed)
456
- ai_prob = random.uniform(0.1, 0.9)
457
- is_ai = ai_prob > 0.5
458
- confidence_level = "HIGH" if abs(ai_prob - 0.5) > 0.3 else "MEDIUM" if abs(ai_prob - 0.5) > 0.15 else "LOW"
459
-
460
- return {
461
- "prediction": "AI-Generated" if is_ai else "Real",
462
- "ai_probability": ai_prob,
463
- "real_probability": 1 - ai_prob,
464
- "confidence": confidence_level,
465
- "is_demo": True
466
- }
467
-
468
- def _overlay_final_verdict(annotated_rgb: np.ndarray, verdict_text: str, ai_prob: float, real_prob: float, is_ai: bool):
469
- out = annotated_rgb.copy()
470
- color = (0,255,0) if not is_ai else (0,0,255)
471
- conf = max(ai_prob, real_prob)
472
- cv2.putText(out, verdict_text, (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 3)
473
- cv2.putText(out, f"Confidence: {conf*100:.1f}%", (30, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
474
- return out
475
-
476
- # --------------------------------------------------------------------------------------
477
- # Gradio App
478
- # --------------------------------------------------------------------------------------
479
- def create_gradio_interface():
480
- """Enhanced Gradio interface with Stage 1 + Stage 2."""
481
- custom_css = """
482
- .gradio-container { font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, Ubuntu, Cantarell, 'Helvetica Neue', Arial, 'Noto Sans', 'Apple Color Emoji', 'Segoe UI Emoji'; }
483
- """
484
-
485
- with gr.Blocks(title="AI Image Detection", css=custom_css) as app:
486
- gr.HTML("""
487
- <div style="text-align: center; padding: 20px; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px;">
488
- <h1 style="margin: 0;">🤖 AI Image Detection</h1>
489
- <p style="margin: 10px 0 0 0;">Stage 1 (Damage) + Stage 2 (AI-Generated) — with graceful fallbacks</p>
490
- </div>
491
- """)
492
-
493
- with gr.Row():
494
- with gr.Column():
495
- input_image = gr.Image(
496
- type="numpy",
497
- label="Upload Image",
498
- height=400
499
- )
500
-
501
- with gr.Row():
502
- predict_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")
503
- clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg")
504
-
505
- enable_damage = gr.Checkbox(value=True, label="Enable Stage 1 (Damage Detection)")
506
- damage_thresh = gr.Slider(0.1, 0.95, value=0.5, step=0.05, label="Damage Score Threshold")
507
-
508
- with gr.Column():
509
- output_text = gr.Textbox(
510
- label="Prediction Result",
511
- placeholder="Upload an image and click Analyze",
512
- interactive=False,
513
- lines=2
514
- )
515
- output_json = gr.JSON(label="Detailed Analysis (Stage 2)")
516
- damage_json = gr.JSON(label="Stage 1: Damage Detections")
517
- annotated_image = gr.Image(label="Annotated Output (Damage + Verdict)")
518
- status_display = gr.HTML("""
519
- <div style="padding: 10px; background: #f0f4f8; border-radius: 8px; margin-top: 10px;">
520
- <p style="margin: 0; color: #64748b;">Ready for analysis...</p>
521
- </div>
522
- """)
523
-
524
- # --- Analyze callback ---
525
- def analyze_with_status(image, enable_damage, damage_thresh):
526
- """Analyze and update all outputs (Stage 1 + Stage 2)."""
527
- if image is None:
528
- return (
529
- "❌ No image provided",
530
- {"error": "No image provided"},
531
- """
532
- <div style="padding: 10px; background: #fee2e2; border-radius: 8px; margin-top: 10px;">
533
- <p style="margin: 0; color: #dc2626; font-weight: bold;">❌ No image provided</p>
534
- </div>
535
- """,
536
- [],
537
- None
538
- )
539
-
540
- # Stage 2 init
541
- model_initialized = (image_processor is not None and model is not None) or preload_models()
542
- model_path = huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH
543
- classifier = ai_detection_classifier or load_ai_detection_classifier(model_path)
544
-
545
- demo_reasons = []
546
- if not model_initialized: demo_reasons.append("feature extractor missing")
547
- if classifier is None: demo_reasons.append("classifier missing")
548
-
549
- # Stage 2 run
550
- try:
551
- if demo_reasons:
552
- result2 = simulate_prediction(preprocess_image(image))
553
- result2["demo_reasons"] = demo_reasons
554
- simple_result = f"{result2['prediction']} (AI: {result2['ai_probability']:.2%}) [Demo: {', '.join(demo_reasons)}]"
555
- detailed_result = result2
556
- else:
557
- feats, stats = extract_features(preprocess_image(image), return_stats=True)
558
- feats = feats.reshape(1, -1)
559
- pred, ai_prob, real_prob = _predict_with_classifier(classifier, feats)
560
- is_ai = pred == 1
561
- result_text = "AI-Generated" if is_ai else "Real"
562
- conf_score = max(ai_prob, real_prob)
563
- confidence = "HIGH" if conf_score > 0.80 else "MEDIUM" if conf_score > 0.60 else "LOW"
564
- detailed_result = {
565
- "prediction": result_text,
566
- "ai_probability": ai_prob,
567
- "real_probability": real_prob,
568
- "confidence": confidence,
569
- "confidence_score": conf_score,
570
- "is_demo": False,
571
- "feature_stats": stats
572
- }
573
- simple_result = f"{result_text} (Confidence: {conf_score:.2%}, {confidence})"
574
- except Exception as e:
575
- print(f"❌ Error Stage 2: {e}")
576
- traceback.print_exc()
577
- result2 = simulate_prediction(preprocess_image(image))
578
- result2["demo_reasons"] = ["stage2 error"]
579
- simple_result = f"{result2['prediction']} (AI: {result2['ai_probability']:.2%}) [Demo: stage2 error]"
580
- detailed_result = result2
581
-
582
- # Get Stage 2 values for overlay
583
- is_ai = False
584
- ai_prob = real_prob = 0.5
585
- verdict_text = "Unknown"
586
- if isinstance(detailed_result, dict) and "error" not in detailed_result:
587
- is_ai = (detailed_result.get("prediction") == "AI-Generated")
588
- ai_prob = float(detailed_result.get("ai_probability", 0.5))
589
- real_prob = float(detailed_result.get("real_probability", 0.5))
590
- verdict_text = detailed_result.get("prediction", "Unknown")
591
-
592
- # Stage 1 (optional)
593
- dmg_results = []
594
- annotated = None
595
- stage1_demo_reasons = []
596
- try:
597
- if enable_damage:
598
- pil = preprocess_image(image)
599
- if pil is not None:
600
- boxes, annotated_rgb, demo, reason = run_damage_detection(pil, float(damage_thresh))
601
- dmg_results = boxes
602
- if demo and reason:
603
- stage1_demo_reasons.append(f"Stage1:{reason}")
604
- annotated = annotated_rgb
605
- else:
606
- stage1_demo_reasons.append("Stage1:preprocess-fail")
607
- except Exception as e:
608
- print(f"⚠️ Stage 1 analyze error: {e}")
609
- traceback.print_exc()
610
- stage1_demo_reasons.append("Stage1:error")
611
-
612
- # If no annotated yet, create from original for overlay
613
- if annotated is None:
614
- pil = preprocess_image(image)
615
- if pil is not None:
616
- annotated = np.array(pil.convert("RGB"))
617
-
618
- # Overlay verdict
619
- if annotated is not None:
620
- annotated = _overlay_final_verdict(annotated, verdict_text, ai_prob, real_prob, is_ai)
621
-
622
- # Status HTML
623
- if isinstance(detailed_result, dict) and "error" not in detailed_result:
624
- if detailed_result.get("is_demo"):
625
- status_color = "#f59e0b"
626
- extra = ""
627
- demo_list = detailed_result.get('demo_reasons', [])
628
- if demo_list:
629
- extra = f" (Demo: {', '.join(demo_list)})"
630
- if stage1_demo_reasons:
631
- if extra:
632
- extra = extra[:-1] + ", " + ", ".join(stage1_demo_reasons) + ")"
633
- else:
634
- extra = f" (Stage1 demo: {', '.join(stage1_demo_reasons)})"
635
- status_text = f"⚠️ Demo Mode{extra}"
636
- else:
637
- status_color = "#10b981"
638
- extra = f" (Stage1 demo: {', '.join(stage1_demo_reasons)})" if stage1_demo_reasons else ""
639
- status_text = f"✅ Analysis Complete{extra}"
640
- status_html = f"""
641
- <div style="padding: 10px; background: #f0f4f8; border-radius: 8px; margin-top: 10px;">
642
- <p style="margin: 0; color: {status_color}; font-weight: bold;">{status_text}</p>
643
- </div>
644
- """
645
- else:
646
- status_html = """
647
- <div style="padding: 10px; background: #fee2e2; border-radius: 8px; margin-top: 10px;">
648
- <p style="margin: 0; color: #dc2626; font-weight: bold;">❌ Analysis Failed</p>
649
- </div>
650
- """
651
-
652
- return simple_result, detailed_result, status_html, dmg_results, (annotated if annotated is not None else None)
653
-
654
- def clear_all():
655
- """Clear all fields."""
656
- return (
657
- None, # input image
658
- "", # output_text
659
- {}, # output_json
660
- """
661
- <div style="padding: 10px; background: #f0f4f8; border-radius: 8px; margin-top: 10px;">
662
- <p style="margin: 0; color: #64748b;">Ready for analysis...</p>
663
- </div>
664
- """,
665
- [], # damage_json
666
- None # annotated_image
667
- )
668
-
669
- # Wire events
670
- predict_btn.click(
671
- fn=analyze_with_status,
672
- inputs=[input_image, enable_damage, damage_thresh],
673
- outputs=[output_text, output_json, status_display, damage_json, annotated_image]
674
- )
675
-
676
- clear_btn.click(
677
- fn=clear_all,
678
- outputs=[input_image, output_text, output_json, status_display, damage_json, annotated_image]
679
- )
680
-
681
- # Auto-analyze on image change
682
- input_image.change(
683
- fn=analyze_with_status,
684
- inputs=[input_image, enable_damage, damage_thresh],
685
- outputs=[output_text, output_json, status_display, damage_json, annotated_image]
686
- )
687
-
688
- # Examples (add paths if you have assets)
689
- gr.Examples(
690
- examples=[ # ("path/to/example1.jpg"), ("path/to/example2.png")
691
- ],
692
- inputs=input_image,
693
- label="Example Images"
694
- )
695
-
696
- with gr.Accordion("ℹ️ About This Model", open=False):
697
- gr.Markdown("""
698
- ### Pipeline
699
- - **Stage 1 (optional)**: Detectron2 damage/zone detection (Mask R-CNN R50-FPN), with simulated fallback.
700
- - **Stage 2**: RADIO visual features + scikit-learn classifier (`V1.pkl`) for AI-generated vs real.
701
-
702
- ### Notes
703
- - If the RADIO extractor or classifier is missing, the app runs in **Demo Mode** with deterministic simulation.
704
- - If Detectron2 or your weights are missing, Stage 1 uses a fast simulator but keeps the UX intact.
705
- """)
706
-
707
- return app
708
-
709
- # --------------------------------------------------------------------------------------
710
- # Main
711
- # --------------------------------------------------------------------------------------
712
- if __name__ == "__main__":
713
- print("=" * 60)
714
- print("🚀 Starting AI Image Detection App (Model 2 + Stage 1)")
715
- print("=" * 60)
716
- print(f"📍 Device: {DEVICE}")
717
- print(f"📦 Classifier Path: {huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH}")
718
- print(f"🛠️ Damage Model Path: {DEFAULT_DAMAGE_MODEL_PATH} ({'exists' if os.path.exists(DEFAULT_DAMAGE_MODEL_PATH) else 'missing'})")
719
-
720
- # Check if dependencies are installed
721
- auto_install_dependencies()
722
-
723
- # Preload C model at startup
724
- preload_models()
725
-
726
- # Load classifier
727
- model_path = huggingface_model_path or DEFAULT_AI_DETECTION_MODEL_PATH
728
- classifier_loaded = load_ai_detection_classifier(model_path) is not None
729
-
730
- print("=" * 60)
731
- print("=" * 60)
732
-
733
-
734
-
735
-
736
- app = create_gradio_interface()
737
- app.launch(
738
- share=False,
739
- server_name="0.0.0.0",
740
- server_port=7860,
741
- show_error=True
742
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def preload_models():
2
  """Preload models at startup to improve response time"""
3
  global radio_l_image_processor, radio_l_model
 
7
  hf_repo = os.getenv('MODEL_REPO', 'fallback')
8
  if hf_repo and hf_repo != 'fallback':
9
  from transformers import AutoModel, CLIPImageProcessor
10
+
11
+ # Load the model first to inspect it
12
  radio_l_model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
13
+
14
+ # Debug: Print available keys
15
+ state_dict = radio_l_model.state_dict()
16
+ print("Available keys in model (first 10):")
17
+ for i, key in enumerate(list(state_dict.keys())[:10]):
18
+ print(f" {key}")
19
+
20
+ # Check for blocks.0.ls1 related keys
21
+ ls1_keys = [k for k in state_dict.keys() if 'ls1' in k]
22
+ if ls1_keys:
23
+ print(f"Found ls1 keys: {ls1_keys[:5]}")
24
+
25
+ radio_l_image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
26
  radio_l_model = radio_l_model.to(RADIO_DEVICE)
27
  radio_l_model.eval()
28
  print("✅ C model preloaded successfully!")
29
  return True
30
  except Exception as e:
31
+ print(f"⚠️ Could not preload C model: {repr(e)}")
32
+ import traceback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  traceback.print_exc()
34
+ return False