ccm commited on
Commit
c8d0eae
·
verified ·
1 Parent(s): 52ad6eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os # For filesystem operations
2
+ import shutil # For directory cleanup
3
+ import zipfile # For extracting model archives
4
+ import pathlib # For path manipulations
5
+ import pandas # For tabular data handling
6
+ import gradio # For interactive UI
7
+ import huggingface_hub # For downloading model assets
8
+ import autogluon.tabular # For loading and running AutoGluon predictors
9
+
10
+ # Settings
11
+ MODEL_REPO_ID = "ccm/2024-24679-tabular-autolguon-predictor"
12
+ ZIP_FILENAME = "autogluon_predictor_dir.zip"
13
+ CACHE_DIR = pathlib.Path("hf_assets")
14
+ EXTRACT_DIR = CACHE_DIR / "predictor_native"
15
+
16
+ # Feature schema (must match training)
17
+ FEATURE_COLS = [
18
+ "About how many hours per week do you spend listening to music?",
19
+ "Approximately how many songs are in your music library?",
20
+ "Approximately how many playlists have you created yourself?",
21
+ "How often do you share music with others?",
22
+ "Which decade of music do you listen to most?",
23
+ "How often do you attend live music events?",
24
+ "Do you prefer songs with lyrics or instrumental music?",
25
+ ]
26
+ TARGET_COL = "Do you usually listen to music alone or with others?"
27
+
28
+ # Encodings (aligned to survey UI)
29
+ LIKERT5_LABELS = ["Never", "Rarely", "Sometimes", "Often", "Very Often"]
30
+ LIKERT5_MAP = {label: idx for idx, label in enumerate(LIKERT5_LABELS)}
31
+
32
+ DECADE_LABELS = ["1970s and before", "1980s", "1990s", "2000s", "2010s", "2020s"]
33
+ DECADE_MAP = {label: idx for idx, label in enumerate(DECADE_LABELS)}
34
+
35
+ LYRICS_LABELS = ["Lyrics", "Instrumental", "Both equally"]
36
+ LYRICS_MAP = {label: idx for idx, label in enumerate(LYRICS_LABELS)}
37
+
38
+ # Outcome label mapping
39
+ OUTCOME_LABELS = {
40
+ 0: "Mostly Alone",
41
+ 1: "Mostly With Others",
42
+ }
43
+
44
+ def _prepare_predictor_dir() -> str:
45
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
46
+ local_zip = huggingface_hub.hf_hub_download(
47
+ repo_id=MODEL_REPO_ID,
48
+ filename=ZIP_FILENAME,
49
+ repo_type="model",
50
+ local_dir=str(CACHE_DIR),
51
+ local_dir_use_symlinks=False,
52
+ )
53
+ if EXTRACT_DIR.exists():
54
+ shutil.rmtree(EXTRACT_DIR)
55
+ EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
56
+ with zipfile.ZipFile(local_zip, "r") as zf:
57
+ zf.extractall(str(EXTRACT_DIR))
58
+ contents = list(EXTRACT_DIR.iterdir())
59
+ predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
60
+ return str(predictor_root)
61
+
62
+ PREDICTOR_DIR = _prepare_predictor_dir()
63
+ PREDICTOR = autogluon.tabular.TabularPredictor.load(PREDICTOR_DIR)
64
+
65
+ # Class-to-label mapper
66
+ def _human_label(c):
67
+ try:
68
+ ci = int(c)
69
+ if ci in OUTCOME_LABELS:
70
+ return OUTCOME_LABELS[ci]
71
+ except Exception:
72
+ pass
73
+ if c in OUTCOME_LABELS:
74
+ return OUTCOME_LABELS[c]
75
+ return str(c)
76
+
77
+ # Inference
78
+ def do_predict(hours_per_week, num_songs, num_playlists, share_label, decade_label, live_events_label, lyrics_label):
79
+ share_code = LIKERT5_MAP[share_label]
80
+ decade_code = DECADE_MAP[decade_label]
81
+ live_events_code = LIKERT5_MAP[live_events_label]
82
+ lyrics_code = LYRICS_MAP[lyrics_label]
83
+
84
+ row = {
85
+ FEATURE_COLS[0]: float(hours_per_week),
86
+ FEATURE_COLS[1]: int(num_songs),
87
+ FEATURE_COLS[2]: int(num_playlists),
88
+ FEATURE_COLS[3]: int(share_code),
89
+ FEATURE_COLS[4]: int(decade_code),
90
+ FEATURE_COLS[5]: int(live_events_code),
91
+ FEATURE_COLS[6]: int(lyrics_code),
92
+ }
93
+ X = pandas.DataFrame([row], columns=FEATURE_COLS)
94
+
95
+ pred_series = PREDICTOR.predict(X)
96
+ raw_pred = pred_series.iloc[0]
97
+
98
+ try:
99
+ proba = PREDICTOR.predict_proba(X)
100
+ if isinstance(proba, pandas.Series):
101
+ proba = proba.to_frame().T
102
+ except Exception:
103
+ proba = None
104
+
105
+ pred_label = _human_label(raw_pred)
106
+
107
+ proba_dict = None
108
+ if proba is not None:
109
+ row0 = proba.iloc[0]
110
+ tmp = {}
111
+ for cls, val in row0.items():
112
+ key = _human_label(cls)
113
+ tmp[key] = float(val) + float(tmp.get(key, 0.0))
114
+ proba_dict = dict(sorted(tmp.items(), key=lambda kv: kv[1], reverse=True))
115
+
116
+ df_out = pandas.DataFrame([{
117
+ "Predicted outcome": pred_label,
118
+ "Confidence (%)": round((proba_dict.get(pred_label, 1.0) if proba_dict else 1.0) * 100, 2),
119
+ }])
120
+
121
+ md = f"**Prediction:** {pred_label}"
122
+ if proba_dict:
123
+ md += f" \n**Confidence:** {round(proba_dict.get(pred_label, 0.0) * 100, 2)}%"
124
+
125
+ return md, proba_dict, df_out
126
+
127
+ # Representative examples
128
+ EXAMPLES = [
129
+ [5.0, 300, 3, "Rarely", "2010s", "Rarely", "Lyrics"],
130
+ [18.0, 1500, 25, "Often", "2000s", "Often", "Both equally"],
131
+ [12.0, 8000, 40, "Sometimes", "1990s", "Sometimes", "Instrumental"],
132
+ [4.0, 120, 1, "Never", "1970s and before", "Rarely", "Lyrics"],
133
+ [22.0, 500, 10, "Very Often", "2020s", "Very Often", "Lyrics"],
134
+ ]
135
+
136
+ # Gradio UI
137
+ with gradio.Blocks() as demo:
138
+ with gradio.Row():
139
+ hours_per_week = gradio.Slider(0, 80, step=0.5, value=5.0, label=FEATURE_COLS[0])
140
+ num_songs = gradio.Number(value=200, precision=0, label=FEATURE_COLS[1])
141
+ num_playlists = gradio.Number(value=5, precision=0, label=FEATURE_COLS[2])
142
+
143
+ with gradio.Row():
144
+ share_label = gradio.Radio(choices=LIKERT5_LABELS, value="Sometimes", label="How often do you share music with others?")
145
+ live_events_label = gradio.Radio(choices=LIKERT5_LABELS, value="Rarely", label="How often do you attend live music events?")
146
+
147
+ with gradio.Row():
148
+ decade_label = gradio.Radio(choices=DECADE_LABELS, value="2010s", label="Which decade of music do you listen to most?")
149
+ lyrics_label = gradio.Radio(choices=LYRICS_LABELS, value="Lyrics", label="Do you prefer songs with lyrics or instrumental music?")
150
+
151
+ proba_pretty = gradio.Label(num_top_classes=5, label="Class probabilities")
152
+ pred_table = gradio.Dataframe(headers=["Predicted outcome", "Confidence (%)"], label="Prediction (compact)", interactive=False)
153
+
154
+ inputs = [hours_per_week, num_songs, num_playlists, share_label, decade_label, live_events_label, lyrics_label]
155
+ for comp in inputs:
156
+ comp.change(fn=do_predict, inputs=inputs, outputs=[proba_pretty, pred_table])
157
+
158
+ gradio.Examples(
159
+ examples=EXAMPLES,
160
+ inputs=inputs,
161
+ label="Representative examples",
162
+ examples_per_page=5,
163
+ cache_examples=False,
164
+ )
165
+
166
+ if __name__ == "__main__":
167
+ demo.launch()