CultureBot / main.py
T-Phong
init project
dded31e
"""
Main entry point để test RAG system
"""
import pandas as pd
from mark import evaluate_rag_pipeline
from rag import retrieve_context
import json
from reranking import advanced_search
from rewrite import QueryRewriter
# --- VÍ DỤ THỰC TẾ ---
def main():
rewriter = QueryRewriter()
"""Hàm main để test retrieve_context"""
print("=" * 60)
print("Vietnam Heritage RAG System - Test")
print("=" * 60)
# Test queries
test_queries = [
"Bánh xèo và bánh chưng khác nhau ở điểm nào?"
]
for i, query in enumerate(test_queries, 1):
print(f"\n[Test {i}] Query: {query}")
print("-" * 60)
try:
# Gọi retrieve_context từ rag.py
temp = advanced_search(query,['Bánh xèo','bánh chưng'])
# In kết quả
print(temp)
except Exception as e:
print(f" ❌ Lỗi: {e}")
print("\n" + "=" * 60)
print("Test hoàn thành!")
print("=" * 60)
def interactive_mode():
rewriter = QueryRewriter()
"""Chế độ tương tác - hỏi đáp liên tục"""
print("\n🎯 Chế độ tương tác (Gõ 'exit' để thoát)")
print("-" * 60)
while True:
query = input("\n❓ Câu hỏi: ").strip()
keyword = []
print("Nhập chuỗi (nhập 'done' để kết thúc):")
while True:
nhap = input()
if nhap.lower() == 'done':
break
keyword.append(nhap)
print("Mảng của bạn:", keyword)
if query.lower() in ['exit', 'quit', 'q']:
print("👋 Tạm biệt!")
break
if not query:
print("⚠️ Vui lòng nhập câu hỏi!")
continue
try:
# rewrite question with key word
q_rewrite = rewriter.rewrite(query,keyword)
print(f"\n--- q_rewrite: {q_rewrite} ---")
# get top 30 RAG and reranking by question rewrite and keyword then get 5
p = advanced_search(q_rewrite,keyword)
print("\n📝 Kết quả RAG + Reranking:", p)
except Exception as e:
print(f"❌ Lỗi: {e}")
def mark():
# Giả lập dữ liệu từ hệ thống RAG của bạn
test_cases = [
# CASE 1: TỐT TOÀN DIỆN (Tìm đúng, Trả lời đúng)
{
"question": "Lễ hội Gióng được UNESCO công nhận năm nào?",
"retrieved_contexts": [
"Lễ hội Gióng ở đền Phù Đổng và đền Sóc được UNESCO công nhận là Di sản văn hóa phi vật thể đại diện của nhân loại vào năm 2010.",
"Thánh Gióng là một trong tứ bất tử."
],
"model_answer": "Lễ hội Gióng được UNESCO công nhận vào năm 2010.",
"ground_truth": "Năm 2010."
},
# CASE 2: RETRIEVAL KÉM (Tìm sai tài liệu -> Bot không trả lời được hoặc bịa)
{
"question": "Nguyên liệu chính làm quả cầu trong lễ hội gieo cầu là gì?",
"retrieved_contexts": [
"Lễ hội Gióng tái hiện trận đánh giặc Ân.", # <--- Context không liên quan gì đến quả cầu
"Đền Hùng nằm ở Phú Thọ."
],
"model_answer": "Quả cầu được làm bằng nhựa.", # <--- Bot bịa (Hallucination) do không có context
"ground_truth": "Gỗ hoặc da."
},
# CASE 3: RETRIEVAL TỐT NHƯNG BOT BỊA (Hallucination)
{
"question": "Ý nghĩa của Lễ hội Đền Hùng?",
"retrieved_contexts": [
"Lễ hội Đền Hùng thể hiện lòng biết ơn sâu sắc đối với các Vua Hùng đã có công dựng nước."
],
"model_answer": "Lễ hội Đền Hùng là để cầu mưa thuận gió hòa cho miền Tây sông nước.", # <--- Sai, không dựa vào context
"ground_truth": "Tưởng nhớ công lao dựng nước của các Vua Hùng."
}
]
# --- VÒNG LẶP ĐÁNH GIÁ ---
results = []
print("Đang đánh giá hệ thống RAG...")
for i, case in enumerate(test_cases):
print(f"Processing case {i+1}...")
scores = evaluate_rag_pipeline(
case["question"],
case["retrieved_contexts"],
case["model_answer"],
case["ground_truth"]
)
# Gộp kết quả
case_result = {**case, **scores} # Merge dict
results.append(case_result)
# --- HIỂN THỊ KẾT QUẢ ---
df = pd.DataFrame(results)
# Chỉ hiện các cột quan trọng
display_cols = ["question", "context_score", "faithfulness_score", "correctness_score", "reason"]
print("\nBẢNG ĐIỂM CHI TIẾT:")
print(df[display_cols].to_string())
# Tính điểm tổng kết
print("\n--- TỔNG KẾT HIỆU SUẤT ---")
print(f"Độ chính xác tìm kiếm (Retrieval Accuracy): {df['context_score'].mean() * 100:.1f}%")
print(f"Độ trung thực (Faithfulness): {df['faithfulness_score'].mean() * 100:.1f}%")
print(f"Độ chính xác câu trả lời (End-to-End Accuracy): {df['correctness_score'].mean() * 100:.1f}%")
if __name__ == "__main__":
import sys
# Kiểm tra argument
if len(sys.argv) > 1 and sys.argv[1] == "--interactive":
interactive_mode()
else:
mark()