vanfun_be / grader.py
moonbaek's picture
Upload 118 files
67819f1 verified
Raw
History Blame Contribute Delete
12.2 kB
"""
Người số 2: LLM & Prompt Engineer
Hàm chính: get_ai_grade(bai_van, tai_lieu_chuan) -> dict
Xử lý:
- Gọi Qwen-72B API
- Fix lỗi JSON (thừa text, cắt cụt, markdown fence)
- Retry tự động nếu JSON lỗi
- Validate schema output
"""
import json
import re
import time
import logging
from typing import Optional
from openai import OpenAI
from server.config import config
from server.prompts import build_messages, get_xep_loai
# ============================================================
# LOGGING
# ============================================================
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
# ============================================================
# KHỞI TẠO CLIENT
# ============================================================
def _build_client() -> OpenAI:
"""Tạo OpenAI-compatible client trỏ đến provider đang dùng."""
return OpenAI(
api_key=config.active_api_key,
base_url=config.active_base_url,
)
# ============================================================
# JSON FIX — xử lý mọi dạng output lỗi từ LLM
# ============================================================
def _extract_json(raw_text: str) -> str:
"""
Trích xuất JSON thuần túy từ output của LLM, dù LLM có:
- Bọc trong ```json ... ```
- Thêm text trước/sau
- Có comment //
- Có trailing comma
"""
text = raw_text.strip()
# 1. Bóc markdown code fence nếu có
fence_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
if fence_match:
text = fence_match.group(1).strip()
# 2. Tìm JSON object đầu tiên hợp lệ (từ { đến } cuối cùng khớp)
start = text.find("{")
if start == -1:
raise ValueError("Không tìm thấy '{' trong output của LLM")
# Tìm } đóng khớp bằng cách đếm depth
depth = 0
end = -1
in_string = False
escape_next = False
for i, ch in enumerate(text[start:], start=start):
if escape_next:
escape_next = False
continue
if ch == "\\" and in_string:
escape_next = True
continue
if ch == '"' and not escape_next:
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
end = i
break
if end == -1:
# JSON bị cắt cụt — thử tự đóng các ngoặc còn thiếu
logger.warning("JSON bị cắt cụt, thử tự sửa...")
text = text[start:]
text = _fix_truncated_json(text)
return text
return text[start : end + 1]
def _fix_truncated_json(text: str) -> str:
"""Tự động đóng các cặp ngoặc/nháy còn thiếu."""
# Xóa trailing comma cuối cùng
text = re.sub(r",\s*$", "", text.rstrip())
# Đếm ngoặc thiếu
depth_curly = 0
depth_square = 0
in_string = False
escape_next = False
for ch in text:
if escape_next:
escape_next = False
continue
if ch == "\\":
escape_next = True
continue
if ch == '"' and not escape_next:
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth_curly += 1
elif ch == "}":
depth_curly -= 1
elif ch == "[":
depth_square += 1
elif ch == "]":
depth_square -= 1
# Đóng những gì còn mở
if in_string:
text += '"'
text += "]" * max(0, depth_square)
text += "}" * max(0, depth_curly)
return text
def _clean_json_string(raw: str) -> str:
"""
Làm sạch JSON string:
- Xóa comment // và /* */
- Xóa trailing comma trước } hoặc ]
"""
# Xóa comment // (cẩn thận không xóa URL trong string)
raw = re.sub(r'(?<!:)//.*', '', raw)
# Xóa /* */ comment
raw = re.sub(r'/\*.*?\*/', '', raw, flags=re.DOTALL)
# Xóa trailing comma
raw = re.sub(r',(\s*[}\]])', r'\1', raw)
return raw
# ============================================================
# VALIDATE SCHEMA OUTPUT
# ============================================================
REQUIRED_FIELDS = {"diem", "xep_loai", "nhan_xet_chung", "uu_diem", "nhuoc_diem", "chi_tiet_diem", "ket_luan"}
def _validate_and_fix_schema(data: dict) -> dict:
"""
Kiểm tra và tự điền các field còn thiếu với giá trị mặc định.
"""
# Đảm bảo điểm hợp lệ
diem = data.get("diem", 0)
try:
diem = float(diem)
# Làm tròn đến 0.25
diem = round(diem * 4) / 4
diem = max(0.0, min(10.0, diem))
except (TypeError, ValueError):
diem = 0.0
data["diem"] = diem
# Tự điền xep_loai nếu thiếu hoặc sai
data["xep_loai"] = get_xep_loai(diem)
# Đảm bảo các list field là list
for field in ["uu_diem", "nhuoc_diem"]:
if not isinstance(data.get(field), list):
data[field] = []
# Đảm bảo chi_tiet_diem là dict
if not isinstance(data.get("chi_tiet_diem"), dict):
data["chi_tiet_diem"] = {"noi_dung": 0.0, "hinh_thuc": 0.0, "sang_tao": 0.0}
# Điền field còn thiếu
data.setdefault("nhan_xet_chung", "Không có nhận xét.")
data.setdefault("ket_luan", "")
return data
# ============================================================
# HÀM CALL API VỚI RETRY
# ============================================================
def _call_qwen(
messages: list[dict],
client: OpenAI,
attempt: int = 1,
) -> str:
"""
Gọi Qwen API và trả về raw text. Có timeout và retry.
"""
logger.info(f"[Attempt {attempt}] Gọi {config.active_model} qua {config.ACTIVE_PROVIDER}...")
# Send plain-text chat completion (no structured response_format)
response = client.chat.completions.create(
model=config.active_model,
messages=messages,
temperature=config.TEMPERATURE,
max_tokens=config.MAX_TOKENS,
top_p=config.TOP_P,
)
raw = response.choices[0].message.content
finish_reason = response.choices[0].finish_reason
logger.info(f"Finish reason: {finish_reason} | Output length: {len(raw)} chars")
if finish_reason == "length":
logger.warning("Output bị cắt do max_tokens! Thử tăng MAX_TOKENS.")
return raw
# ============================================================
# HÀM CHÍNH: get_ai_grade()
# ============================================================
def get_ai_grade(
bai_van: str,
tai_lieu_chuan: str,
max_retries: int = 3,
retry_delay: float = 2.0,
) -> dict:
"""
Chấm điểm bài văn học sinh bằng Qwen-72B.
Args:
bai_van: Nội dung bài văn học sinh cần chấm
tai_lieu_chuan: Đáp án chuẩn từ Qdrant (output của search_context())
max_retries: Số lần retry nếu JSON lỗi
retry_delay: Thời gian chờ giữa các lần retry (giây)
Returns:
dict: Kết quả chấm điểm đã validate, gồm:
diem, xep_loai, nhan_xet_chung, uu_diem, nhuoc_diem,
chi_tiet_diem, ket_luan
Raises:
ValueError: Nếu sau max_retries vẫn không parse được JSON
RuntimeError: Nếu API key chưa cấu hình
"""
# Kiểm tra API key
if not config.active_api_key:
raise RuntimeError(
f"Chưa cấu hình API key cho provider '{config.ACTIVE_PROVIDER}'. "
f"Set biến môi trường: NVIDIA_API_KEY, TOGETHER_API_KEY, hoặc OPENROUTER_API_KEY"
)
# Kiểm tra input
if not bai_van or not bai_van.strip():
raise ValueError("bai_van không được để trống")
if not tai_lieu_chuan or not tai_lieu_chuan.strip():
raise ValueError("tai_lieu_chuan không được để trống")
client = _build_client()
messages = build_messages(bai_van=bai_van, tai_lieu_chuan=tai_lieu_chuan)
last_error: Optional[Exception] = None
for attempt in range(1, max_retries + 1):
try:
raw_output = _call_qwen(messages, client, attempt=attempt)
# Bước 1: Trích xuất JSON thuần túy
json_str = _extract_json(raw_output)
# Bước 2: Làm sạch comment / trailing comma
json_str = _clean_json_string(json_str)
# Bước 3: Parse
data = json.loads(json_str)
# Bước 4: Validate và fix schema
data = _validate_and_fix_schema(data)
logger.info(f"✓ Chấm thành công | Điểm: {data['diem']} | Xếp loại: {data['xep_loai']}")
return data
except json.JSONDecodeError as e:
last_error = e
logger.warning(f"[Attempt {attempt}] JSON parse lỗi: {e}")
if attempt < max_retries:
# Retry với hint thêm trong message
logger.info(f"Retry sau {retry_delay}s...")
time.sleep(retry_delay)
# Thêm message nhắc nhở JSON
messages_with_hint = messages + [
{
"role": "assistant",
"content": raw_output,
},
{
"role": "user",
"content": (
"Output của bạn không phải JSON hợp lệ. "
"Hãy trả về lại CHỈ JSON object hợp lệ, "
"không có text nào khác, không có markdown."
),
},
]
messages = messages_with_hint
except Exception as e:
last_error = e
logger.error(f"[Attempt {attempt}] Lỗi không mong đợi: {e}")
if attempt < max_retries:
time.sleep(retry_delay)
raise ValueError(
f"Không thể parse JSON sau {max_retries} lần thử. "
f"Lỗi cuối: {last_error}"
)
# ============================================================
# BATCH GRADING — chấm nhiều bài cùng lúc
# ============================================================
def grade_batch(
bai_van_list: list[str],
tai_lieu_chuan_list: list[str],
delay_between: float = 1.0,
) -> list[dict]:
"""
Chấm nhiều bài văn tuần tự.
Args:
bai_van_list: Danh sách bài văn học sinh
tai_lieu_chuan_list: Danh sách context tương ứng từ Qdrant
delay_between: Thời gian chờ giữa các request (tránh rate limit)
Returns:
list[dict]: Danh sách kết quả, mỗi phần tử là dict từ get_ai_grade()
hoặc {"error": "..."} nếu bài đó bị lỗi
"""
if len(bai_van_list) != len(tai_lieu_chuan_list):
raise ValueError("bai_van_list và tai_lieu_chuan_list phải cùng độ dài")
results = []
total = len(bai_van_list)
for i, (bai_van, context) in enumerate(zip(bai_van_list, tai_lieu_chuan_list), 1):
logger.info(f"--- Chấm bài {i}/{total} ---")
try:
result = get_ai_grade(bai_van, context)
results.append(result)
except Exception as e:
logger.error(f"Bài {i} bị lỗi: {e}")
results.append({"error": str(e), "bai_so": i})
if i < total:
time.sleep(delay_between)
return results