File size: 3,989 Bytes
b91b0a5
 
9681056
 
c429a2d
9681056
 
 
 
c429a2d
9681056
 
 
 
 
 
c429a2d
 
 
 
 
9681056
 
 
b91b0a5
9681056
 
 
 
b91b0a5
9681056
 
 
 
 
 
b91b0a5
 
9681056
 
 
b91b0a5
9681056
 
 
c429a2d
b91b0a5
c429a2d
 
9681056
c429a2d
794ce9a
b91b0a5
794ce9a
 
b91b0a5
794ce9a
c429a2d
 
9681056
 
 
c429a2d
9681056
c429a2d
794ce9a
c429a2d
 
9681056
b91b0a5
9681056
c429a2d
b91b0a5
794ce9a
9681056
b91b0a5
c429a2d
794ce9a
c429a2d
794ce9a
b91b0a5
c429a2d
794ce9a
 
 
 
 
c429a2d
 
9681056
b91b0a5
c429a2d
794ce9a
 
c429a2d
794ce9a
b91b0a5
 
 
794ce9a
c429a2d
 
 
 
b91b0a5
9681056
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Các utility functions cho evaluation."""

import os
import sys
import re
import csv
from pathlib import Path
from datetime import datetime
from dotenv import find_dotenv, load_dotenv
from concurrent.futures import ThreadPoolExecutor, as_completed

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
load_dotenv(find_dotenv(usecwd=True))

from openai import OpenAI
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
from core.rag.retrival import Retriever
from core.rag.generator import RAGGenerator


def strip_thinking(text: str) -> str:
    """Remove every ``<think>...</think>`` block from LLM output.

    Each block is dropped together with any whitespace directly following it
    (blocks may span several lines); the remaining text is returned with
    leading and trailing whitespace stripped.
    """
    think_block = re.compile(r'<think>.*?</think>\s*', re.DOTALL)
    without_blocks = think_block.sub('', text)
    return without_blocks.strip()


def load_csv_data(csv_path: str, sample_size: int = 0) -> tuple[list, list]:
    """Read question / ground-truth pairs from a CSV file.

    Only rows whose ``question`` and ``ground_truth`` columns are both
    non-empty are kept; all other rows are skipped.

    Args:
        csv_path: Path to a CSV file with a header row.
        sample_size: Maximum number of pairs to return. Values <= 0 mean
            "no limit".

    Returns:
        A ``(questions, ground_truths)`` tuple of equally long lists, in
        file order.
    """
    questions, ground_truths = [], []
    # 'utf-8-sig' drops a leading byte-order mark if present; with plain
    # 'utf-8' the BOM becomes part of the first header name, so the
    # 'question' column is never found and every row is silently skipped.
    # newline='' is what the csv module expects, so that quoted fields
    # containing line breaks are parsed correctly.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            question = row.get('question')
            ground_truth = row.get('ground_truth')
            if not (question and ground_truth):
                continue
            questions.append(question)
            ground_truths.append(ground_truth)
            # Stop reading as soon as the requested sample count is reached.
            if 0 < sample_size <= len(questions):
                break

    return questions, ground_truths


def init_rag() -> tuple[RAGGenerator, QwenEmbeddings, OpenAI]:
    """Build the RAG components used for evaluation.

    Constructs the embedding model, the Chroma vector store, the retriever
    and the RAG generator with their default configs, then creates an
    OpenAI-compatible client pointed at the SiliconFlow endpoint.

    Returns:
        A ``(rag, embeddings, llm_client)`` tuple.

    Raises:
        ValueError: If the ``SILICONFLOW_API_KEY`` environment variable is
            unset or blank.
    """
    embedder = QwenEmbeddings(EmbeddingConfig())
    vector_db = ChromaVectorDB(embedder=embedder, config=ChromaConfig())
    generator = RAGGenerator(retriever=Retriever(vector_db=vector_db))

    # The LLM client needs a non-blank API key from the environment.
    key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if key == "":
        raise ValueError("Chưa đặt SILICONFLOW_API_KEY")

    client = OpenAI(
        api_key=key,
        base_url="https://api.siliconflow.com/v1",
        timeout=60.0,
    )
    return generator, embedder, client


def generate_answers(
    rag: RAGGenerator,
    questions: list,
    llm_client: OpenAI,
    llm_model: str = "nex-agi/DeepSeek-V3.1-Nex-N1",
    retrieval_mode: str = "hybrid_rerank",
    max_workers: int = 8,
) -> tuple[list, list]:
    """Generate an answer for every question, processing them in parallel.

    For each question the retriever prepares a prompt, the LLM is called
    once, and any ``<think>`` blocks are stripped from the reply. Failures
    are handled per question: the error is printed and a fallback answer is
    stored, so one bad question never aborts the whole batch.

    Args:
        rag: Generator exposing ``retrieve_and_prepare(question, mode=...)``;
            its result is indexed with ``"results"``, ``"prompt"`` and
            ``"contexts"``.
        questions: Questions to answer.
        llm_client: OpenAI-compatible client used for chat completions.
        llm_model: Model name sent to the chat completions endpoint.
        retrieval_mode: Passed through to ``retrieve_and_prepare`` as ``mode``.
        max_workers: Size of the thread pool.

    Returns:
        ``(answers, contexts)``: two lists aligned with ``questions``.
        ``contexts[i]`` is whatever ``prepared["contexts"]`` held for
        question ``i``, or ``[]`` when retrieval found nothing or an error
        occurred.
    """

    def process(idx_q):
        """Retrieve and generate for one ``(index, question)`` pair."""
        idx, q = idx_q
        try:
            # Retrieve and build the prompt/context for this question.
            prepared = rag.retrieve_and_prepare(q, mode=retrieval_mode)
            if not prepared["results"]:
                return idx, "Không tìm thấy thông tin.", []

            # Ask the LLM for the answer.
            resp = llm_client.chat.completions.create(
                model=llm_model,
                messages=[{"role": "user", "content": prepared["prompt"]}],
                temperature=0.0,
                max_tokens=4096,
            )
            answer = strip_thinking(resp.choices[0].message.content or "")
            return idx, answer, prepared["contexts"]
        except Exception as e:
            # Deliberate best-effort: report the failure and keep going.
            print(f"  Q{idx+1} Lỗi: {e}")
            return idx, "Không thể trả lời.", []

    n = len(questions)
    answers, contexts = [""] * n, [[] for _ in range(n)]

    print(f"  Đang generate {n} câu trả lời...")

    # Fan the questions out over a thread pool; results arrive in completion
    # order and are written back by their original index.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process, (i, q)) for i, q in enumerate(questions)]
        for done, future in enumerate(as_completed(futures), 1):
            # Futures yielded by as_completed are already finished, so
            # result() returns immediately and needs no timeout.
            idx, ans, ctx = future.result()
            answers[idx], contexts[idx] = ans, ctx
            print(f"  [{done}/{n}] Xong")

    return answers, contexts