# Source: CapStoneRAG10 / background_evaluator.py
# Commit d3be8f6: "Add domain-specific embedding models for financial, law,
# and customer service"
"""Background evaluation job manager for RAG evaluation."""
import os
import json
import threading
import time
from datetime import datetime
from typing import Dict, Optional, List, Any
import traceback
# Job status constants — lifecycle of a job record's "status" field.
STATUS_PENDING = "pending"      # created via create_job(), not yet started
STATUS_RUNNING = "running"      # worker thread picked it up
STATUS_COMPLETED = "completed"  # finished OK; "results" / "results_file" set
STATUS_FAILED = "failed"        # aborted; "error" / "error_trace" set
# Directory for storing job files (one <job_id>.json per job)
JOBS_DIR = "./evaluation_jobs"
def ensure_jobs_dir():
    """Create the job-file directory if it is not already present."""
    # exist_ok makes this safe to call from every file operation.
    os.makedirs(JOBS_DIR, exist_ok=True)
def get_job_file_path(job_id: str) -> str:
    """Return the on-disk path of the status file for *job_id*."""
    ensure_jobs_dir()
    filename = f"{job_id}.json"
    return os.path.join(JOBS_DIR, filename)
def save_job_status(job_id: str, status: Dict):
    """Persist *status* as pretty-printed JSON under the job's file path."""
    # default=str stringifies non-JSON types (e.g. datetime) on the way out.
    with open(get_job_file_path(job_id), 'w', encoding='utf-8') as f:
        json.dump(status, f, indent=2, default=str)
def load_job_status(job_id: str) -> Optional[Dict]:
    """Load a job's status dict from its JSON file.

    Returns:
        The parsed status dict, or None when the file is missing, empty,
        or unreadable. Empty/corrupted files are deleted so they do not
        poison subsequent listings.
    """
    filepath = get_job_file_path(job_id)
    if not os.path.exists(filepath):
        return None
    # Catch only the errors this path can actually produce: filesystem
    # failures, undecodable bytes, and malformed JSON. The original
    # `except (json.JSONDecodeError, Exception)` was redundant (Exception
    # already subsumes JSONDecodeError) and over-broad.
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read().strip()
        if not content:
            # Empty file (likely an interrupted write) — delete it.
            os.remove(filepath)
            return None
        return json.loads(content)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        print(f"[Background] Error loading job {job_id}: {e}")
        # Remove the corrupted file so it is not retried forever. A bare
        # `except:` here would swallow KeyboardInterrupt/SystemExit.
        try:
            os.remove(filepath)
        except OSError:
            pass
        return None
def list_jobs() -> List[Dict]:
    """Return every job's status dict, newest first (by created_at)."""
    ensure_jobs_dir()
    collected = []
    for entry in os.listdir(JOBS_DIR):
        if not entry.endswith('.json'):
            continue
        job_id = entry[:-5]  # strip the ".json" suffix
        try:
            record = load_job_status(job_id)
        except Exception as e:
            # A single unreadable job must not break the whole listing.
            print(f"[Background] Error loading job {job_id}: {e}")
            continue
        if record:
            collected.append(record)
    collected.sort(key=lambda job: job.get('created_at', ''), reverse=True)
    return collected
def create_job(
    job_id: str,
    collection_name: str,
    dataset_name: str,
    num_samples: int,
    method: str,
    llm_model: str,
    embedding_model: str,
    llm_provider: str = "groq"
) -> Dict:
    """Create, persist, and return a fresh evaluation job record.

    The record starts in the "pending" state with zero progress, empty
    logs, and no results; it is written to disk before being returned.
    """
    record = {
        "job_id": job_id,
        "collection_name": collection_name,
        "dataset_name": dataset_name,
        "num_samples": num_samples,
        "method": method,
        "llm_model": llm_model,
        "embedding_model": embedding_model,
        "llm_provider": llm_provider,
        # Lifecycle bookkeeping — filled in by the background worker.
        "status": STATUS_PENDING,
        "progress": 0,
        "current_step": "Initializing...",
        "logs": [],
        "created_at": datetime.now().isoformat(),
        "started_at": None,
        "completed_at": None,
        "results": None,
        "error": None,
    }
    save_job_status(job_id, record)
    return record
def update_job_progress(job_id: str, progress: int, current_step: str, log_message: Optional[str] = None):
    """Update a job's progress percentage and step label on disk.

    Optionally appends a timestamped entry to the job's log. Silently
    does nothing when the job file cannot be loaded.
    """
    record = load_job_status(job_id)
    if not record:
        return
    record["progress"] = progress
    record["current_step"] = current_step
    if log_message:
        record["logs"].append({
            "time": datetime.now().isoformat(),
            "message": log_message,
        })
    save_job_status(job_id, record)
def run_background_evaluation(
    job_id: str,
    rag_pipeline,
    vector_store,
    dataset_name: str,
    num_samples: int,
    method: str,
    llm_model: str,
    embedding_model: str,
    llm_provider: str,
    groq_api_key: Optional[str] = None,
    groq_api_keys: Optional[List[str]] = None,
    ollama_host: Optional[str] = None
):
    """Run a full RAG evaluation, recording progress in the job file.

    Intended to execute inside a daemon thread (see start_background_job).
    Workflow: mark the job running -> load test samples -> optionally build
    an evaluation LLM client -> query the RAG pipeline per sample -> score
    the batch with AdvancedRAGEvaluator -> write results to a timestamped
    JSON file and mark the job completed. Any exception marks the job
    failed with the message and traceback stored in the job file.

    Args:
        job_id: Identifier of a job previously created via create_job().
        rag_pipeline: Project pipeline object; this function calls
            .query(question, n_results=5) expecting keys "response" and
            "retrieved_documents", and reads .llm as a fallback evaluator
            client -- assumes that interface; confirm against the pipeline
            class.
        vector_store: Only inspected via getattr() for optional
            chunking_strategy / chunk_size / chunk_overlap metadata.
        dataset_name: RAGBench dataset name passed to RAGBenchLoader.
        num_samples: Number of test samples to request from the loader.
        method: Evaluation method; "gpt_labeling" and "hybrid" trigger
            creation of a dedicated evaluation LLM client.
        llm_model: Model name for the evaluation LLM client.
        embedding_model: Embedding model name handed to the evaluator.
        llm_provider: Provider key for create_llm_client (e.g. "groq").
        groq_api_key: Single API key for the provider (optional).
        groq_api_keys: Rotating list of API keys (optional).
        ollama_host: Override for the Ollama host; falls back to
            settings.ollama_host when None.
    """
    try:
        # Update status to running; bail out if the job file is gone.
        status = load_job_status(job_id)
        if not status:
            print(f"[Background] Error: Could not load job status for {job_id}")
            return
        status["status"] = STATUS_RUNNING
        status["started_at"] = datetime.now().isoformat()
        status["logs"].append({
            "time": datetime.now().isoformat(),
            "message": f"⏱️ Evaluation started"
        })
        save_job_status(job_id, status)
        # Import project modules lazily so the module can be imported
        # without pulling in the whole evaluation stack.
        from dataset_loader import RAGBenchLoader
        from config import settings
        from llm_client import create_llm_client
        update_job_progress(job_id, 5, "Loading test data...", "📥 Loading test data...")
        # Load test data (question/answer samples) from the dataset.
        loader = RAGBenchLoader()
        test_data = loader.get_test_data(dataset_name, num_samples)
        update_job_progress(job_id, 10, f"Loaded {len(test_data)} samples", f"✅ Loaded {len(test_data)} test samples")
        # Create a dedicated LLM client only for LLM-judged methods;
        # other methods reuse rag_pipeline.llm below.
        if method in ["gpt_labeling", "hybrid"]:
            update_job_progress(job_id, 15, "Initializing LLM client...", f"🤖 Initializing {llm_provider.upper()} LLM...")
            eval_llm_client = create_llm_client(
                provider=llm_provider,
                api_key=groq_api_key,
                api_keys=groq_api_keys,
                model_name=llm_model,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )
        # Query the RAG system once per sample, collecting test cases
        # in the shape AdvancedRAGEvaluator.evaluate_batch expects.
        test_cases = []
        total_samples = len(test_data)
        update_job_progress(job_id, 20, "Processing samples...", "🔍 Starting sample processing...")
        for i, sample in enumerate(test_data):
            progress = 20 + int((i / total_samples) * 40)  # 20-60% for processing
            # Query the RAG system
            result = rag_pipeline.query(sample["question"], n_results=5)
            test_cases.append({
                "query": sample["question"],
                "response": result["response"],
                "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                "ground_truth": sample.get("answer", "")
            })
            # Throttle progress writes: only every 5th sample (and the last)
            # to limit disk churn on the job file.
            if (i + 1) % 5 == 0 or (i + 1) == total_samples:
                update_job_progress(
                    job_id,
                    progress,
                    f"Processed {i+1}/{total_samples} samples",
                    f"   ✓ Processed {i + 1}/{total_samples} samples"
                )
        update_job_progress(job_id, 60, "Running evaluation...", f"📊 Running {method} evaluation...")
        # Get chunking metadata (optional; None when the store lacks them).
        chunking_strategy = getattr(vector_store, 'chunking_strategy', None)
        chunk_size = getattr(vector_store, 'chunk_size', None)
        chunk_overlap = getattr(vector_store, 'chunk_overlap', None)
        # Run evaluation using advanced evaluator. Note: eval_llm_client is
        # only bound when method is gpt_labeling/hybrid, which is exactly
        # when this branch selects it — the other branch uses the
        # pipeline's own LLM.
        from advanced_rag_evaluator import AdvancedRAGEvaluator
        evaluator = AdvancedRAGEvaluator(
            llm_client=eval_llm_client if method in ["gpt_labeling", "hybrid"] else rag_pipeline.llm,
            embedding_model=embedding_model,
            chunking_strategy=chunking_strategy,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        results = evaluator.evaluate_batch(test_cases)
        update_job_progress(job_id, 95, "Saving results...", "💾 Saving results...")
        # Save results to a standalone timestamped file next to the script.
        status = load_job_status(job_id)
        collection_name = status.get("collection_name", "unknown") if status else "unknown"
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        results_filename = f"{collection_name}_{timestamp}_evaluation.json"
        save_data = {
            "collection_name": collection_name,
            "timestamp": datetime.now().isoformat(),
            "evaluation_method": method,
            "num_samples": num_samples,
            "embedding_model": embedding_model,
            "llm_model": llm_model,
            "results": results
        }
        with open(results_filename, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, default=str)
        # Update job as completed.
        # NOTE(review): load_job_status can return None here (e.g. job file
        # deleted mid-run); status["status"] would then raise and fall into
        # the except below, marking nothing — consider guarding.
        status = load_job_status(job_id)
        status["status"] = STATUS_COMPLETED
        status["progress"] = 100
        status["current_step"] = "Completed!"
        status["completed_at"] = datetime.now().isoformat()
        status["results"] = results
        status["results_file"] = results_filename
        status["logs"].append({
            "time": datetime.now().isoformat(),
            "message": f"✅ Evaluation completed! Results saved to {results_filename}"
        })
        save_job_status(job_id, status)
    except Exception as e:
        # Update job as failed, preserving message and traceback for the UI.
        error_msg = str(e)
        error_trace = traceback.format_exc()
        status = load_job_status(job_id)
        if status:
            status["status"] = STATUS_FAILED
            status["error"] = error_msg
            status["error_trace"] = error_trace
            status["completed_at"] = datetime.now().isoformat()
            status["logs"].append({
                "time": datetime.now().isoformat(),
                "message": f"❌ Error: {error_msg}"
            })
            save_job_status(job_id, status)
def start_background_job(
    job_id: str,
    rag_pipeline,
    vector_store,
    dataset_name: str,
    num_samples: int,
    method: str,
    llm_model: str,
    embedding_model: str,
    llm_provider: str,
    groq_api_key: Optional[str] = None,
    groq_api_keys: Optional[List[str]] = None,
    ollama_host: Optional[str] = None
) -> threading.Thread:
    """Launch run_background_evaluation on a daemon thread.

    Returns:
        The started Thread; daemon=True so it never blocks interpreter
        shutdown.
    """
    worker = threading.Thread(
        target=run_background_evaluation,
        args=(
            job_id, rag_pipeline, vector_store,
            dataset_name, num_samples, method,
            llm_model, embedding_model, llm_provider,
            groq_api_key, groq_api_keys, ollama_host,
        ),
        daemon=True,
    )
    worker.start()
    return worker
def delete_job(job_id: str) -> bool:
    """Delete a job's status file.

    Returns:
        True if the file existed and was removed, False if there was
        nothing to delete.
    """
    filepath = get_job_file_path(job_id)
    # EAFP: remove directly instead of exists()+remove(), which races when
    # two callers delete the same job concurrently.
    try:
        os.remove(filepath)
        return True
    except FileNotFoundError:
        return False
def cleanup_old_jobs(max_age_hours: int = 24):
    """Remove job files whose mtime is older than *max_age_hours*.

    Files that vanish between listing and removal (e.g. deleted by a
    concurrent delete_job call) are skipped instead of raising.
    """
    ensure_jobs_dir()
    cutoff = datetime.now().timestamp() - (max_age_hours * 3600)
    for filename in os.listdir(JOBS_DIR):
        if not filename.endswith('.json'):
            continue
        filepath = os.path.join(JOBS_DIR, filename)
        try:
            if os.path.getmtime(filepath) < cutoff:
                os.remove(filepath)
        except FileNotFoundError:
            # Another thread/process already removed it — nothing to do.
            continue