Spaces:
Sleeping
Sleeping
| import html | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| from fastapi import FastAPI, Form | |
| from fastapi.responses import HTMLResponse | |
| from huggingface_hub import snapshot_download | |
| from setfit import SetFitModel | |
# Closed set of categories the classifier can emit. A label's position in
# this list is the integer class id used to build ID_TO_LABEL below.
LABELS = [
    "course concept confusion",
    "teaching or explanation problem",
    "language or communication barrier",
    "workload or time-management difficulty",
    "assessment or grade anxiety",
    "isolation or lack of support",
    "explicit help-seeking",
    "urgent crisis-related language",
]

# --- Runtime configuration (each value can be overridden via env vars) ---
# Optional Hub repository to download the trained model from; "" disables it.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "").strip()
# Local directory the SetFit model is loaded from.
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "./student_support_setfit_model"))
# Directory for auxiliary output such as labels.json.
ARTIFACT_DIR = Path(os.environ.get("ARTIFACT_DIR", "/data/setfit_artifacts"))
# Fall back to a relative directory when the configured parent directory is
# missing or not writable by this process.
if not (ARTIFACT_DIR.parent.exists() and os.access(ARTIFACT_DIR.parent, os.W_OK)):
    ARTIFACT_DIR = Path("./setfit_artifacts")

# Integer class id -> human-readable label.
ID_TO_LABEL = dict(enumerate(LABELS))
# File names that must exist inside MODEL_DIR before the model is loaded.
REQUIRED_MODEL_FILES = ["model_head.pkl", "config.json"]

app = FastAPI(title="Student Academic Support Classifier")
def ensure_support_files() -> None:
    """Write the label mapping to ``ARTIFACT_DIR / "labels.json"``.

    ARTIFACT_DIR (and any missing parents) is created first. The JSON
    document has three keys: ``id_to_label``, ``label_to_id`` and
    ``labels``. ``json.dumps`` serialises the integer ids of
    ``id_to_label`` as string keys.
    """
    ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
    payload = {
        "id_to_label": ID_TO_LABEL,
        "label_to_id": {name: idx for idx, name in ID_TO_LABEL.items()},
        "labels": LABELS,
    }
    labels_path = ARTIFACT_DIR / "labels.json"
    labels_path.write_text(json.dumps(payload, indent=2))
def ensure_model_files() -> None:
    """Make sure the trained SetFit model files are present in MODEL_DIR.

    When MODEL_REPO_ID is non-empty, a snapshot of that Hub repository is
    downloaded into MODEL_DIR first. Afterwards every name listed in
    REQUIRED_MODEL_FILES must exist inside MODEL_DIR.

    Raises:
        FileNotFoundError: if any required file is still missing. The message
            lists the missing names and how to provide them.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    if MODEL_REPO_ID:
        # Ask for real files in MODEL_DIR rather than symlinks into the cache.
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=str(MODEL_DIR),
            local_dir_use_symlinks=False,
        )
    absent = []
    for required in REQUIRED_MODEL_FILES:
        if not (MODEL_DIR / required).exists():
            absent.append(required)
    if not absent:
        return
    raise FileNotFoundError(
        f"Missing SetFit model files: {', '.join(absent)}. "
        f"Upload your trained SetFit files into {MODEL_DIR}, "
        "or set MODEL_REPO_ID in the Space settings."
    )
def load_model() -> tuple[Any, str]:
    """Prepare support files and load the SetFit model from MODEL_DIR.

    Returns:
        ``(model, status)``. On success ``model`` is the loaded SetFitModel
        and ``status`` names the directory it came from. On any failure
        ``model`` is None and ``status`` carries the error text, so the
        caller can keep running and display the problem.
    """
    try:
        ensure_support_files()
        ensure_model_files()
        loaded = SetFitModel.from_pretrained(str(MODEL_DIR))
    except Exception as exc:  # deliberate catch-all: report, don't crash
        return None, f"Model failed to load: {exc}"
    return loaded, f"Model loaded from {MODEL_DIR}."
# Load once at import time; all requests share this instance. MODEL is None
# when loading failed, in which case MODEL_STATUS holds the error text.
MODEL, MODEL_STATUS = load_model()
# flush=True so the status line reaches the log right away even when stdout
# is block-buffered (the default when it is a pipe rather than a terminal).
print(MODEL_STATUS, flush=True)
def normalize_prediction(prediction: Any) -> str:
    """Turn one raw model prediction into a label string.

    Handling, applied in order:
      * dict: the value of the first present key among ``label``,
        ``prediction`` and ``class`` is used; a dict with none of those
        keys is replaced by its JSON encoding;
      * list/tuple: its first element is used (``""`` when it is empty);
      * str: returned unchanged;
      * anything else: ``int()``-converted and looked up in ID_TO_LABEL,
        falling back to ``str(prediction)`` if that fails for any reason.
    """
    if isinstance(prediction, dict):
        found_key = next(
            (k for k in ("label", "prediction", "class") if k in prediction),
            None,
        )
        if found_key is None:
            prediction = json.dumps(prediction)
        else:
            prediction = prediction[found_key]

    if isinstance(prediction, (list, tuple)):
        prediction = next(iter(prediction), "")

    if isinstance(prediction, str):
        return prediction
    try:
        return ID_TO_LABEL[int(prediction)]
    except Exception:
        return str(prediction)
def classify(text: str) -> tuple[str, str, str]:
    """Classify one student message and map the label to a triage decision.

    Args:
        text: Raw message text. ``None`` and surrounding whitespace are
            tolerated.

    Returns:
        ``(decision, label, interpretation)``, three display strings. Empty
        input, an unloaded model and a failed prediction each return a
        descriptive tuple instead of raising, so callers always get
        something they can render.
    """
    text = (text or "").strip()
    if not text:
        return "Input required", "N/A", "Please enter text."
    if MODEL is None:
        return "Error", "Model not initialized", "Check the model status below."
    try:
        raw_prediction = MODEL.predict([text])[0]
    except Exception as exc:
        # Without this, an inference error escapes the route as a bare HTTP 500
        # instead of being shown on the page like the other failure modes.
        return "Error", "Prediction failed", f"The model could not classify this text: {exc}"
    label = normalize_prediction(raw_prediction)
    if label == "urgent crisis-related language":
        return "Flagged: Urgent", label, "Immediate concern detected. Escalate to appropriate support."
    if label == "explicit help-seeking":
        return "Flagged: Help-Seeking", label, "Student is directly asking for assistance."
    return "Academic Support", label, "Standard academic struggle detected."
def page(result: tuple[str, str, str] | None = None, input_text: str = "") -> str:
    """Render the complete HTML page: input form, optional result, model status.

    Every dynamic value (the result strings, the echoed input text and
    MODEL_STATUS) is passed through ``html.escape`` before interpolation.

    Args:
        result: ``(decision, label, interpretation)`` as returned by
            ``classify``. When None (or empty) the result panel is omitted.
        input_text: Text to pre-fill in the textarea.

    Returns:
        A full HTML document as a string.
    """
    decision, label, interpretation = result or ("", "", "")
    result_html = ""
    if result:
        result_html = f"""
<section class='result'>
  <p><strong>Decision:</strong> {html.escape(decision)}</p>
  <p><strong>Predicted category:</strong> {html.escape(label)}</p>
  <p><strong>Interpretation:</strong> {html.escape(interpretation)}</p>
</section>
"""
    # lang="en" declares the page language for assistive technology; the
    # viewport meta keeps the layout readable on narrow (mobile) screens.
    return f"""
<!doctype html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>Student Academic Support Classifier</title>
  <style>
    body {{ font-family: system-ui, sans-serif; max-width: 840px; margin: 48px auto; padding: 0 24px; }}
    textarea {{ width: 100%; min-height: 130px; padding: 12px; font-size: 16px; }}
    button {{ margin-top: 12px; padding: 10px 16px; font-size: 16px; cursor: pointer; }}
    .result {{ margin-top: 24px; padding: 16px; border: 1px solid #ddd; border-radius: 12px; background: #fafafa; }}
    .status {{ margin-top: 24px; color: #555; font-size: 14px; }}
  </style>
</head>
<body>
  <h1>Student Academic Support Classifier</h1>
  <form method="post" action="/classify">
    <label for="text">What are you struggling most with in class?</label><br><br>
    <textarea id="text" name="text">{html.escape(input_text)}</textarea><br>
    <button type="submit">Classify</button>
  </form>
  {result_html}
  <p class="status"><strong>Model status:</strong> {html.escape(MODEL_STATUS)}</p>
  <p class="status">Educational demo only; not a diagnostic, counseling, disciplinary, or surveillance tool.</p>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def home() -> HTMLResponse:
    """Serve the empty classifier form at the site root."""
    return HTMLResponse(page())
# Registered as POST /classify, the target of the form rendered by page().
@app.post("/classify", response_class=HTMLResponse)
def classify_route(text: str = Form("")) -> HTMLResponse:
    """Classify the submitted form text and re-render the page with the result.

    Args:
        text: The ``text`` form field; ``""`` when the field is absent.
    """
    return HTMLResponse(page(classify(text), input_text=text))
@app.get("/health")
def health() -> dict[str, str]:
    """Report service health together with the model-load status text.

    ``status`` is always ``"ok"``, even when the model failed to load;
    inspect ``model_status`` to tell those cases apart.
    """
    return {"status": "ok", "model_status": MODEL_STATUS}