vanfun_be / model_main.py
moonbaek's picture
Upload model_main.py
2d2d14f verified
Raw
History Blame Contribute Delete
24.1 kB
"""
HỆ THỐNG TÍCH HỢP CHẤM BÀI VĂN TỰ ĐỘNG (RAG + LLM GRADER)
Sự kết hợp giữa:
- Người số 1: Trích xuất và tìm kiếm tài liệu chuẩn từ Google Drive & Qdrant Vector DB
- Người số 2: Chấm điểm bằng Qwen-72B qua API, sửa lỗi JSON tự động
Vui lòng cài đặt các thư viện trước khi chạy:
pip install sentence-transformers qdrant-client google-api-python-client google-auth-httplib2 google-auth-oauthlib numpy openai
"""
import os
import re
import sys
import json
import time
import logging
import unicodedata
from typing import List, Dict, Optional
from dataclasses import dataclass
from pathlib import Path
# --- KIỂM TRA THƯ VIỆN TRƯỚC KHI CHẠY (PRE-FLIGHT CHECK) ---
REQUIRED_LIBRARIES = {
"sentence_transformers": "sentence-transformers",
"qdrant_client": "qdrant-client",
"googleapiclient": "google-api-python-client",
"google.auth": "google-auth-oauthlib",
"numpy": "numpy",
"openai": "openai"
}
missing_libraries = []
for module_name, pip_name in REQUIRED_LIBRARIES.items():
try:
__import__(module_name)
except ImportError:
missing_libraries.append(pip_name)
if missing_libraries:
print("\n" + "="*80)
print("❌ LỖI KHÔNG TÌM THẤY THƯ VIỆN TRÊN MÁY CỦA BẠN!")
print("="*80)
print("Vui lòng mở Terminal / Command Prompt và chạy lệnh dưới đây để cài đặt:")
print(f"\npip install {' '.join(missing_libraries)}")
print("="*80 + "\n")
sys.exit(1)
# --- IMPORT CÁC THƯ VIỆN KHI ĐÃ ĐẢM BẢO ĐỦ ĐIỀU KIỆN ---
import numpy as np
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
from google.auth.transport.requests import Request
from google_auth_oauthlib.flow import InstalledAppFlow
from openai import OpenAI
# ============================================================
# LOGGING SETUP
# ============================================================
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
try:
from dotenv import load_dotenv
dotenv_path = Path(__file__).resolve().parent / '.env'
if dotenv_path.exists():
load_dotenv(dotenv_path=dotenv_path)
logger.info(f"✅ Đã nạp biến môi trường từ {dotenv_path}")
else:
logger.info(f"⚠️ .env không tồn tại tại {dotenv_path}. Nếu bạn dùng .env, hãy tạo file ở thư mục server.")
except ImportError:
logger.warning("⚠️ python-dotenv chưa cài, .env sẽ không được load tự động.")
# ============================================================
# CONFIGURATION CLASS
# ============================================================
@dataclass
class SystemConfig:
# --- Cấu hình Người số 1 (Data & DB) ---
GOOGLE_CREDENTIALS_FILE: str = "CREDENTIALS_JSON_CONTENT"
GOOGLE_TOKEN_FILE: str = "TOKEN_JSON_CONTENT"
SCOPES: List[str] = None
GOOGLE_FOLDER_ID: str = "1RLjyoxo88y0wpQNgTG4uhFPdWQyRBmIn"
QDRANT_HOST: str = "localhost"
QDRANT_PORT: int = 6333
QDRANT_COLLECTION_NAME: str = "van_mau"
EMBEDDING_MODEL_NAME: str = "dangvantuan/vietnamese-embedding"
CHUNK_SIZE: int = 500
CHUNK_OVERLAP: int = 100
RAW_DIR: str = "data/raw"
# --- Cấu hình Người số 2 (Qwen LLM) ---
NVIDIA_API_KEY: str = os.getenv("NVIDIA_API_KEY", "")
NVIDIA_BASE_URL: str = "https://integrate.api.nvidia.com/v1"
NVIDIA_MODEL: str = "google/gemma-4-31b-it"
TOGETHER_API_KEY: str = os.getenv("TOGETHER_API_KEY", "")
TOGETHER_BASE_URL: str = "https://api.together.xyz/v1"
TOGETHER_MODEL: str = "Qwen/Qwen2.5-72B-Instruct-Turbo"
OPENROUTER_API_KEY: str = os.getenv("OPENROUTER_API_KEY", "")
OPENROUTER_BASE_URL: str = "https://openrouter.ai/api/v1"
OPENROUTER_MODEL: str = "google/gemma-4-31b-it:free"
# nvidia | together | openrouter
ACTIVE_PROVIDER: str = os.getenv("QWEN_PROVIDER", "openrouter")
TEMPERATURE: float = 0.1
MAX_TOKENS: int = 2048
TOP_P: float = 0.9
def __post_init__(self):
if self.SCOPES is None:
self.SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
Path(self.RAW_DIR).mkdir(parents=True, exist_ok=True)
@property
def active_api_key(self) -> str:
mapping = {
"nvidia": self.NVIDIA_API_KEY,
"together": self.TOGETHER_API_KEY,
"openrouter": self.OPENROUTER_API_KEY
}
return mapping.get(self.ACTIVE_PROVIDER, "")
@property
def active_base_url(self) -> str:
mapping = {
"nvidia": self.NVIDIA_BASE_URL,
"together": self.TOGETHER_BASE_URL,
"openrouter": self.OPENROUTER_BASE_URL
}
return mapping.get(self.ACTIVE_PROVIDER, "")
@property
def active_model(self) -> str:
mapping = {
"nvidia": self.NVIDIA_MODEL,
"together": self.TOGETHER_MODEL,
"openrouter": self.OPENROUTER_MODEL
}
return mapping.get(self.ACTIVE_PROVIDER, "")
# Khởi tạo Singleton Config
config = SystemConfig()
logger.info(f"✅ Active provider: {config.ACTIVE_PROVIDER}")
if config.active_api_key:
logger.info("✅ Đã tìm thấy API key cho provider hiện tại.")
else:
logger.warning(
"⚠️ Chưa tìm thấy API key cho provider hiện tại. "
"Thiết lập NVIDIA_API_KEY/TOGETHER_API_KEY/OPENROUTER_API_KEY trước khi chạy."
)
# ============================================================
# PROMPTS & ESSAY RUBRICS (Người số 2)
# ============================================================
SYSTEM_PROMPT = """Bạn là một giám khảo chấm thi môn Ngữ văn cấp THPT với hơn 20 năm kinh nghiệm.
Nhiệm vụ của bạn là chấm điểm bài văn học sinh dựa trên tài liệu đáp án chuẩn được cung cấp.
# NGUYÊN TẮC CHẤM BÀI
1. Bám sát đáp án chuẩn — chỉ cho điểm những ý học sinh trình bày đúng hoặc tương đương đáp án.
2. Khách quan — không thiên vị, không thêm điểm vì văn phong hay nếu nội dung sai lệch biểu điểm.
3. Nhất quán — áp dụng cùng một tiêu chí chấm điểm chặt chẽ.
4. Chi tiết — chỉ ra cụ thể ưu điểm và nhược điểm, không nhận xét mơ hồ.
# THANG ĐIỂM
- Thang điểm tối đa: 10
- Làm tròn đến 0.25 điểm gần nhất (ví dụ: 6.0, 6.25, 6.5, 6.75, 7.0...)
- Điểm liệt: bài lạc đề hoàn toàn hoặc bỏ giấy trắng.
# QUY TẮC OUTPUT BẮT BUỘC
- Bạn PHẢI trả về DUY NHẤT một JSON object hợp lệ.
- KHÔNG viết bất kỳ chữ nào trước hoặc sau JSON.
- KHÔNG dùng markdown code block (không có ```json).
- KHÔNG để comment (như // hoặc /* */) bên trong JSON.
- Đảm bảo tất cả dấu nháy và ngoặc đóng/mở đều khớp chính xác.
# CẤU TRÚC JSON OUTPUT MẪU:
{
"diem": 7.5,
"xep_loai": "Khá",
"nhan_xet_chung": "Bài viết có bố cục rõ ràng, nắm được đặc điểm nội dung yêu cầu...",
"uu_diem": [
{
"tieu_chi": "Hiểu nhân vật",
"mo_ta": "Học sinh hiểu sâu sắc về số phận của Thúy Kiều..."
}
],
"nhuoc_diem": [
{
"tieu_chi": "Phân tích nội tâm",
"mo_ta": "Bài viết chưa phân tích kỹ mâu thuẫn giằng xé bên trong nhân vật.",
"goi_y_sua": "Nên đưa thêm dẫn chứng về các câu thơ độc thoại nội tâm ở lầu Ngưng Bích."
}
],
"chi_tiet_diem": {
"noi_dung": 4.5,
"hinh_thuc": 1.5,
"sang_tao": 1.5
},
"ket_luan": "Một bài viết khá, cần phát huy khả năng sáng tạo độc đáo hơn."
}"""
USER_PROMPT_TEMPLATE = """# TÀI LIỆU ĐÁP ÁN CHUẨN
Dưới đây là đáp án/biểu điểm chính thức dùng để đối chiếu:
---
{tai_lieu_chuan}
---
# BÀI VĂN CỦA HỌC SINH CẦN CHẤM
---
{bai_van}
---
# YÊU CẦU:
Chỉ trả về duy nhất chuỗi JSON hợp lệ không có tiền tố hay hậu tố gì khác ngoài chuỗi JSON này."""
def build_messages(bai_van: str, tai_lieu_chuan: str) -> List[Dict[str, str]]:
user_content = USER_PROMPT_TEMPLATE.format(
tai_lieu_chuan=tai_lieu_chuan.strip(),
bai_van=bai_van.strip(),
)
return [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
]
# ============================================================
# NGƯỜI SỐ 1: VĂN BẢN & VECTOR DATABASE PIPELINE (QDRANT)
# ============================================================
class TextProcessor:
@staticmethod
def normalize_unicode(text: str) -> str:
return unicodedata.normalize('NFC', text)
@staticmethod
def remove_extra_spaces(text: str) -> str:
text = re.sub(r'\s+', ' ', text)
return text.strip()
@staticmethod
def clean_text(text: str) -> str:
text = TextProcessor.normalize_unicode(text)
return TextProcessor.remove_extra_spaces(text)
class ChunkingStrategy:
def __init__(self, chunk_size: int = 300, overlap: int = 50):
self.chunk_size = chunk_size
self.overlap = overlap
def chunk_smart(self, text: str) -> List[str]:
# Cắt theo đoạn văn trước
paragraphs = re.split(r'\n\s*\n', text)
paragraphs = [p.strip() for p in paragraphs if p.strip()]
chunks = []
for para in paragraphs:
if len(para) <= self.chunk_size:
chunks.append(para)
else:
chunks.extend(self._chunk_by_size(para))
return chunks
def _chunk_by_size(self, text: str) -> List[str]:
chunks = []
start = 0
text_len = len(text)
while start < text_len:
end = min(start + self.chunk_size, text_len)
if end < text_len:
for sep in ['. ', '; ', ', ', ' ']:
pos = text.rfind(sep, start, end)
if pos != -1:
end = pos + len(sep)
break
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
start = max(start + 1, end - self.overlap)
return chunks
class GoogleDriveManager:
def __init__(self, credentials_file: str, token_file: str):
self.credentials_file = credentials_file
self.token_file = token_file
self.service = None
self._authenticate()
def _authenticate(self):
creds = None
if os.path.exists(self.token_file):
creds = Credentials.from_authorized_user_file(self.token_file)
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
if not os.path.exists(self.credentials_file):
raise FileNotFoundError(
f"Không tìm thấy file xác thực '{self.credentials_file}' cho Google Drive. "
"Vui lòng tải nó từ Google Cloud Console về."
)
flow = InstalledAppFlow.from_client_secrets_file(
self.credentials_file,
['[https://www.googleapis.com/auth/drive.readonly](https://www.googleapis.com/auth/drive.readonly)']
)
creds = flow.run_local_server(port=0)
with open(self.token_file, 'w') as token:
token.write(creds.to_json())
self.service = build('drive', 'v3', credentials=creds)
logger.info("✅ Xác thực Google Drive API thành công.")
def download_all_files(self, folder_id: str, destination: str) -> List[str]:
query = f"'{folder_id}' in parents and trashed=false"
results = self.service.files().list(
q=query,
fields="files(id, name, mimeType)"
).execute()
files = results.get('files', [])
logger.info(f"📁 Tìm thấy {len(files)} file trên thư mục Google Drive.")
downloaded_files = []
for file_info in files:
file_name = file_info['name']
file_path = os.path.join(destination, file_name)
request = self.service.files().get_media(fileId=file_info['id'])
with open(file_path, 'wb') as f:
downloader = MediaIoBaseDownload(f, request)
done = False
while not done:
status, done = downloader.next_chunk()
downloaded_files.append(file_path)
logger.info(f"✅ Đã tải: {file_name}")
return downloaded_files
class DocumentReader:
@staticmethod
def read_file(file_path: str) -> str:
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
logger.error(f"❌ Không thể đọc file {file_path}: {e}")
return ""
class EmbeddingModel:
def __init__(self, model_name: str):
logger.info(f"🧠 Đang nạp mô hình Embedding: {model_name}...")
self.model = SentenceTransformer(model_name)
self.vector_size = self.model.get_embedding_dimension()
logger.info(f"✅ Mô hình sẵn sàng. Vector size: {self.vector_size}")
def encode(self, texts: List[str]) -> np.ndarray:
if isinstance(texts, str):
texts = [texts]
return self.model.encode(texts, convert_to_numpy=True)
class QdrantManager:
def __init__(self, host: str, port: int, collection_name: str, vector_size: int):
self.collection_name = collection_name
self.vector_size = vector_size
try:
self.client = QdrantClient(url=f"http://{host}:{port}")
self.client.get_collections()
logger.info(f"✅ Đã kết nối đến Qdrant server tại {host}:{port}")
except Exception as e:
logger.warning(f"⚠️ Không kết nối được Qdrant server tại {host}:{port}. Thử local path ./qdrant_data. Lỗi: {e}")
try:
self.client = QdrantClient(path="./qdrant_data")
logger.info("✅ Đã kết nối tới Qdrant local path ./qdrant_data")
except Exception as e_local:
logger.error(f"❌ Không thể dùng Qdrant local: {e_local}")
logger.warning("⚠️ Qdrant không khả dụng — server sẽ chạy ở chế độ không có RAG (chỉ chấm thuần LLM).")
self.client = None
# Nếu client không khả dụng, create_collection sẽ bỏ qua
try:
self.create_collection(force=False)
except Exception as e_create:
logger.warning(f"⚠️ Tạo collection thất bại (bỏ qua): {e_create}")
def create_collection(self, force: bool = False):
if self.client is None:
logger.info("ℹ️ Bỏ qua tạo collection vì Qdrant không khả dụng.")
return
collections = self.client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
if exists:
if force:
self.client.delete_collection(self.collection_name)
else:
logger.info(f"Bộ sưu tập '{self.collection_name}' đã tồn tại.")
return
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=self.vector_size,
distance=Distance.COSINE
)
)
logger.info(f"✅ Đã khởi tạo mới bộ sưu tập '{self.collection_name}' trên Qdrant.")
def upsert_chunks(self, chunks: List[Dict], embedding_model: EmbeddingModel):
if not chunks:
logger.warning("Không có dữ liệu văn bản để đẩy lên DB.")
return
contents = [chunk['content'] for chunk in chunks]
embeddings = embedding_model.encode(contents)
points = []
for i, chunk in enumerate(chunks):
points.append(PointStruct(
id=i,
vector=embeddings[i].tolist(),
payload={
"content": chunk['content'],
"document_id": chunk['document_id'],
"file_name": chunk.get('file_name', ''),
"chunk_index": chunk.get('chunk_index', i)
}
))
if self.client is None:
logger.warning("⚠️ Bỏ qua upsert chunks vì Qdrant không khả dụng.")
return
self.client.upsert(
collection_name=self.collection_name,
points=points
)
logger.info(f"✅ Đã tải {len(points)} khối thông tin (chunks) lên Qdrant thành công.")
def search(self, query_vector: List[float], limit: int = 3) -> List[Dict]:
if self.client is None:
logger.info("ℹ️ Qdrant không khả dụng — trả về danh sách ngữ cảnh rỗng.")
return []
query_response = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
limit=limit,
with_payload=True,
)
return [
{
"content": hit.payload["content"],
"score": hit.score,
"file_name": hit.payload.get("file_name", "")
}
for hit in query_response.points
]
class ContextRetriever:
def __init__(self, qdrant: QdrantManager, embedding: EmbeddingModel):
self.qdrant = qdrant
self.embedding = embedding
self.text_processor = TextProcessor()
def search_context(self, bai_van_hoc_sinh: str, limit: int = 3) -> str:
cleaned_query = self.text_processor.clean_text(bai_van_hoc_sinh)
query_vector = self.embedding.encode([cleaned_query])[0]
results = self.qdrant.search(query_vector=query_vector.tolist(), limit=limit)
contexts = []
for r in results:
contexts.append(f"--- NGUỒN ĐÁP ÁN ({r['file_name']}) ---\n{r['content']}")
return "\n\n".join(contexts)
embedding_model = EmbeddingModel(config.EMBEDDING_MODEL_NAME)
# 2. Khởi tạo Qdrant (Lưu ý: đoạn này vẫn sẽ tự nhận path="./qdrant_data" vì ông đã sửa trong class ở trên rồi)
qdrant_manager = QdrantManager(
host=config.QDRANT_HOST,
port=config.QDRANT_PORT,
collection_name=config.QDRANT_COLLECTION_NAME,
vector_size=embedding_model.vector_size
)
# ============================================================
# NGƯỜI SỐ 2: LLM ENGINE & AUTO-JSON RECONSTRUCTION
# ============================================================
def _build_client() -> OpenAI:
return OpenAI(
api_key=config.active_api_key,
base_url=config.active_base_url,
)
def _extract_json(raw_text: str) -> str:
text = raw_text.strip()
# Bóc markdown code fence nếu model trả về ```json ... ```
fence_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
if fence_match:
text = fence_match.group(1).strip()
# Nếu output chỉ có một JSON object thì trả ngay
if text.startswith("{") and text.endswith("}"):
return text
# Tìm dấu { đầu tiên và } cuối cùng đóng cặp
start = text.find("{")
if start == -1:
raise ValueError("Không tìm thấy JSON object trong output của LLM")
depth = 0
in_string = False
escape_next = False
end = -1
for i, ch in enumerate(text[start:], start=start):
if escape_next:
escape_next = False
continue
if ch == "\\":
escape_next = True
continue
if ch == '"':
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
end = i
break
if end == -1:
raise ValueError("Không tìm thấy dấu đóng '}' tương ứng trong JSON output của LLM")
return text[start : end + 1]
def get_ai_grade(
bai_van: str,
tai_lieu_chuan: str,
max_retries: int = 3,
retry_delay: float = 2.0,
):
if not config.active_api_key:
raise RuntimeError(
f"Chưa cấu hình API key cho provider '{config.ACTIVE_PROVIDER}'"
)
client = _build_client()
messages = build_messages(
bai_van=bai_van,
tai_lieu_chuan=tai_lieu_chuan
)
for attempt in range(max_retries):
try:
# Use plain-text completion (no structured response_format) to maximize provider compatibility
response = client.chat.completions.create(
model=config.active_model,
messages=messages,
temperature=config.TEMPERATURE,
max_tokens=config.MAX_TOKENS,
top_p=config.TOP_P,
)
raw = response.choices[0].message.content
json_str = _extract_json(raw)
data = json.loads(json_str)
return data
except Exception as e:
logger.error(f"Attempt {attempt+1}: {e}")
if attempt + 1 == max_retries:
# Nếu LLM thực sự lỗi sau nhiều lần retry, trả về kết quả tạm thời
logger.error("LLM không phản hồi hợp lệ sau nhiều lần thử. Trả về kết quả tạm thời.")
fallback = {
"diem": 0.0,
"xep_loai": "Không chấm được",
"nhan_xet_chung": f"Lỗi LLM: {str(e)}",
"uu_diem": [],
"nhuoc_diem": [],
"chi_tiet_diem": {"noi_dung": 0.0, "hinh_thuc": 0.0, "sang_tao": 0.0},
"ket_luan": "Kết quả tạm thời do lỗi hệ thống LLM."
}
return fallback
time.sleep(retry_delay)
def grade_batch(
bai_van_list,
tai_lieu_chuan_list
):
results = []
for bai_van, context in zip(
bai_van_list,
tai_lieu_chuan_list
):
try:
results.append(
get_ai_grade(
bai_van,
context
)
)
except Exception as e:
results.append(
{"error": str(e)}
)
return results