mastefan commited on
Commit
c405c2d
·
verified ·
1 Parent(s): 3ca6a57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -29
app.py CHANGED
@@ -80,12 +80,32 @@ def load_yolo_from_hub():
80
  return YOLO(w)
81
 
82
  def load_autogluon_tabular_from_hub():
 
83
  z = hf_hub_download(repo_id=AG_REPO_ID, filename=AG_ZIP_NAME, cache_dir=CACHE_DIR)
84
  extract_dir = CACHE_DIR / "ag_predictor_native"
85
- if extract_dir.exists(): shutil.rmtree(extract_dir)
86
- with zipfile.ZipFile(z, "r") as zip_ref: zip_ref.extractall(extract_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  print(f"[INFO] Loaded AutoGluon predictor from {extract_dir}")
88
- return TabularPredictor.load(str(extract_dir), require_py_version_match=False)
 
89
 
90
 
91
  _YOLO = None
@@ -191,39 +211,16 @@ def predict_scores(df):
191
  feats = ["red_ratio", "green_ratio", "red_diff", "green_diff", "z_red", "z_green"]
192
  X = df[feats].copy()
193
  ag = ag_predictor()
194
-
195
- # List all models and exclude any FastAI learners (which can cause tabular DL issues)
196
- try:
197
- available_models = ag.model_names()
198
- except Exception:
199
- available_models = ag.model_names
200
- safe_models = [m for m in available_models if "fastai" not in m.lower()]
201
- print(f"[INFO] Using safe models: {safe_models}")
202
-
203
- # Access AutoGluon's internal learner, which supports the `models` argument
204
- learner = ag._learner
205
-
206
  try:
207
- # Get probabilities from the safe models only
208
- proba = learner.predict_proba(X, models=safe_models)
209
  if isinstance(proba, pd.DataFrame) and (1 in proba.columns):
210
  return proba[1]
211
  except Exception as e:
212
- print("[WARN] FastAI submodel failed retrying with safe models only:", e)
213
- proba = learner.predict_proba(X, models=safe_models)
214
- if isinstance(proba, pd.DataFrame) and (1 in proba.columns):
215
- return proba[1]
216
-
217
- # Fallback if probabilities aren’t available
218
- try:
219
- preds = learner.predict(X, models=safe_models)
220
  s = pd.Series(preds).astype(float)
221
  rng = (s.quantile(0.95) - s.quantile(0.05)) or 1.0
222
  return ((s - s.quantile(0.05)) / rng).clip(0, 1)
223
- except Exception as e:
224
- print("[ERROR] Predictor fallback failed:", e)
225
- return pd.Series(np.zeros(len(df)))
226
-
227
 
228
  def pick_events(df,score,fps):
229
  z=rolling_z(score,45); strong=(z>4.0); keep=strong.rolling(3,min_periods=1).sum()>=2
 
80
  return YOLO(w)
81
 
82
  def load_autogluon_tabular_from_hub():
83
+ """Download and load AutoGluon predictor, removing any FastAI submodels."""
84
  z = hf_hub_download(repo_id=AG_REPO_ID, filename=AG_ZIP_NAME, cache_dir=CACHE_DIR)
85
  extract_dir = CACHE_DIR / "ag_predictor_native"
86
+ if extract_dir.exists():
87
+ shutil.rmtree(extract_dir)
88
+ with zipfile.ZipFile(z, "r") as zip_ref:
89
+ zip_ref.extractall(extract_dir)
90
+
91
+ # --- delete fastai models before loading to avoid deserialization errors ---
92
+ fastai_dirs = list(extract_dir.rglob("*fastai*"))
93
+ for p in fastai_dirs:
94
+ try:
95
+ if p.is_dir():
96
+ shutil.rmtree(p)
97
+ else:
98
+ p.unlink()
99
+ except Exception as e:
100
+ print(f"[WARN] Could not remove {p}: {e}")
101
+ print(f"[CLEANUP] Removed {len(fastai_dirs)} FastAI model files.")
102
+
103
+ # Now load normally (no version check)
104
+ from autogluon.tabular import TabularPredictor
105
+ predictor = TabularPredictor.load(str(extract_dir), require_py_version_match=False)
106
  print(f"[INFO] Loaded AutoGluon predictor from {extract_dir}")
107
+ return predictor
108
+
109
 
110
 
111
  _YOLO = None
 
211
  feats = ["red_ratio", "green_ratio", "red_diff", "green_diff", "z_red", "z_green"]
212
  X = df[feats].copy()
213
  ag = ag_predictor()
 
 
 
 
 
 
 
 
 
 
 
 
214
  try:
215
+ proba = ag.predict_proba(X)
 
216
  if isinstance(proba, pd.DataFrame) and (1 in proba.columns):
217
  return proba[1]
218
  except Exception as e:
219
+ print("[WARN] Predictor .predict_proba() failed, falling back:", e)
220
+ preds = ag.predict(X)
 
 
 
 
 
 
221
  s = pd.Series(preds).astype(float)
222
  rng = (s.quantile(0.95) - s.quantile(0.05)) or 1.0
223
  return ((s - s.quantile(0.05)) / rng).clip(0, 1)
 
 
 
 
224
 
225
  def pick_events(df,score,fps):
226
  z=rolling_z(score,45); strong=(z>4.0); keep=strong.rolling(3,min_periods=1).sum()>=2