import html
import json
import os
import traceback
from pathlib import Path
from typing import Any

from fastapi import FastAPI, Form
from fastapi.responses import HTMLResponse
from huggingface_hub import snapshot_download
from setfit import SetFitModel
# Ordered label vocabulary. Integer predictions are mapped back through
# ID_TO_LABEL (built from this list), so list index i is label id i.
LABELS = [
    "course concept confusion",
    "teaching or explanation problem",
    "language or communication barrier",
    "workload or time-management difficulty",
    "assessment or grade anxiety",
    "isolation or lack of support",
    "explicit help-seeking",
    "urgent crisis-related language",
]

# Model source: an optional Hub repo id (blank disables downloading) and the
# local directory the model is loaded from.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "").strip()
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "./student_support_setfit_model"))

# Artifact directory. Fall back to a CWD-relative directory when the
# configured parent is missing or not writable by this process.
ARTIFACT_DIR = Path(os.environ.get("ARTIFACT_DIR", "/data/setfit_artifacts"))
if not (ARTIFACT_DIR.parent.exists() and os.access(ARTIFACT_DIR.parent, os.W_OK)):
    ARTIFACT_DIR = Path("./setfit_artifacts")

# Label id -> label name, e.g. 0 -> "course concept confusion".
ID_TO_LABEL = dict(enumerate(LABELS))

# Files that must exist in MODEL_DIR before the model is loaded.
REQUIRED_MODEL_FILES = ["model_head.pkl", "config.json"]

app = FastAPI(title="Student Academic Support Classifier")
def ensure_support_files() -> None:
    """Write ``labels.json`` describing the label vocabulary into ARTIFACT_DIR.

    Creates ARTIFACT_DIR (including parents) when needed and overwrites any
    existing ``labels.json``. The JSON object has three keys, in this order:
    ``id_to_label``, its inverse ``label_to_id``, and the ordered ``labels``
    list. JSON serialization turns the integer keys of ``id_to_label`` into
    strings.
    """
    ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
    inverse = {name: idx for idx, name in ID_TO_LABEL.items()}
    payload = {
        "id_to_label": ID_TO_LABEL,
        "label_to_id": inverse,
        "labels": LABELS,
    }
    labels_path = ARTIFACT_DIR / "labels.json"
    labels_path.write_text(json.dumps(payload, indent=2))
def ensure_model_files() -> None:
    """Ensure the trained SetFit model files are present in MODEL_DIR.

    Creates MODEL_DIR if needed. When MODEL_REPO_ID is non-empty, a snapshot
    of that Hub repository is downloaded into MODEL_DIR first. Afterwards,
    every name in REQUIRED_MODEL_FILES must exist in MODEL_DIR.

    Raises:
        FileNotFoundError: naming every required file that is still missing.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    if MODEL_REPO_ID:
        # NOTE(review): newer huggingface_hub releases deprecate/ignore
        # local_dir_use_symlinks; confirm against the pinned version.
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=str(MODEL_DIR),
            local_dir_use_symlinks=False,
        )

    absent = []
    for filename in REQUIRED_MODEL_FILES:
        if not (MODEL_DIR / filename).exists():
            absent.append(filename)
    if not absent:
        return

    joined = ", ".join(absent)
    raise FileNotFoundError(
        f"Missing SetFit model files: {joined}. Upload your trained SetFit "
        f"files into {MODEL_DIR}, or set MODEL_REPO_ID in the Space settings."
    )
def load_model() -> tuple[Any, str]:
    """Prepare support files, ensure model files exist, and load the model.

    This function never raises. Any exception from preparation or loading is
    turned into a ``(None, message)`` result, so the module-level load below
    cannot abort import.

    Returns:
        ``(model, "Model loaded from <MODEL_DIR>.")`` on success, otherwise
        ``(None, "Model failed to load: <error>")``.
    """
    try:
        ensure_support_files()
        ensure_model_files()
        loaded = SetFitModel.from_pretrained(str(MODEL_DIR))
    except Exception as exc:  # deliberate catch-all: failure is reported via status
        return None, f"Model failed to load: {exc}"
    else:
        return loaded, f"Model loaded from {MODEL_DIR}."


# Load once at import time. Request handlers read MODEL / MODEL_STATUS.
MODEL, MODEL_STATUS = load_model()
print(MODEL_STATUS)
def normalize_prediction(prediction: Any) -> str:
    """Coerce a raw model prediction into a label string.

    Unwrapping happens in this order:
      * dict: replaced by the value of the first key among "label",
        "prediction", "class" that is present; if none is present, replaced
        by the dict's ``json.dumps`` text.
      * list/tuple: replaced by its first element ("" when empty).
    The result is then returned unchanged if it is a ``str``. Otherwise it is
    converted with ``int()`` and looked up in ID_TO_LABEL. Any value that
    cannot be converted or looked up is returned as ``str(value)``.
    """
    if isinstance(prediction, dict):
        found = next(
            (key for key in ("label", "prediction", "class") if key in prediction),
            None,
        )
        if found is None:
            prediction = json.dumps(prediction)
        else:
            prediction = prediction[found]

    if isinstance(prediction, (list, tuple)):
        prediction = next(iter(prediction), "")

    if isinstance(prediction, str):
        return prediction
    try:
        return ID_TO_LABEL[int(prediction)]
    except Exception:
        return str(prediction)
def classify(text: str) -> tuple[str, str, str]:
    """Classify one student message and map the label to a UI decision.

    Args:
        text: Raw user input. ``None`` and whitespace-only input count as empty.

    Returns:
        ``(decision, label, interpretation)``. Empty input, a model that failed
        to load, and a prediction that raises each yield an explanatory tuple
        instead of an exception, so the calling route can always render a page.
    """
    text = (text or "").strip()
    if not text:
        return "Input required", "N/A", "Please enter text."
    if MODEL is None:
        return "Error", "Model not initialized", "Check the model status below."

    try:
        raw_prediction = MODEL.predict([text])[0]
        label = normalize_prediction(raw_prediction)
    except Exception:
        # Request boundary: keep the full traceback in the server log, but
        # return the same error-tuple shape as the model-not-loaded path
        # instead of surfacing an HTTP 500 or internal details to the user.
        traceback.print_exc()
        return (
            "Error",
            "Prediction failed",
            "The model could not classify this text. Check the server logs.",
        )

    if label == "urgent crisis-related language":
        return "Flagged: Urgent", label, "Immediate concern detected. Escalate to appropriate support."
    if label == "explicit help-seeking":
        return "Flagged: Help-Seeking", label, "Student is directly asking for assistance."
    return "Academic Support", label, "Standard academic struggle detected."
def page(result: tuple[str, str, str] | None = None, input_text: str = "") -> str:
    """Render the single-page HTML UI.

    The page always contains the input form, the current model status and the
    usage disclaimer. A result section is added when ``result`` is given. All
    interpolated values are HTML-escaped because they include user input.

    Args:
        result: Optional ``(decision, label, interpretation)`` tuple from
            ``classify``.
        input_text: Text to pre-fill into the textarea, so the user sees what
            was classified.

    Returns:
        A complete HTML document as a string.
    """
    result_html = ""
    if result:
        decision, label, interpretation = result
        result_html = f"""
  <section>
    <h2>Result</h2>
    <p><strong>Decision:</strong> {html.escape(decision)}</p>
    <p><strong>Predicted category:</strong> {html.escape(label)}</p>
    <p><strong>Interpretation:</strong> {html.escape(interpretation)}</p>
  </section>"""

    return f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>Student Academic Support Classifier</title>
</head>
<body>
  <h1>Student Academic Support Classifier</h1>
  <form method="post" action="/classify">
    <label for="text">Student message</label><br>
    <textarea id="text" name="text" rows="6" cols="80">{html.escape(input_text or "")}</textarea><br>
    <button type="submit">Classify</button>
  </form>
{result_html}
  <p><strong>Model status:</strong> {html.escape(MODEL_STATUS)}</p>
  <p><em>Educational demo only; not a diagnostic, counseling, disciplinary, or surveillance tool.</em></p>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
def home() -> HTMLResponse:
    """Render the page with an empty form and no result."""
    return HTMLResponse(page())


@app.post("/classify", response_class=HTMLResponse)
def classify_route(text: str = Form("")) -> HTMLResponse:
    """Classify the submitted form text and re-render the page with the result."""
    return HTMLResponse(page(classify(text), input_text=text))


@app.get("/health")
def health() -> dict[str, str]:
    """Report that the app is serving, plus the model-load status string."""
    return {"status": "ok", "model_status": MODEL_STATUS}