asdrty123 commited on
Commit
984fbd1
·
verified ·
1 Parent(s): 44767f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -104,7 +104,7 @@ def model_data(model_name, model_root="weights", repo_id="simpsonsaiorg/stream-m
104
  if len(pth_files) == 0:
105
  raise ValueError(f"No .pth file found in {model_root}/{model_name}")
106
  pth_path = pth_files[0]
107
-
108
  index_files = [
109
  os.path.join(model_root, model_name, f)
110
  for f in os.listdir(os.path.join(model_root, model_name))
@@ -112,30 +112,24 @@ def model_data(model_name, model_root="weights", repo_id="simpsonsaiorg/stream-m
112
  ]
113
  else:
114
  # --- HuggingFace load ---
115
- # List all files in the repo first
116
  all_files = api.list_repo_files(repo_id)
117
  model_files = [f for f in all_files if f.startswith(f"weights/{model_name}/")]
118
 
119
- # Find .pth file (pick the first one)
120
  pth_files = [f for f in model_files if f.endswith(".pth")]
121
  if not pth_files:
122
  raise ValueError(f"No .pth file found for model {model_name} in repo")
123
  pth_path = hf_hub_download(repo_id=repo_id, filename=pth_files[0])
124
 
125
- # Find optional index files (match any pattern like added_IVF*_v2.index)
126
  index_files = [
127
  hf_hub_download(repo_id=repo_id, filename=f)
128
  for f in model_files
129
  if f.endswith(".index") and "added_IVF" in f
130
  ]
131
-
132
-
133
- print(f"Loading {pth_path}")
134
 
135
- if len(pth_files) == 0:
136
- raise ValueError(f"No .pth file found for model {model_name}")
137
- pth_path = pth_files[0]
138
- print(f"Loading {pth_path}")
139
 
140
  # -----------------------
141
  # 2. Load checkpoint
 
104
  if len(pth_files) == 0:
105
  raise ValueError(f"No .pth file found in {model_root}/{model_name}")
106
  pth_path = pth_files[0]
107
+
108
  index_files = [
109
  os.path.join(model_root, model_name, f)
110
  for f in os.listdir(os.path.join(model_root, model_name))
 
112
  ]
113
  else:
114
  # --- HuggingFace load ---
 
115
  all_files = api.list_repo_files(repo_id)
116
  model_files = [f for f in all_files if f.startswith(f"weights/{model_name}/")]
117
 
118
+ # Find .pth file
119
  pth_files = [f for f in model_files if f.endswith(".pth")]
120
  if not pth_files:
121
  raise ValueError(f"No .pth file found for model {model_name} in repo")
122
  pth_path = hf_hub_download(repo_id=repo_id, filename=pth_files[0])
123
 
124
+ # Find index files
125
  index_files = [
126
  hf_hub_download(repo_id=repo_id, filename=f)
127
  for f in model_files
128
  if f.endswith(".index") and "added_IVF" in f
129
  ]
 
 
 
130
 
131
+ print(f"Loading {pth_path}") # <-- safe to do for both cases
132
+
 
 
133
 
134
  # -----------------------
135
  # 2. Load checkpoint