| """
|
| Người số 2: LLM & Prompt Engineer
|
| Hàm chính: get_ai_grade(bai_van, tai_lieu_chuan) -> dict
|
|
|
| Xử lý:
|
| - Gọi Qwen-72B API
|
| - Fix lỗi JSON (thừa text, cắt cụt, markdown fence)
|
| - Retry tự động nếu JSON lỗi
|
| - Validate schema output
|
| """
|
|
|
| import json
|
| import re
|
| import time
|
| import logging
|
| from typing import Optional
|
|
|
| from openai import OpenAI
|
|
|
| from server.config import config
|
| from server.prompts import build_messages, get_xep_loai
|
|
|
|
|
|
|
|
|
| logging.basicConfig(
|
| level=logging.INFO,
|
| format="%(asctime)s [%(levelname)s] %(message)s",
|
| datefmt="%H:%M:%S",
|
| )
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
| def _build_client() -> OpenAI:
|
| """Tạo OpenAI-compatible client trỏ đến provider đang dùng."""
|
| return OpenAI(
|
| api_key=config.active_api_key,
|
| base_url=config.active_base_url,
|
| )
|
|
|
|
|
|
|
|
|
|
|
| def _extract_json(raw_text: str) -> str:
|
| """
|
| Trích xuất JSON thuần túy từ output của LLM, dù LLM có:
|
| - Bọc trong ```json ... ```
|
| - Thêm text trước/sau
|
| - Có comment //
|
| - Có trailing comma
|
| """
|
| text = raw_text.strip()
|
|
|
|
|
| fence_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
|
| if fence_match:
|
| text = fence_match.group(1).strip()
|
|
|
|
|
| start = text.find("{")
|
| if start == -1:
|
| raise ValueError("Không tìm thấy '{' trong output của LLM")
|
|
|
|
|
| depth = 0
|
| end = -1
|
| in_string = False
|
| escape_next = False
|
|
|
| for i, ch in enumerate(text[start:], start=start):
|
| if escape_next:
|
| escape_next = False
|
| continue
|
| if ch == "\\" and in_string:
|
| escape_next = True
|
| continue
|
| if ch == '"' and not escape_next:
|
| in_string = not in_string
|
| continue
|
| if in_string:
|
| continue
|
| if ch == "{":
|
| depth += 1
|
| elif ch == "}":
|
| depth -= 1
|
| if depth == 0:
|
| end = i
|
| break
|
|
|
| if end == -1:
|
|
|
| logger.warning("JSON bị cắt cụt, thử tự sửa...")
|
| text = text[start:]
|
| text = _fix_truncated_json(text)
|
| return text
|
|
|
| return text[start : end + 1]
|
|
|
|
|
| def _fix_truncated_json(text: str) -> str:
|
| """Tự động đóng các cặp ngoặc/nháy còn thiếu."""
|
|
|
| text = re.sub(r",\s*$", "", text.rstrip())
|
|
|
|
|
| depth_curly = 0
|
| depth_square = 0
|
| in_string = False
|
| escape_next = False
|
|
|
| for ch in text:
|
| if escape_next:
|
| escape_next = False
|
| continue
|
| if ch == "\\":
|
| escape_next = True
|
| continue
|
| if ch == '"' and not escape_next:
|
| in_string = not in_string
|
| continue
|
| if in_string:
|
| continue
|
| if ch == "{":
|
| depth_curly += 1
|
| elif ch == "}":
|
| depth_curly -= 1
|
| elif ch == "[":
|
| depth_square += 1
|
| elif ch == "]":
|
| depth_square -= 1
|
|
|
|
|
| if in_string:
|
| text += '"'
|
| text += "]" * max(0, depth_square)
|
| text += "}" * max(0, depth_curly)
|
|
|
| return text
|
|
|
|
|
| def _clean_json_string(raw: str) -> str:
|
| """
|
| Làm sạch JSON string:
|
| - Xóa comment // và /* */
|
| - Xóa trailing comma trước } hoặc ]
|
| """
|
|
|
| raw = re.sub(r'(?<!:)//.*', '', raw)
|
|
|
| raw = re.sub(r'/\*.*?\*/', '', raw, flags=re.DOTALL)
|
|
|
| raw = re.sub(r',(\s*[}\]])', r'\1', raw)
|
| return raw
|
|
|
|
|
|
|
|
|
|
|
| REQUIRED_FIELDS = {"diem", "xep_loai", "nhan_xet_chung", "uu_diem", "nhuoc_diem", "chi_tiet_diem", "ket_luan"}
|
|
|
| def _validate_and_fix_schema(data: dict) -> dict:
|
| """
|
| Kiểm tra và tự điền các field còn thiếu với giá trị mặc định.
|
| """
|
|
|
| diem = data.get("diem", 0)
|
| try:
|
| diem = float(diem)
|
|
|
| diem = round(diem * 4) / 4
|
| diem = max(0.0, min(10.0, diem))
|
| except (TypeError, ValueError):
|
| diem = 0.0
|
| data["diem"] = diem
|
|
|
|
|
| data["xep_loai"] = get_xep_loai(diem)
|
|
|
|
|
| for field in ["uu_diem", "nhuoc_diem"]:
|
| if not isinstance(data.get(field), list):
|
| data[field] = []
|
|
|
|
|
| if not isinstance(data.get("chi_tiet_diem"), dict):
|
| data["chi_tiet_diem"] = {"noi_dung": 0.0, "hinh_thuc": 0.0, "sang_tao": 0.0}
|
|
|
|
|
| data.setdefault("nhan_xet_chung", "Không có nhận xét.")
|
| data.setdefault("ket_luan", "")
|
|
|
| return data
|
|
|
|
|
|
|
|
|
|
|
| def _call_qwen(
|
| messages: list[dict],
|
| client: OpenAI,
|
| attempt: int = 1,
|
| ) -> str:
|
| """
|
| Gọi Qwen API và trả về raw text. Có timeout và retry.
|
| """
|
| logger.info(f"[Attempt {attempt}] Gọi {config.active_model} qua {config.ACTIVE_PROVIDER}...")
|
|
|
|
|
| response = client.chat.completions.create(
|
| model=config.active_model,
|
| messages=messages,
|
| temperature=config.TEMPERATURE,
|
| max_tokens=config.MAX_TOKENS,
|
| top_p=config.TOP_P,
|
| )
|
|
|
| raw = response.choices[0].message.content
|
| finish_reason = response.choices[0].finish_reason
|
|
|
| logger.info(f"Finish reason: {finish_reason} | Output length: {len(raw)} chars")
|
|
|
| if finish_reason == "length":
|
| logger.warning("Output bị cắt do max_tokens! Thử tăng MAX_TOKENS.")
|
|
|
| return raw
|
|
|
|
|
|
|
|
|
|
|
| def get_ai_grade(
|
| bai_van: str,
|
| tai_lieu_chuan: str,
|
| max_retries: int = 3,
|
| retry_delay: float = 2.0,
|
| ) -> dict:
|
| """
|
| Chấm điểm bài văn học sinh bằng Qwen-72B.
|
|
|
| Args:
|
| bai_van: Nội dung bài văn học sinh cần chấm
|
| tai_lieu_chuan: Đáp án chuẩn từ Qdrant (output của search_context())
|
| max_retries: Số lần retry nếu JSON lỗi
|
| retry_delay: Thời gian chờ giữa các lần retry (giây)
|
|
|
| Returns:
|
| dict: Kết quả chấm điểm đã validate, gồm:
|
| diem, xep_loai, nhan_xet_chung, uu_diem, nhuoc_diem,
|
| chi_tiet_diem, ket_luan
|
|
|
| Raises:
|
| ValueError: Nếu sau max_retries vẫn không parse được JSON
|
| RuntimeError: Nếu API key chưa cấu hình
|
| """
|
|
|
| if not config.active_api_key:
|
| raise RuntimeError(
|
| f"Chưa cấu hình API key cho provider '{config.ACTIVE_PROVIDER}'. "
|
| f"Set biến môi trường: NVIDIA_API_KEY, TOGETHER_API_KEY, hoặc OPENROUTER_API_KEY"
|
| )
|
|
|
|
|
| if not bai_van or not bai_van.strip():
|
| raise ValueError("bai_van không được để trống")
|
| if not tai_lieu_chuan or not tai_lieu_chuan.strip():
|
| raise ValueError("tai_lieu_chuan không được để trống")
|
|
|
| client = _build_client()
|
| messages = build_messages(bai_van=bai_van, tai_lieu_chuan=tai_lieu_chuan)
|
|
|
| last_error: Optional[Exception] = None
|
|
|
| for attempt in range(1, max_retries + 1):
|
| try:
|
| raw_output = _call_qwen(messages, client, attempt=attempt)
|
|
|
|
|
| json_str = _extract_json(raw_output)
|
|
|
|
|
| json_str = _clean_json_string(json_str)
|
|
|
|
|
| data = json.loads(json_str)
|
|
|
|
|
| data = _validate_and_fix_schema(data)
|
|
|
| logger.info(f"✓ Chấm thành công | Điểm: {data['diem']} | Xếp loại: {data['xep_loai']}")
|
| return data
|
|
|
| except json.JSONDecodeError as e:
|
| last_error = e
|
| logger.warning(f"[Attempt {attempt}] JSON parse lỗi: {e}")
|
|
|
| if attempt < max_retries:
|
|
|
| logger.info(f"Retry sau {retry_delay}s...")
|
| time.sleep(retry_delay)
|
|
|
|
|
| messages_with_hint = messages + [
|
| {
|
| "role": "assistant",
|
| "content": raw_output,
|
| },
|
| {
|
| "role": "user",
|
| "content": (
|
| "Output của bạn không phải JSON hợp lệ. "
|
| "Hãy trả về lại CHỈ JSON object hợp lệ, "
|
| "không có text nào khác, không có markdown."
|
| ),
|
| },
|
| ]
|
| messages = messages_with_hint
|
|
|
| except Exception as e:
|
| last_error = e
|
| logger.error(f"[Attempt {attempt}] Lỗi không mong đợi: {e}")
|
| if attempt < max_retries:
|
| time.sleep(retry_delay)
|
|
|
| raise ValueError(
|
| f"Không thể parse JSON sau {max_retries} lần thử. "
|
| f"Lỗi cuối: {last_error}"
|
| )
|
|
|
|
|
|
|
|
|
|
|
| def grade_batch(
|
| bai_van_list: list[str],
|
| tai_lieu_chuan_list: list[str],
|
| delay_between: float = 1.0,
|
| ) -> list[dict]:
|
| """
|
| Chấm nhiều bài văn tuần tự.
|
|
|
| Args:
|
| bai_van_list: Danh sách bài văn học sinh
|
| tai_lieu_chuan_list: Danh sách context tương ứng từ Qdrant
|
| delay_between: Thời gian chờ giữa các request (tránh rate limit)
|
|
|
| Returns:
|
| list[dict]: Danh sách kết quả, mỗi phần tử là dict từ get_ai_grade()
|
| hoặc {"error": "..."} nếu bài đó bị lỗi
|
| """
|
| if len(bai_van_list) != len(tai_lieu_chuan_list):
|
| raise ValueError("bai_van_list và tai_lieu_chuan_list phải cùng độ dài")
|
|
|
| results = []
|
| total = len(bai_van_list)
|
|
|
| for i, (bai_van, context) in enumerate(zip(bai_van_list, tai_lieu_chuan_list), 1):
|
| logger.info(f"--- Chấm bài {i}/{total} ---")
|
| try:
|
| result = get_ai_grade(bai_van, context)
|
| results.append(result)
|
| except Exception as e:
|
| logger.error(f"Bài {i} bị lỗi: {e}")
|
| results.append({"error": str(e), "bai_so": i})
|
|
|
| if i < total:
|
| time.sleep(delay_between)
|
|
|
| return results
|
|
|