Update app.py
Browse files
app.py
CHANGED
|
@@ -8,14 +8,47 @@ import gradio as gr
|
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# --- CONFIG & MODEL DOWNLOAD ---
|
| 12 |
MODEL_PATH = "LookThem_V8_MNIST.pth"
|
|
|
|
| 13 |
HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"
|
| 14 |
|
| 15 |
if not os.path.exists(MODEL_PATH):
|
| 16 |
print(f"Downloading model weights from Hugging Face...")
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# --- DEFINE YOUR MODEL ARCHITECTURE ---
|
| 21 |
class LookThemLayer(nn.Module):
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
| 10 |
|
| 11 |
+
import os
|
| 12 |
+
import urllib.request
|
| 13 |
+
import zipfile
|
| 14 |
+
|
| 15 |
# --- CONFIG & MODEL DOWNLOAD ---
|
| 16 |
MODEL_PATH = "LookThem_V8_MNIST.pth"
|
| 17 |
+
ZIP_PATH = "LookThem_V8_MNIST.zip"
|
| 18 |
HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"
|
| 19 |
|
| 20 |
if not os.path.exists(MODEL_PATH):
|
| 21 |
print(f"Downloading model weights from Hugging Face...")
|
| 22 |
+
# Download the file as a zip first
|
| 23 |
+
urllib.request.urlretrieve(HF_URL, ZIP_PATH)
|
| 24 |
+
print("Download complete! Checking for zip compression...")
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
# Unzip the file
|
| 28 |
+
with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
|
| 29 |
+
# Look for a .pth file inside the zip
|
| 30 |
+
file_list = zip_ref.namelist()
|
| 31 |
+
pth_files = [f for f in file_list if f.endswith('.pth')]
|
| 32 |
+
|
| 33 |
+
if pth_files:
|
| 34 |
+
# Extract the .pth file and rename it to our expected MODEL_PATH
|
| 35 |
+
zip_ref.extract(pth_files[0], path=".")
|
| 36 |
+
if pth_files[0] != MODEL_PATH:
|
| 37 |
+
os.rename(pth_files[0], MODEL_PATH)
|
| 38 |
+
print(f"Successfully extracted: {pth_files[0]} -> {MODEL_PATH}")
|
| 39 |
+
else:
|
| 40 |
+
# If no .pth inside, maybe the zip itself *is* the model (PyTorch 2.0+ format)
|
| 41 |
+
print("No .pth file found inside zip. Renaming zip directly to .pth...")
|
| 42 |
+
os.rename(ZIP_PATH, MODEL_PATH)
|
| 43 |
+
|
| 44 |
+
except zipfile.BadZipFile:
|
| 45 |
+
# If it wasn't actually a zip file, just rename it
|
| 46 |
+
print("File is not a zip archive. Proceeding with standard weight loading.")
|
| 47 |
+
os.rename(ZIP_PATH, MODEL_PATH)
|
| 48 |
+
|
| 49 |
+
# Clean up the temporary zip file if it still exists
|
| 50 |
+
if os.path.exists(ZIP_PATH):
|
| 51 |
+
os.remove(ZIP_PATH)
|
| 52 |
|
| 53 |
# --- DEFINE YOUR MODEL ARCHITECTURE ---
|
| 54 |
class LookThemLayer(nn.Module):
|