"""Gradio demo that predicts phone use before bed with an AutoGluon model.

At import time the module downloads a zipped AutoGluon predictor from the
Hugging Face Hub, extracts it locally and loads it. The Gradio UI then feeds
single-row DataFrames to that predictor.
"""

import pathlib
import shutil
import zipfile

import pandas
import gradio
import huggingface_hub
import autogluon.tabular

MODEL_REPO_ID = "jennifee/classical_automl_model"
ZIP_FILENAME = "autogluon_predictor_dir.zip"
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR / "predictor_native"

# Column order used when building the single-row DataFrame for the predictor.
FEATURE_COLS_MODEL = ['phone_hours', 'computer_hours', 'device_count', 'sleep_quality', 'sleep_time', 'sleep_hours']
TARGET_COL = "use_before_bed"
SLEEP_QUALITY_LABELS = ['good', 'medium', 'bad']
# Maps the model's integer class ids to the labels shown in the UI.
USE_BEFORE_BED_LABELS = {0: 'No', 1: 'Yes'}


def _prepare_predictor_dir() -> str:
    """Download and unpack the predictor archive; return the predictor root.

    Any previous extraction in ``EXTRACT_DIR`` is removed first so stale files
    never mix with the freshly extracted archive.

    Returns:
        Path (as ``str``) of the directory to pass to ``TabularPredictor.load``.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # NOTE(review): `local_dir_use_symlinks` appears to be deprecated in recent
    # huggingface_hub releases -- confirm against the pinned version.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    contents = list(EXTRACT_DIR.iterdir())
    # If the archive wraps everything in a single top-level folder, that
    # folder is the predictor root; otherwise the extraction dir itself is.
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
    return str(predictor_root)


PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.tabular.TabularPredictor.load(PREDICTOR_DIR, require_py_version_match=False)


def _human_label(c):
    """Return the display label for class value ``c``.

    Values that convert to an int key of ``USE_BEFORE_BED_LABELS`` map to
    their label; anything else falls back to ``str(c)``.
    """
    try:
        ci = int(c)
    except (TypeError, ValueError, OverflowError):
        # Not integer-like (e.g. None, "abc", inf): show the raw value.
        return str(c)
    return USE_BEFORE_BED_LABELS.get(ci, str(c))


def _label_probabilities(proba_row):
    """Collapse one row of class probabilities into ``{label: probability}``.

    Args:
        proba_row: pandas Series indexed by class id (int or str form).

    Returns:
        Dict sorted by descending probability, or ``None`` when none of the
        known classes are present in ``proba_row``.
    """
    totals = {}
    for cls in USE_BEFORE_BED_LABELS:
        val = None
        # The probability columns may be keyed by int or by str class ids.
        if cls in proba_row.index:
            val = proba_row[cls]
        elif str(cls) in proba_row.index:
            val = proba_row[str(cls)]
        if val is not None:
            key = _human_label(cls)
            totals[key] = float(totals.get(key, 0.0)) + float(val)
    if not totals:
        return None
    return dict(sorted(totals.items(), key=lambda kv: kv[1], reverse=True))


def do_predict(phone_hours, computer_hours, device_count, sleep_quality_label, sleep_time, sleep_hours):
    """Predict phone use before bed for one set of inputs.

    Args:
        phone_hours: Value for the ``phone_hours`` feature (cast to float).
        computer_hours: Value for the ``computer_hours`` feature (cast to float).
        device_count: Value for the ``device_count`` feature (cast to int).
        sleep_quality_label: Value for the ``sleep_quality`` feature.
        sleep_time: Value for the ``sleep_time`` feature (cast to int).
        sleep_hours: Value for the ``sleep_hours`` feature (cast to float).

    Returns:
        Tuple ``(proba_dict, md)``: a ``{label: probability}`` dict (or
        ``None`` when probabilities are unavailable) and a Markdown summary.
    """
    raw_inputs = (phone_hours, computer_hours, device_count, sleep_quality_label, sleep_time, sleep_hours)
    # A cleared input field arrives as None; the numeric casts below would
    # raise on it, so ask for complete input instead of erroring out.
    if any(value is None for value in raw_inputs):
        return None, "**Prediction:** please fill in every field."

    row = {
        "phone_hours": float(phone_hours),
        "computer_hours": float(computer_hours),
        "device_count": int(device_count),
        "sleep_quality": sleep_quality_label,
        "sleep_time": int(sleep_time),
        "sleep_hours": float(sleep_hours),
    }
    X = pandas.DataFrame([row], columns=FEATURE_COLS_MODEL)

    pred_series = PREDICTOR.predict(X)
    raw_pred = pred_series.iloc[0]
    pred_label = _human_label(raw_pred)

    # Probabilities are best-effort: a failure here must not hide the
    # prediction itself, so it is reported and the confidence is omitted.
    proba = None
    try:
        proba = PREDICTOR.predict_proba(X)
        if isinstance(proba, pandas.Series):
            proba = proba.to_frame().T
    except Exception as e:
        print(f"Error calculating probabilities: {e}")
        proba = None

    proba_dict = None
    if proba is not None:
        proba_dict = _label_probabilities(proba.iloc[0])

    md = f"**Prediction:** {pred_label}"
    if proba_dict:
        md += f" \n**Confidence:** {round(proba_dict.get(pred_label, 0.0) * 100, 2)}%"
    return proba_dict, md


# Rows follow the order of the `inputs` list built in the UI below.
EXAMPLES = [
    [3.5, 5.0, 3, 'good', 23, 7.0],
    [4.2, 6.5, 3, 'medium', 0, 6.5],
    [5.0, 4.0, 4, 'bad', 1, 6.0],
    [2.0, 7.5, 3, 'good', 22, 7.5],
    [3.8, 6.0, 3, 'medium', 0, 6.0],
]

with gradio.Blocks() as demo:
    gradio.Markdown("# Predict Phone Use Before Bed")
    gradio.Markdown(
        "This app predicts whether a student uses their phone before bed based on their sleeping habits."
        "\nEnter the student's sleeping habits below to get a prediction."
    )

    with gradio.Row():
        phone_hours = gradio.Number(value=3.5, precision=1, label=FEATURE_COLS_MODEL[0])
        computer_hours = gradio.Number(value=5.0, precision=1, label=FEATURE_COLS_MODEL[1])
        device_count = gradio.Number(value=3, precision=0, label=FEATURE_COLS_MODEL[2])

    with gradio.Row():
        sleep_quality_label = gradio.Radio(choices=SLEEP_QUALITY_LABELS, value="good", label=FEATURE_COLS_MODEL[3])
        sleep_time = gradio.Number(value=23, precision=0, label=FEATURE_COLS_MODEL[4])
        sleep_hours = gradio.Number(value=7.0, precision=1, label=FEATURE_COLS_MODEL[5])

    proba_pretty = gradio.Label(num_top_classes=2, label="Probability of Using Phone Before Bed")
    prediction_output = gradio.Markdown()

    inputs = [phone_hours, computer_hours, device_count, sleep_quality_label, sleep_time, sleep_hours]
    outputs = [proba_pretty, prediction_output]

    # Re-run the prediction whenever any input changes.
    for comp in inputs:
        comp.change(fn=do_predict, inputs=inputs, outputs=outputs)

    gradio.Examples(
        examples=EXAMPLES,
        inputs=inputs,
        label="Representative examples",
        examples_per_page=5,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch(debug=False)