Mustafa Akcanca committed on
Commit
90f2bb7
·
1 Parent(s): 20d6258

Fix weights downloader

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. app.py +16 -16
  3. app_requirements.txt +1 -1
  4. src/utils/weight_downloader.py +76 -65
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🔍
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.0.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 6.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -204,21 +204,7 @@ def create_interface():
204
  """Create and configure the Gradio interface."""
205
 
206
  with gr.Blocks(
207
- title="Forensic Image Analysis Agent",
208
- theme=gr.themes.Soft(),
209
- css="""
210
- .gradio-container {
211
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
212
- }
213
- .main-header {
214
- text-align: center;
215
- padding: 20px;
216
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
217
- color: white;
218
- border-radius: 10px;
219
- margin-bottom: 20px;
220
- }
221
- """
222
  ) as demo:
223
  gr.HTML("""
224
  <div class="main-header">
@@ -337,6 +323,20 @@ if __name__ == "__main__":
337
  demo.launch(
338
  server_name="0.0.0.0",
339
  server_port=7860,
340
- share=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  )
342
 
 
204
  """Create and configure the Gradio interface."""
205
 
206
  with gr.Blocks(
207
+ title="Forensic Image Analysis Agent"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  ) as demo:
209
  gr.HTML("""
210
  <div class="main-header">
 
323
  demo.launch(
324
  server_name="0.0.0.0",
325
  server_port=7860,
326
+ share=False,
327
+ theme=gr.themes.Soft(),
328
+ css="""
329
+ .gradio-container {
330
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
331
+ }
332
+ .main-header {
333
+ text-align: center;
334
+ padding: 20px;
335
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
336
+ color: white;
337
+ border-radius: 10px;
338
+ margin-bottom: 20px;
339
+ }
340
+ """
341
  )
342
 
app_requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  # Gradio for UI
2
- gradio>=4.0.0
3
 
4
  # Core LLM and agent dependencies
5
  langchain>=0.1.0
 
1
  # Gradio for UI
2
+ gradio>=6.0.2
3
 
4
  # Core LLM and agent dependencies
5
  langchain>=0.1.0
src/utils/weight_downloader.py CHANGED
@@ -4,6 +4,7 @@ Utility to download and verify TruFor model weights automatically.
4
 
5
  import hashlib
6
  import os
 
7
  import zipfile
8
  from pathlib import Path
9
  from typing import Optional, Tuple
@@ -64,6 +65,13 @@ def ensure_trufor_weights(workspace_root: Optional[Path] = None, auto_download:
64
  """
65
  Ensure TruFor weights are available, downloading if necessary.
66
 
 
 
 
 
 
 
 
67
  Args:
68
  workspace_root: Root directory of the workspace. If None, tries to detect it.
69
  auto_download: If True, automatically download weights if missing.
@@ -82,19 +90,14 @@ def ensure_trufor_weights(workspace_root: Optional[Path] = None, auto_download:
82
 
83
  # Check if weights already exist
84
  if weights_path.exists():
85
- # Verify MD5 if possible
86
- try:
87
- md5_hash = _calculate_md5(weights_path)
88
- if md5_hash.lower() == TRUFOR_WEIGHTS_MD5.lower():
89
- return True, f"✅ TruFor weights found and verified at {weights_path}"
90
- else:
91
- print(f"⚠️ MD5 mismatch. Expected: {TRUFOR_WEIGHTS_MD5}, Got: {md5_hash}")
92
- print(" Weights file exists but MD5 doesn't match. Consider re-downloading.")
93
- # Still return True - file exists, just not verified
94
- return True, f"⚠️ TruFor weights found at {weights_path} but MD5 verification failed"
95
- except Exception as e:
96
- # If MD5 check fails, still return True if file exists
97
- return True, f"✅ TruFor weights found at {weights_path} (MD5 check skipped: {e})"
98
 
99
  # Weights don't exist
100
  if not auto_download:
@@ -118,73 +121,81 @@ def ensure_trufor_weights(workspace_root: Optional[Path] = None, auto_download:
118
  if not _download_file(TRUFOR_WEIGHTS_URL, zip_path):
119
  return False, f"❌ Failed to download weights from {TRUFOR_WEIGHTS_URL}"
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # Extract zip file
 
122
  print(f"📦 Extracting weights...")
123
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
124
  # Find the weights file in the zip
125
  members = zip_ref.namelist()
126
  weights_member = None
127
 
128
- # Look for trufor.pth.tar in the zip
129
  for member in members:
130
- if member.endswith(TRUFOR_WEIGHTS_FILENAME) or member.endswith(f"/{TRUFOR_WEIGHTS_FILENAME}"):
131
  weights_member = member
132
  break
133
 
134
- if weights_member:
135
- # Extract just the weights file
136
- zip_ref.extract(weights_member, weights_dir)
137
-
138
- # Move to final location if needed
139
- extracted_path = weights_dir / weights_member
140
- if extracted_path != weights_path:
141
- if weights_path.exists():
142
- weights_path.unlink()
143
- extracted_path.rename(weights_path)
144
-
145
- # Clean up zip file
146
  zip_path.unlink()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # Verify MD5
149
  try:
150
- md5_hash = _calculate_md5(weights_path)
151
- if md5_hash.lower() == TRUFOR_WEIGHTS_MD5.lower():
152
- return True, f"✅ TruFor weights downloaded and verified at {weights_path}"
153
- else:
154
- return False, (
155
- f"❌ Downloaded weights MD5 mismatch!\n"
156
- f" Expected: {TRUFOR_WEIGHTS_MD5}\n"
157
- f" Got: {md5_hash}\n"
158
- f" File at: {weights_path}"
159
- )
160
- except Exception as e:
161
- return True, f"✅ TruFor weights downloaded at {weights_path} (MD5 check failed: {e})"
 
 
 
 
 
162
  else:
163
- # Extract all files and look for weights
164
- zip_ref.extractall(weights_dir)
165
- zip_path.unlink()
166
-
167
- # Look for the weights file in extracted files
168
- for root, dirs, files in os.walk(weights_dir):
169
- for file in files:
170
- if file == TRUFOR_WEIGHTS_FILENAME:
171
- found_path = Path(root) / file
172
- if found_path != weights_path:
173
- if weights_path.exists():
174
- weights_path.unlink()
175
- found_path.rename(weights_path)
176
-
177
- # Verify MD5
178
- try:
179
- md5_hash = _calculate_md5(weights_path)
180
- if md5_hash.lower() == TRUFOR_WEIGHTS_MD5.lower():
181
- return True, f"✅ TruFor weights downloaded and verified at {weights_path}"
182
- else:
183
- return False, f"❌ Downloaded weights MD5 mismatch: {md5_hash}"
184
- except Exception as e:
185
- return True, f"✅ TruFor weights downloaded at {weights_path} (MD5 check failed: {e})"
186
-
187
- return False, f"❌ Could not find {TRUFOR_WEIGHTS_FILENAME} in downloaded zip"
188
 
189
  except Exception as e:
190
  # Clean up on error
 
4
 
5
  import hashlib
6
  import os
7
+ import shutil
8
  import zipfile
9
  from pathlib import Path
10
  from typing import Optional, Tuple
 
65
  """
66
  Ensure TruFor weights are available, downloading if necessary.
67
 
68
+ Downloads TruFor_weights.zip from the official source, verifies MD5,
69
+ and extracts trufor.pth.tar to weights/trufor/trufor.pth.tar.
70
+
71
+ Zip structure: weights/trufor.pth.tar
72
+ Final path: projectroot/weights/trufor/trufor.pth.tar
73
+ MD5 is verified on the zip file (not the tar).
74
+
75
  Args:
76
  workspace_root: Root directory of the workspace. If None, tries to detect it.
77
  auto_download: If True, automatically download weights if missing.
 
90
 
91
  # Check if weights already exist
92
  if weights_path.exists():
93
+ # File exists - we can't verify MD5 since it's for the zip, not the tar
94
+ file_size = weights_path.stat().st_size
95
+ if file_size > 0:
96
+ return True, f"✅ TruFor weights found at {weights_path} ({file_size / 1024 / 1024:.1f} MB)"
97
+ else:
98
+ # Empty file - delete and re-download
99
+ weights_path.unlink()
100
+ print("⚠️ Found empty weights file, re-downloading...")
 
 
 
 
 
101
 
102
  # Weights don't exist
103
  if not auto_download:
 
121
  if not _download_file(TRUFOR_WEIGHTS_URL, zip_path):
122
  return False, f"❌ Failed to download weights from {TRUFOR_WEIGHTS_URL}"
123
 
124
+ # Verify MD5 of the zip file immediately after download
125
+ print(f"🔍 Verifying download integrity (MD5)...")
126
+ try:
127
+ zip_md5 = _calculate_md5(zip_path)
128
+ if zip_md5.lower() != TRUFOR_WEIGHTS_MD5.lower():
129
+ zip_path.unlink()
130
+ return False, (
131
+ f"❌ Downloaded zip MD5 mismatch!\n"
132
+ f" Expected: {TRUFOR_WEIGHTS_MD5}\n"
133
+ f" Got: {zip_md5}\n"
134
+ f" The download may be corrupted. Please try again."
135
+ )
136
+ print(f"✅ MD5 verified: {zip_md5}")
137
+ except Exception as e:
138
+ print(f"⚠️ MD5 verification failed: {e}. Continuing with extraction...")
139
+
140
  # Extract zip file
141
+ # Zip structure: weights/trufor.pth.tar
142
  print(f"📦 Extracting weights...")
143
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
144
  # Find the weights file in the zip
145
  members = zip_ref.namelist()
146
  weights_member = None
147
 
148
+ # Look for trufor.pth.tar in the zip (expected: weights/trufor.pth.tar)
149
  for member in members:
150
+ if member.endswith(TRUFOR_WEIGHTS_FILENAME):
151
  weights_member = member
152
  break
153
 
154
+ if not weights_member:
 
 
 
 
 
 
 
 
 
 
 
155
  zip_path.unlink()
156
+ return False, (
157
+ f"❌ Could not find {TRUFOR_WEIGHTS_FILENAME} in downloaded zip.\n"
158
+ f" Zip contents: {members}"
159
+ )
160
+
161
+ # Extract just the weights file to a temp location
162
+ # zip_ref.extract will create the nested directory structure
163
+ zip_ref.extract(weights_member, weights_dir)
164
+
165
+ # Move from extracted location to final location
166
+ # e.g., weights/trufor/weights/trufor.pth.tar -> weights/trufor/trufor.pth.tar
167
+ extracted_path = weights_dir / weights_member
168
+
169
+ if extracted_path != weights_path:
170
+ # Move to final location
171
+ if weights_path.exists():
172
+ weights_path.unlink()
173
+ shutil.move(str(extracted_path), str(weights_path))
174
 
175
+ # Clean up any empty directories left from extraction
176
  try:
177
+ # Remove the 'weights' folder if it was created inside weights_dir
178
+ extracted_parent = extracted_path.parent
179
+ while extracted_parent != weights_dir and extracted_parent.exists():
180
+ if not any(extracted_parent.iterdir()):
181
+ extracted_parent.rmdir()
182
+ extracted_parent = extracted_parent.parent
183
+ except Exception:
184
+ pass # Ignore cleanup errors
185
+
186
+ # Clean up zip file
187
+ zip_path.unlink()
188
+
189
+ # Verify final file exists and has content
190
+ if weights_path.exists():
191
+ file_size = weights_path.stat().st_size
192
+ if file_size > 0:
193
+ return True, f"✅ TruFor weights downloaded successfully to {weights_path} ({file_size / 1024 / 1024:.1f} MB)"
194
  else:
195
+ weights_path.unlink()
196
+ return False, f"❌ Extracted weights file is empty"
197
+ else:
198
+ return False, f"❌ Failed to extract weights to {weights_path}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  except Exception as e:
201
  # Clean up on error