# FastAPIVersion / app.py
# Last commit by katesiplon: "Upload 4 files" (420b1ec, verified)
import html
import json
import os
from pathlib import Path
from typing import Any
from fastapi import FastAPI, Form
from fastapi.responses import HTMLResponse
from huggingface_hub import snapshot_download
from setfit import SetFitModel
# Closed label vocabulary; a label's list position is its integer class id.
LABELS = [
    "course concept confusion",
    "teaching or explanation problem",
    "language or communication barrier",
    "workload or time-management difficulty",
    "assessment or grade anxiety",
    "isolation or lack of support",
    "explicit help-seeking",
    "urgent crisis-related language",
]

# Runtime configuration, overridable through environment variables.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "").strip()
MODEL_DIR = Path(os.getenv("MODEL_DIR", "./student_support_setfit_model"))
ARTIFACT_DIR = Path(os.getenv("ARTIFACT_DIR", "/data/setfit_artifacts"))

# Fall back to a directory relative to the working directory when the parent
# of the configured artifact dir is missing or not writable
# (NOTE(review): /data is presumably an optional persistent mount — confirm).
if not (ARTIFACT_DIR.parent.exists() and os.access(ARTIFACT_DIR.parent, os.W_OK)):
    ARTIFACT_DIR = Path("./setfit_artifacts")

# Integer class id -> human-readable label.
ID_TO_LABEL = dict(enumerate(LABELS))

# Files whose presence in MODEL_DIR is checked before loading the model.
REQUIRED_MODEL_FILES = ["model_head.pkl", "config.json"]
# The FastAPI application; the route handlers further down register on it via decorators.
app = FastAPI(title="Student Academic Support Classifier")
def ensure_support_files() -> None:
    """Write ``labels.json`` describing the label vocabulary into ARTIFACT_DIR.

    The file holds the id->label map, its inverse, and the ordered label list.
    JSON turns the integer ids of ``id_to_label`` into string keys.
    The directory is created if needed, and an existing file is overwritten.
    """
    ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
    payload = {
        "id_to_label": ID_TO_LABEL,
        "label_to_id": {name: idx for idx, name in ID_TO_LABEL.items()},
        "labels": LABELS,
    }
    target = ARTIFACT_DIR / "labels.json"
    target.write_text(json.dumps(payload, indent=2))
def ensure_model_files() -> None:
    """Make sure the SetFit model files are present in MODEL_DIR.

    When MODEL_REPO_ID is set, the repository snapshot is first downloaded
    into MODEL_DIR.

    Raises:
        FileNotFoundError: if any name in REQUIRED_MODEL_FILES is still
            missing from MODEL_DIR afterwards.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    if MODEL_REPO_ID:
        # NOTE(review): recent huggingface_hub releases deprecate/ignore
        # local_dir_use_symlinks — confirm against the pinned version.
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=str(MODEL_DIR),
            local_dir_use_symlinks=False,
        )
    absent = []
    for filename in REQUIRED_MODEL_FILES:
        if not (MODEL_DIR / filename).exists():
            absent.append(filename)
    if not absent:
        return
    raise FileNotFoundError(
        f"Missing SetFit model files: {', '.join(absent)}"
        f". Upload your trained SetFit files into {MODEL_DIR}, or set MODEL_REPO_ID in the Space settings."
    )
def load_model() -> tuple[Any, str]:
    """Prepare support artifacts and load the SetFit model from MODEL_DIR.

    Returns:
        ``(model, status)``. ``model`` is the loaded SetFitModel, or ``None``
        if loading failed. ``status`` is a human-readable message that the
        page and the ``/health`` endpoint display.

    Exceptions are not propagated. Failures are reported through ``status``,
    so the web app can still start and show what went wrong.
    """
    warning = ""
    try:
        ensure_support_files()
    except OSError as exc:
        # labels.json is an auxiliary artifact that nothing visible in this
        # file reads. A failure to write it (e.g. a read-only disk) should not
        # prevent the model from loading and serving predictions.
        warning = f" (warning: could not write label artifacts: {exc})"
    try:
        ensure_model_files()
        model = SetFitModel.from_pretrained(str(MODEL_DIR))
    except Exception as exc:  # top-level boundary: surface the error via status
        return None, f"Model failed to load: {exc}{warning}"
    return model, f"Model loaded from {MODEL_DIR}.{warning}"
# Load the model once at import time; every request then uses this shared MODEL.
# load_model() catches load errors. On failure MODEL is None and
# MODEL_STATUS holds the message shown on the page and by /health.
MODEL, MODEL_STATUS = load_model()
print(MODEL_STATUS)
def normalize_prediction(prediction: Any) -> str:
    """Coerce one raw model prediction into a human-readable label string.

    Accepted shapes:
    - A dict carrying the label under ``label``, ``prediction`` or ``class``,
      checked in that order. Any other dict is JSON-serialised.
    - A list or tuple. Its first element is used, or ``""`` if it is empty.
    - A string, which is returned unchanged.
    - An integer-like class id, which is looked up in ``ID_TO_LABEL``.

    Anything that cannot be mapped falls back to ``str(value)``.
    """
    value = prediction
    if isinstance(value, dict):
        found = next((k for k in ("label", "prediction", "class") if k in value), None)
        value = json.dumps(value) if found is None else value[found]
    if isinstance(value, (list, tuple)):
        value = value[0] if value else ""
    if isinstance(value, str):
        return value
    try:
        return ID_TO_LABEL[int(value)]
    except Exception:
        # Unknown id or non-numeric value: show the raw value rather than fail.
        return str(value)
def classify(text: str) -> tuple[str, str, str]:
    """Classify free-text student input and map the label to a triage decision.

    Args:
        text: Raw form input. ``None`` and surrounding whitespace are tolerated.

    Returns:
        ``(decision, label, interpretation)`` strings for display. Empty input,
        an unloaded model, and an inference failure each produce an
        explanatory triple instead of raising.
    """
    text = (text or "").strip()
    if not text:
        return "Input required", "N/A", "Please enter text."
    if MODEL is None:
        return "Error", "Model not initialized", "Check the model status below."
    try:
        raw_prediction = MODEL.predict([text])[0]
    except Exception as exc:  # boundary: inference errors must not crash the page
        # Without this, the error would propagate out of the route as a bare
        # HTTP 500. Log the details server-side and show a generic message.
        print(f"Prediction failed: {exc!r}")
        return "Error", "Prediction failed", "The model could not process this input."
    label = normalize_prediction(raw_prediction)
    if label == "urgent crisis-related language":
        return "Flagged: Urgent", label, "Immediate concern detected. Escalate to appropriate support."
    if label == "explicit help-seeking":
        return "Flagged: Help-Seeking", label, "Student is directly asking for assistance."
    return "Academic Support", label, "Standard academic struggle detected."
def page(result: tuple[str, str, str] | None = None, input_text: str = "") -> str:
    """Render the full HTML page: input form, optional result panel, status.

    Args:
        result: ``(decision, label, interpretation)`` as produced by
            ``classify``, or ``None`` to render only the empty form.
        input_text: Text to pre-fill the textarea with, echoing the user's
            submission back.

    Returns:
        A complete HTML document. Every interpolated value is passed through
        ``html.escape``, because it can contain user-submitted text or
        exception messages.
    """
    decision, label, interpretation = result or ("", "", "")
    result_html = ""
    if result:
        result_html = f"""
<section class='result'>
<p><strong>Decision:</strong> {html.escape(decision)}</p>
<p><strong>Predicted category:</strong> {html.escape(label)}</p>
<p><strong>Interpretation:</strong> {html.escape(interpretation)}</p>
</section>
"""
    # lang="en" declares the page language for assistive technology
    # (WCAG 3.1.1). The viewport meta makes the layout usable on mobile.
    return f"""
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Student Academic Support Classifier</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 840px; margin: 48px auto; padding: 0 24px; }}
textarea {{ width: 100%; min-height: 130px; padding: 12px; font-size: 16px; }}
button {{ margin-top: 12px; padding: 10px 16px; font-size: 16px; cursor: pointer; }}
.result {{ margin-top: 24px; padding: 16px; border: 1px solid #ddd; border-radius: 12px; background: #fafafa; }}
.status {{ margin-top: 24px; color: #555; font-size: 14px; }}
</style>
</head>
<body>
<h1>Student Academic Support Classifier</h1>
<form method="post" action="/classify">
<label for="text">What are you struggling most with in class?</label><br><br>
<textarea id="text" name="text">{html.escape(input_text)}</textarea><br>
<button type="submit">Classify</button>
</form>
{result_html}
<p class="status"><strong>Model status:</strong> {html.escape(MODEL_STATUS)}</p>
<p class="status">Educational demo only; not a diagnostic, counseling, disciplinary, or surveillance tool.</p>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def home() -> HTMLResponse:
    """Serve the empty classifier form."""
    document = page()
    return HTMLResponse(content=document)
@app.post("/classify", response_class=HTMLResponse)
def classify_route(text: str = Form("")) -> HTMLResponse:
    """Classify the submitted form text and re-render the page with the result."""
    outcome = classify(text)
    return HTMLResponse(content=page(outcome, input_text=text))
@app.get("/health")
def health() -> dict[str, str]:
    """Lightweight health check: always reports ``ok`` plus the model load status."""
    return dict(status="ok", model_status=MODEL_STATUS)