Che237 commited on
Commit
0955fe4
·
verified ·
1 Parent(s): 3e18f98

Fix MLModelLoader to search notebooks models/ directory + trained_models/

Browse files
Files changed (1) hide show
  1. app.py +44 -30
app.py CHANGED
@@ -247,8 +247,18 @@ class MLModelLoader:
247
 
248
  def initialize(self):
249
  loaded = 0
 
 
 
 
 
 
 
250
  for name in self.MODEL_NAMES:
 
 
251
  try:
 
252
  model_file = f"{name}/best_model.pkl"
253
  scaler_file = f"{name}/scaler.pkl"
254
  try:
@@ -264,43 +274,47 @@ class MLModelLoader:
264
  self.scalers[name] = joblib.load(scaler_path)
265
  loaded += 1
266
  logger.info(f"✅ Loaded model from Hub: {name}")
 
267
  except Exception:
268
- # Try flat filename pattern
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  try:
270
- model_path = hf_hub_download(
271
- repo_id=HF_MODEL_REPO, filename=f"{name}_model.pkl",
272
- token=HF_TOKEN or None, cache_dir=str(MODELS_DIR),
273
- )
274
- self.models[name] = joblib.load(model_path)
275
  loaded += 1
276
- logger.info(f"✅ Loaded model (flat): {name}")
277
  except Exception:
278
  pass
279
- # Try local
280
- local_model = MODELS_DIR / name / "best_model.pkl"
281
- if local_model.exists() and name not in self.models:
282
- self.models[name] = joblib.load(local_model)
283
- local_scaler = MODELS_DIR / name / "scaler.pkl"
284
- if local_scaler.exists():
285
- self.scalers[name] = joblib.load(local_scaler)
286
- loaded += 1
287
- logger.info(f"✅ Loaded model from local: {name}")
288
- except Exception as e:
289
- logger.warning(f"Error loading model {name}: {e}")
290
-
291
- # Also check trained_models dir for any .pkl files
292
- for pkl in MODELS_DIR.glob("*.pkl"):
293
- stem = pkl.stem.replace("_model", "").replace("_best", "")
294
- if stem not in self.models:
295
- try:
296
- self.models[stem] = joblib.load(pkl)
297
- loaded += 1
298
- logger.info(f"✅ Loaded local model: {stem}")
299
- except Exception:
300
- pass
301
 
302
  self.ready = loaded > 0
303
- logger.info(f"ML Models: {loaded} loaded ({list(ml_loader.models.keys()) if loaded else 'none'})")
304
 
305
  def predict(self, model_name: str, features: Dict) -> Dict:
306
  if model_name not in self.models:
 
247
 
248
  def initialize(self):
249
  loaded = 0
250
+ # All directories where models might exist (notebooks save to ../models)
251
+ search_dirs = [
252
+ MODELS_DIR, # trained_models/
253
+ APP_DIR / "models", # models/ (where notebooks output)
254
+ APP_DIR.parent / "models", # one level up fallback
255
+ ]
256
+
257
  for name in self.MODEL_NAMES:
258
+ if name in self.models:
259
+ continue
260
  try:
261
+ # 1. Try HuggingFace Hub first
262
  model_file = f"{name}/best_model.pkl"
263
  scaler_file = f"{name}/scaler.pkl"
264
  try:
 
274
  self.scalers[name] = joblib.load(scaler_path)
275
  loaded += 1
276
  logger.info(f"✅ Loaded model from Hub: {name}")
277
+ continue
278
  except Exception:
279
+ pass
280
+
281
+ # 2. Try all local search directories
282
+ for sdir in search_dirs:
283
+ if name in self.models:
284
+ break
285
+ for model_fname in [f"{name}/best_model.pkl", f"{name}/model.pkl", f"{name}_model.pkl"]:
286
+ candidate = sdir / model_fname
287
+ if candidate.exists():
288
+ self.models[name] = joblib.load(candidate)
289
+ # Try to find matching scaler
290
+ for scaler_fname in [f"{name}/scaler.pkl", f"{name}_scaler.pkl"]:
291
+ sc = sdir / scaler_fname
292
+ if sc.exists():
293
+ self.scalers[name] = joblib.load(sc)
294
+ break
295
+ loaded += 1
296
+ logger.info(f"✅ Loaded model from {sdir.name}/{model_fname}: {name}")
297
+ break
298
+
299
+ except Exception as e:
300
+ logger.warning(f"Error loading model {name}: {e}")
301
+
302
+ # Sweep all search dirs for any .pkl files not yet loaded
303
+ for sdir in search_dirs:
304
+ if not sdir.exists():
305
+ continue
306
+ for pkl in sdir.glob("*.pkl"):
307
+ stem = pkl.stem.replace("_model", "").replace("_best", "")
308
+ if stem not in self.models:
309
  try:
310
+ self.models[stem] = joblib.load(pkl)
 
 
 
 
311
  loaded += 1
312
+ logger.info(f"✅ Loaded model sweep: {stem} from {sdir.name}")
313
  except Exception:
314
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  self.ready = loaded > 0
317
+ logger.info(f"ML Models: {loaded} loaded {list(self.models.keys())}")
318
 
319
  def predict(self, model_name: str, features: Dict) -> Dict:
320
  if model_name not in self.models: