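# NOTE: the next three lines appear to be web-page header residue (uploader
# avatar alt text, commit message, commit hash) rather than part of the program:
#   bcueva's picture
#   Upload app.py with huggingface_hub
#   e38ead2 verified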
import pathlib, shutil, zipfile
import pandas
import gradio
import huggingface_hub
import autogluon.tabular
# --- Model artifact location -------------------------------------------------
# Hub repository and archive name handed to hf_hub_download().
MODEL_REPO_ID = "jennifee/classical_automl_model"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# Local working directories: the download lands in CACHE_DIR and the archive
# is unpacked into EXTRACT_DIR underneath it.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# --- Data schema --------------------------------------------------------------
# Column order used when building the single-row DataFrame for prediction and
# when labelling the input widgets.
FEATURE_COLS_MODEL = [
    "phone_hours",
    "computer_hours",
    "device_count",
    "sleep_quality",
    "sleep_time",
    "sleep_hours",
]
# Name of the predicted column (not referenced elsewhere in the visible code).
TARGET_COL = "use_before_bed"

# Choices offered by the sleep-quality radio button.
SLEEP_QUALITY_LABELS = ["good", "medium", "bad"]
# Display text for the integer class codes returned by the predictor.
USE_BEFORE_BED_LABELS = {
    0: "No",
    1: "Yes",
}
def _prepare_predictor_dir() -> str:
    """Download the predictor archive and unpack it into a fresh directory.

    The zip named ``ZIP_FILENAME`` is fetched from ``MODEL_REPO_ID`` into
    ``CACHE_DIR``; ``EXTRACT_DIR`` is then wiped, recreated and filled with
    the archive contents.

    Returns:
        The path (as ``str``) to load the predictor from: the single
        top-level folder if the archive unpacked to exactly one directory,
        otherwise ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Always start from an empty extraction directory so files from an
    # earlier run cannot mix with the newly unpacked ones.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # An archive that wraps everything in one folder is loaded from that
    # folder; anything else is loaded from the extraction root.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
# Fetch and unpack the model once, at import time, and keep the loaded
# predictor as a module-level object shared by every request.
PREDICTOR_DIR = _prepare_predictor_dir()
# NOTE(review): AutoGluon's documentation warns that TabularPredictor.load()
# unpickles the saved files, so only archives from a trusted repository
# should be loaded here -- confirm MODEL_REPO_ID is trusted.
PREDICTOR = autogluon.tabular.TabularPredictor.load(
    PREDICTOR_DIR,
    require_py_version_match=False,
)
def _human_label(c):
try:
ci = int(c)
if ci in USE_BEFORE_BED_LABELS:
return USE_BEFORE_BED_LABELS[ci]
except Exception:
pass
return str(c)
def do_predict(phone_hours, computer_hours, device_count, sleep_quality_label, sleep_time, sleep_hours):
    """Run the predictor on one set of form values.

    Args:
        phone_hours: Value for the ``phone_hours`` feature (converted to float).
        computer_hours: Value for the ``computer_hours`` feature (converted to float).
        device_count: Value for the ``device_count`` feature (converted to int).
        sleep_quality_label: Value for the ``sleep_quality`` feature, passed through as-is.
        sleep_time: Value for the ``sleep_time`` feature (converted to int).
        sleep_hours: Value for the ``sleep_hours`` feature (converted to float).

    Returns:
        A ``(proba_dict, md)`` tuple. ``proba_dict`` maps display labels to
        probabilities, sorted from most to least likely, or is ``None`` when
        no probabilities are available. ``md`` is a Markdown string with the
        predicted label and, when probabilities exist, its confidence.
        If any argument is ``None`` the model is not called and
        ``(None, <hint message>)`` is returned.
    """
    # This handler fires on every change event, including the moment a field
    # is cleared and its value arrives as None; float(None) / int(None) would
    # raise, so answer with an empty result instead.
    form_values = (phone_hours, computer_hours, device_count, sleep_quality_label, sleep_time, sleep_hours)
    if any(value is None for value in form_values):
        return None, "**Prediction:** unavailable (please fill in every input field)"

    row = {
        "phone_hours": float(phone_hours),
        "computer_hours": float(computer_hours),
        "device_count": int(device_count),
        "sleep_quality": sleep_quality_label,
        "sleep_time": int(sleep_time),
        "sleep_hours": float(sleep_hours),
    }
    X = pandas.DataFrame([row], columns=FEATURE_COLS_MODEL)

    raw_pred = PREDICTOR.predict(X).iloc[0]
    pred_label = _human_label(raw_pred)

    # Probabilities are best-effort: a failure here is reported on stdout and
    # the plain prediction is still returned.
    try:
        proba = PREDICTOR.predict_proba(X)
        if isinstance(proba, pandas.Series):
            # Normalise a Series result to a one-row frame so the lookup
            # below can always use .iloc[0].
            proba = proba.to_frame().T
    except Exception as e:
        print(f"Error calculating probabilities: {e}")
        proba = None

    proba_dict = None
    if proba is not None:
        row0 = proba.iloc[0]
        tmp = {}
        for cls in (0, 1):
            # The class columns may be keyed by int or by str.
            val = None
            if cls in row0.index:
                val = row0[cls]
            elif str(cls) in row0.index:
                val = row0[str(cls)]
            if val is not None:
                key = _human_label(cls)
                tmp[key] = float(tmp.get(key, 0.0)) + float(val)
        if tmp:
            proba_dict = dict(sorted(tmp.items(), key=lambda kv: kv[1], reverse=True))

    md = f"**Prediction:** {pred_label}"
    if proba_dict:
        # NOTE(review): a Markdown hard line break normally needs two trailing
        # spaces before the newline; confirm whether one space is intended.
        md += f" \n**Confidence:** {round(proba_dict.get(pred_label, 0.0) * 100, 2)}%"
    return proba_dict, md
# Sample rows for the examples table; values follow the same order as the
# input widgets: phone_hours, computer_hours, device_count, sleep_quality,
# sleep_time, sleep_hours.
EXAMPLES = [
    [3.5, 5.0, 3, "good", 23, 7.0],
    [4.2, 6.5, 3, "medium", 0, 6.5],
    [5.0, 4.0, 4, "bad", 1, 6.0],
    [2.0, 7.5, 3, "good", 22, 7.5],
    [3.8, 6.0, 3, "medium", 0, 6.0],
]
with gradio.Blocks() as demo:
    # Page title and short usage blurb.
    gradio.Markdown("# Predict Phone Use Before Bed")
    gradio.Markdown(
        "This app predicts whether a student uses their phone before bed based on their sleeping habits."
        "\nEnter the student's sleeping habits below to get a prediction."
    )

    # First row of inputs: device-usage features.
    with gradio.Row():
        phone_hours = gradio.Number(label=FEATURE_COLS_MODEL[0], value=3.5, precision=1)
        computer_hours = gradio.Number(label=FEATURE_COLS_MODEL[1], value=5.0, precision=1)
        device_count = gradio.Number(label=FEATURE_COLS_MODEL[2], value=3, precision=0)

    # Second row of inputs: sleep features.
    with gradio.Row():
        sleep_quality_label = gradio.Radio(
            label=FEATURE_COLS_MODEL[3],
            choices=SLEEP_QUALITY_LABELS,
            value="good",
        )
        sleep_time = gradio.Number(label=FEATURE_COLS_MODEL[4], value=23, precision=0)
        sleep_hours = gradio.Number(label=FEATURE_COLS_MODEL[5], value=7.0, precision=1)

    # Result widgets: class probabilities plus a Markdown summary line.
    proba_pretty = gradio.Label(label="Probability of Using Phone Before Bed", num_top_classes=2)
    prediction_output = gradio.Markdown()

    # The order of `inputs` must match do_predict's positional parameters.
    inputs = [phone_hours, computer_hours, device_count, sleep_quality_label, sleep_time, sleep_hours]
    outputs = [proba_pretty, prediction_output]

    # Re-run the prediction whenever any single input changes.
    for component in inputs:
        component.change(fn=do_predict, inputs=inputs, outputs=outputs)

    gradio.Examples(
        examples=EXAMPLES,
        inputs=inputs,
        label="Representative examples",
        examples_per_page=5,
        cache_examples=False,
    )
# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch(debug=False)