|
|
"""Các utility functions cho evaluation.""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import re |
|
|
import csv |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
from dotenv import find_dotenv, load_dotenv |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
# Repository root: this file lives one directory below it, so `core.*`
# imports resolve regardless of where the script is launched from.
REPO_ROOT = Path(__file__).resolve().parents[1]


if str(REPO_ROOT) not in sys.path:


    sys.path.insert(0, str(REPO_ROOT))


# Load environment variables (e.g. SILICONFLOW_API_KEY) from the nearest
# .env file, searching upward from the current working directory.
load_dotenv(find_dotenv(usecwd=True))
|
|
|
|
|
from openai import OpenAI |
|
|
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings |
|
|
from core.rag.vector_store import ChromaConfig, ChromaVectorDB |
|
|
from core.rag.retrival import Retriever |
|
|
from core.rag.generator import RAGGenerator |
|
|
|
|
|
|
|
|
def strip_thinking(text: str) -> str:
    """Remove every ``<think>...</think>`` block emitted by the LLM.

    Returns the remaining text with leading/trailing whitespace trimmed.
    """
    thinking = re.compile(r'<think>.*?</think>\s*', re.DOTALL)
    return thinking.sub('', text).strip()
|
|
|
|
|
|
|
|
def load_csv_data(csv_path: str, sample_size: int = 0) -> tuple[list, list]:
    """Read question / ground-truth pairs from a CSV file.

    Args:
        csv_path: Path to a CSV file with ``question`` and ``ground_truth``
            columns; rows missing either value are skipped.
        sample_size: If > 0, keep only the first ``sample_size`` pairs.

    Returns:
        Two parallel lists: ``(questions, ground_truths)``.
    """
    with open(csv_path, 'r', encoding='utf-8') as fh:
        pairs = [
            (row['question'], row['ground_truth'])
            for row in csv.DictReader(fh)
            if row.get('question') and row.get('ground_truth')
        ]

    if sample_size > 0:
        pairs = pairs[:sample_size]

    questions = [q for q, _ in pairs]
    ground_truths = [gt for _, gt in pairs]
    return questions, ground_truths
|
|
|
|
|
|
|
|
def init_rag() -> tuple[RAGGenerator, QwenEmbeddings, OpenAI]:
    """Build the RAG components used during evaluation.

    Returns:
        ``(rag_generator, embedding_model, llm_client)``.

    Raises:
        ValueError: If ``SILICONFLOW_API_KEY`` is missing or blank.
    """
    # Retrieval stack: embeddings -> vector DB -> retriever -> generator.
    embedder = QwenEmbeddings(EmbeddingConfig())
    vector_db = ChromaVectorDB(embedder=embedder, config=ChromaConfig())
    pipeline = RAGGenerator(retriever=Retriever(vector_db=vector_db))

    key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not key:
        raise ValueError("Chưa đặt SILICONFLOW_API_KEY")

    client = OpenAI(
        api_key=key,
        base_url="https://api.siliconflow.com/v1",
        timeout=60.0,
    )
    return pipeline, embedder, client
|
|
|
|
|
|
|
|
def generate_answers(
    rag: RAGGenerator,
    questions: list,
    llm_client: OpenAI,
    llm_model: str = "nex-agi/DeepSeek-V3.1-Nex-N1",
    retrieval_mode: str = "hybrid_rerank",
    max_workers: int = 8,
) -> tuple[list, list]:
    """Generate answers for a list of questions with parallel workers.

    Args:
        rag: RAG pipeline used to retrieve context and build the prompt.
        questions: Questions to answer.
        llm_client: OpenAI-compatible client used for generation.
        llm_model: Model identifier sent to the LLM endpoint.
        retrieval_mode: Retrieval strategy forwarded to ``retrieve_and_prepare``.
        max_workers: Maximum number of concurrent worker threads.

    Returns:
        ``(answers, contexts)`` where ``answers[i]`` and ``contexts[i]``
        correspond to ``questions[i]``. Questions that fail or yield no
        retrieval results get a fallback answer and an empty context list.
    """

    def _process(idx: int, question: str) -> tuple[int, str, list]:
        """Retrieve context and generate the answer for one question."""
        try:
            prepared = rag.retrieve_and_prepare(question, mode=retrieval_mode)
            if not prepared["results"]:
                return idx, "Không tìm thấy thông tin.", []

            resp = llm_client.chat.completions.create(
                model=llm_model,
                messages=[{"role": "user", "content": prepared["prompt"]}],
                temperature=0.0,
                max_tokens=4096,
            )
            answer = strip_thinking(resp.choices[0].message.content or "")
            return idx, answer, prepared["contexts"]
        except Exception as e:
            # Best-effort: log and fall back so one failing question does
            # not abort the whole evaluation run.
            print(f" Q{idx+1} Lỗi: {e}")
            return idx, "Không thể trả lời.", []

    n = len(questions)
    # Pre-size the result slots so completion order does not matter.
    answers, contexts = [""] * n, [[] for _ in range(n)]

    print(f" Đang generate {n} câu trả lời...")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # A plain list suffices: each future already carries its own index
        # in its result tuple, so no future->index mapping is needed.
        futures = [
            executor.submit(_process, i, q) for i, q in enumerate(questions)
        ]
        for done, future in enumerate(as_completed(futures), 1):
            # as_completed yields only finished futures, so result() returns
            # immediately and a timeout here would never fire; _process also
            # never raises because it catches Exception internally.
            idx, ans, ctx = future.result()
            answers[idx], contexts[idx] = ans, ctx
            print(f" [{done}/{n}] Xong")

    return answers, contexts
|
|
|