astrosbd committed on
Commit
c387e2a
·
verified ·
1 Parent(s): b382923

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +385 -25
app.py CHANGED
@@ -1,34 +1,394 @@
1
- def preload_models():
2
- """Preload models at startup to improve response time"""
3
- global radio_l_image_processor, radio_l_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- print("🔄 Preloading C model (4GB)...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  try:
7
- hf_repo = os.getenv('MODEL_REPO', 'fallback')
8
- if hf_repo and hf_repo != 'fallback':
9
- from transformers import AutoModel, CLIPImageProcessor
10
-
11
- # Load the model first to inspect it
12
- radio_l_model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
13
 
14
- # Debug: Print available keys
15
- state_dict = radio_l_model.state_dict()
16
- print("Available keys in model (first 10):")
17
- for i, key in enumerate(list(state_dict.keys())[:10]):
18
- print(f" {key}")
 
 
 
 
19
 
20
- # Check for blocks.0.ls1 related keys
21
- ls1_keys = [k for k in state_dict.keys() if 'ls1' in k]
22
- if ls1_keys:
23
- print(f"Found ls1 keys: {ls1_keys[:5]}")
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  radio_l_image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  radio_l_model = radio_l_model.to(RADIO_DEVICE)
27
- radio_l_model.eval()
28
- print(" C model preloaded successfully!")
29
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  except Exception as e:
31
- print(f"⚠️ Could not preload C model: {repr(e)}")
32
- import traceback
33
  traceback.print_exc()
34
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import warnings
4
+ import traceback
5
+ import json
6
+ import pickle
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+ import gradio as gr
11
+ from typing import Optional, Tuple, Dict, Any
12
+
13
# ============================================================
# CONFIGURATION AND SETUP
# ============================================================

print("=" * 60)
print("DEBUGGING APP.PY - Model Loading Diagnostics")
print("=" * 60)

# Environment info
print("\n📋 Environment Information:")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Check environment variables
print("\n📦 Environment Variables:")
env_vars = ['MODEL_REPO', 'HF_TOKEN', 'CUDA_VISIBLE_DEVICES']
for var in env_vars:
    value = os.getenv(var, 'NOT SET')
    if var == 'HF_TOKEN' and value != 'NOT SET':
        # SECURITY FIX: never echo credential material into logs.
        # The previous code printed the first 10 characters of the token,
        # and printed tokens of <= 10 characters IN FULL. Report presence only.
        value = '<set, hidden>'
    print(f" {var}: {value}")

# Device configuration
RADIO_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
AI_DETECT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n🖥️ Using device: {RADIO_DEVICE}")

# Global model handles, populated by the preload_* functions below.
radio_l_image_processor = None
radio_l_model = None
ai_detection_classifier = None
48
+
49
# ============================================================
# DEBUGGING UTILITIES
# ============================================================

def inspect_model_architecture(model, model_name="Model"):
    """Print diagnostic details about a loaded model's architecture.

    Scans the model's state dict for key-name patterns relevant to
    loading problems (layer-scale ``ls1`` keys, transformer ``blocks``,
    ``gamma``/layer-norm parameters) and reports which common top-level
    attributes the model object exposes.

    Args:
        model: Any object exposing ``state_dict()`` (typically a
            ``torch.nn.Module``).
        model_name: Label used in the printed header.

    Returns:
        None — all output goes to stdout.
    """
    print(f"\n🔍 Inspecting {model_name}:")
    print(f" Model type: {type(model).__name__}")

    try:
        state_dict = model.state_dict()
        print(f" Total parameters: {len(state_dict)}")

        # Patterns that have shown up in past loading failures.
        # FIX: the original list also contained 'grandma' — almost certainly
        # an autocorrect of 'gamma' (already present) — so it was dropped.
        patterns = ['ls1', 'blocks', 'gamma', 'layer_norm', 'ln']
        for pattern in patterns:
            matching_keys = [k for k in state_dict.keys() if pattern in k.lower()]
            if matching_keys:
                print(f" Keys containing '{pattern}': {len(matching_keys)}")
                print(f" First 3: {matching_keys[:3]}")

        # Show first 10 keys
        print(" First 10 state dict keys:")
        for i, key in enumerate(list(state_dict.keys())[:10]):
            print(f" {i+1}. {key}")

    except Exception as e:
        print(f" Could not inspect state dict: {e}")

    # Cheap structural probe: which common sub-module attributes exist?
    attrs_to_check = ['blocks', 'layers', 'encoder', 'decoder', 'visual', 'text']
    print(" Model attributes:")
    for attr in attrs_to_check:
        if hasattr(model, attr):
            print(f" ✓ Has '{attr}'")
85
 
86
+ def test_pickle_file(filepath):
87
+ """Test loading a pickle file and inspect its contents"""
88
+ print(f"\n🥒 Testing pickle file: {filepath}")
89
+ try:
90
+ with open(filepath, 'rb') as f:
91
+ obj = pickle.load(f)
92
+ print(f" ✓ Successfully loaded pickle")
93
+ print(f" Object type: {type(obj).__name__}")
94
+
95
+ if hasattr(obj, '__dict__'):
96
+ print(f" Object attributes: {list(obj.__dict__.keys())[:5]}")
97
+
98
+ if hasattr(obj, 'get_params'):
99
+ params = obj.get_params()
100
+ print(f" Model parameters: {list(params.keys())[:5]}")
101
 
102
+ return obj
103
+ except Exception as e:
104
+ print(f" ✗ Failed to load pickle: {e}")
105
+ return None
106
+
107
# ============================================================
# MODEL LOADING FUNCTIONS
# ============================================================

def preload_c_model_debug():
    """Debug version of C model preloading.

    Loads the image processor and model from the repo named by the
    MODEL_REPO env var, printing a step-by-step trace so loading failures
    can be diagnosed. On success populates the module globals
    ``radio_l_image_processor`` and ``radio_l_model``.

    Returns:
        bool: True when the model loaded (possibly with warnings),
        False when MODEL_REPO is unset or loading failed completely.
    """
    global radio_l_image_processor, radio_l_model

    print("\n" + "=" * 60)
    print("LOADING C MODEL")
    print("=" * 60)

    hf_repo = os.getenv('MODEL_REPO', 'fallback')
    print(f"Repository: {hf_repo}")

    if not hf_repo or hf_repo == 'fallback':
        print("⚠️ No MODEL_REPO environment variable set")
        return False

    try:
        print("\n1️⃣ Importing transformers...")
        from transformers import AutoModel, CLIPImageProcessor, AutoConfig
        print(" ✓ Imports successful")

        # Config first: a config failure pinpoints repo/auth problems
        # before the much slower weight download.
        print("\n2️⃣ Loading model config...")
        try:
            config = AutoConfig.from_pretrained(hf_repo, trust_remote_code=True)
            print(f" ✓ Config loaded: {type(config).__name__}")
            if hasattr(config, 'architectures'):
                print(f" Architectures: {config.architectures}")
        except Exception as e:
            print(f" ⚠️ Could not load config: {e}")

        # Load image processor
        print("\n3️⃣ Loading image processor...")
        try:
            radio_l_image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
            print(f" ✓ Image processor loaded: {type(radio_l_image_processor).__name__}")
        except Exception as e:
            print(f" ✗ Failed to load image processor: {e}")
            print(" Trying alternative processors...")
            from transformers import AutoImageProcessor
            try:
                radio_l_image_processor = AutoImageProcessor.from_pretrained(hf_repo)
                print(f" ✓ Alternative processor loaded: {type(radio_l_image_processor).__name__}")
            # FIX: was a bare `except:`, which swallows SystemExit /
            # KeyboardInterrupt and hid the actual failure reason.
            except Exception as fallback_err:
                print(f" ✗ All processor attempts failed: {fallback_err}")

        # Load model with detailed error catching
        print("\n4️⃣ Loading model...")
        try:
            # First attempt - standard loading
            radio_l_model = AutoModel.from_pretrained(
                hf_repo,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                device_map='auto' if torch.cuda.is_available() else None
            )
            print(f" ✓ Model loaded: {type(radio_l_model).__name__}")

        except Exception as e:
            print(f" ✗ Standard loading failed: {e}")
            print(" Trying with force_download...")

            # Second attempt - force a fresh download in case the cache is corrupt.
            try:
                radio_l_model = AutoModel.from_pretrained(
                    hf_repo,
                    trust_remote_code=True,
                    force_download=True,
                    torch_dtype=torch.float32
                )
                print(" ✓ Model loaded with force_download")
            except Exception as e2:
                print(f" ✗ Force download failed: {e2}")
                raise

        # Inspect the loaded model
        inspect_model_architecture(radio_l_model, "C Model")

        # Move to device
        print(f"\n5️⃣ Moving model to {RADIO_DEVICE}...")
        # NOTE(review): when device_map='auto' was used above, the model may
        # already be dispatched across devices and .to() can fail or no-op —
        # confirm against the accelerate big-model-inference docs.
        if RADIO_DEVICE.type != 'cpu':
            radio_l_model = radio_l_model.to(RADIO_DEVICE)
        radio_l_model.eval()
        print(" Model moved and set to eval mode")

        # Smoke-test a forward pass on a blank image.
        print("\n6️⃣ Testing forward pass...")
        try:
            with torch.no_grad():
                dummy_image = Image.new('RGB', (224, 224), color='white')
                if radio_l_image_processor:
                    inputs = radio_l_image_processor(dummy_image, return_tensors="pt")
                    if RADIO_DEVICE.type != 'cpu':
                        inputs = {k: v.to(RADIO_DEVICE) for k, v in inputs.items()}

                    outputs = radio_l_model(**inputs)
                    print(" ✓ Forward pass successful!")
                    print(f" Output type: {type(outputs)}")
                    if hasattr(outputs, 'keys'):
                        print(f" Output keys: {outputs.keys()}")
                else:
                    print(" ⚠️ No image processor available for test")
        except Exception as e:
            print(f" ✗ Forward pass failed: {e}")
            traceback.print_exc()

        print("\n✅ C model loading completed (with warnings)")
        return True

    except Exception as e:
        print(f"\n❌ C model loading failed completely: {e}")
        traceback.print_exc()
        return False
226
+
227
def preload_ai_detector_debug():
    """Debug version of AI detector preloading.

    Downloads the pickled classifier from the Askhedi hub repo, loads it
    via ``test_pickle_file``, and runs a smoke-test prediction. Stores
    the result in the module global ``ai_detection_classifier``.

    NOTE(review): the classifier is a pickle fetched from a remote repo;
    ``pickle.load`` executes arbitrary code — treat the repo as trusted
    or migrate to a safer format.

    Returns:
        bool: True when the classifier was downloaded and loaded,
        False on any failure (including a missing huggingface_hub).
    """
    global ai_detection_classifier

    print("\n" + "=" * 60)
    print("LOADING AI DETECTION CLASSIFIER")
    print("=" * 60)

    try:
        print("\n1️⃣ Checking for Askhedi model...")
        from huggingface_hub import hf_hub_download, list_repo_files

        repo_id = "Askhedi/hedi_v0_mix"
        print(f" Repository: {repo_id}")

        # Enumerate repo contents so a wrong filename is easy to spot.
        try:
            repo_files = list(list_repo_files(repo_id))
            print(f" Files in repo: {repo_files[:10]}") # Show first 10 files
            pickle_names = [name for name in repo_files if name.endswith('.pkl')]
            print(f" PKL files found: {pickle_names}")
        except Exception as listing_err:
            print(f" Could not list repo files: {listing_err}")

        print("\n2️⃣ Downloading classifier...")
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename="V1.pkl"
        )
        print(f" ✓ Downloaded to: {local_path}")

        print("\n3️⃣ Loading classifier...")
        ai_detection_classifier = test_pickle_file(local_path)

        if ai_detection_classifier:
            print("\n4️⃣ Testing classifier...")
            try:
                # Create dummy features
                probe = np.random.randn(1, 100) # Adjust size as needed
                pred = ai_detection_classifier.predict(probe)
                print(f" ✓ Prediction successful: {pred}")

                if hasattr(ai_detection_classifier, 'predict_proba'):
                    probs = ai_detection_classifier.predict_proba(probe)
                    print(f" ✓ Probability shape: {probs.shape}")

            except Exception as smoke_err:
                print(f" ⚠️ Classifier test failed: {smoke_err}")
                print(" This might be due to incorrect feature dimensions")

        print("\n✅ AI detection classifier loaded")
        return True

    except Exception as err:
        print(f"\n❌ AI detector loading failed: {err}")
        traceback.print_exc()
        return False
284
+
285
# ============================================================
# MAIN GRADIO APP
# ============================================================

def analyze_image_debug(image):
    """Debug version of image analysis.

    Reports whether the models are loaded; it does not actually run
    inference on ``image`` yet (placeholders are marked below).

    Args:
        image: Image from the Gradio widget (may be None; currently unused).

    Returns:
        str: pretty-printed JSON status report.
    """
    results = {
        "status": "Processing",
        "c_model": "Not loaded",
        "ai_detection": "Not loaded",
        "errors": []
    }

    try:
        # FIX: use `is not None` rather than truthiness — model/processor
        # objects can define __len__/__bool__ and evaluate falsy even
        # when loaded.
        if radio_l_model is not None and radio_l_image_processor is not None:
            results["c_model"] = "Model loaded and ready"
            # Add actual processing here if needed
        else:
            results["c_model"] = "Model not loaded"

        if ai_detection_classifier is not None:
            results["ai_detection"] = "Classifier loaded and ready"
            # Add actual detection here if needed
        else:
            results["ai_detection"] = "Classifier not loaded"

        results["status"] = "Complete"

    except Exception as e:
        results["errors"].append(str(e))
        results["status"] = "Error"

    return json.dumps(results, indent=2)
318
+
319
def create_gradio_interface():
    """Assemble the model-loading debugger UI.

    Two tabs: a status view with a manual refresh button, and an image
    upload wired to ``analyze_image_debug``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Model Loading Debugger") as demo:
        gr.Markdown("# Model Loading Debugger")
        gr.Markdown("This interface helps debug model loading issues.")

        with gr.Tab("Status"):
            gr.Markdown("## Current Model Status")
            initial_status = (
                f"C Model: {'Loaded' if radio_l_model else 'Not loaded'}\n"
                f"AI Detector: {'Loaded' if ai_detection_classifier else 'Not loaded'}"
            )
            status_box = gr.Textbox(
                label="Model Status",
                value=initial_status,
                lines=10
            )

            refresh_button = gr.Button("Refresh Status")

            def refresh_status():
                # Re-read the module globals on each click so the view
                # tracks any loading done after the UI was built.
                report_lines = [
                    f"C Model: {'✓ Loaded' if radio_l_model else '✗ Not loaded'}",
                    f"Image Processor: {'✓ Loaded' if radio_l_image_processor else '✗ Not loaded'}",
                    f"AI Detector: {'✓ Loaded' if ai_detection_classifier else '✗ Not loaded'}",
                    f"\nDevice: {RADIO_DEVICE}",
                ]
                if radio_l_model:
                    report_lines.append(f"Model type: {type(radio_l_model).__name__}")
                return "\n".join(report_lines)

            refresh_button.click(refresh_status, outputs=status_box)

        with gr.Tab("Test Image"):
            test_image = gr.Image(label="Upload Test Image", type="pil")
            run_button = gr.Button("Analyze")
            report_box = gr.Textbox(label="Analysis Results", lines=10)

            run_button.click(analyze_image_debug, inputs=test_image, outputs=report_box)

    return demo
358
+
359
# ============================================================
# MAIN EXECUTION
# ============================================================

if __name__ == "__main__":
    def _banner(title):
        # Same separator style the loader functions print.
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    _banner("STARTING MODEL PRELOAD")

    # Suppress specific warnings if needed
    warnings.filterwarnings("ignore", message="Couldn't find the key")
    warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

    # Load models
    c_ok = preload_c_model_debug()
    detector_ok = preload_ai_detector_debug()

    # Summary
    _banner("LOADING SUMMARY")
    print(f"C Model: {'✅ Success' if c_ok else '❌ Failed'}")
    print(f"AI Detector: {'✅ Success' if detector_ok else '❌ Failed'}")

    # Launch Gradio
    _banner("LAUNCHING GRADIO INTERFACE")
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True
    )