File size: 5,745 Bytes
420b1ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import html
import json
import os
from pathlib import Path
from typing import Any

from fastapi import FastAPI, Form
from fastapi.responses import HTMLResponse
from huggingface_hub import snapshot_download
from setfit import SetFitModel

# Ordered class names; a label's position in this list is its integer class id
# (ID_TO_LABEL below is built from this ordering).
LABELS = [
    "course concept confusion",
    "teaching or explanation problem",
    "language or communication barrier",
    "workload or time-management difficulty",
    "assessment or grade anxiety",
    "isolation or lack of support",
    "explicit help-seeking",
    "urgent crisis-related language",
]

# Runtime configuration, each overridable through an environment variable.
# An empty MODEL_REPO_ID means "do not download; use files already in MODEL_DIR".
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "").strip()
MODEL_DIR = Path(os.getenv("MODEL_DIR", "./student_support_setfit_model"))
ARTIFACT_DIR = Path(os.getenv("ARTIFACT_DIR", "/data/setfit_artifacts"))

# Use a working-directory location when the configured parent directory is
# missing or not writable by this process.
if not (ARTIFACT_DIR.parent.exists() and os.access(ARTIFACT_DIR.parent, os.W_OK)):
    ARTIFACT_DIR = Path("./setfit_artifacts")

ID_TO_LABEL = dict(enumerate(LABELS))
# Files whose presence in MODEL_DIR is checked before loading the model.
REQUIRED_MODEL_FILES = ["model_head.pkl", "config.json"]

# FastAPI application; the routes defined below are registered on this instance.
app = FastAPI(title="Student Academic Support Classifier")


def ensure_support_files() -> None:
    """Write ``labels.json`` describing the label/id mapping into ``ARTIFACT_DIR``.

    The directory is created first if needed, and the file is rewritten on
    every call. ``json.dumps`` turns the integer keys of ``ID_TO_LABEL`` into
    strings in the written file.
    """
    ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
    mapping = {
        "id_to_label": ID_TO_LABEL,
        "label_to_id": {name: idx for idx, name in ID_TO_LABEL.items()},
        "labels": LABELS,
    }
    labels_path = ARTIFACT_DIR / "labels.json"
    labels_path.write_text(json.dumps(mapping, indent=2))


def ensure_model_files() -> None:
    """Make sure the trained SetFit files are present in ``MODEL_DIR``.

    If ``MODEL_REPO_ID`` is non-empty, the repository snapshot is first
    downloaded into ``MODEL_DIR`` with ``huggingface_hub.snapshot_download``.

    Raises:
        FileNotFoundError: if any name in ``REQUIRED_MODEL_FILES`` is still
            absent from ``MODEL_DIR`` afterwards.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    if MODEL_REPO_ID:
        # NOTE(review): recent huggingface_hub releases deprecate/ignore
        # local_dir_use_symlinks; kept for older versions — confirm the pinned version.
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=str(MODEL_DIR),
            local_dir_use_symlinks=False,
        )

    absent = [fname for fname in REQUIRED_MODEL_FILES if not MODEL_DIR.joinpath(fname).exists()]
    if not absent:
        return
    raise FileNotFoundError(
        f"Missing SetFit model files: {', '.join(absent)}. "
        f"Upload your trained SetFit files into {MODEL_DIR}, "
        "or set MODEL_REPO_ID in the Space settings."
    )


def load_model() -> tuple[Any, str]:
    """Prepare support files and load the SetFit model from ``MODEL_DIR``.

    Returns:
        ``(model, status_message)`` on success, or ``(None, status_message)``
        when any step fails. This is a deliberate catch-all boundary: failures
        are reported through the message instead of being raised, so the app
        can still start.
    """
    try:
        ensure_support_files()
        ensure_model_files()
        loaded = SetFitModel.from_pretrained(str(MODEL_DIR))
    except Exception as exc:
        return None, f"Model failed to load: {exc}"
    else:
        return loaded, f"Model loaded from {MODEL_DIR}."


# Load the model once, at import time. On failure MODEL is None and MODEL_STATUS
# holds the error text, which page() and the /health endpoint display.
MODEL, MODEL_STATUS = load_model()
print(MODEL_STATUS)


def normalize_prediction(prediction: Any) -> str:
    """Reduce a raw model prediction to a displayable label string.

    Unwrapping rules, applied in order:

    1. A dict is replaced by the value of the first key present among
       ``"label"``, ``"prediction"``, ``"class"``. If none is present, the dict
       is serialised to text instead (JSON when possible, else ``str()``).
    2. A list/tuple is replaced by its first element (``""`` when empty).
    3. A string is returned unchanged. Anything else is treated as an integer
       class index into ``ID_TO_LABEL``, falling back to ``str()`` when the
       conversion or lookup fails.

    Never raises for odd prediction shapes; the result is always a string.
    """
    if isinstance(prediction, dict):
        for key in ("label", "prediction", "class"):
            if key in prediction:
                prediction = prediction[key]
                break
        else:
            # json.dumps raises TypeError on non-JSON values (e.g. sets,
            # numpy/torch scalars) and ValueError on circular references.
            # Degrade to str() rather than failing the whole request.
            try:
                prediction = json.dumps(prediction)
            except (TypeError, ValueError):
                prediction = str(prediction)
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0] if prediction else ""
    if isinstance(prediction, str):
        return prediction
    try:
        return ID_TO_LABEL[int(prediction)]
    except Exception:
        # Best-effort: unknown ids or non-numeric objects are shown verbatim.
        return str(prediction)


def classify(text: str) -> tuple[str, str, str]:
    """Classify one student message and map its label to a triage decision.

    Args:
        text: Raw form input. ``None`` is treated as empty, and surrounding
            whitespace is stripped.

    Returns:
        A ``(decision, predicted_label, interpretation)`` tuple of display
        strings. Empty input, a missing model and inference failures are all
        reported through this tuple instead of raised, so the page still renders.
    """
    text = (text or "").strip()
    if not text:
        return "Input required", "N/A", "Please enter text."
    if MODEL is None:
        return "Error", "Model not initialized", "Check the model status below."

    try:
        raw_prediction = MODEL.predict([text])[0]
        label = normalize_prediction(raw_prediction)
    except Exception as exc:
        # An inference failure would otherwise escape to FastAPI as a bare 500
        # and drop the submitted text. Report it in-page and keep the details
        # in the server log.
        print(f"Prediction failed: {exc!r}")
        return "Error", "Prediction failed", "The model could not classify this text. Please try again."

    if label == "urgent crisis-related language":
        return "Flagged: Urgent", label, "Immediate concern detected. Escalate to appropriate support."
    if label == "explicit help-seeking":
        return "Flagged: Help-Seeking", label, "Student is directly asking for assistance."
    return "Academic Support", label, "Standard academic struggle detected."


def page(result: tuple[str, str, str] | None = None, input_text: str = "") -> str:
    """Render the complete HTML document for the classifier UI.

    Args:
        result: Optional ``(decision, label, interpretation)`` tuple. When it
            is falsy, no result section is rendered.
        input_text: Text to pre-fill into the textarea.

    Returns:
        The page as a string. Every dynamic value (result fields, input text,
        model status) passes through ``html.escape`` before interpolation.
    """
    esc = html.escape
    result_html = ""
    if result:
        verdict, category, note = result
        result_html = f"""
        <section class='result'>
          <p><strong>Decision:</strong> {esc(verdict)}</p>
          <p><strong>Predicted category:</strong> {esc(category)}</p>
          <p><strong>Interpretation:</strong> {esc(note)}</p>
        </section>
        """
    return f"""
    <!doctype html>
    <html>
      <head>
        <title>Student Academic Support Classifier</title>
        <style>
          body {{ font-family: system-ui, sans-serif; max-width: 840px; margin: 48px auto; padding: 0 24px; }}
          textarea {{ width: 100%; min-height: 130px; padding: 12px; font-size: 16px; }}
          button {{ margin-top: 12px; padding: 10px 16px; font-size: 16px; cursor: pointer; }}
          .result {{ margin-top: 24px; padding: 16px; border: 1px solid #ddd; border-radius: 12px; background: #fafafa; }}
          .status {{ margin-top: 24px; color: #555; font-size: 14px; }}
        </style>
      </head>
      <body>
        <h1>Student Academic Support Classifier</h1>
        <form method="post" action="/classify">
          <label for="text">What are you struggling most with in class?</label><br><br>
          <textarea id="text" name="text">{esc(input_text)}</textarea><br>
          <button type="submit">Classify</button>
        </form>
        {result_html}
        <p class="status"><strong>Model status:</strong> {esc(MODEL_STATUS)}</p>
        <p class="status">Educational demo only; not a diagnostic, counseling, disciplinary, or surveillance tool.</p>
      </body>
    </html>
    """


@app.get("/", response_class=HTMLResponse)
def home() -> HTMLResponse:
    """Serve the classifier page with an empty form and no result section."""
    body = page()
    return HTMLResponse(content=body)


@app.post("/classify", response_class=HTMLResponse)
def classify_route(text: str = Form("")) -> HTMLResponse:
    """Classify the submitted form text and re-render the page with the outcome.

    The submitted text is echoed back into the textarea.
    """
    outcome = classify(text)
    rendered = page(outcome, input_text=text)
    return HTMLResponse(content=rendered)


@app.get("/health")
def health() -> dict[str, str]:
    """Return a fixed ``ok`` status together with the model-load status message."""
    payload = dict(status="ok", model_status=MODEL_STATUS)
    return payload