allenborochin committed on
Commit
ac4d749
·
verified ·
1 Parent(s): b42197e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -50
app.py CHANGED
@@ -3,14 +3,13 @@ FPL Player Predictor - Streamlit app for HuggingFace Space.
3
  Loads trained models + engineered dataset from companion model repo.
4
  """
5
 
6
- import io
7
  import urllib.parse
8
- from pathlib import Path
9
 
10
  import numpy as np
11
  import pandas as pd
12
  import pickle
13
  import streamlit as st
 
14
  from huggingface_hub import hf_hub_download
15
 
16
  # ============================================================
@@ -22,7 +21,6 @@ REGRESSION_FILE = "fpl_regression_model.pkl"
22
  CLASSIFICATION_FILE = "fpl_classification_model.pkl"
23
  DATASET_FILE = "df_fe.parquet"
24
 
25
- # PL brand palette
26
  PL_PURPLE = "#37003C"
27
  PL_CYAN = "#00FF87"
28
  PL_LIME = "#04F5FF"
@@ -34,7 +32,6 @@ CLASS_LABELS = {0: "Blank", 1: "Decent", 2: "Good", 3: "Haul"}
34
  CLASS_RANGES = {0: "0-1 pts", 1: "2-4 pts", 2: "5-9 pts", 3: "10+ pts"}
35
  CLASS_COLORS = {0: "#888888", 1: "#04F5FF", 2: "#00FF87", 3: "#E90052"}
36
 
37
- # Features that are 0/1 booleans - display as Yes/No
38
  BOOLEAN_FEATURES = {"was_home", "had_haul_last_3", "has_std_history"}
39
  CLUSTER_FEATURES = {"cluster_0", "cluster_1", "cluster_2", "cluster_3", "cluster_4"}
40
 
@@ -66,15 +63,6 @@ st.markdown(
66
  padding: 20px;
67
  margin-bottom: 16px;
68
  }}
69
- .pl-pill {{
70
- display: inline-block;
71
- padding: 4px 12px;
72
- border-radius: 999px;
73
- font-size: 11px;
74
- font-weight: 700;
75
- letter-spacing: 0.05em;
76
- text-transform: uppercase;
77
- }}
78
  .stSelectbox label, .stRadio label {{
79
  color: {PL_CYAN} !important;
80
  font-size: 12px !important;
@@ -125,7 +113,6 @@ st.markdown(
125
 
126
  @st.cache_resource(show_spinner="Loading models from HuggingFace...")
127
  def load_artifacts():
128
- """Pull all 3 artifacts from the model repo on first run, then cache."""
129
  reg_path = hf_hub_download(repo_id=MODEL_REPO, filename=REGRESSION_FILE)
130
  cls_path = hf_hub_download(repo_id=MODEL_REPO, filename=CLASSIFICATION_FILE)
131
  data_path = hf_hub_download(repo_id=MODEL_REPO, filename=DATASET_FILE, repo_type="model")
@@ -144,7 +131,6 @@ def load_artifacts():
144
  # ============================================================
145
 
146
  def get_player_row(df, name, position, season, gameweek):
147
- """Find the exact row matching the user's selection."""
148
  mask = (
149
  (df["name"] == name)
150
  & (df["position"] == position)
@@ -158,37 +144,27 @@ def get_player_row(df, name, position, season, gameweek):
158
 
159
 
160
  def predict_regression(row, bundle):
161
- """Run regression prediction. Returns predicted points (float)."""
162
  model = bundle["model"]
163
  scaler = bundle["scaler"]
164
  features = bundle["feature_names"]
165
-
166
  X = pd.DataFrame([row[features].values], columns=features)
167
  X_scaled = scaler.transform(X)
168
  return float(model.predict(X_scaled)[0])
169
 
170
 
171
  def predict_classification(row, bundle):
172
- """Run classification prediction. Returns (predicted_class_int, probabilities_dict)."""
173
  model = bundle["model"]
174
  scaler = bundle["scaler"]
175
  features = bundle["feature_names"]
176
-
177
  X = pd.DataFrame([row[features].values], columns=features)
178
  X_scaled = scaler.transform(X)
179
-
180
  pred_class = int(model.predict(X_scaled)[0])
181
  probs = model.predict_proba(X_scaled)[0]
182
  probs_dict = {i: float(p) for i, p in enumerate(probs)}
183
  return pred_class, probs_dict
184
 
185
 
186
- def feature_contributions(row, bundle, top_n_up=8, top_n_down=6):
187
- """
188
- For LogReg classifier: compute per-feature contribution to the predicted Haul class.
189
- Contribution = scaled_feature_value * coefficient_for_class.
190
- Returns DataFrame sorted by signed contribution.
191
- """
192
  model = bundle["model"]
193
  scaler = bundle["scaler"]
194
  features = bundle["feature_names"]
@@ -196,10 +172,8 @@ def feature_contributions(row, bundle, top_n_up=8, top_n_down=6):
196
  X = pd.DataFrame([row[features].values], columns=features)
197
  X_scaled = scaler.transform(X)[0]
198
 
199
- # Use Haul class (3) coefficients - what's pushing toward "is this a haul"
200
  haul_class_idx = list(model.classes_).index(3)
201
  coefs = model.coef_[haul_class_idx]
202
-
203
  contributions = X_scaled * coefs
204
 
205
  df_contrib = pd.DataFrame({
@@ -214,7 +188,6 @@ def feature_contributions(row, bundle, top_n_up=8, top_n_down=6):
214
 
215
 
216
  def feature_friendly_name(feature):
217
- """Convert internal feature names to human-readable labels."""
218
  mapping = {
219
  "minutes_lag_1": "Minutes played last gameweek",
220
  "points_lag_1": "Points scored last gameweek",
@@ -251,35 +224,137 @@ def feature_friendly_name(feature):
251
 
252
 
253
  def format_value_display(feature, value):
254
- """Format the raw value for display in the feature panel."""
255
- # Booleans, cluster membership, position dummies, season dummies -> Yes/No
 
256
  if feature in BOOLEAN_FEATURES or feature in CLUSTER_FEATURES:
 
 
257
  return "Yes" if value >= 0.5 else "No"
258
  if feature.startswith("position_") or feature.startswith("season_"):
259
  return "Yes" if value >= 0.5 else "No"
260
- # Numeric -> 1 decimal
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  if isinstance(value, (int, float, np.floating, np.integer)) and not isinstance(value, bool):
262
  return f"{value:.1f}"
263
  return str(value)
264
 
265
 
266
  def should_skip_feature(feature, value):
267
- """Hide cluster/position/season rows where the player doesn't belong - they aren't informative."""
 
268
  if feature in CLUSTER_FEATURES and value < 0.5:
269
  return True
270
  if feature.startswith("position_") and value < 0.5:
271
  return True
272
- if feature.startswith("season_") and value < 0.5:
 
 
 
 
273
  return True
274
  return False
275
 
276
 
277
  # ============================================================
278
- # PLAIN-ENGLISH SUMMARY GENERATOR
279
  # ============================================================
280
 
281
  def generate_summary(row, reg_pred, cls_pred, probs):
282
- """Build a rule-based plain-English match preview."""
283
  parts = []
284
 
285
  venue = "at home" if row.get("was_home", 0) == 1 else "away"
@@ -328,11 +403,58 @@ def generate_summary(row, reg_pred, cls_pred, probs):
328
 
329
 
330
  # ============================================================
331
- # YOUTUBE SEARCH LINK
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  # ============================================================
333
 
334
  def youtube_highlights_url(row):
335
- """Build a smart YouTube search link for this fixture's highlights."""
336
  home_team = row["team"] if row.get("was_home", 0) == 1 else row["opponent"]
337
  away_team = row["opponent"] if row.get("was_home", 0) == 1 else row["team"]
338
  season = row["season"]
@@ -532,17 +654,10 @@ elif st.session_state.step == "results":
532
 
533
  st.markdown("<br>", unsafe_allow_html=True)
534
 
535
- # CLASS PROBABILITY BAR
536
  st.markdown("#### Class probabilities")
537
- prob_df = pd.DataFrame({
538
- "Tier": [CLASS_LABELS[i] for i in range(4)],
539
- "Probability": [probs.get(i, 0) * 100 for i in range(4)],
540
- })
541
- st.bar_chart(
542
- prob_df.set_index("Tier"),
543
- height=160,
544
- color=PL_CYAN,
545
- )
546
 
547
  # MAIN CONTENT - 2 COLS
548
  col_left, col_right = st.columns([1, 1])
@@ -556,12 +671,15 @@ elif st.session_state.step == "results":
556
  with col_right:
557
  st.markdown('<div class="pl-card">', unsafe_allow_html=True)
558
  st.markdown("#### Why the model is predicting this")
559
- top_up, top_down = feature_contributions(row, cls_bundle, top_n_up=8, top_n_down=6)
560
 
561
  st.markdown(f"<div style='color: {PL_CYAN}; font-weight: 700; font-size: 12px; letter-spacing: 0.05em; margin-bottom: 8px;'>PUSHING TOWARD HAUL</div>", unsafe_allow_html=True)
 
562
  for _, r in top_up.iterrows():
563
  if should_skip_feature(r["feature"], r["raw_value"]):
564
  continue
 
 
565
  label = feature_friendly_name(r["feature"])
566
  val_str = format_value_display(r["feature"], r["raw_value"])
567
  st.markdown(
@@ -571,11 +689,15 @@ elif st.session_state.step == "results":
571
  f"</div>",
572
  unsafe_allow_html=True,
573
  )
 
574
 
575
  st.markdown(f"<div style='color: {PL_PINK}; font-weight: 700; font-size: 12px; letter-spacing: 0.05em; margin-top: 16px; margin-bottom: 8px;'>PUSHING AWAY FROM HAUL</div>", unsafe_allow_html=True)
 
576
  for _, r in top_down.iterrows():
577
  if should_skip_feature(r["feature"], r["raw_value"]):
578
  continue
 
 
579
  label = feature_friendly_name(r["feature"])
580
  val_str = format_value_display(r["feature"], r["raw_value"])
581
  st.markdown(
@@ -585,10 +707,11 @@ elif st.session_state.step == "results":
585
  f"</div>",
586
  unsafe_allow_html=True,
587
  )
 
588
 
589
  st.markdown("</div>", unsafe_allow_html=True)
590
 
591
- # YOUTUBE LINK
592
  yt_url = youtube_highlights_url(row)
593
  st.markdown(
594
  f"""
 
3
  Loads trained models + engineered dataset from companion model repo.
4
  """
5
 
 
6
  import urllib.parse
 
7
 
8
  import numpy as np
9
  import pandas as pd
10
  import pickle
11
  import streamlit as st
12
+ import plotly.graph_objects as go
13
  from huggingface_hub import hf_hub_download
14
 
15
  # ============================================================
 
21
  CLASSIFICATION_FILE = "fpl_classification_model.pkl"
22
  DATASET_FILE = "df_fe.parquet"
23
 
 
24
  PL_PURPLE = "#37003C"
25
  PL_CYAN = "#00FF87"
26
  PL_LIME = "#04F5FF"
 
32
  CLASS_RANGES = {0: "0-1 pts", 1: "2-4 pts", 2: "5-9 pts", 3: "10+ pts"}
33
  CLASS_COLORS = {0: "#888888", 1: "#04F5FF", 2: "#00FF87", 3: "#E90052"}
34
 
 
35
  BOOLEAN_FEATURES = {"was_home", "had_haul_last_3", "has_std_history"}
36
  CLUSTER_FEATURES = {"cluster_0", "cluster_1", "cluster_2", "cluster_3", "cluster_4"}
37
 
 
63
  padding: 20px;
64
  margin-bottom: 16px;
65
  }}
 
 
 
 
 
 
 
 
 
66
  .stSelectbox label, .stRadio label {{
67
  color: {PL_CYAN} !important;
68
  font-size: 12px !important;
 
113
 
114
  @st.cache_resource(show_spinner="Loading models from HuggingFace...")
115
  def load_artifacts():
 
116
  reg_path = hf_hub_download(repo_id=MODEL_REPO, filename=REGRESSION_FILE)
117
  cls_path = hf_hub_download(repo_id=MODEL_REPO, filename=CLASSIFICATION_FILE)
118
  data_path = hf_hub_download(repo_id=MODEL_REPO, filename=DATASET_FILE, repo_type="model")
 
131
  # ============================================================
132
 
133
  def get_player_row(df, name, position, season, gameweek):
 
134
  mask = (
135
  (df["name"] == name)
136
  & (df["position"] == position)
 
144
 
145
 
146
  def predict_regression(row, bundle):
 
147
  model = bundle["model"]
148
  scaler = bundle["scaler"]
149
  features = bundle["feature_names"]
 
150
  X = pd.DataFrame([row[features].values], columns=features)
151
  X_scaled = scaler.transform(X)
152
  return float(model.predict(X_scaled)[0])
153
 
154
 
155
  def predict_classification(row, bundle):
 
156
  model = bundle["model"]
157
  scaler = bundle["scaler"]
158
  features = bundle["feature_names"]
 
159
  X = pd.DataFrame([row[features].values], columns=features)
160
  X_scaled = scaler.transform(X)
 
161
  pred_class = int(model.predict(X_scaled)[0])
162
  probs = model.predict_proba(X_scaled)[0]
163
  probs_dict = {i: float(p) for i, p in enumerate(probs)}
164
  return pred_class, probs_dict
165
 
166
 
167
+ def feature_contributions(row, bundle, top_n_up=10, top_n_down=8):
 
 
 
 
 
168
  model = bundle["model"]
169
  scaler = bundle["scaler"]
170
  features = bundle["feature_names"]
 
172
  X = pd.DataFrame([row[features].values], columns=features)
173
  X_scaled = scaler.transform(X)[0]
174
 
 
175
  haul_class_idx = list(model.classes_).index(3)
176
  coefs = model.coef_[haul_class_idx]
 
177
  contributions = X_scaled * coefs
178
 
179
  df_contrib = pd.DataFrame({
 
188
 
189
 
190
  def feature_friendly_name(feature):
 
191
  mapping = {
192
  "minutes_lag_1": "Minutes played last gameweek",
193
  "points_lag_1": "Points scored last gameweek",
 
224
 
225
 
226
  def format_value_display(feature, value):
227
+ """Format the raw value with contextual hint where useful."""
228
+
229
+ # Boolean / cluster / position / season → Yes / No
230
  if feature in BOOLEAN_FEATURES or feature in CLUSTER_FEATURES:
231
+ if feature == "had_haul_last_3":
232
+ return "Yes (in form)" if value >= 0.5 else "No"
233
  return "Yes" if value >= 0.5 else "No"
234
  if feature.startswith("position_") or feature.startswith("season_"):
235
  return "Yes" if value >= 0.5 else "No"
236
+
237
+ # Minutes - rounded int + context
238
+ if feature == "minutes_lag_1":
239
+ v = int(round(value))
240
+ if v >= 75:
241
+ return f"{v} min (full match)"
242
+ elif v >= 30:
243
+ return f"{v} min (sub)"
244
+ elif v > 0:
245
+ return f"{v} min (cameo)"
246
+ else:
247
+ return f"{v} min (didn't play)"
248
+
249
+ if feature.startswith("minutes_played_rolling_"):
250
+ v = int(round(value))
251
+ if v >= 75:
252
+ return f"{v} min (regular starter)"
253
+ elif v >= 45:
254
+ return f"{v} min (rotation regular)"
255
+ elif v >= 15:
256
+ return f"{v} min (fringe role)"
257
+ else:
258
+ return f"{v} min (rarely featured)"
259
+
260
+ # Points lag
261
+ if feature == "points_lag_1":
262
+ v = int(round(value))
263
+ if v >= 10:
264
+ return f"{v} (hauled)"
265
+ elif v >= 5:
266
+ return f"{v} (good return)"
267
+ elif v >= 2:
268
+ return f"{v} (decent)"
269
+ else:
270
+ return f"{v} (blank)"
271
+
272
+ # Rolling points
273
+ if feature.startswith("points_rolling_"):
274
+ if value >= 6:
275
+ return f"{value:.1f} (excellent form)"
276
+ elif value >= 4:
277
+ return f"{value:.1f} (decent form)"
278
+ elif value >= 2:
279
+ return f"{value:.1f} (modest form)"
280
+ else:
281
+ return f"{value:.1f} (out of form)"
282
+
283
+ # Strengths
284
+ if feature == "opponent_strength":
285
+ if value <= 2.5:
286
+ return f"{value:.1f} (weak)"
287
+ elif value >= 4:
288
+ return f"{value:.1f} (strong)"
289
+ else:
290
+ return f"{value:.1f} (average)"
291
+
292
+ if feature == "team_strength":
293
+ if value <= 2.5:
294
+ return f"{value:.1f} (struggling)"
295
+ elif value >= 4:
296
+ return f"{value:.1f} (in form)"
297
+ else:
298
+ return f"{value:.1f} (average)"
299
+
300
+ # Price
301
+ if feature == "value":
302
+ if value >= 10:
303
+ return f"£{value:.1f}m (premium)"
304
+ elif value >= 7:
305
+ return f"£{value:.1f}m (mid-price)"
306
+ else:
307
+ return f"£{value:.1f}m (budget)"
308
+
309
+ # BPS
310
+ if feature.startswith("bps_rolling_"):
311
+ if value >= 25:
312
+ return f"{value:.1f} (high)"
313
+ elif value >= 15:
314
+ return f"{value:.1f} (decent)"
315
+ else:
316
+ return f"{value:.1f} (low)"
317
+
318
+ # Gameweek number - just hide context, it's not interpretable
319
+ if feature == "gameweek":
320
+ return f"GW{int(round(value))}"
321
+
322
+ # Volatility
323
+ if feature == "points_rolling_std_10":
324
+ if value >= 4:
325
+ return f"{value:.1f} (volatile)"
326
+ elif value >= 2:
327
+ return f"{value:.1f} (moderate)"
328
+ else:
329
+ return f"{value:.1f} (consistent)"
330
+
331
+ # Default numeric
332
  if isinstance(value, (int, float, np.floating, np.integer)) and not isinstance(value, bool):
333
  return f"{value:.1f}"
334
  return str(value)
335
 
336
 
337
  def should_skip_feature(feature, value):
338
+ """Hide rows that aren't informative for the user."""
339
+ # Cluster=No / position=No / season=No → not informative
340
  if feature in CLUSTER_FEATURES and value < 0.5:
341
  return True
342
  if feature.startswith("position_") and value < 0.5:
343
  return True
344
+ if feature.startswith("season_"):
345
+ # Season membership is true/false but not actionable - always hide
346
+ return True
347
+ # Gameweek number isn't a meaningful "factor" - hide
348
+ if feature == "gameweek":
349
  return True
350
  return False
351
 
352
 
353
  # ============================================================
354
+ # PLAIN-ENGLISH SUMMARY
355
  # ============================================================
356
 
357
  def generate_summary(row, reg_pred, cls_pred, probs):
 
358
  parts = []
359
 
360
  venue = "at home" if row.get("was_home", 0) == 1 else "away"
 
403
 
404
 
405
  # ============================================================
406
+ # CLASS PROBABILITY CHART (Plotly, on-brand)
407
+ # ============================================================
408
+
409
def plot_class_probabilities(probs, predicted_class):
    """Build the Plotly bar chart of tier probabilities.

    Args:
        probs: Mapping of class id -> probability in [0, 1]. Missing classes
            are drawn as 0%.
        predicted_class: Class id of the predicted tier. Its bar is drawn
            fully opaque and the others are dimmed.

    Returns:
        A ``plotly.graph_objects.Figure`` with one bar per tier in
        ``CLASS_LABELS``.
    """
    # Follow the label table rather than a hard-coded range(4) so the chart
    # stays in step with CLASS_LABELS / CLASS_RANGES / CLASS_COLORS.
    tiers = sorted(CLASS_LABELS)

    labels = [
        f"{CLASS_LABELS[i]}<br><span style='font-size:10px;color:rgba(255,255,255,0.5)'>{CLASS_RANGES[i]}</span>"
        for i in tiers
    ]
    values = [probs.get(i, 0) * 100 for i in tiers]
    # Every bar keeps its own tier colour; the predicted tier is emphasised
    # through opacity only (the old per-bar conditional picked the same colour
    # on both arms).
    colors = [CLASS_COLORS[i] for i in tiers]
    opacities = [1.0 if i == predicted_class else 0.45 for i in tiers]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=values,
        marker=dict(
            color=colors,
            opacity=opacities,
            line=dict(width=0),
        ),
        text=[f"{v:.1f}%" for v in values],
        textposition="outside",
        textfont=dict(color=PL_WHITE, size=14, family="Helvetica Neue"),
        hovertemplate="<b>%{x}</b><br>Probability: %{y:.1f}%<extra></extra>",
    ))

    fig.update_layout(
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color=PL_WHITE, family="Helvetica Neue"),
        height=280,
        margin=dict(l=20, r=20, t=40, b=20),
        yaxis=dict(
            # Headroom above the tallest bar for the outside text labels, with
            # a 20% floor so near-zero charts do not collapse.
            range=[0, max(max(values) * 1.25, 20)],
            showgrid=True,
            gridcolor="rgba(255,255,255,0.08)",
            tickformat=".0f",
            ticksuffix="%",
            zeroline=False,
        ),
        xaxis=dict(
            showgrid=False,
            zeroline=False,
        ),
        showlegend=False,
    )
    return fig
451
+
452
+
453
+ # ============================================================
454
+ # YOUTUBE
455
  # ============================================================
456
 
457
  def youtube_highlights_url(row):
 
458
  home_team = row["team"] if row.get("was_home", 0) == 1 else row["opponent"]
459
  away_team = row["opponent"] if row.get("was_home", 0) == 1 else row["team"]
460
  season = row["season"]
 
654
 
655
  st.markdown("<br>", unsafe_allow_html=True)
656
 
657
+ # CLASS PROBABILITY BAR (Plotly)
658
  st.markdown("#### Class probabilities")
659
+ fig = plot_class_probabilities(probs, cls_pred)
660
+ st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
 
 
 
 
 
 
 
661
 
662
  # MAIN CONTENT - 2 COLS
663
  col_left, col_right = st.columns([1, 1])
 
671
  with col_right:
672
  st.markdown('<div class="pl-card">', unsafe_allow_html=True)
673
  st.markdown("#### Why the model is predicting this")
674
+ top_up, top_down = feature_contributions(row, cls_bundle, top_n_up=10, top_n_down=8)
675
 
676
  st.markdown(f"<div style='color: {PL_CYAN}; font-weight: 700; font-size: 12px; letter-spacing: 0.05em; margin-bottom: 8px;'>PUSHING TOWARD HAUL</div>", unsafe_allow_html=True)
677
+ shown_up = 0
678
  for _, r in top_up.iterrows():
679
  if should_skip_feature(r["feature"], r["raw_value"]):
680
  continue
681
+ if shown_up >= 6:
682
+ break
683
  label = feature_friendly_name(r["feature"])
684
  val_str = format_value_display(r["feature"], r["raw_value"])
685
  st.markdown(
 
689
  f"</div>",
690
  unsafe_allow_html=True,
691
  )
692
+ shown_up += 1
693
 
694
  st.markdown(f"<div style='color: {PL_PINK}; font-weight: 700; font-size: 12px; letter-spacing: 0.05em; margin-top: 16px; margin-bottom: 8px;'>PUSHING AWAY FROM HAUL</div>", unsafe_allow_html=True)
695
+ shown_down = 0
696
  for _, r in top_down.iterrows():
697
  if should_skip_feature(r["feature"], r["raw_value"]):
698
  continue
699
+ if shown_down >= 5:
700
+ break
701
  label = feature_friendly_name(r["feature"])
702
  val_str = format_value_display(r["feature"], r["raw_value"])
703
  st.markdown(
 
707
  f"</div>",
708
  unsafe_allow_html=True,
709
  )
710
+ shown_down += 1
711
 
712
  st.markdown("</div>", unsafe_allow_html=True)
713
 
714
+ # YOUTUBE
715
  yt_url = youtube_highlights_url(row)
716
  st.markdown(
717
  f"""