ASomeoneWhoInterestedWithAI commited on
Commit
3fec527
·
verified ·
1 Parent(s): 077e064

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -2
app.py CHANGED
@@ -8,14 +8,47 @@ import gradio as gr
8
  import numpy as np
9
  from PIL import Image
10
 
 
 
 
 
11
  # --- CONFIG & MODEL DOWNLOAD ---
12
  MODEL_PATH = "LookThem_V8_MNIST.pth"
 
13
  HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"
14
 
15
  if not os.path.exists(MODEL_PATH):
16
  print(f"Downloading model weights from Hugging Face...")
17
- urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
- print("Download complete!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # --- DEFINE YOUR MODEL ARCHITECTURE ---
21
  class LookThemLayer(nn.Module):
 
8
  import numpy as np
9
  from PIL import Image
10
 
11
+ import os
12
+ import urllib.request
13
+ import zipfile
14
+
15
  # --- CONFIG & MODEL DOWNLOAD ---
16
  MODEL_PATH = "LookThem_V8_MNIST.pth"
17
+ ZIP_PATH = "LookThem_V8_MNIST.zip"
18
  HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"
19
 
20
  if not os.path.exists(MODEL_PATH):
21
  print(f"Downloading model weights from Hugging Face...")
22
+ # Download the file as a zip first
23
+ urllib.request.urlretrieve(HF_URL, ZIP_PATH)
24
+ print("Download complete! Checking for zip compression...")
25
+
26
+ try:
27
+ # Unzip the file
28
+ with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
29
+ # Look for a .pth file inside the zip
30
+ file_list = zip_ref.namelist()
31
+ pth_files = [f for f in file_list if f.endswith('.pth')]
32
+
33
+ if pth_files:
34
+ # Extract the .pth file and rename it to our expected MODEL_PATH
35
+ zip_ref.extract(pth_files[0], path=".")
36
+ if pth_files[0] != MODEL_PATH:
37
+ os.rename(pth_files[0], MODEL_PATH)
38
+ print(f"Successfully extracted: {pth_files[0]} -> {MODEL_PATH}")
39
+ else:
40
+ # If no .pth inside, maybe the zip itself *is* the model (PyTorch 2.0+ format)
41
+ print("No .pth file found inside zip. Renaming zip directly to .pth...")
42
+ os.rename(ZIP_PATH, MODEL_PATH)
43
+
44
+ except zipfile.BadZipFile:
45
+ # If it wasn't actually a zip file, just rename it
46
+ print("File is not a zip archive. Proceeding with standard weight loading.")
47
+ os.rename(ZIP_PATH, MODEL_PATH)
48
+
49
+ # Clean up the temporary zip file if it still exists
50
+ if os.path.exists(ZIP_PATH):
51
+ os.remove(ZIP_PATH)
52
 
53
  # --- DEFINE YOUR MODEL ARCHITECTURE ---
54
  class LookThemLayer(nn.Module):