Mustafa Akcanca commited on
Commit Β·
90f2bb7
1
Parent(s): 20d6258
Fix weights downloader
Browse files- README.md +1 -1
- app.py +16 -16
- app_requirements.txt +1 -1
- src/utils/weight_downloader.py +76 -65
README.md
CHANGED
|
@@ -4,7 +4,7 @@ emoji: π
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.0.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
app.py
CHANGED
|
@@ -204,21 +204,7 @@ def create_interface():
|
|
| 204 |
"""Create and configure the Gradio interface."""
|
| 205 |
|
| 206 |
with gr.Blocks(
|
| 207 |
-
title="Forensic Image Analysis Agent"
|
| 208 |
-
theme=gr.themes.Soft(),
|
| 209 |
-
css="""
|
| 210 |
-
.gradio-container {
|
| 211 |
-
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
| 212 |
-
}
|
| 213 |
-
.main-header {
|
| 214 |
-
text-align: center;
|
| 215 |
-
padding: 20px;
|
| 216 |
-
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 217 |
-
color: white;
|
| 218 |
-
border-radius: 10px;
|
| 219 |
-
margin-bottom: 20px;
|
| 220 |
-
}
|
| 221 |
-
"""
|
| 222 |
) as demo:
|
| 223 |
gr.HTML("""
|
| 224 |
<div class="main-header">
|
|
@@ -337,6 +323,20 @@ if __name__ == "__main__":
|
|
| 337 |
demo.launch(
|
| 338 |
server_name="0.0.0.0",
|
| 339 |
server_port=7860,
|
| 340 |
-
share=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
)
|
| 342 |
|
|
|
|
| 204 |
"""Create and configure the Gradio interface."""
|
| 205 |
|
| 206 |
with gr.Blocks(
|
| 207 |
+
title="Forensic Image Analysis Agent"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
) as demo:
|
| 209 |
gr.HTML("""
|
| 210 |
<div class="main-header">
|
|
|
|
| 323 |
demo.launch(
|
| 324 |
server_name="0.0.0.0",
|
| 325 |
server_port=7860,
|
| 326 |
+
share=False,
|
| 327 |
+
theme=gr.themes.Soft(),
|
| 328 |
+
css="""
|
| 329 |
+
.gradio-container {
|
| 330 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
| 331 |
+
}
|
| 332 |
+
.main-header {
|
| 333 |
+
text-align: center;
|
| 334 |
+
padding: 20px;
|
| 335 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 336 |
+
color: white;
|
| 337 |
+
border-radius: 10px;
|
| 338 |
+
margin-bottom: 20px;
|
| 339 |
+
}
|
| 340 |
+
"""
|
| 341 |
)
|
| 342 |
|
app_requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# Gradio for UI
|
| 2 |
-
gradio>=
|
| 3 |
|
| 4 |
# Core LLM and agent dependencies
|
| 5 |
langchain>=0.1.0
|
|
|
|
| 1 |
# Gradio for UI
|
| 2 |
+
gradio>=6.0.2
|
| 3 |
|
| 4 |
# Core LLM and agent dependencies
|
| 5 |
langchain>=0.1.0
|
src/utils/weight_downloader.py
CHANGED
|
@@ -4,6 +4,7 @@ Utility to download and verify TruFor model weights automatically.
|
|
| 4 |
|
| 5 |
import hashlib
|
| 6 |
import os
|
|
|
|
| 7 |
import zipfile
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Optional, Tuple
|
|
@@ -64,6 +65,13 @@ def ensure_trufor_weights(workspace_root: Optional[Path] = None, auto_download:
|
|
| 64 |
"""
|
| 65 |
Ensure TruFor weights are available, downloading if necessary.
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
Args:
|
| 68 |
workspace_root: Root directory of the workspace. If None, tries to detect it.
|
| 69 |
auto_download: If True, automatically download weights if missing.
|
|
@@ -82,19 +90,14 @@ def ensure_trufor_weights(workspace_root: Optional[Path] = None, auto_download:
|
|
| 82 |
|
| 83 |
# Check if weights already exist
|
| 84 |
if weights_path.exists():
|
| 85 |
-
#
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
# Still return True - file exists, just not verified
|
| 94 |
-
return True, f"β οΈ TruFor weights found at {weights_path} but MD5 verification failed"
|
| 95 |
-
except Exception as e:
|
| 96 |
-
# If MD5 check fails, still return True if file exists
|
| 97 |
-
return True, f"β
TruFor weights found at {weights_path} (MD5 check skipped: {e})"
|
| 98 |
|
| 99 |
# Weights don't exist
|
| 100 |
if not auto_download:
|
|
@@ -118,73 +121,81 @@ def ensure_trufor_weights(workspace_root: Optional[Path] = None, auto_download:
|
|
| 118 |
if not _download_file(TRUFOR_WEIGHTS_URL, zip_path):
|
| 119 |
return False, f"β Failed to download weights from {TRUFOR_WEIGHTS_URL}"
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
# Extract zip file
|
|
|
|
| 122 |
print(f"π¦ Extracting weights...")
|
| 123 |
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
| 124 |
# Find the weights file in the zip
|
| 125 |
members = zip_ref.namelist()
|
| 126 |
weights_member = None
|
| 127 |
|
| 128 |
-
# Look for trufor.pth.tar in the zip
|
| 129 |
for member in members:
|
| 130 |
-
if member.endswith(TRUFOR_WEIGHTS_FILENAME)
|
| 131 |
weights_member = member
|
| 132 |
break
|
| 133 |
|
| 134 |
-
if weights_member:
|
| 135 |
-
# Extract just the weights file
|
| 136 |
-
zip_ref.extract(weights_member, weights_dir)
|
| 137 |
-
|
| 138 |
-
# Move to final location if needed
|
| 139 |
-
extracted_path = weights_dir / weights_member
|
| 140 |
-
if extracted_path != weights_path:
|
| 141 |
-
if weights_path.exists():
|
| 142 |
-
weights_path.unlink()
|
| 143 |
-
extracted_path.rename(weights_path)
|
| 144 |
-
|
| 145 |
-
# Clean up zip file
|
| 146 |
zip_path.unlink()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
#
|
| 149 |
try:
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
else:
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
# Look for the weights file in extracted files
|
| 168 |
-
for root, dirs, files in os.walk(weights_dir):
|
| 169 |
-
for file in files:
|
| 170 |
-
if file == TRUFOR_WEIGHTS_FILENAME:
|
| 171 |
-
found_path = Path(root) / file
|
| 172 |
-
if found_path != weights_path:
|
| 173 |
-
if weights_path.exists():
|
| 174 |
-
weights_path.unlink()
|
| 175 |
-
found_path.rename(weights_path)
|
| 176 |
-
|
| 177 |
-
# Verify MD5
|
| 178 |
-
try:
|
| 179 |
-
md5_hash = _calculate_md5(weights_path)
|
| 180 |
-
if md5_hash.lower() == TRUFOR_WEIGHTS_MD5.lower():
|
| 181 |
-
return True, f"β
TruFor weights downloaded and verified at {weights_path}"
|
| 182 |
-
else:
|
| 183 |
-
return False, f"β Downloaded weights MD5 mismatch: {md5_hash}"
|
| 184 |
-
except Exception as e:
|
| 185 |
-
return True, f"β
TruFor weights downloaded at {weights_path} (MD5 check failed: {e})"
|
| 186 |
-
|
| 187 |
-
return False, f"β Could not find {TRUFOR_WEIGHTS_FILENAME} in downloaded zip"
|
| 188 |
|
| 189 |
except Exception as e:
|
| 190 |
# Clean up on error
|
|
|
|
| 4 |
|
| 5 |
import hashlib
|
| 6 |
import os
|
| 7 |
+
import shutil
|
| 8 |
import zipfile
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Optional, Tuple
|
|
|
|
| 65 |
"""
|
| 66 |
Ensure TruFor weights are available, downloading if necessary.
|
| 67 |
|
| 68 |
+
Downloads TruFor_weights.zip from the official source, verifies MD5,
|
| 69 |
+
and extracts trufor.pth.tar to weights/trufor/trufor.pth.tar.
|
| 70 |
+
|
| 71 |
+
Zip structure: weights/trufor.pth.tar
|
| 72 |
+
Final path: projectroot/weights/trufor/trufor.pth.tar
|
| 73 |
+
MD5 is verified on the zip file (not the tar).
|
| 74 |
+
|
| 75 |
Args:
|
| 76 |
workspace_root: Root directory of the workspace. If None, tries to detect it.
|
| 77 |
auto_download: If True, automatically download weights if missing.
|
|
|
|
| 90 |
|
| 91 |
# Check if weights already exist
|
| 92 |
if weights_path.exists():
|
| 93 |
+
# File exists - we can't verify MD5 since it's for the zip, not the tar
|
| 94 |
+
file_size = weights_path.stat().st_size
|
| 95 |
+
if file_size > 0:
|
| 96 |
+
return True, f"β
TruFor weights found at {weights_path} ({file_size / 1024 / 1024:.1f} MB)"
|
| 97 |
+
else:
|
| 98 |
+
# Empty file - delete and re-download
|
| 99 |
+
weights_path.unlink()
|
| 100 |
+
print("β οΈ Found empty weights file, re-downloading...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
# Weights don't exist
|
| 103 |
if not auto_download:
|
|
|
|
| 121 |
if not _download_file(TRUFOR_WEIGHTS_URL, zip_path):
|
| 122 |
return False, f"β Failed to download weights from {TRUFOR_WEIGHTS_URL}"
|
| 123 |
|
| 124 |
+
# Verify MD5 of the zip file immediately after download
|
| 125 |
+
print(f"π Verifying download integrity (MD5)...")
|
| 126 |
+
try:
|
| 127 |
+
zip_md5 = _calculate_md5(zip_path)
|
| 128 |
+
if zip_md5.lower() != TRUFOR_WEIGHTS_MD5.lower():
|
| 129 |
+
zip_path.unlink()
|
| 130 |
+
return False, (
|
| 131 |
+
f"β Downloaded zip MD5 mismatch!\n"
|
| 132 |
+
f" Expected: {TRUFOR_WEIGHTS_MD5}\n"
|
| 133 |
+
f" Got: {zip_md5}\n"
|
| 134 |
+
f" The download may be corrupted. Please try again."
|
| 135 |
+
)
|
| 136 |
+
print(f"β
MD5 verified: {zip_md5}")
|
| 137 |
+
except Exception as e:
|
| 138 |
+
print(f"β οΈ MD5 verification failed: {e}. Continuing with extraction...")
|
| 139 |
+
|
| 140 |
# Extract zip file
|
| 141 |
+
# Zip structure: weights/trufor.pth.tar
|
| 142 |
print(f"π¦ Extracting weights...")
|
| 143 |
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
| 144 |
# Find the weights file in the zip
|
| 145 |
members = zip_ref.namelist()
|
| 146 |
weights_member = None
|
| 147 |
|
| 148 |
+
# Look for trufor.pth.tar in the zip (expected: weights/trufor.pth.tar)
|
| 149 |
for member in members:
|
| 150 |
+
if member.endswith(TRUFOR_WEIGHTS_FILENAME):
|
| 151 |
weights_member = member
|
| 152 |
break
|
| 153 |
|
| 154 |
+
if not weights_member:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
zip_path.unlink()
|
| 156 |
+
return False, (
|
| 157 |
+
f"β Could not find {TRUFOR_WEIGHTS_FILENAME} in downloaded zip.\n"
|
| 158 |
+
f" Zip contents: {members}"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Extract just the weights file to a temp location
|
| 162 |
+
# zip_ref.extract will create the nested directory structure
|
| 163 |
+
zip_ref.extract(weights_member, weights_dir)
|
| 164 |
+
|
| 165 |
+
# Move from extracted location to final location
|
| 166 |
+
# e.g., weights/trufor/weights/trufor.pth.tar -> weights/trufor/trufor.pth.tar
|
| 167 |
+
extracted_path = weights_dir / weights_member
|
| 168 |
+
|
| 169 |
+
if extracted_path != weights_path:
|
| 170 |
+
# Move to final location
|
| 171 |
+
if weights_path.exists():
|
| 172 |
+
weights_path.unlink()
|
| 173 |
+
shutil.move(str(extracted_path), str(weights_path))
|
| 174 |
|
| 175 |
+
# Clean up any empty directories left from extraction
|
| 176 |
try:
|
| 177 |
+
# Remove the 'weights' folder if it was created inside weights_dir
|
| 178 |
+
extracted_parent = extracted_path.parent
|
| 179 |
+
while extracted_parent != weights_dir and extracted_parent.exists():
|
| 180 |
+
if not any(extracted_parent.iterdir()):
|
| 181 |
+
extracted_parent.rmdir()
|
| 182 |
+
extracted_parent = extracted_parent.parent
|
| 183 |
+
except Exception:
|
| 184 |
+
pass # Ignore cleanup errors
|
| 185 |
+
|
| 186 |
+
# Clean up zip file
|
| 187 |
+
zip_path.unlink()
|
| 188 |
+
|
| 189 |
+
# Verify final file exists and has content
|
| 190 |
+
if weights_path.exists():
|
| 191 |
+
file_size = weights_path.stat().st_size
|
| 192 |
+
if file_size > 0:
|
| 193 |
+
return True, f"β
TruFor weights downloaded successfully to {weights_path} ({file_size / 1024 / 1024:.1f} MB)"
|
| 194 |
else:
|
| 195 |
+
weights_path.unlink()
|
| 196 |
+
return False, f"β Extracted weights file is empty"
|
| 197 |
+
else:
|
| 198 |
+
return False, f"β Failed to extract weights to {weights_path}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
except Exception as e:
|
| 201 |
# Clean up on error
|