astrosbd committed on
Commit
8568bc5
·
verified ·
1 Parent(s): c387e2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -363
app.py CHANGED
@@ -1,394 +1,268 @@
1
  import os
2
- import sys
3
- import warnings
4
- import traceback
5
- import json
6
- import pickle
7
  import torch
8
- import numpy as np
9
- from PIL import Image
10
- import gradio as gr
11
- from typing import Optional, Tuple, Dict, Any
12
-
13
- # ============================================================
14
- # CONFIGURATION AND SETUP
15
- # ============================================================
16
 
17
  print("=" * 60)
18
- print("DEBUGGING APP.PY - Model Loading Diagnostics")
19
  print("=" * 60)
20
 
21
- # Environment info
22
- print("\n📋 Environment Information:")
23
- print(f"Python version: {sys.version}")
24
- print(f"PyTorch version: {torch.__version__}")
25
- print(f"CUDA available: {torch.cuda.is_available()}")
26
- if torch.cuda.is_available():
27
- print(f"CUDA version: {torch.version.cuda}")
28
- print(f"GPU: {torch.cuda.get_device_name(0)}")
29
-
30
- # Check environment variables
31
- print("\n📦 Environment Variables:")
32
- env_vars = ['MODEL_REPO', 'HF_TOKEN', 'CUDA_VISIBLE_DEVICES']
33
- for var in env_vars:
34
- value = os.getenv(var, 'NOT SET')
35
- if var == 'HF_TOKEN' and value != 'NOT SET':
36
- value = value[:10] + '...' if len(value) > 10 else value
37
- print(f" {var}: {value}")
38
-
39
- # Device configuration
40
- RADIO_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
- AI_DETECT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- print(f"\n🖥️ Using device: {RADIO_DEVICE}")
43
-
44
- # Global variables
45
- radio_l_image_processor = None
46
- radio_l_model = None
47
- ai_detection_classifier = None
48
-
49
- # ============================================================
50
- # DEBUGGING UTILITIES
51
- # ============================================================
52
-
53
- def inspect_model_architecture(model, model_name="Model"):
54
- """Inspect and print model architecture details"""
55
- print(f"\n🔍 Inspecting {model_name}:")
56
- print(f" Model type: {type(model).__name__}")
57
 
58
- # Get state dict
59
- try:
60
- state_dict = model.state_dict()
61
- print(f" Total parameters: {len(state_dict)}")
62
-
63
- # Look for specific patterns
64
- patterns = ['ls1', 'blocks', 'gamma', 'grandma', 'layer_norm', 'ln']
65
- for pattern in patterns:
66
- matching_keys = [k for k in state_dict.keys() if pattern in k.lower()]
67
- if matching_keys:
68
- print(f" Keys containing '{pattern}': {len(matching_keys)}")
69
- print(f" First 3: {matching_keys[:3]}")
70
-
71
- # Show first 10 keys
72
- print(f" First 10 state dict keys:")
73
- for i, key in enumerate(list(state_dict.keys())[:10]):
74
- print(f" {i+1}. {key}")
75
-
76
- except Exception as e:
77
- print(f" Could not inspect state dict: {e}")
78
-
79
- # Check for specific attributes
80
- attrs_to_check = ['blocks', 'layers', 'encoder', 'decoder', 'visual', 'text']
81
- print(f" Model attributes:")
82
- for attr in attrs_to_check:
83
- if hasattr(model, attr):
84
- print(f" Has '{attr}'")
85
-
86
- def test_pickle_file(filepath):
87
- """Test loading a pickle file and inspect its contents"""
88
- print(f"\n🥒 Testing pickle file: {filepath}")
89
- try:
90
- with open(filepath, 'rb') as f:
91
- obj = pickle.load(f)
92
- print(f" ✓ Successfully loaded pickle")
93
- print(f" Object type: {type(obj).__name__}")
94
-
95
- if hasattr(obj, '__dict__'):
96
- print(f" Object attributes: {list(obj.__dict__.keys())[:5]}")
97
-
98
- if hasattr(obj, 'get_params'):
99
- params = obj.get_params()
100
- print(f" Model parameters: {list(params.keys())[:5]}")
101
-
102
- return obj
103
- except Exception as e:
104
- print(f" ✗ Failed to load pickle: {e}")
105
- return None
106
-
107
- # ============================================================
108
- # MODEL LOADING FUNCTIONS
109
- # ============================================================
110
 
111
- def preload_c_model_debug():
112
- """Debug version of C model preloading"""
113
- global radio_l_image_processor, radio_l_model
114
 
115
- print("\n" + "=" * 60)
116
- print("LOADING C MODEL")
117
- print("=" * 60)
118
 
119
- hf_repo = os.getenv('MODEL_REPO', 'fallback')
120
- print(f"Repository: {hf_repo}")
 
 
 
121
 
122
- if not hf_repo or hf_repo == 'fallback':
123
- print("⚠️ No MODEL_REPO environment variable set")
124
- return False
125
 
126
- try:
127
- print("\n1️⃣ Importing transformers...")
128
- from transformers import AutoModel, CLIPImageProcessor, AutoConfig
129
- print(" ✓ Imports successful")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # Try to load config first
132
- print("\n2️⃣ Loading model config...")
133
- try:
134
- config = AutoConfig.from_pretrained(hf_repo, trust_remote_code=True)
135
- print(f" ✓ Config loaded: {type(config).__name__}")
136
- if hasattr(config, 'architectures'):
137
- print(f" Architectures: {config.architectures}")
138
- except Exception as e:
139
- print(f" ⚠️ Could not load config: {e}")
140
 
141
- # Load image processor
142
- print("\n3️⃣ Loading image processor...")
143
- try:
144
- radio_l_image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
145
- print(f" ✓ Image processor loaded: {type(radio_l_image_processor).__name__}")
146
- except Exception as e:
147
- print(f" ✗ Failed to load image processor: {e}")
148
- # Try alternative processors
149
- print(" Trying alternative processors...")
150
- from transformers import AutoImageProcessor
151
- try:
152
- radio_l_image_processor = AutoImageProcessor.from_pretrained(hf_repo)
153
- print(f" ✓ Alternative processor loaded: {type(radio_l_image_processor).__name__}")
154
- except:
155
- print(" ✗ All processor attempts failed")
156
 
157
- # Load model with detailed error catching
158
- print("\n4️⃣ Loading model...")
159
- try:
160
- # First attempt - standard loading
161
- radio_l_model = AutoModel.from_pretrained(
162
- hf_repo,
163
- trust_remote_code=True,
164
- torch_dtype=torch.float32,
165
- device_map='auto' if torch.cuda.is_available() else None
166
- )
167
- print(f" ✓ Model loaded: {type(radio_l_model).__name__}")
168
 
169
- except Exception as e:
170
- print(f" ✗ Standard loading failed: {e}")
171
- print(" Trying with force_download...")
172
 
173
- # Second attempt - force download
174
- try:
175
- radio_l_model = AutoModel.from_pretrained(
176
- hf_repo,
177
- trust_remote_code=True,
178
- force_download=True,
179
- torch_dtype=torch.float32
180
- )
181
- print(f" ✓ Model loaded with force_download")
182
- except Exception as e2:
183
- print(f" ✗ Force download failed: {e2}")
184
- raise
185
-
186
- # Inspect the loaded model
187
- inspect_model_architecture(radio_l_model, "C Model")
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- # Move to device
190
- print(f"\n5️⃣ Moving model to {RADIO_DEVICE}...")
191
- if RADIO_DEVICE.type != 'cpu':
192
- radio_l_model = radio_l_model.to(RADIO_DEVICE)
193
- radio_l_model.eval()
194
- print(" ✓ Model moved and set to eval mode")
195
 
196
- # Test forward pass
197
- print("\n6️⃣ Testing forward pass...")
198
  try:
199
- with torch.no_grad():
200
- # Create a dummy image
201
- dummy_image = Image.new('RGB', (224, 224), color='white')
202
- if radio_l_image_processor:
203
- inputs = radio_l_image_processor(dummy_image, return_tensors="pt")
204
- if RADIO_DEVICE.type != 'cpu':
205
- inputs = {k: v.to(RADIO_DEVICE) for k, v in inputs.items()}
206
-
207
- # Try forward pass
208
- outputs = radio_l_model(**inputs)
209
- print(f" ✓ Forward pass successful!")
210
- print(f" Output type: {type(outputs)}")
211
- if hasattr(outputs, 'keys'):
212
- print(f" Output keys: {outputs.keys()}")
213
- else:
214
- print(" ⚠️ No image processor available for test")
215
- except Exception as e:
216
- print(f" ✗ Forward pass failed: {e}")
217
- traceback.print_exc()
218
-
219
- print("\n✅ C model loading completed (with warnings)")
220
- return True
221
 
222
- except Exception as e:
223
- print(f"\n❌ C model loading failed completely: {e}")
224
- traceback.print_exc()
225
- return False
226
 
227
- def preload_ai_detector_debug():
228
- """Debug version of AI detector preloading"""
229
- global ai_detection_classifier
230
 
231
- print("\n" + "=" * 60)
232
- print("LOADING AI DETECTION CLASSIFIER")
233
- print("=" * 60)
234
 
235
- try:
236
- print("\n1️⃣ Checking for Askhedi model...")
237
- from huggingface_hub import hf_hub_download, list_repo_files
238
-
239
- repo_id = "Askhedi/hedi_v0_mix"
240
- print(f" Repository: {repo_id}")
241
-
242
- # List files in repo
243
- try:
244
- files = list(list_repo_files(repo_id))
245
- print(f" Files in repo: {files[:10]}") # Show first 10 files
246
- pkl_files = [f for f in files if f.endswith('.pkl')]
247
- print(f" PKL files found: {pkl_files}")
248
- except Exception as e:
249
- print(f" Could not list repo files: {e}")
250
-
251
- print("\n2️⃣ Downloading classifier...")
252
- classifier_path = hf_hub_download(
253
- repo_id=repo_id,
254
- filename="V1.pkl"
255
- )
256
- print(f" ✓ Downloaded to: {classifier_path}")
257
-
258
- print("\n3️⃣ Loading classifier...")
259
- ai_detection_classifier = test_pickle_file(classifier_path)
260
-
261
- if ai_detection_classifier:
262
- print("\n4️⃣ Testing classifier...")
263
- try:
264
- # Create dummy features
265
- dummy_features = np.random.randn(1, 100) # Adjust size as needed
266
- prediction = ai_detection_classifier.predict(dummy_features)
267
- print(f" ✓ Prediction successful: {prediction}")
268
-
269
- if hasattr(ai_detection_classifier, 'predict_proba'):
270
- proba = ai_detection_classifier.predict_proba(dummy_features)
271
- print(f" ✓ Probability shape: {proba.shape}")
272
-
273
- except Exception as e:
274
- print(f" ⚠️ Classifier test failed: {e}")
275
- print(f" This might be due to incorrect feature dimensions")
276
-
277
- print("\n✅ AI detection classifier loaded")
278
- return True
279
-
280
- except Exception as e:
281
- print(f"\n❌ AI detector loading failed: {e}")
282
- traceback.print_exc()
283
- return False
284
-
285
- # ============================================================
286
- # MAIN GRADIO APP
287
- # ============================================================
288
 
289
- def analyze_image_debug(image):
290
- """Debug version of image analysis"""
291
- results = {
292
- "status": "Processing",
293
- "c_model": "Not loaded",
294
- "ai_detection": "Not loaded",
295
- "errors": []
296
- }
297
 
298
- try:
299
- if radio_l_model and radio_l_image_processor:
300
- results["c_model"] = "Model loaded and ready"
301
- # Add actual processing here if needed
302
- else:
303
- results["c_model"] = "Model not loaded"
304
-
305
- if ai_detection_classifier:
306
- results["ai_detection"] = "Classifier loaded and ready"
307
- # Add actual detection here if needed
308
- else:
309
- results["ai_detection"] = "Classifier not loaded"
310
-
311
- results["status"] = "Complete"
312
-
313
- except Exception as e:
314
- results["errors"].append(str(e))
315
- results["status"] = "Error"
316
 
317
- return json.dumps(results, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
- def create_gradio_interface():
320
- """Create the Gradio interface"""
321
- with gr.Blocks(title="Model Loading Debugger") as demo:
322
- gr.Markdown("# Model Loading Debugger")
323
- gr.Markdown("This interface helps debug model loading issues.")
324
 
325
- with gr.Tab("Status"):
326
- gr.Markdown("## Current Model Status")
327
- status_text = gr.Textbox(
328
- label="Model Status",
329
- value=f"C Model: {'Loaded' if radio_l_model else 'Not loaded'}\n"
330
- f"AI Detector: {'Loaded' if ai_detection_classifier else 'Not loaded'}",
331
- lines=10
332
- )
333
-
334
- refresh_btn = gr.Button("Refresh Status")
335
-
336
- def refresh_status():
337
- status = []
338
- status.append(f"C Model: {'✓ Loaded' if radio_l_model else '✗ Not loaded'}")
339
- status.append(f"Image Processor: {'✓ Loaded' if radio_l_image_processor else '✗ Not loaded'}")
340
- status.append(f"AI Detector: {'✓ Loaded' if ai_detection_classifier else '✗ Not loaded'}")
341
- status.append(f"\nDevice: {RADIO_DEVICE}")
342
-
343
- if radio_l_model:
344
- status.append(f"Model type: {type(radio_l_model).__name__}")
345
-
346
- return "\n".join(status)
347
-
348
- refresh_btn.click(refresh_status, outputs=status_text)
349
 
350
- with gr.Tab("Test Image"):
351
- image_input = gr.Image(label="Upload Test Image", type="pil")
352
- analyze_btn = gr.Button("Analyze")
353
- output = gr.Textbox(label="Analysis Results", lines=10)
354
-
355
- analyze_btn.click(analyze_image_debug, inputs=image_input, outputs=output)
356
 
357
- return demo
358
-
359
- # ============================================================
360
- # MAIN EXECUTION
361
- # ============================================================
362
-
363
- if __name__ == "__main__":
364
- print("\n" + "=" * 60)
365
- print("STARTING MODEL PRELOAD")
366
- print("=" * 60)
367
-
368
- # Suppress specific warnings if needed
369
- warnings.filterwarnings("ignore", message="Couldn't find the key")
370
- warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
371
-
372
- # Load models
373
- c_model_success = preload_c_model_debug()
374
- ai_detector_success = preload_ai_detector_debug()
375
-
376
- # Summary
377
- print("\n" + "=" * 60)
378
- print("LOADING SUMMARY")
379
- print("=" * 60)
380
- print(f"C Model: {'✅ Success' if c_model_success else '❌ Failed'}")
381
- print(f"AI Detector: {'✅ Success' if ai_detector_success else '❌ Failed'}")
382
-
383
- # Launch Gradio
384
- print("\n" + "=" * 60)
385
- print("LAUNCHING GRADIO INTERFACE")
386
- print("=" * 60)
387
-
388
- demo = create_gradio_interface()
389
- demo.launch(
390
- server_name="0.0.0.0",
391
- server_port=7860,
392
- share=False,
393
- debug=True
394
- )
 
1
  import os
 
 
 
 
 
2
  import torch
3
+ import json
4
+ from huggingface_hub import hf_hub_download
5
+ import safetensors.torch
 
 
 
 
 
6
 
7
# Startup banner: a 60-character rule above and below the script title.
print("=" * 60, "C-RADIOv3-B Model Deep Inspection", "=" * 60, sep="\n")
10
 
11
+ # Step 1: Download and inspect the model file directly
12
def inspect_model_weights(state_dict=None, *, repo_id="nvidia/C-RADIOv3-B",
                          filename="model.safetensors"):
    """Print a summary of the parameter names stored in a checkpoint.

    Args:
        state_dict: Optional mapping of parameter name -> tensor. When omitted,
            ``filename`` is downloaded from ``repo_id`` with ``hf_hub_download``
            and loaded with ``safetensors.torch.load_file``.
        repo_id: Hub repository to download from when ``state_dict`` is None.
        filename: Checkpoint file name inside ``repo_id``.

    Returns:
        A ``(state_dict, all_keys)`` tuple, where ``all_keys`` is the list of
        parameter names in the mapping's own order.
    """
    from collections import Counter

    if state_dict is None:
        print("\n📥 Downloading model weights for inspection...")
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Downloaded to: {model_path}")

        print("\n🔍 Inspecting model weights...")
        state_dict = safetensors.torch.load_file(model_path)
    else:
        print("\n🔍 Inspecting model weights...")

    all_keys = list(state_dict.keys())
    print(f"Total keys in model: {len(all_keys)}")

    # Name-based filters. NOTE(review): the 'blocks.' filters assume block
    # parameters are stored without a wrapper prefix -- confirm against the
    # actual checkpoint, otherwise these counts will all be zero.
    ls1_keys = [k for k in all_keys if 'ls1' in k.lower()]
    ls_keys = [k for k in all_keys if 'ls' in k.lower()]
    gamma_keys = [k for k in all_keys if 'gamma' in k.lower()]
    block_keys = [k for k in all_keys if k.startswith('blocks.')]

    print(f"\n📊 Key Analysis:")
    print(f" Keys with 'ls1': {len(ls1_keys)}")
    print(f" Keys with 'ls': {len(ls_keys)}")
    print(f" Keys with 'gamma': {len(gamma_keys)}")
    print(f" Keys starting with 'blocks.': {len(block_keys)}")

    # Everything below looks only at the first block.
    blocks_0_keys = [k for k in all_keys if k.startswith('blocks.0.')]

    print(f"\n📝 First 20 blocks.0 keys:")
    for key in sorted(blocks_0_keys)[:20]:
        print(f" {key}")

    # Group by the path component right after 'blocks.0.'. Counting on that
    # component (instead of a substring test) also counts parameters that live
    # directly on the block, which were previously reported as 0.
    print(f"\n🔎 All blocks.0 submodules:")
    submodule_counts = Counter(key.split('.')[2] for key in blocks_0_keys)
    for submodule in sorted(submodule_counts):
        count = submodule_counts[submodule]
        print(f" blocks.0.{submodule}.*: {count} parameters")

    return state_dict, all_keys
 
 
 
 
63
 
64
+ # Step 2: Check the model architecture expectations
65
def inspect_model_code(source_path=None, *, error_line=309, context=10):
    """Print the architecture source around a line of interest.

    Two excerpts are printed: the lines surrounding ``error_line`` (the target
    line is marked with ``>>>``), and the first occurrence of the text
    ``_load_from_state_dict`` together with a few lines of context.

    Args:
        source_path: Path to a local source file. When omitted,
            ``dinov2_arch.py`` is downloaded from ``nvidia/C-RADIOv3-B`` with
            ``hf_hub_download``.
        error_line: 1-based line number to centre the first excerpt on.
        context: Number of lines shown before the target line; the excerpt
            stops ``context - 1`` lines after it.
    """
    if source_path is None:
        print("\n📜 Downloading model code...")
        source_path = hf_hub_download(
            repo_id="nvidia/C-RADIOv3-B",
            filename="dinov2_arch.py"
        )
        print(f"Downloaded dinov2_arch.py to: {source_path}")

    # Explicit encoding so decoding does not depend on the platform default.
    with open(source_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    target_idx = error_line - 1  # 0-based index of the line of interest
    first = max(0, target_idx - context)
    last = min(len(lines), target_idx + context)

    print(f"\n🔍 Code around line {error_line} (error location):")
    if first >= last:
        print(f" File has only {len(lines)} lines; nothing to show")
    for i in range(first, last):
        marker = ">>>" if i == target_idx else ""
        print(f"{marker} {i+1}: {lines[i].rstrip()}")

    print("\n📖 Looking for _load_from_state_dict method...")
    for i, line in enumerate(lines):
        if '_load_from_state_dict' in line:
            print(f"Found at line {i+1}: {line.rstrip()}")
            # Show context: two lines before, up to fourteen after.
            for j in range(max(0, i-2), min(len(lines), i+15)):
                print(f" {j+1}: {lines[j].rstrip()}")
            break
    else:
        print(" _load_from_state_dict not found in file")
99
+
100
+ # Step 3: Create a working loader
101
def create_fixed_loader():
    """Generate ``radio_loader_fixed.py``, a standalone compatibility loader.

    The generated module defines ``RADIOModelFixed.from_pretrained``, which
    temporarily wraps ``transformers.modeling_utils._load_state_dict_into_meta_model``
    so that a placeholder ``blocks.<i>.ls1.gamma`` entry is added for every
    block when the state dict has ``blocks.*`` keys but none containing
    ``ls1``. The generated source is printed and then written to the current
    working directory.

    Returns:
        None.
    """
    import textwrap

    print("\n🔧 Creating Fixed Model Loader...")

    # Source of the generated module. dedent() strips the indentation this
    # literal carries from being nested inside the function.
    code = textwrap.dedent('''
        import torch
        from transformers import AutoModel, AutoConfig
        import warnings


        class RADIOModelFixed:
            @staticmethod
            def from_pretrained(repo_id="nvidia/C-RADIOv3-B"):
                """Load RADIO model with compatibility fixes"""

                print("Loading with compatibility fixes...")

                import transformers.modeling_utils as mu

                def load_model():
                    return AutoModel.from_pretrained(
                        repo_id,
                        trust_remote_code=True,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
                    )

                # Private transformers hook: it can be missing, or take different
                # arguments, depending on the installed version.
                original_load = getattr(mu, "_load_state_dict_into_meta_model", None)
                if original_load is None:
                    print(" No patchable loader hook found; loading without fixes")
                    return load_model()

                def patched_load(model, state_dict, *args, **kwargs):
                    """Patched loader that handles missing ls1 keys"""

                    # Create a modified state dict with dummy ls1 keys if needed
                    modified_state = dict(state_dict)

                    # Check if we need to add dummy ls1 keys
                    block_keys = [k for k in state_dict.keys() if k.startswith('blocks.')]
                    if block_keys and not any('ls1' in k for k in block_keys):
                        print(" Adding compatibility keys for ls1 layers...")

                        # Find all blocks
                        block_indices = set()
                        for key in block_keys:
                            parts = key.split('.')
                            if len(parts) > 1 and parts[1].isdigit():
                                block_indices.add(int(parts[1]))

                        # Add dummy ls1 parameters for each block
                        for idx in block_indices:
                            if f'blocks.{idx}.norm1.weight' in state_dict:
                                # Use norm1 as a template for shape
                                template = state_dict[f'blocks.{idx}.norm1.weight']
                                modified_state[f'blocks.{idx}.ls1.gamma'] = torch.ones_like(template)
                            else:
                                # Default to scalar
                                modified_state[f'blocks.{idx}.ls1.gamma'] = torch.tensor(1.0)

                    # Forward every other argument untouched so the wrapper does
                    # not depend on the hook's exact signature.
                    return original_load(model, modified_state, *args, **kwargs)

                # Temporarily replace the function
                mu._load_state_dict_into_meta_model = patched_load

                try:
                    model = load_model()
                    print(" Model loaded successfully with compatibility fixes!")
                finally:
                    # Restore original function
                    mu._load_state_dict_into_meta_model = original_load

                return model


        if __name__ == "__main__":
            # Usage:
            model = RADIOModelFixed.from_pretrained()
    ''')

    print(code)

    # Save to file (explicit encoding keeps the output platform-independent).
    with open('radio_loader_fixed.py', 'w', encoding='utf-8') as f:
        f.write(code)

    print("\n✅ Fixed loader saved to 'radio_loader_fixed.py'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ # Step 4: Alternative loading approach
192
def _patch_missing_key_raise(source):
    """Swap the "Couldn't find the key" raise for a defaulting fallback.

    The replacement lines are re-indented to match the ``raise`` statement they
    replace, so the surrounding block structure is preserved.

    Returns:
        ``(patched_source, replacements)`` where ``replacements`` is the number
        of ``raise`` statements that were replaced.
    """
    import re

    # NOTE(review): the fallback relies on `torch` and `self.dim` being
    # available at the patch site -- confirm against dinov2_arch.py.
    fallback_lines = (
        'print(f" Warning: Missing keys {key_a} and {key_b}, using defaults")',
        '# Use identity/ones as default',
        'if "gamma" in key_a:',
        '    setattr(self, key_a.split(".")[-1], torch.nn.Parameter(torch.ones(self.dim)))',
        'elif "beta" in key_a:',
        '    setattr(self, key_a.split(".")[-1], torch.nn.Parameter(torch.zeros(self.dim)))',
        'return',
    )
    target = re.compile(
        r'^(?P<indent>[ \t]*)'
        + re.escape('raise KeyError(f"Couldn\'t find the key {key_a} nor {key_b} in the state dict!")')
        + r'[ \t]*$',
        re.MULTILINE,
    )

    def _reindent(match):
        indent = match.group("indent")
        return "\n".join(indent + line for line in fallback_lines)

    return target.subn(_reindent, source)


def try_alternative_loading():
    """Print the model config and write a patched copy of ``dinov2_arch.py``.

    Steps:
        1. Load the ``nvidia/C-RADIOv3-B`` config and print its architecture
           list and model type.
        2. Download ``dinov2_arch.py``, replace the ``KeyError`` raised for
           missing state-dict keys with a fallback that installs default
           parameters, and write the result to ``dinov2_arch_patched.py`` in
           the current working directory.

    The patched file is only written when the target ``raise`` was found and
    the patched source still compiles. No model is loaded here.

    Returns:
        None.
    """
    from transformers import AutoConfig

    print("\n🔄 Trying Alternative Loading Methods...")

    repo_id = "nvidia/C-RADIOv3-B"

    # Method 1: Load config first and check architecture
    print("\n1️⃣ Checking model config...")
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    print(f" Architecture: {config.architectures}")
    print(f" Model type: {config.model_type}")

    # Method 2: Produce a patched copy of the architecture source
    print("\n2️⃣ Preparing patched copy of dinov2_arch.py...")
    dinov2_path = hf_hub_download(repo_id=repo_id, filename="dinov2_arch.py")

    with open(dinov2_path, 'r', encoding='utf-8') as f:
        original_code = f.read()

    patched_code, replaced = _patch_missing_key_raise(original_code)
    if not replaced:
        print(" ⚠️ Target raise statement not found; no patched file written")
        return

    patched_path = "dinov2_arch_patched.py"

    # compile() only parses; it does not execute the downloaded code.
    try:
        compile(patched_code, patched_path, "exec")
    except SyntaxError as exc:
        print(f" ⚠️ Patched source does not compile ({exc}); no patched file written")
        return

    with open(patched_path, 'w', encoding='utf-8') as f:
        f.write(patched_code)

    print(f" Created patched architecture file: {patched_path}")

    print("\n✅ Alternative loading methods prepared")
247
 
248
+ # Run all inspections
249
def main():
    """Run the four diagnostic steps in order and print a closing banner."""
    # Step 1: Inspect weights (the returned state dict is not needed here)
    inspect_model_weights()

    # Step 2: Inspect code
    inspect_model_code()

    # Step 3: Create fixed loader
    create_fixed_loader()

    # Step 4: Try alternatives
    try_alternative_loading()

    print("\n" + "=" * 60)
    print("DIAGNOSIS COMPLETE")
    print("=" * 60)


# Run all inspections
if __name__ == "__main__":
    import traceback

    # The original `try:` had no handler, which is a syntax error; report the
    # failure with its traceback and exit non-zero instead.
    try:
        main()
    except Exception as exc:
        print(f"\n❌ Diagnosis failed: {exc}")
        traceback.print_exc()
        raise SystemExit(1)
 
 
 
 
 
 
 
 
 
 
 
 
266
 
 
 
 
 
 
 
267
 
268
+