Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -104,7 +104,7 @@ def model_data(model_name, model_root="weights", repo_id="simpsonsaiorg/stream-m
|
|
| 104 |
if len(pth_files) == 0:
|
| 105 |
raise ValueError(f"No .pth file found in {model_root}/{model_name}")
|
| 106 |
pth_path = pth_files[0]
|
| 107 |
-
|
| 108 |
index_files = [
|
| 109 |
os.path.join(model_root, model_name, f)
|
| 110 |
for f in os.listdir(os.path.join(model_root, model_name))
|
|
@@ -112,30 +112,24 @@ def model_data(model_name, model_root="weights", repo_id="simpsonsaiorg/stream-m
|
|
| 112 |
]
|
| 113 |
else:
|
| 114 |
# --- HuggingFace load ---
|
| 115 |
-
# List all files in the repo first
|
| 116 |
all_files = api.list_repo_files(repo_id)
|
| 117 |
model_files = [f for f in all_files if f.startswith(f"weights/{model_name}/")]
|
| 118 |
|
| 119 |
-
# Find .pth file
|
| 120 |
pth_files = [f for f in model_files if f.endswith(".pth")]
|
| 121 |
if not pth_files:
|
| 122 |
raise ValueError(f"No .pth file found for model {model_name} in repo")
|
| 123 |
pth_path = hf_hub_download(repo_id=repo_id, filename=pth_files[0])
|
| 124 |
|
| 125 |
-
# Find
|
| 126 |
index_files = [
|
| 127 |
hf_hub_download(repo_id=repo_id, filename=f)
|
| 128 |
for f in model_files
|
| 129 |
if f.endswith(".index") and "added_IVF" in f
|
| 130 |
]
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
print(f"Loading {pth_path}")
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
pth_path = pth_files[0]
|
| 138 |
-
print(f"Loading {pth_path}")
|
| 139 |
|
| 140 |
# -----------------------
|
| 141 |
# 2. Load checkpoint
|
|
|
|
| 104 |
if len(pth_files) == 0:
|
| 105 |
raise ValueError(f"No .pth file found in {model_root}/{model_name}")
|
| 106 |
pth_path = pth_files[0]
|
| 107 |
+
|
| 108 |
index_files = [
|
| 109 |
os.path.join(model_root, model_name, f)
|
| 110 |
for f in os.listdir(os.path.join(model_root, model_name))
|
|
|
|
| 112 |
]
|
| 113 |
else:
|
| 114 |
# --- HuggingFace load ---
|
|
|
|
| 115 |
all_files = api.list_repo_files(repo_id)
|
| 116 |
model_files = [f for f in all_files if f.startswith(f"weights/{model_name}/")]
|
| 117 |
|
| 118 |
+
# Find .pth file
|
| 119 |
pth_files = [f for f in model_files if f.endswith(".pth")]
|
| 120 |
if not pth_files:
|
| 121 |
raise ValueError(f"No .pth file found for model {model_name} in repo")
|
| 122 |
pth_path = hf_hub_download(repo_id=repo_id, filename=pth_files[0])
|
| 123 |
|
| 124 |
+
# Find index files
|
| 125 |
index_files = [
|
| 126 |
hf_hub_download(repo_id=repo_id, filename=f)
|
| 127 |
for f in model_files
|
| 128 |
if f.endswith(".index") and "added_IVF" in f
|
| 129 |
]
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
print(f"Loading {pth_path}") # <-- safe to do for both cases
|
| 132 |
+
|
|
|
|
|
|
|
| 133 |
|
| 134 |
# -----------------------
|
| 135 |
# 2. Load checkpoint
|