Spaces:
Sleeping
Sleeping
File size: 5,745 Bytes
420b1ec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | import html
import json
import os
from pathlib import Path
from typing import Any
from fastapi import FastAPI, Form
from fastapi.responses import HTMLResponse
from huggingface_hub import snapshot_download
from setfit import SetFitModel
# Category names the classifier can emit. A label's position in this list is
# the integer class id used by ID_TO_LABEL below.
LABELS = [
    "course concept confusion",
    "teaching or explanation problem",
    "language or communication barrier",
    "workload or time-management difficulty",
    "assessment or grade anxiety",
    "isolation or lack of support",
    "explicit help-seeking",
    "urgent crisis-related language",
]

# Optional repository id to download model files from; blank disables the download.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "").strip()

# Local directory the model is loaded from (and downloaded into, when enabled).
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "./student_support_setfit_model"))

# Where support files such as labels.json are written. Fall back to a
# directory under the working directory when the configured location's parent
# is missing or not writable.
ARTIFACT_DIR = Path(os.environ.get("ARTIFACT_DIR", "/data/setfit_artifacts"))
if not (ARTIFACT_DIR.parent.exists() and os.access(ARTIFACT_DIR.parent, os.W_OK)):
    ARTIFACT_DIR = Path("./setfit_artifacts")

# Integer class id -> label name, in LABELS order.
ID_TO_LABEL = dict(enumerate(LABELS))

# Files that must be present in MODEL_DIR before the model can be loaded.
REQUIRED_MODEL_FILES = ["model_head.pkl", "config.json"]

app = FastAPI(title="Student Academic Support Classifier")
def ensure_support_files() -> None:
    """Create ARTIFACT_DIR and write ``labels.json`` into it.

    The JSON document contains the id-to-label mapping, the inverse
    label-to-id mapping, and the ordered label list. The file is rewritten
    on every call.
    """
    ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)

    # Invert the id -> label mapping.
    label_to_id = {}
    for idx, name in ID_TO_LABEL.items():
        label_to_id[name] = idx

    payload = {"id_to_label": ID_TO_LABEL, "label_to_id": label_to_id, "labels": LABELS}
    labels_path = ARTIFACT_DIR / "labels.json"
    labels_path.write_text(json.dumps(payload, indent=2))
def ensure_model_files() -> None:
    """Make sure MODEL_DIR exists and holds every required model file.

    When MODEL_REPO_ID is set, the repository snapshot is first downloaded
    into MODEL_DIR.

    Raises:
        FileNotFoundError: If any name in REQUIRED_MODEL_FILES is still
            missing from MODEL_DIR afterwards.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)

    if MODEL_REPO_ID:
        # NOTE(review): recent huggingface_hub releases appear to deprecate
        # local_dir_use_symlinks - confirm against the pinned version.
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=str(MODEL_DIR),
            local_dir_use_symlinks=False,
        )

    absent = []
    for filename in REQUIRED_MODEL_FILES:
        if not (MODEL_DIR / filename).exists():
            absent.append(filename)

    if not absent:
        return

    listing = ", ".join(absent)
    raise FileNotFoundError(
        f"Missing SetFit model files: {listing}. Upload your trained SetFit files into {MODEL_DIR}, or set MODEL_REPO_ID in the Space settings."
    )
def load_model() -> tuple[Any, str]:
    """Prepare support and model files, then load the SetFit model.

    Returns:
        ``(model, status)``. On success ``model`` is the loaded model and
        ``status`` names the directory it came from. On any failure ``model``
        is ``None`` and ``status`` carries the exception text, so a failed
        load is reported as a status message rather than raised.
    """
    try:
        ensure_support_files()
        ensure_model_files()
        loaded = SetFitModel.from_pretrained(str(MODEL_DIR))
    except Exception as exc:
        # Deliberately broad: a failed load is surfaced through the status string.
        return None, f"Model failed to load: {exc}"
    return loaded, f"Model loaded from {MODEL_DIR}."
# Load the model once at import time. MODEL is None when loading failed, and
# MODEL_STATUS holds the human-readable outcome in either case.
MODEL, MODEL_STATUS = load_model()
# Echo the load outcome to stdout.
print(MODEL_STATUS)
def normalize_prediction(prediction: Any) -> str:
    """Reduce a raw model prediction to a label string.

    Containers are unwrapped repeatedly until a scalar remains:

    * a dict yields the value under the first of ``"label"``,
      ``"prediction"`` or ``"class"`` that is present; a dict with none of
      those keys is rendered as JSON text (or ``str()`` of the dict when it
      is not JSON-serialisable);
    * a list or tuple yields its first element, or ``""`` when empty.

    Unwrapping in a loop means nested shapes such as ``[{"label": 3}]`` are
    resolved instead of being returned as the repr of the inner container.

    Args:
        prediction: Raw value taken from the model's output.

    Returns:
        The value itself when it is a string; otherwise the ID_TO_LABEL entry
        for ``int(prediction)``; otherwise ``str(prediction)`` when it cannot
        be converted to a known class id.
    """
    # Bounded so a self-referencing container cannot loop forever.
    max_depth = 16
    for _ in range(max_depth):
        if isinstance(prediction, dict):
            for key in ("label", "prediction", "class"):
                if key in prediction:
                    prediction = prediction[key]
                    break
            else:
                # No recognised key: fall back to a textual rendering.
                try:
                    prediction = json.dumps(prediction)
                except (TypeError, ValueError):
                    prediction = str(prediction)
        elif isinstance(prediction, (list, tuple)):
            prediction = prediction[0] if prediction else ""
        else:
            break

    if isinstance(prediction, str):
        return prediction
    try:
        return ID_TO_LABEL[int(prediction)]
    except Exception:
        # Unknown id or a value that is not int-convertible: show it verbatim.
        return str(prediction)
def classify(text: str) -> tuple[str, str, str]:
    """Classify one piece of student text.

    Args:
        text: Free-form input; ``None`` and whitespace-only input are treated
            as empty.

    Returns:
        ``(decision, label, interpretation)``. Empty input and an unloaded
        model each produce a fixed explanatory triple without running the
        model.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "Input required", "N/A", "Please enter text."
    if MODEL is None:
        return "Error", "Model not initialized", "Check the model status below."

    predictions = MODEL.predict([cleaned])
    label = normalize_prediction(predictions[0])

    # Labels that get a dedicated decision; every other label is treated as a
    # general academic-support case.
    flagged = {
        "urgent crisis-related language": (
            "Flagged: Urgent",
            "Immediate concern detected. Escalate to appropriate support.",
        ),
        "explicit help-seeking": (
            "Flagged: Help-Seeking",
            "Student is directly asking for assistance.",
        ),
    }
    fallback = ("Academic Support", "Standard academic struggle detected.")
    decision, interpretation = flagged.get(label, fallback)
    return decision, label, interpretation
def page(result: tuple[str, str, str] | None = None, input_text: str = "") -> str:
    """Render the HTML page for the classifier UI.

    Args:
        result: ``(decision, label, interpretation)`` to show in a result
            panel; a falsy value omits the panel.
        input_text: Text to place back into the textarea.

    Returns:
        The complete HTML document. Every interpolated value (result fields,
        input text, model status) is passed through ``html.escape``.
    """
    if result:
        decision, label, interpretation = result
        safe_decision = html.escape(decision)
        safe_label = html.escape(label)
        safe_interpretation = html.escape(interpretation)
        result_html = f"""
<section class='result'>
<p><strong>Decision:</strong> {safe_decision}</p>
<p><strong>Predicted category:</strong> {safe_label}</p>
<p><strong>Interpretation:</strong> {safe_interpretation}</p>
</section>
"""
    else:
        result_html = ""

    safe_input = html.escape(input_text)
    safe_status = html.escape(MODEL_STATUS)

    # Doubled braces in the <style> section are literal CSS braces in the f-string.
    return f"""
<!doctype html>
<html>
<head>
<title>Student Academic Support Classifier</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 840px; margin: 48px auto; padding: 0 24px; }}
textarea {{ width: 100%; min-height: 130px; padding: 12px; font-size: 16px; }}
button {{ margin-top: 12px; padding: 10px 16px; font-size: 16px; cursor: pointer; }}
.result {{ margin-top: 24px; padding: 16px; border: 1px solid #ddd; border-radius: 12px; background: #fafafa; }}
.status {{ margin-top: 24px; color: #555; font-size: 14px; }}
</style>
</head>
<body>
<h1>Student Academic Support Classifier</h1>
<form method="post" action="/classify">
<label for="text">What are you struggling most with in class?</label><br><br>
<textarea id="text" name="text">{safe_input}</textarea><br>
<button type="submit">Classify</button>
</form>
{result_html}
<p class="status"><strong>Model status:</strong> {safe_status}</p>
<p class="status">Educational demo only; not a diagnostic, counseling, disciplinary, or surveillance tool.</p>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def home() -> HTMLResponse:
    """Serve the page with an empty form and no result panel."""
    blank_form = page()
    return HTMLResponse(blank_form)
@app.post("/classify", response_class=HTMLResponse)
def classify_route(text: str = Form("")) -> HTMLResponse:
    """Classify the submitted form text and re-render the page with the result.

    The submitted text is echoed back into the textarea.
    """
    outcome = classify(text)
    rendered = page(outcome, input_text=text)
    return HTMLResponse(rendered)
@app.get("/health")
def health() -> dict[str, str]:
    """Return a fixed ``"ok"`` status together with the model-load status message."""
    payload = {"status": "ok"}
    payload["model_status"] = MODEL_STATUS
    return payload
|