Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| import urllib.error | |
| import urllib.request | |
| from pathlib import Path | |
| import torch | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| load_dotenv() | |
| ROOT_DIR = Path(__file__).resolve().parent | |
| DIST_DIR = ROOT_DIR / "dist" | |
| # Determine the model path or Hugging Face Hub repo ID | |
| BERT_MODEL_DIR_ENV = os.getenv("BERT_MODEL_DIR") | |
| if BERT_MODEL_DIR_ENV: | |
| MODEL_PATH_STR = BERT_MODEL_DIR_ENV | |
| else: | |
| local_model_path = ROOT_DIR / "trained_bert_model" | |
| if local_model_path.exists(): | |
| MODEL_PATH_STR = str(local_model_path) | |
| else: | |
| MODEL_PATH_STR = "brianhuster/dass_bert" | |
| GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") | |
| PRIMARY_MODEL = os.getenv("GEMINI_MODEL", "gemini-3.1-flash-lite") | |
| FALLBACK_MODELS = [ | |
| value.strip() | |
| for value in os.getenv("GEMINI_FALLBACK_MODELS", "").split(",") | |
| if value.strip() | |
| ] | |
| LABELS = {0: "0", 1: "1", 2: "2", 3: "3", 4: "w"} | |
| app = FastAPI(title="DASS-21 Screening App") | |
| class AnalyzeRequest(BaseModel): | |
| question: str = Field(min_length=1) | |
| answer: str = Field(min_length=1) | |
| class AnalyzeResponse(BaseModel): | |
| label: int | |
| score: int | None | |
| confidence: float | |
| needsClarification: bool | |
| reply: str | |
| model: str | None = None | |
| class AdviceRequest(BaseModel): | |
| assessment: dict = Field(default_factory=dict) | |
| messages: list[dict] = Field(default_factory=list) | |
| class AdviceResponse(BaseModel): | |
| reply: str | |
| riskLevel: str | |
| model: str | None = None | |
| # Check if model exists locally. If not, assume it's a Hugging Face Hub repo ID | |
| if os.path.exists(MODEL_PATH_STR): | |
| print(f"[MODEL] Loading model from local directory: {MODEL_PATH_STR}") | |
| MODEL_TARGET = Path(MODEL_PATH_STR) | |
| else: | |
| # If path has separators or starts with dot/slash, it was meant to be local but is missing | |
| if MODEL_PATH_STR.startswith("/") or MODEL_PATH_STR.startswith("./") or MODEL_PATH_STR.startswith("../"): | |
| raise FileNotFoundError(f"Local model directory not found: {MODEL_PATH_STR}") | |
| print(f"[MODEL] Local directory not found. Loading model from Hugging Face Hub: {MODEL_PATH_STR}") | |
| MODEL_TARGET = MODEL_PATH_STR # type: ignore | |
| MODEL_DIR = MODEL_TARGET | |
| print(f"\n[MODEL] Bắt đầu nạp mô hình từ nguồn: {MODEL_TARGET}") | |
| print("[MODEL] Lưu ý: Nếu chạy lần đầu, quá trình tải tự động mô hình (~1.1GB) từ Hugging Face Hub sẽ chạy ngầm. Vui lòng giữ kết nối Internet và chờ đợi từ 1-5 phút...\n") | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_TARGET) | |
| print("[MODEL] Nạp thành công Tokenizer!") | |
| model = AutoModelForSequenceClassification.from_pretrained(MODEL_TARGET) | |
| print("[MODEL] Nạp thành công Model weights! Đang khởi động web server...\n") | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| model.eval() | |
| def gemini_json_reply(prompt: str, fallback_reply: str, fallback_risk: str = "low") -> dict[str, str]: | |
| if not GEMINI_API_KEY: | |
| print("[GEMINI] Không tìm thấy GEMINI_API_KEY. Sử dụng câu trả lời dự phòng (fallback).") | |
| return {"reply": fallback_reply, "riskLevel": fallback_risk, "model": "fallback"} | |
| for candidate_model in [PRIMARY_MODEL, *FALLBACK_MODELS]: | |
| url = ( | |
| f"https://generativelanguage.googleapis.com/v1beta/models/" | |
| f"{candidate_model}:generateContent?key={GEMINI_API_KEY}" | |
| ) | |
| body = { | |
| "contents": [{"parts": [{"text": prompt}]}], | |
| "generationConfig": { | |
| "temperature": 0.6, | |
| "responseMimeType": "application/json", | |
| }, | |
| } | |
| request = urllib.request.Request( | |
| url, | |
| data=json.dumps(body).encode("utf-8"), | |
| headers={"Content-Type": "application/json"}, | |
| method="POST", | |
| ) | |
| try: | |
| print(f"[GEMINI] Đang gọi mô hình '{candidate_model}'...") | |
| with urllib.request.urlopen(request, timeout=45) as response: | |
| raw = json.loads(response.read().decode("utf-8")) | |
| text = ( | |
| raw.get("candidates", [{}])[0] | |
| .get("content", {}) | |
| .get("parts", [{}])[0] | |
| .get("text", "") | |
| ) | |
| parsed = json.loads(text) | |
| if isinstance(parsed, dict) and isinstance(parsed.get("reply"), str): | |
| print(f"[GEMINI] Gọi mô hình '{candidate_model}' thành công!") | |
| print(f"[GEMINI] Phản hồi nhận được: {parsed['reply']}") | |
| return { | |
| "reply": parsed["reply"], | |
| "riskLevel": parsed.get("riskLevel", "low"), | |
| "model": candidate_model, | |
| } | |
| except Exception as e: | |
| print(f"[GEMINI] Lỗi khi gọi mô hình '{candidate_model}': {type(e).__name__} - {e}") | |
| continue | |
| print("[GEMINI] Tất cả các mô hình Gemini đều thất bại hoặc trả về dữ liệu sai định dạng. Sử dụng câu trả lời dự phòng (fallback).") | |
| return {"reply": fallback_reply, "riskLevel": fallback_risk, "model": "fallback"} | |
| def build_turn_prompt(question: str, answer: str, label: str, confidence: float) -> str: | |
| tone_guidelines = "" | |
| if label == "0": | |
| tone_guidelines = "- Tình trạng sinh viên đang tốt (nhãn 0): Khen ngợi nhẹ nhàng, ngắn gọn trong 1 câu ấm áp." | |
| elif label == "1": | |
| tone_guidelines = "- Tình trạng sinh viên bình thường (nhãn 1): An ủi và khích lệ nhẹ nhàng trong 1-2 câu." | |
| elif label == "2": | |
| tone_guidelines = "- Tình trạng sinh viên không tốt lắm (nhãn 2): Thể hiện sự đồng cảm sâu sắc, trấn an, giải thích nguyên nhân theo hướng tích cực, đề xuất giải pháp ngắn gọn và lời động viên trong 3-4 câu." | |
| elif label == "3": | |
| tone_guidelines = "- Tình trạng sinh viên khá tệ (nhãn 3): Thể hiện sự đồng cảm cao nhất, xoa dịu tinh thần, định hướng suy nghĩ tích cực hơn, khuyên nghỉ ngơi/chia sẻ với người thân và đưa ra lời khuyên thực tế trong 3-4 câu." | |
| else: | |
| tone_guidelines = "- Câu trả lời có thể mơ hồ hoặc chưa rõ ràng. Hãy nhẹ nhàng khuyến khích sinh viên chia sẻ chi tiết hơn một cách tinh tế." | |
| return f""" | |
| Hãy đóng vai như thể bạn là một trợ lý ảo tư vấn tâm lý học của Đại học Bách khoa Hà Nội (HUST). Bạn có phong cách nói chuyện vô cùng tích cực, ấm áp và đồng cảm cao với sinh viên. | |
| Nhiệm vụ & Nguyên tắc hội thoại: | |
| - Xưng hô thân thiện, gần gũi (dùng "mình" và "bạn"). | |
| - Bạn TUYỆT ĐỐI không được hỏi thêm bất kỳ câu hỏi nào. Chỉ đưa ra lời khuyên, sự an ủi, thấu hiểu hoặc khen ngợi. | |
| - Không chẩn đoán bệnh lý hay nói giọng bác sĩ lâm sàng. | |
| - Hướng dẫn phong cách phản hồi theo nhãn điểm được phân loại: | |
| {tone_guidelines} | |
| Ngữ cảnh hiện tại: | |
| - Câu hỏi khảo sát DASS-21: {question} | |
| - Câu trả lời của sinh viên: '{answer}' | |
| Hãy trả về JSON hợp lệ theo cấu trúc chính xác: | |
| {{"reply": "nội dung phản hồi ấm áp của bạn"}} | |
| """.strip() | |
| def build_final_advice_prompt(assessment: dict, conversation: str) -> str: | |
| top_concern = assessment.get("topConcernLabel", "n/a") | |
| return f""" | |
| Hãy đóng vai như thể bạn là một trợ lý ảo tư vấn tâm lý học của Đại học Bách khoa Hà Nội (HUST). Bạn đang đưa ra những lời khuyên tổng kết và lời tạm biệt sau khi sinh viên đã hoàn thành cuộc khảo sát DASS-21. | |
| Nhiệm vụ & Nguyên tắc hội thoại: | |
| - Chúc mừng sinh viên đã kiên nhẫn cùng bạn hoàn thành tất cả các câu hỏi đánh giá. | |
| - Thể hiện sự đồng cảm, cảm ơn sinh viên đã dành thời gian chia sẻ chân thành. | |
| - Nhận xét kết quả sàng lọc tinh thần bằng những lời lẽ tích cực, mang tính khích lệ: | |
| + Stress: {assessment.get("stress", "n/a")} | |
| + Lo âu: {assessment.get("anxiety", "n/a")} | |
| + Trầm cảm: {assessment.get("depression", "n/a")} | |
| + Yếu tố nổi bật nhất: {top_concern} | |
| - Đưa ra lời khuyên thiết thực (chế độ nghỉ ngơi, gặp gỡ bạn bè, quản lý thời gian học tập tại Bách khoa). | |
| - Nhấn mạnh rằng nhà trường và các trợ lý ảo/thầy cô sẽ luôn đồng hành, sẵn sàng lắng nghe và hỗ trợ bạn bất cứ lúc nào. | |
| - Tạm biệt sinh viên bằng lời chúc ấm áp (dùng xưng hô "mình" và "bạn"). | |
| - Độ dài khoảng 8-10 câu. | |
| Hãy trả về JSON hợp lệ theo cấu trúc chính xác: | |
| {{"reply": "nội dung lời khuyên và lời chào tạm biệt đầy đủ của bạn", "riskLevel": "low" | "moderate" | "high"}} | |
| """.strip() | |
| def health() -> dict[str, object]: | |
| return { | |
| "ok": True, | |
| "bert": { | |
| "status": "ok", | |
| "device": str(device), | |
| "model_dir": str(MODEL_DIR), | |
| }, | |
| "gemini": bool(GEMINI_API_KEY), | |
| } | |
| def analyze(payload: AnalyzeRequest) -> AnalyzeResponse: | |
| print(f"\n--- [API /api/dass/analyze] Nhận yêu cầu ---") | |
| print(f"Câu hỏi: {payload.question}") | |
| print(f"Câu trả lời của sinh viên: '{payload.answer}'") | |
| inputs = tokenizer( | |
| payload.question, | |
| payload.answer, | |
| return_tensors="pt", | |
| truncation=True, | |
| padding=True, | |
| max_length=192, | |
| ) | |
| inputs = {key: value.to(device) for key, value in inputs.items()} | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| probabilities = torch.softmax(outputs.logits, dim=-1)[0] | |
| label = int(torch.argmax(probabilities).item()) | |
| confidence = float(probabilities[label].item()) | |
| print(f"[BERT] Kết quả phân loại: nhãn={label} (Điểm quy đổi={label if label in {0, 1, 2, 3} else 'N/A'})") | |
| print(f"[BERT] Độ tin cậy (Confidence): {confidence:.4f}") | |
| if label == 4 or confidence < 0.45: | |
| reason = "nhãn = 4 (lạc đề/off-topic)" if label == 4 else f"độ tin cậy {confidence:.4f} thấp hơn ngưỡng 0.45" | |
| print(f"[BERT] Yêu cầu làm rõ (Clarification Needed): {reason}") | |
| return AnalyzeResponse( | |
| label=label, | |
| score=None, | |
| confidence=confidence, | |
| needsClarification=True, | |
| reply="Mình chưa chắc mình hiểu đúng ý bạn. Bạn có thể trả lời lại ngắn gọn và trực tiếp hơn theo đúng câu hỏi này không?", | |
| model="bert", | |
| ) | |
| fallback_reply = ( | |
| "Cảm ơn bạn đã chia sẻ. Nghe như điều này đang ảnh hưởng bạn khá nhiều; mình sẽ ghi nhận để xem tổng thể sau." | |
| if label == 3 | |
| else "Cảm ơn bạn đã chia sẻ, mình đã hiểu hơn rồi." | |
| ) | |
| print("[BERT] Đủ độ tin cậy. Đang chuyển tiếp sang Gemini để tạo phản hồi đồng cảm...") | |
| reply = gemini_json_reply( | |
| build_turn_prompt(payload.question, payload.answer, LABELS.get(label, str(label)), confidence), | |
| fallback_reply, | |
| ) | |
| print(f"--- Kết quả phản hồi gửi về FE ---") | |
| print(f"Nhãn điểm: {label if label in {0, 1, 2, 3} else None}") | |
| print(f"Phản hồi: {reply['reply']}") | |
| print(f"Mô hình thực tế sử dụng: {reply.get('model')}\n") | |
| return AnalyzeResponse( | |
| label=label, | |
| score=label if label in {0, 1, 2, 3} else None, | |
| confidence=confidence, | |
| needsClarification=False, | |
| reply=reply["reply"], | |
| model=reply.get("model"), | |
| ) | |
| def chat(payload: AdviceRequest) -> AdviceResponse: | |
| print(f"\n--- [API /api/chat] Nhận yêu cầu tư vấn ---") | |
| print(f"Kết quả phân tích: {payload.assessment}") | |
| conversation = "\n".join( | |
| f"{message.get('role', 'user')}: {message.get('content', '')}" | |
| for message in payload.messages[-10:] | |
| if isinstance(message, dict) and isinstance(message.get("content"), str) | |
| ) | |
| print(f"Lịch sử hội thoại gửi lên:\n{conversation}") | |
| reply = gemini_json_reply( | |
| build_final_advice_prompt(payload.assessment, conversation), | |
| "Mình nghe thấy bạn đang cần được hỗ trợ thêm. Nếu cảm thấy quá tải, hãy nói với người thân hoặc một chuyên gia nhé.", | |
| "moderate", | |
| ) | |
| print(f"--- Phản hồi tư vấn gửi về FE ---") | |
| print(f"Tư vấn: {reply['reply']}") | |
| print(f"Mức độ rủi ro: {reply.get('riskLevel')}") | |
| print(f"Mô hình thực tế sử dụng: {reply.get('model')}\n") | |
| return AdviceResponse( | |
| reply=reply["reply"], | |
| riskLevel=reply.get("riskLevel", "low"), | |
| model=reply.get("model"), | |
| ) | |
| if DIST_DIR.exists(): | |
| app.mount("/", StaticFiles(directory=DIST_DIR, html=True), name="dist") | |