# DoAn/evaluation/eval_utils.py
"""Các utility functions cho evaluation."""
import os
import sys
import re
import csv
from pathlib import Path
from datetime import datetime
from dotenv import find_dotenv, load_dotenv
from concurrent.futures import ThreadPoolExecutor, as_completed
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
load_dotenv(find_dotenv(usecwd=True))
from openai import OpenAI
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
from core.rag.retrival import Retriever
from core.rag.generator import RAGGenerator


def strip_thinking(text: str) -> str:
    """Remove <think>...</think> blocks from LLM output."""
    return re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL).strip()
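
# Minimal usage sketch for strip_thinking (the input string is illustrative,
# not taken from the repo):
#   >>> strip_thinking("<think>chain of thought here</think>Final answer.")
#   'Final answer.'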


def load_csv_data(csv_path: str, sample_size: int = 0) -> tuple[list, list]:
    """Read question and ground-truth data from a CSV file."""
    questions, ground_truths = [], []
    with open(csv_path, 'r', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            if row.get('question') and row.get('ground_truth'):
                questions.append(row['question'])
                ground_truths.append(row['ground_truth'])
    # Cap the number of samples if requested
    if sample_size > 0:
        questions = questions[:sample_size]
        ground_truths = ground_truths[:sample_size]
    return questions, ground_truths
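
# Example call (the CSV path is an assumption for illustration; substitute
# the real evaluation dataset file):
#   questions, ground_truths = load_csv_data("evaluation/testset.csv", sample_size=10)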


def init_rag() -> tuple[RAGGenerator, QwenEmbeddings, OpenAI]:
    """Initialize the RAG components for evaluation."""
    embeddings = QwenEmbeddings(EmbeddingConfig())
    db = ChromaVectorDB(embedder=embeddings, config=ChromaConfig())
    retriever = Retriever(vector_db=db)
    rag = RAGGenerator(retriever=retriever)
    # Initialize the LLM client
    api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not api_key:
        raise ValueError("SILICONFLOW_API_KEY is not set")
    llm_client = OpenAI(api_key=api_key, base_url="https://api.siliconflow.com/v1", timeout=60.0)
    return rag, embeddings, llm_client


def generate_answers(
    rag: RAGGenerator,
    questions: list,
    llm_client: OpenAI,
    llm_model: str = "nex-agi/DeepSeek-V3.1-Nex-N1",
    retrieval_mode: str = "hybrid_rerank",
    max_workers: int = 8,
) -> tuple[list, list]:
    """Generate answers for a list of questions using parallel processing."""
    def process(idx_q):
        """Process one question: retrieve + generate."""
        idx, q = idx_q
        try:
            # Retrieve and prepare the context
            prepared = rag.retrieve_and_prepare(q, mode=retrieval_mode)
            if not prepared["results"]:
                return idx, "No information found.", []
            # Call the LLM to generate the answer
            resp = llm_client.chat.completions.create(
                model=llm_model,
                messages=[{"role": "user", "content": prepared["prompt"]}],
                temperature=0.0,
                max_tokens=4096,
            )
            answer = strip_thinking(resp.choices[0].message.content or "")
            return idx, answer, prepared["contexts"]
        except Exception as e:
            print(f" Q{idx+1} error: {e}")
            return idx, "Unable to answer.", []

    n = len(questions)
    answers, contexts = [""] * n, [[] for _ in range(n)]
    print(f" Generating {n} answers...")
    # Run the questions in parallel with a ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process, (i, q)): i for i, q in enumerate(questions)}
        for i, future in enumerate(as_completed(futures), 1):
            idx, ans, ctx = future.result(timeout=120)
            answers[idx], contexts[idx] = ans, ctx
            print(f" [{i}/{n}] Done")
    return answers, contexts
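

if __name__ == "__main__":
    # Minimal end-to-end sketch of how these helpers compose: load the data,
    # build the RAG stack, then generate answers in parallel. The CSV path is
    # an assumption for illustration; point it at the real evaluation dataset.
    questions, ground_truths = load_csv_data("evaluation/testset.csv", sample_size=3)
    rag, embeddings, llm_client = init_rag()
    answers, contexts = generate_answers(rag, questions, llm_client)
    # ground_truths and contexts would normally feed the metric computation;
    # here we just print the generated answers for a quick smoke test.
    for q, a in zip(questions, answers):
        print(f"Q: {q}\nA: {a}\n")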