""" Người số 2: LLM & Prompt Engineer Hàm chính: get_ai_grade(bai_van, tai_lieu_chuan) -> dict Xử lý: - Gọi Qwen-72B API - Fix lỗi JSON (thừa text, cắt cụt, markdown fence) - Retry tự động nếu JSON lỗi - Validate schema output """ import json import re import time import logging from typing import Optional from openai import OpenAI from server.config import config from server.prompts import build_messages, get_xep_loai # ============================================================ # LOGGING # ============================================================ logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger(__name__) # ============================================================ # KHỞI TẠO CLIENT # ============================================================ def _build_client() -> OpenAI: """Tạo OpenAI-compatible client trỏ đến provider đang dùng.""" return OpenAI( api_key=config.active_api_key, base_url=config.active_base_url, ) # ============================================================ # JSON FIX — xử lý mọi dạng output lỗi từ LLM # ============================================================ def _extract_json(raw_text: str) -> str: """ Trích xuất JSON thuần túy từ output của LLM, dù LLM có: - Bọc trong ```json ... ``` - Thêm text trước/sau - Có comment // - Có trailing comma """ text = raw_text.strip() # 1. Bóc markdown code fence nếu có fence_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text) if fence_match: text = fence_match.group(1).strip() # 2. Tìm JSON object đầu tiên hợp lệ (từ { đến } cuối cùng khớp) start = text.find("{") if start == -1: raise ValueError("Không tìm thấy '{' trong output của LLM") # Tìm } đóng khớp bằng cách đếm depth depth = 0 end = -1 in_string = False escape_next = False for i, ch in enumerate(text[start:], start=start): if escape_next: escape_next = False continue if ch == "\\" and in_string: escape_next = True continue if ch == '"' and not escape_next: in_string = not in_string continue if in_string: continue if ch == "{": depth += 1 elif ch == "}": depth -= 1 if depth == 0: end = i break if end == -1: # JSON bị cắt cụt — thử tự đóng các ngoặc còn thiếu logger.warning("JSON bị cắt cụt, thử tự sửa...") text = text[start:] text = _fix_truncated_json(text) return text return text[start : end + 1] def _fix_truncated_json(text: str) -> str: """Tự động đóng các cặp ngoặc/nháy còn thiếu.""" # Xóa trailing comma cuối cùng text = re.sub(r",\s*$", "", text.rstrip()) # Đếm ngoặc thiếu depth_curly = 0 depth_square = 0 in_string = False escape_next = False for ch in text: if escape_next: escape_next = False continue if ch == "\\": escape_next = True continue if ch == '"' and not escape_next: in_string = not in_string continue if in_string: continue if ch == "{": depth_curly += 1 elif ch == "}": depth_curly -= 1 elif ch == "[": depth_square += 1 elif ch == "]": depth_square -= 1 # Đóng những gì còn mở if in_string: text += '"' text += "]" * max(0, depth_square) text += "}" * max(0, depth_curly) return text def _clean_json_string(raw: str) -> str: """ Làm sạch JSON string: - Xóa comment // và /* */ - Xóa trailing comma trước } hoặc ] """ # Xóa comment // (cẩn thận không xóa URL trong string) raw = re.sub(r'(? dict: """ Kiểm tra và tự điền các field còn thiếu với giá trị mặc định. """ # Đảm bảo điểm hợp lệ diem = data.get("diem", 0) try: diem = float(diem) # Làm tròn đến 0.25 diem = round(diem * 4) / 4 diem = max(0.0, min(10.0, diem)) except (TypeError, ValueError): diem = 0.0 data["diem"] = diem # Tự điền xep_loai nếu thiếu hoặc sai data["xep_loai"] = get_xep_loai(diem) # Đảm bảo các list field là list for field in ["uu_diem", "nhuoc_diem"]: if not isinstance(data.get(field), list): data[field] = [] # Đảm bảo chi_tiet_diem là dict if not isinstance(data.get("chi_tiet_diem"), dict): data["chi_tiet_diem"] = {"noi_dung": 0.0, "hinh_thuc": 0.0, "sang_tao": 0.0} # Điền field còn thiếu data.setdefault("nhan_xet_chung", "Không có nhận xét.") data.setdefault("ket_luan", "") return data # ============================================================ # HÀM CALL API VỚI RETRY # ============================================================ def _call_qwen( messages: list[dict], client: OpenAI, attempt: int = 1, ) -> str: """ Gọi Qwen API và trả về raw text. Có timeout và retry. """ logger.info(f"[Attempt {attempt}] Gọi {config.active_model} qua {config.ACTIVE_PROVIDER}...") # Send plain-text chat completion (no structured response_format) response = client.chat.completions.create( model=config.active_model, messages=messages, temperature=config.TEMPERATURE, max_tokens=config.MAX_TOKENS, top_p=config.TOP_P, ) raw = response.choices[0].message.content finish_reason = response.choices[0].finish_reason logger.info(f"Finish reason: {finish_reason} | Output length: {len(raw)} chars") if finish_reason == "length": logger.warning("Output bị cắt do max_tokens! Thử tăng MAX_TOKENS.") return raw # ============================================================ # HÀM CHÍNH: get_ai_grade() # ============================================================ def get_ai_grade( bai_van: str, tai_lieu_chuan: str, max_retries: int = 3, retry_delay: float = 2.0, ) -> dict: """ Chấm điểm bài văn học sinh bằng Qwen-72B. Args: bai_van: Nội dung bài văn học sinh cần chấm tai_lieu_chuan: Đáp án chuẩn từ Qdrant (output của search_context()) max_retries: Số lần retry nếu JSON lỗi retry_delay: Thời gian chờ giữa các lần retry (giây) Returns: dict: Kết quả chấm điểm đã validate, gồm: diem, xep_loai, nhan_xet_chung, uu_diem, nhuoc_diem, chi_tiet_diem, ket_luan Raises: ValueError: Nếu sau max_retries vẫn không parse được JSON RuntimeError: Nếu API key chưa cấu hình """ # Kiểm tra API key if not config.active_api_key: raise RuntimeError( f"Chưa cấu hình API key cho provider '{config.ACTIVE_PROVIDER}'. " f"Set biến môi trường: NVIDIA_API_KEY, TOGETHER_API_KEY, hoặc OPENROUTER_API_KEY" ) # Kiểm tra input if not bai_van or not bai_van.strip(): raise ValueError("bai_van không được để trống") if not tai_lieu_chuan or not tai_lieu_chuan.strip(): raise ValueError("tai_lieu_chuan không được để trống") client = _build_client() messages = build_messages(bai_van=bai_van, tai_lieu_chuan=tai_lieu_chuan) last_error: Optional[Exception] = None for attempt in range(1, max_retries + 1): try: raw_output = _call_qwen(messages, client, attempt=attempt) # Bước 1: Trích xuất JSON thuần túy json_str = _extract_json(raw_output) # Bước 2: Làm sạch comment / trailing comma json_str = _clean_json_string(json_str) # Bước 3: Parse data = json.loads(json_str) # Bước 4: Validate và fix schema data = _validate_and_fix_schema(data) logger.info(f"✓ Chấm thành công | Điểm: {data['diem']} | Xếp loại: {data['xep_loai']}") return data except json.JSONDecodeError as e: last_error = e logger.warning(f"[Attempt {attempt}] JSON parse lỗi: {e}") if attempt < max_retries: # Retry với hint thêm trong message logger.info(f"Retry sau {retry_delay}s...") time.sleep(retry_delay) # Thêm message nhắc nhở JSON messages_with_hint = messages + [ { "role": "assistant", "content": raw_output, }, { "role": "user", "content": ( "Output của bạn không phải JSON hợp lệ. " "Hãy trả về lại CHỈ JSON object hợp lệ, " "không có text nào khác, không có markdown." ), }, ] messages = messages_with_hint except Exception as e: last_error = e logger.error(f"[Attempt {attempt}] Lỗi không mong đợi: {e}") if attempt < max_retries: time.sleep(retry_delay) raise ValueError( f"Không thể parse JSON sau {max_retries} lần thử. " f"Lỗi cuối: {last_error}" ) # ============================================================ # BATCH GRADING — chấm nhiều bài cùng lúc # ============================================================ def grade_batch( bai_van_list: list[str], tai_lieu_chuan_list: list[str], delay_between: float = 1.0, ) -> list[dict]: """ Chấm nhiều bài văn tuần tự. Args: bai_van_list: Danh sách bài văn học sinh tai_lieu_chuan_list: Danh sách context tương ứng từ Qdrant delay_between: Thời gian chờ giữa các request (tránh rate limit) Returns: list[dict]: Danh sách kết quả, mỗi phần tử là dict từ get_ai_grade() hoặc {"error": "..."} nếu bài đó bị lỗi """ if len(bai_van_list) != len(tai_lieu_chuan_list): raise ValueError("bai_van_list và tai_lieu_chuan_list phải cùng độ dài") results = [] total = len(bai_van_list) for i, (bai_van, context) in enumerate(zip(bai_van_list, tai_lieu_chuan_list), 1): logger.info(f"--- Chấm bài {i}/{total} ---") try: result = get_ai_grade(bai_van, context) results.append(result) except Exception as e: logger.error(f"Bài {i} bị lỗi: {e}") results.append({"error": str(e), "bai_so": i}) if i < total: time.sleep(delay_between) return results