| """ | |
| Single Turn Synthetic Retrieval Evaluation Dataset Creator for ViettelPay RAG System | |
| Uses Google Gemini 2.0 Flash with JSON responses for better parsing | |
| Simplified version with only MRR and hit rate evaluation (no qrels generation) | |
| """ | |
| import json | |
| import os | |
| import sys | |
| import argparse | |
| import time | |
| from typing import Dict, List, Tuple, Optional, Union | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import pandas as pd | |
| from tqdm import tqdm | |
| import re | |
| # Load environment variables from .env file | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| # Add the project root to Python path so we can import from src | |
| project_root = Path(__file__).parent.parent.parent | |
| sys.path.insert(0, str(project_root)) | |
| # Import prompts (only the ones we need) | |
| from src.evaluation.prompts import ( | |
| QUESTION_GENERATION_PROMPT, | |
| QUESTION_QUALITY_CHECK_PROMPT, | |
| CONTEXT_QUALITY_CHECK_PROMPT, | |
| QUESTION_EVOLUTION_PROMPT, | |
| ) | |
| # Import your existing knowledge base and LLM client | |
| from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase | |
| from src.llm.llm_client import LLMClientFactory, BaseLLMClient | |


class SingleTurnDatasetCreator:
    """Single turn synthetic evaluation dataset creator with JSON responses and all chunks processing"""

    def __init__(
        self, gemini_api_key: str, knowledge_base: ViettelKnowledgeBase = None
    ):
        """
        Initialize with Gemini API key and optional knowledge base

        Args:
            gemini_api_key: Google AI API key for Gemini
            knowledge_base: Pre-initialized ViettelKnowledgeBase instance
        """
        self.llm_client = LLMClientFactory.create_client(
            "gemini", api_key=gemini_api_key, model="gemini-2.0-flash"
        )
        self.knowledge_base = knowledge_base
        self.dataset = {
            "queries": {},
            "documents": {},
            "metadata": {
                "total_chunks_processed": 0,
                "questions_generated": 0,
                "creation_timestamp": time.time(),
            },
        }
        print("✅ SingleTurnDatasetCreator initialized with Gemini 2.0 Flash")

    def generate_json_response(
        self, prompt: str, max_retries: int = 3
    ) -> Optional[Dict]:
        """
        Generate response and parse as JSON with retries

        Args:
            prompt: Input prompt
            max_retries: Maximum number of retry attempts

        Returns:
            Parsed JSON response or None if failed
        """
        for attempt in range(max_retries):
            try:
                response = self.llm_client.generate(prompt, temperature=0.1)
                if response:
                    # Clean response text
                    response_text = response.strip()
                    # Extract JSON from response (handle cases with extra text)
                    json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
                    if json_match:
                        json_text = json_match.group()
                        return json.loads(json_text)
                    else:
                        # Try parsing the whole response
                        return json.loads(response_text)
            except json.JSONDecodeError as e:
                print(f"⚠️ JSON parsing error (attempt {attempt + 1}): {e}")
                if attempt == max_retries - 1:
                    print(f"❌ Failed to parse JSON after {max_retries} attempts")
                    print(
                        f"Raw response: {response if 'response' in locals() else 'No response'}"
                    )
            except Exception as e:
                print(f"⚠️ API error (attempt {attempt + 1}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(2**attempt)  # Exponential backoff
        return None
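
    # Illustrative note (an assumption about typical model output, not part of the
    # pipeline): a reply such as
    #   'Here is the JSON:\n{"questions": ["..."]}'
    # is handled by the regex above, which pulls out the {...} span before
    # json.loads() is applied.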

    def get_all_chunks(self) -> List[Dict]:
        """
        Get ALL chunks directly from ChromaDB vectorstore (no sampling)

        Returns:
            List of all document chunks with content and metadata
        """
        print("🔍 Retrieving ALL chunks directly from ChromaDB vectorstore...")

        if not self.knowledge_base:
            raise ValueError(
                "Knowledge base not provided. Please initialize with a ViettelKnowledgeBase instance."
            )

        try:
            # Access the ChromaDB vectorstore directly
            if (
                not hasattr(self.knowledge_base, "chroma_retriever")
                or not self.knowledge_base.chroma_retriever
            ):
                raise ValueError("ChromaDB retriever not found in knowledge base")

            # Get the vectorstore from the retriever
            vectorstore = self.knowledge_base.chroma_retriever.vectorstore

            # Get all documents directly from ChromaDB
            print("   Accessing ChromaDB collection...")
            all_docs = vectorstore.get(include=["documents", "metadatas"])
            documents = all_docs["documents"]
            metadatas = all_docs["metadatas"]

            print(f"   Found {len(documents)} documents in ChromaDB")
            print("   Sample document preview:")
            for i, doc in enumerate(documents[:3]):
                print(f"     Doc {i + 1}: {doc[:100]}...")

            # Convert to our expected format
            all_chunks = []
            seen_content_hashes = set()

            for i, (content, metadata) in enumerate(zip(documents, metadatas)):
                # Create content hash for deduplication (just in case)
                content_hash = hash(content[:300])

                if (
                    content_hash not in seen_content_hashes
                    and len(content.strip()) > 50
                ):
                    chunk_info = {
                        "id": f"chunk_{len(all_chunks)}",
                        "content": content,
                        "metadata": metadata or {},
                        "source": "chromadb_direct",
                        "content_length": len(content),
                        "original_index": i,
                    }
                    all_chunks.append(chunk_info)
                    seen_content_hashes.add(content_hash)
                else:
                    if content_hash in seen_content_hashes:
                        print(f"   ⚠️ Skipping duplicate content at index {i}")
                    else:
                        print(
                            f"   ⚠️ Skipping short content at index {i} (length: {len(content.strip())})"
                        )

            print(f"✅ Retrieved {len(all_chunks)} unique chunks from ChromaDB")
            print(
                f"   Filtered out {len(documents) - len(all_chunks)} duplicates/short chunks"
            )

            # Sort by content length (longer chunks first, usually more informative)
            all_chunks.sort(key=lambda x: x["content_length"], reverse=True)

            # Display statistics (guard against an empty collection)
            if all_chunks:
                avg_length = sum(
                    chunk["content_length"] for chunk in all_chunks
                ) / len(all_chunks)
                min_length = min(chunk["content_length"] for chunk in all_chunks)
                max_length = max(chunk["content_length"] for chunk in all_chunks)

                print("   📊 Chunk Statistics:")
                print(f"     • Average length: {avg_length:.0f} characters")
                print(f"     • Min length: {min_length} characters")
                print(f"     • Max length: {max_length} characters")

            return all_chunks

        except Exception as e:
            print(f"❌ Error accessing ChromaDB directly: {e}")
            print("   Falling back to search-based method...")
            return self._get_all_chunks_fallback()

    def _get_all_chunks_fallback(self) -> List[Dict]:
        """
        Fallback method using search queries if direct ChromaDB access fails

        Returns:
            List of document chunks retrieved via search
        """
        print("🔍 Using fallback search-based chunk retrieval...")

        # Use comprehensive search terms to capture most content
        comprehensive_queries = [
            "ViettelPay",
            "nạp",
            "cước",
            "giao dịch",
            "thanh toán",
            "lỗi",
            "hủy",
            "thẻ",
            "chuyển",
            "tiền",
            "quy định",
            "phí",
            "dịch vụ",
            "tài khoản",
            "ngân hàng",
            "OTP",
            "PIN",
            "mã",
            "số",
            "điện thoại",
            "internet",
            "truyền hình",
            "homephone",
            "cố định",
            "game",
            "Viettel",
            "Mobifone",
            # Add some Vietnamese words that might not be captured above
            "ứng dụng",
            "khách hàng",
            "hỗ trợ",
            "kiểm tra",
            "xác nhận",
            "bảo mật",
        ]

        all_chunks = []
        seen_content_hashes = set()

        for query in comprehensive_queries:
            try:
                # Search with large k to get as many chunks as possible
                docs = self.knowledge_base.search(query, top_k=50)

                for doc in docs:
                    # Create content hash for deduplication
                    content_hash = hash(doc.page_content[:300])

                    if (
                        content_hash not in seen_content_hashes
                        and len(doc.page_content.strip()) > 50
                    ):
                        chunk_info = {
                            "id": f"chunk_{len(all_chunks)}",
                            "content": doc.page_content,
                            "metadata": doc.metadata,
                            "source": f"search_{query}",
                            "content_length": len(doc.page_content),
                        }
                        all_chunks.append(chunk_info)
                        seen_content_hashes.add(content_hash)

            except Exception as e:
                print(f"⚠️ Error searching for '{query}': {e}")
                continue

        print(f"✅ Fallback method retrieved {len(all_chunks)} unique chunks")

        # Sort by content length
        all_chunks.sort(key=lambda x: x["content_length"], reverse=True)

        return all_chunks

    def generate_questions_for_chunk(
        self, chunk: Dict, num_questions: int = 2
    ) -> List[Dict]:
        """
        Generate questions for a single chunk using Gemini with JSON response

        Args:
            chunk: Chunk dictionary with content and metadata
            num_questions: Number of questions to generate per chunk

        Returns:
            List of question dictionaries with metadata
        """
        content = chunk["content"]
        prompt = QUESTION_GENERATION_PROMPT.format(
            num_questions=num_questions, content=content
        )

        response_json = self.generate_json_response(prompt)

        if response_json and "questions" in response_json:
            questions = response_json["questions"]

            # Create question objects with metadata
            question_objects = []
            for i, question_text in enumerate(questions):
                if len(question_text.strip()) > 5:  # Filter very short questions
                    question_obj = {
                        "id": f"q_{chunk['id']}_{i}",
                        "text": question_text.strip(),
                        "source_chunk": chunk["id"],
                        "chunk_metadata": chunk["metadata"],
                        "generation_method": "gemini_json",
                    }
                    question_objects.append(question_obj)

            return question_objects
        else:
            print(f"⚠️ No valid questions generated for chunk {chunk['id']}")
            return []
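
    # The prompt is expected to elicit a JSON object shaped roughly like
    #   {"questions": ["<question 1>", "<question 2>", ...]}
    # (an assumption about QUESTION_GENERATION_PROMPT, inferred from the parsing above).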

    def check_context_quality(self, chunk: Dict) -> bool:
        """
        Check if a chunk is suitable for question generation

        Args:
            chunk: Chunk dictionary

        Returns:
            True if chunk should be used, False otherwise
        """
        content = chunk["content"]

        # Basic checks first
        if len(content.strip()) < 100:
            return False

        # Use Gemini for quality assessment
        prompt = CONTEXT_QUALITY_CHECK_PROMPT.format(content=content[:1000])
        response_json = self.generate_json_response(prompt)

        if response_json:
            return response_json.get("use_context", True)
        else:
            # Fallback to basic heuristics
            return len(content.strip()) > 100 and len(content.split()) > 20
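
    # CONTEXT_QUALITY_CHECK_PROMPT is assumed to return JSON containing at least a
    # boolean "use_context" field; anything else falls back to the length heuristics above.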

    def create_complete_dataset(
        self,
        questions_per_chunk: int = 2,
        save_path: str = "evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval_dataset.json",
        quality_check: bool = True,
    ) -> Dict:
        """
        Create complete synthetic evaluation dataset using ALL chunks

        Args:
            questions_per_chunk: Number of questions to generate per chunk
            save_path: Path to save the dataset JSON file
            quality_check: Whether to perform quality checks on chunks

        Returns:
            Complete dataset dictionary
        """
        print("\n🚀 Creating simplified synthetic evaluation dataset...")
        print("   Target: Process ALL chunks from knowledge base")
        print(f"   Questions per chunk: {questions_per_chunk}")
        print(f"   Quality check: {quality_check}")
        print("   Evaluation method: MRR and Hit Rates only (no qrels)")

        # Step 1: Get all chunks
        all_chunks = self.get_all_chunks()
        total_chunks = len(all_chunks)

        if total_chunks == 0:
            raise ValueError("No chunks found in knowledge base!")

        print(f"✅ Found {total_chunks} chunks to process")

        # Step 2: Quality filtering (optional)
        if quality_check:
            print("\n🔍 Performing quality checks on chunks...")
            quality_chunks = []

            for chunk in tqdm(all_chunks, desc="Quality checking"):
                if self.check_context_quality(chunk):
                    quality_chunks.append(chunk)
                time.sleep(0.1)  # Rate limiting

            print(
                f"✅ {len(quality_chunks)}/{total_chunks} chunks passed quality check"
            )
            chunks_to_process = quality_chunks
        else:
            chunks_to_process = all_chunks

        # Step 3: Generate questions for all chunks
        print(f"\n📝 Generating questions for {len(chunks_to_process)} chunks...")
        all_questions = []

        for chunk in tqdm(chunks_to_process, desc="Generating questions"):
            questions = self.generate_questions_for_chunk(chunk, questions_per_chunk)
            all_questions.extend(questions)
            time.sleep(0.2)  # Rate limiting for Gemini API

        print(
            f"✅ Generated {len(all_questions)} questions from {len(chunks_to_process)} chunks"
        )

        # Step 4: Populate dataset structure
        self.dataset["documents"] = {
            chunk["id"]: chunk["content"] for chunk in chunks_to_process
        }
        self.dataset["queries"] = {q["id"]: q["text"] for q in all_questions}

        # Add question metadata
        question_metadata = {
            q["id"]: {
                "source_chunk": q["source_chunk"],
                "chunk_metadata": q["chunk_metadata"],
                "generation_method": q["generation_method"],
            }
            for q in all_questions
        }
        self.dataset["question_metadata"] = question_metadata

        # Step 5: Update metadata
        self.dataset["metadata"].update(
            {
                "total_chunks_processed": len(chunks_to_process),
                "total_chunks_available": total_chunks,
                "questions_generated": len(all_questions),
                "questions_per_chunk": questions_per_chunk,
                "quality_check_enabled": quality_check,
                "evaluation_method": "mrr_hit_rates_only",
                "completion_timestamp": time.time(),
            }
        )

        # Step 6: Save dataset
        os.makedirs(
            os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
            exist_ok=True,
        )
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(self.dataset, f, ensure_ascii=False, indent=2)

        print("\n✅ COMPLETE dataset created successfully!")
        print(f"   📁 Saved to: {save_path}")
        print("   📊 Statistics:")
        print(f"     • Chunks processed: {len(chunks_to_process)}/{total_chunks}")
        print(f"     • Questions generated: {len(all_questions)}")
        print("     • Evaluation method: MRR and Hit Rates only")
        print(
            f"     • Coverage: {len(chunks_to_process) / total_chunks * 100:.1f}% of knowledge base"
        )

        return self.dataset
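
    # Resulting on-disk layout (a sketch derived from the code above; the literal
    # values are illustrative, not taken from a real run):
    # {
    #   "queries":  {"q_chunk_0_0": "How do I ...?", ...},
    #   "documents": {"chunk_0": "<chunk text>", ...},
    #   "question_metadata": {"q_chunk_0_0": {"source_chunk": "chunk_0", ...}, ...},
    #   "metadata": {"questions_generated": 123, ...}
    # }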

    def load_dataset(self, dataset_path: str) -> Dict:
        """Load dataset from JSON file with metadata"""
        with open(dataset_path, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)

        metadata = self.dataset.get("metadata", {})
        print(f"📂 Loaded dataset from {dataset_path}")
        print("   📊 Dataset Statistics:")
        print(f"     • Queries: {len(self.dataset['queries'])}")
        print(f"     • Documents: {len(self.dataset['documents'])}")
        print(f"     • Created: {time.ctime(metadata.get('creation_timestamp', 0))}")

        return self.dataset


class SingleTurnRetrievalEvaluator:
    """Simplified retrieval evaluator with only MRR and hit rates"""

    def __init__(self, dataset: Dict, knowledge_base: ViettelKnowledgeBase):
        """
        Initialize evaluator with dataset and knowledge base

        Args:
            dataset: Evaluation dataset with queries and documents
            knowledge_base: ViettelKnowledgeBase instance to evaluate
        """
        self.dataset = dataset
        self.knowledge_base = knowledge_base
        self.results = {}

    def _match_retrieved_documents(self, retrieved_docs) -> List[str]:
        """
        Enhanced document matching with multiple strategies

        Args:
            retrieved_docs: Retrieved Document objects from knowledge base

        Returns:
            List of matched document IDs
        """
        matched_ids = []

        for doc in retrieved_docs:
            # Strategy 1: Try to find exact content match
            doc_id = self._find_exact_content_match(doc.page_content)

            if not doc_id:
                # Strategy 2: Try fuzzy content matching
                doc_id = self._find_fuzzy_content_match(doc.page_content)

            if doc_id:
                matched_ids.append(doc_id)

        return matched_ids

    def _find_exact_content_match(self, retrieved_content: str) -> Optional[str]:
        """Find exact content match"""
        for doc_id, doc_content in self.dataset["documents"].items():
            if retrieved_content.strip() == doc_content.strip():
                return doc_id
        return None

    def _find_fuzzy_content_match(
        self, retrieved_content: str, min_overlap: int = 50
    ) -> Optional[str]:
        """Find fuzzy content match with word overlap"""
        best_match_id = None
        best_overlap = 0
        retrieved_words = set(retrieved_content.lower().split())

        for doc_id, doc_content in self.dataset["documents"].items():
            doc_words = set(doc_content.lower().split())
            overlap = len(retrieved_words & doc_words)

            if overlap > best_overlap and overlap >= min_overlap:
                best_overlap = overlap
                best_match_id = doc_id

        return best_match_id
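
    # Example of the overlap rule (illustrative numbers): if a retrieved passage
    # shares 63 distinct lowercase tokens with "chunk_12" and 48 with "chunk_7",
    # only "chunk_12" clears min_overlap=50 and is returned as the best match.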

    def _safe_average(self, values: List[float]) -> float:
        """Calculate average safely handling empty lists"""
        return sum(values) / len(values) if values else 0.0

    def evaluate(self, k_values: List[int] = [1, 3, 5, 10]) -> Dict:
        """
        Simplified evaluation with only MRR and hit rates

        This method checks if the source document (where the question was generated from)
        is retrieved among the top-k results.

        Args:
            k_values: List of k values to evaluate

        Returns:
            Dictionary with MRR and hit rate results
        """
        print("\n🚀 Running simplified evaluation (MRR and Hit Rates only)...")
        print(f"   📊 K values: {k_values}")
        print(f"   📊 Total queries: {len(self.dataset['queries'])}")

        # Initialize results
        hit_rates = {k: [] for k in k_values}
        rr_scores = []  # Reciprocal Rank scores for MRR calculation
        query_results = {}
        failed_queries = []

        # Process each query
        for query_id, query_text in tqdm(
            self.dataset["queries"].items(), desc="Evaluating queries"
        ):
            try:
                # Get source document from metadata - handle both single-turn and multi-turn formats
                source_chunk_id = None

                # Try question_metadata first (single-turn format)
                question_meta = self.dataset.get("question_metadata", {}).get(
                    query_id, {}
                )
                if question_meta:
                    source_chunk_id = question_meta.get("source_chunk")

                # If not found, try conversation_metadata (multi-turn format)
                if not source_chunk_id:
                    conversation_meta = self.dataset.get(
                        "conversation_metadata", {}
                    ).get(query_id, {})
                    if conversation_meta:
                        source_chunk_id = conversation_meta.get("source_chunk")

                if not source_chunk_id:
                    print(f"⚠️ No source chunk info for query {query_id}")
                    continue

                # Get retrieval results
                retrieved_docs = self.knowledge_base.search(
                    query_text, top_k=max(k_values)
                )
                retrieved_doc_ids = self._match_retrieved_documents(retrieved_docs)

                # Check if source document is in top-k for each k
                query_results[query_id] = {
                    "query": query_text,
                    "source_chunk": source_chunk_id,
                    "retrieved": retrieved_doc_ids,
                    "hit_rates": {},
                }

                # Calculate Reciprocal Rank (MRR) - once per query
                if source_chunk_id in retrieved_doc_ids:
                    source_rank = (
                        retrieved_doc_ids.index(source_chunk_id) + 1
                    )  # 1-indexed rank
                    rr_score = 1.0 / source_rank
                else:
                    source_rank = None
                    rr_score = 0.0

                query_results[query_id]["rr"] = rr_score
                query_results[query_id]["source_rank"] = source_rank
                rr_scores.append(rr_score)

                for k in k_values:
                    top_k_docs = retrieved_doc_ids[:k]
                    hit = 1 if source_chunk_id in top_k_docs else 0
                    hit_rates[k].append(hit)
                    query_results[query_id]["hit_rates"][k] = hit

            except Exception as e:
                print(f"❌ Error evaluating query {query_id}: {e}")
                failed_queries.append((query_id, str(e)))
                continue

        # Calculate average metrics
        avg_hit_rates = {}
        avg_rr = sum(rr_scores) / len(rr_scores) if rr_scores else 0.0

        for k in k_values:
            avg_hit_rates[k] = self._safe_average(hit_rates[k])

        results = {
            "hit_rates": avg_hit_rates,
            "mrr": avg_rr,
            "per_query_results": query_results,
            "failed_queries": failed_queries,
            "summary": {
                "total_queries": len(self.dataset["queries"]),
                "evaluated_queries": len(query_results),
                "failed_queries": len(failed_queries),
                "success_rate": len(query_results) / len(self.dataset["queries"]) * 100,
                "k_values": k_values,
                "evaluation_type": "mrr_hit_rates_only",
                "evaluation_timestamp": time.time(),
            },
        }

        return results
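
    # Worked example of the per-query metrics (illustrative numbers): if the source
    # chunk appears at rank 2 of the retrieved list, then RR = 1/2 = 0.5, hit@1 = 0,
    # and hit@3 = hit@5 = hit@10 = 1. MRR is the mean of the RR values over all
    # evaluated queries.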

    def print_evaluation_results(self, results: Dict):
        """Print simplified evaluation results"""
        print("\n📊 SIMPLIFIED EVALUATION RESULTS (MRR + Hit Rates)")
        print("=" * 60)

        print("\n📈 Hit Rates (Source Document Found in Top-K):")
        print(f"{'K':<5} {'Hit Rate':<12} {'Percentage':<12}")
        print("-" * 30)

        for k in sorted(results["hit_rates"].keys()):
            hit_rate = results["hit_rates"][k]
            percentage = hit_rate * 100
            print(f"{k:<5} {hit_rate:<12.4f} {percentage:<12.1f}%")

        # Display MRR separately since it's not k-dependent
        mrr = results["mrr"]
        print(f"\n📈 Mean Reciprocal Rank (MRR): {mrr:.4f}")
        print("   • MRR measures the average reciprocal rank of the source document")
        print("   • Higher is better (max = 1.0 if all sources are rank 1)")

        print("\n📈 Hit Rate Summary:")
        for k in sorted(results["hit_rates"].keys()):
            hit_rate = results["hit_rates"][k]
            percentage = hit_rate * 100
            print(
                f"   • Top-{k}: {percentage:.1f}% of questions find their source document"
            )

        # Summary stats
        summary = results["summary"]
        print("\n📊 Evaluation Summary:")
        print(f"   • Total queries: {summary['total_queries']}")
        print(f"   • Successfully evaluated: {summary['evaluated_queries']}")
        print(f"   • Failed queries: {summary['failed_queries']}")
        print(f"   • Success rate: {summary['success_rate']:.1f}%")
        print(f"   • Evaluation type: {summary['evaluation_type']}")

        # Simple interpretation
        avg_hit_rate_5 = results["hit_rates"].get(5, 0)
        mrr = results["mrr"]

        print("\n🎯 Quick Interpretation:")
        if avg_hit_rate_5 > 0.8:
            print(
                f"   ✅ Excellent: {avg_hit_rate_5 * 100:.1f}% hit rate@5, MRR = {mrr:.3f}"
            )
        elif avg_hit_rate_5 > 0.6:
            print(f"   👍 Good: {avg_hit_rate_5 * 100:.1f}% hit rate@5, MRR = {mrr:.3f}")
        elif avg_hit_rate_5 > 0.4:
            print(f"   ⚠️ Fair: {avg_hit_rate_5 * 100:.1f}% hit rate@5, MRR = {mrr:.3f}")
        else:
            print(f"   ❌ Poor: {avg_hit_rate_5 * 100:.1f}% hit rate@5, MRR = {mrr:.3f}")


def main():
    """Main function with argument parsing for separate operations"""
    parser = argparse.ArgumentParser(
        description="ViettelPay Retrieval Evaluation Dataset Creator (Simplified)"
    )
    parser.add_argument(
        "--mode",
        choices=["create", "evaluate", "both"],
        default="both",
        help="Mode: create dataset, evaluate only, or both",
    )
    parser.add_argument(
        "--dataset-path",
        default="evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval.json",
        help="Path to dataset file",
    )
    parser.add_argument(
        "--results-path",
        default="evaluation_data/results/single_turn_retrieval/viettelpay_eval_results.json",
        help="Path to save evaluation results",
    )
    parser.add_argument(
        "--questions-per-chunk",
        type=int,
        default=3,
        help="Number of questions per chunk",
    )
    parser.add_argument(
        "--k-values",
        nargs="+",
        type=int,
        default=[1, 3, 5, 10],
        help="K values for evaluation",
    )
    parser.add_argument(
        "--quality-check",
        action="store_true",
        help="Enable quality checking for chunks",
    )
    parser.add_argument(
        "--knowledge-base-path",
        default="./knowledge_base",
        help="Path to knowledge base",
    )
    args = parser.parse_args()
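
    # Example invocation (a sketch; the script filename and local paths are
    # assumptions about your checkout):
    #   python single_turn_retrieval_eval.py --mode both \
    #       --questions-per-chunk 2 --k-values 1 3 5 10 \
    #       --quality-check --knowledge-base-path ./knowledge_base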

    # Configuration
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
    if not GEMINI_API_KEY:
        print("❌ Please set GEMINI_API_KEY environment variable")
        return

    try:
        # Initialize knowledge base
        print("🔧 Initializing ViettelPay knowledge base...")
        kb = ViettelKnowledgeBase()
        if not kb.load_knowledge_base(args.knowledge_base_path):
            print(
                "❌ Failed to load knowledge base. Please run build_database_script.py first."
            )
            return

        # Create dataset if requested
        if args.mode in ["create", "both"]:
            print("\n🎯 Creating synthetic evaluation dataset...")
            creator = SingleTurnDatasetCreator(GEMINI_API_KEY, kb)
            dataset = creator.create_complete_dataset(
                questions_per_chunk=args.questions_per_chunk,
                save_path=args.dataset_path,
                quality_check=args.quality_check,
            )

        # Evaluate if requested
        if args.mode in ["evaluate", "both"]:
            print("\n⚡ Evaluating retrieval performance...")

            # Load dataset if not created in this run
            if args.mode == "evaluate":
                if not os.path.exists(args.dataset_path):
                    print(f"❌ Dataset file not found: {args.dataset_path}")
                    return
                creator = SingleTurnDatasetCreator(GEMINI_API_KEY, kb)
                dataset = creator.load_dataset(args.dataset_path)

            # Run evaluation
            evaluator = SingleTurnRetrievalEvaluator(dataset, kb)
            results = evaluator.evaluate(k_values=args.k_values)
            evaluator.print_evaluation_results(results)

            # Save results (make sure the output directory exists first)
            if args.results_path:
                os.makedirs(os.path.dirname(args.results_path) or ".", exist_ok=True)
                with open(args.results_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, ensure_ascii=False, indent=2)
                print(f"\n💾 Results saved to: {args.results_path}")

        print("\n✅ Operation completed successfully!")
        print("\n💡 Next steps:")
        print("   1. Review the MRR and hit rate results")
        print("   2. Identify queries with low performance")
        print("   3. Optimize your retrieval system")
        print("   4. Re-run evaluation to measure progress")

    except Exception as e:
        print(f"❌ Error in main execution: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    main()