Spaces:
Sleeping
Sleeping
| """Background evaluation job manager for RAG evaluation.""" | |
| import os | |
| import json | |
| import threading | |
| import time | |
| from datetime import datetime | |
| from typing import Dict, Optional, List, Any | |
| import traceback | |
# Job lifecycle status constants (values stored in each job's "status" field).
STATUS_PENDING = "pending"      # job created, worker thread not started yet
STATUS_RUNNING = "running"      # worker thread is executing the evaluation
STATUS_COMPLETED = "completed"  # finished successfully; "results" populated
STATUS_FAILED = "failed"        # raised an exception; "error" populated
# Directory for storing job files (one "<job_id>.json" status file per job)
JOBS_DIR = "./evaluation_jobs"
def ensure_jobs_dir():
    """Create the job-status directory if it does not exist yet (idempotent)."""
    os.makedirs(JOBS_DIR, exist_ok=True)
def get_job_file_path(job_id: str) -> str:
    """Return the JSON status-file path for *job_id*, ensuring the jobs dir exists."""
    ensure_jobs_dir()
    filename = job_id + ".json"
    return os.path.join(JOBS_DIR, filename)
def save_job_status(job_id: str, status: Dict):
    """Atomically persist *status* as pretty-printed JSON to the job's file.

    Writes to a temporary sibling file and then os.replace()s it into place.
    The previous implementation wrote the file in place, so a concurrent
    load_job_status() could observe a half-written file, treat it as corrupt,
    and delete the job. os.replace() is atomic on both POSIX and Windows.

    Args:
        job_id: Job identifier; determines the target file name.
        status: Job status dict (non-JSON values are stringified via default=str).
    """
    filepath = get_job_file_path(job_id)
    tmp_path = filepath + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(status, f, indent=2, default=str)
    os.replace(tmp_path, filepath)  # atomic swap into the final path
def load_job_status(job_id: str) -> Optional[Dict]:
    """Load a job's status dict from its JSON file.

    Returns None when the file is missing, empty, or unparseable; empty and
    corrupted files are deleted so they stop reappearing in job listings.

    Previous version caught `(json.JSONDecodeError, Exception)` (redundant —
    Exception subsumes JSONDecodeError) and used a bare `except:` for the
    cleanup; both are narrowed here without changing observable behavior.
    """
    filepath = get_job_file_path(job_id)
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read().strip()
    except FileNotFoundError:
        return None
    except OSError as e:
        print(f"[Background] Error loading job {job_id}: {e}")
        return None
    if not content:
        # Empty file (e.g. interrupted write): remove it and report missing.
        try:
            os.remove(filepath)
        except OSError:
            pass
        return None
    try:
        return json.loads(content)
    except json.JSONDecodeError as e:
        print(f"[Background] Error loading job {job_id}: {e}")
        # Remove corrupted file so it is not retried forever
        try:
            os.remove(filepath)
        except OSError:
            pass
        return None
def list_jobs() -> List[Dict]:
    """Return every stored job status, most recently created first."""
    ensure_jobs_dir()
    collected = []
    for entry in os.listdir(JOBS_DIR):
        if not entry.endswith('.json'):
            continue
        job_id = entry[:-5]  # strip the ".json" suffix
        try:
            data = load_job_status(job_id)
        except Exception as e:
            print(f"[Background] Error loading job {job_id}: {e}")
            continue
        if data:
            collected.append(data)
    # Newest jobs first, keyed on the ISO created_at timestamp.
    collected.sort(key=lambda job: job.get('created_at', ''), reverse=True)
    return collected
def create_job(
    job_id: str,
    collection_name: str,
    dataset_name: str,
    num_samples: int,
    method: str,
    llm_model: str,
    embedding_model: str,
    llm_provider: str = "groq"
) -> Dict:
    """Build, persist, and return the initial status record for a new job.

    The returned dict is the canonical job-status schema used throughout
    this module (progress, logs, timestamps, results, error).
    """
    job_status = dict(
        job_id=job_id,
        collection_name=collection_name,
        dataset_name=dataset_name,
        num_samples=num_samples,
        method=method,
        llm_model=llm_model,
        embedding_model=embedding_model,
        llm_provider=llm_provider,
        status=STATUS_PENDING,
        progress=0,
        current_step="Initializing...",
        logs=[],
        created_at=datetime.now().isoformat(),
        started_at=None,
        completed_at=None,
        results=None,
        error=None,
    )
    save_job_status(job_id, job_status)
    return job_status
def update_job_progress(job_id: str, progress: int, current_step: str, log_message: Optional[str] = None):
    """Update a job's progress percentage and step label; optionally append a log entry.

    Silently does nothing when the job's status file cannot be loaded.
    """
    status = load_job_status(job_id)
    if not status:
        return
    status["progress"] = progress
    status["current_step"] = current_step
    if log_message:
        entry = {
            "time": datetime.now().isoformat(),
            "message": log_message,
        }
        status["logs"].append(entry)
    save_job_status(job_id, status)
def run_background_evaluation(
    job_id: str,
    rag_pipeline,
    vector_store,
    dataset_name: str,
    num_samples: int,
    method: str,
    llm_model: str,
    embedding_model: str,
    llm_provider: str,
    groq_api_key: Optional[str] = None,
    groq_api_keys: Optional[List[str]] = None,
    ollama_host: Optional[str] = None
):
    """Execute a full RAG evaluation for *job_id*; intended for a daemon thread.

    Pipeline: load test samples -> query the RAG pipeline per sample ->
    score with AdvancedRAGEvaluator -> write a timestamped results JSON
    file -> mark the job completed. Progress and log messages are persisted
    to the job's status file throughout; any exception marks the job failed
    with the message and traceback recorded.

    Args:
        job_id: Identifier of a job previously created with create_job().
        rag_pipeline: Must expose query(question, n_results) returning a dict
            with "response" and "retrieved_documents"; its `llm` attribute is
            used as the evaluator LLM for non-judge methods.
        vector_store: chunking_strategy / chunk_size / chunk_overlap
            attributes (all optional) are recorded with the results.
        dataset_name: RAGBench dataset to draw test samples from.
        num_samples: Number of samples to evaluate.
        method: Evaluation method; "gpt_labeling" and "hybrid" get a
            dedicated judge LLM client.
        llm_model / embedding_model / llm_provider: Model configuration.
        groq_api_key / groq_api_keys / ollama_host: Provider credentials and
            endpoint override (falls back to settings.ollama_host).
    """
    try:
        # Transition the job to RUNNING.
        status = load_job_status(job_id)
        if not status:
            print(f"[Background] Error: Could not load job status for {job_id}")
            return
        status["status"] = STATUS_RUNNING
        status["started_at"] = datetime.now().isoformat()
        status["logs"].append({
            "time": datetime.now().isoformat(),
            "message": "⏱️ Evaluation started"
        })
        save_job_status(job_id, status)

        # Project-local imports are deferred so importing this module stays
        # cheap and heavy dependencies load inside the worker thread.
        from dataset_loader import RAGBenchLoader
        from config import settings
        from llm_client import create_llm_client

        update_job_progress(job_id, 5, "Loading test data...", "📥 Loading test data...")
        loader = RAGBenchLoader()
        test_data = loader.get_test_data(dataset_name, num_samples)
        update_job_progress(job_id, 10, f"Loaded {len(test_data)} samples", f"✅ Loaded {len(test_data)} test samples")

        # Only LLM-as-judge methods need their own client; other methods
        # reuse the pipeline's LLM when the evaluator is built below.
        eval_llm_client = None
        if method in ["gpt_labeling", "hybrid"]:
            update_job_progress(job_id, 15, "Initializing LLM client...", f"🤖 Initializing {llm_provider.upper()} LLM...")
            eval_llm_client = create_llm_client(
                provider=llm_provider,
                api_key=groq_api_key,
                api_keys=groq_api_keys,
                model_name=llm_model,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )

        # Query the RAG system for every test sample (20-60% of progress bar).
        test_cases = []
        total_samples = len(test_data)
        update_job_progress(job_id, 20, "Processing samples...", "🔍 Starting sample processing...")
        for i, sample in enumerate(test_data):
            progress = 20 + int((i / total_samples) * 40)  # 20-60% band
            result = rag_pipeline.query(sample["question"], n_results=5)
            test_cases.append({
                "query": sample["question"],
                "response": result["response"],
                "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                "ground_truth": sample.get("answer", "")
            })
            # Persist progress every 5 samples (and on the last) to limit
            # status-file churn.
            if (i + 1) % 5 == 0 or (i + 1) == total_samples:
                update_job_progress(
                    job_id,
                    progress,
                    f"Processed {i+1}/{total_samples} samples",
                    f" ✓ Processed {i + 1}/{total_samples} samples"
                )
        update_job_progress(job_id, 60, "Running evaluation...", f"📊 Running {method} evaluation...")

        # Chunking metadata is optional on the vector store.
        chunking_strategy = getattr(vector_store, 'chunking_strategy', None)
        chunk_size = getattr(vector_store, 'chunk_size', None)
        chunk_overlap = getattr(vector_store, 'chunk_overlap', None)

        # Run evaluation using the advanced evaluator.
        from advanced_rag_evaluator import AdvancedRAGEvaluator
        evaluator = AdvancedRAGEvaluator(
            llm_client=eval_llm_client if method in ["gpt_labeling", "hybrid"] else rag_pipeline.llm,
            embedding_model=embedding_model,
            chunking_strategy=chunking_strategy,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        results = evaluator.evaluate_batch(test_cases)
        update_job_progress(job_id, 95, "Saving results...", "💾 Saving results...")

        # Persist results to a standalone timestamped JSON file.
        status = load_job_status(job_id)
        collection_name = status.get("collection_name", "unknown") if status else "unknown"
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        results_filename = f"{collection_name}_{timestamp}_evaluation.json"
        save_data = {
            "collection_name": collection_name,
            "timestamp": datetime.now().isoformat(),
            "evaluation_method": method,
            "num_samples": num_samples,
            "embedding_model": embedding_model,
            "llm_model": llm_model,
            "results": results
        }
        with open(results_filename, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, default=str)

        # Mark the job completed. BUGFIX: the previous code indexed into
        # load_job_status()'s result without a None check here, so a
        # missing/corrupt status file after a *successful* run raised
        # TypeError and the job was reported as failed. Rebuild a minimal
        # record instead so completion is always recorded.
        status = load_job_status(job_id) or {"job_id": job_id}
        status.setdefault("logs", [])
        status["status"] = STATUS_COMPLETED
        status["progress"] = 100
        status["current_step"] = "Completed!"
        status["completed_at"] = datetime.now().isoformat()
        status["results"] = results
        status["results_file"] = results_filename
        status["logs"].append({
            "time": datetime.now().isoformat(),
            "message": f"✅ Evaluation completed! Results saved to {results_filename}"
        })
        save_job_status(job_id, status)
    except Exception as e:
        # Record failure details so the UI can surface them.
        error_msg = str(e)
        error_trace = traceback.format_exc()
        status = load_job_status(job_id)
        if status:
            status["status"] = STATUS_FAILED
            status["error"] = error_msg
            status["error_trace"] = error_trace
            status["completed_at"] = datetime.now().isoformat()
            status["logs"].append({
                "time": datetime.now().isoformat(),
                "message": f"❌ Error: {error_msg}"
            })
            save_job_status(job_id, status)
def start_background_job(
    job_id: str,
    rag_pipeline,
    vector_store,
    dataset_name: str,
    num_samples: int,
    method: str,
    llm_model: str,
    embedding_model: str,
    llm_provider: str,
    groq_api_key: Optional[str] = None,
    groq_api_keys: Optional[List[str]] = None,
    ollama_host: Optional[str] = None
) -> threading.Thread:
    """Spawn a daemon thread running run_background_evaluation and return it.

    The thread is daemonic so a shutting-down interpreter is not kept
    alive by an in-flight evaluation.
    """
    worker = threading.Thread(
        target=run_background_evaluation,
        kwargs=dict(
            job_id=job_id,
            rag_pipeline=rag_pipeline,
            vector_store=vector_store,
            dataset_name=dataset_name,
            num_samples=num_samples,
            method=method,
            llm_model=llm_model,
            embedding_model=embedding_model,
            llm_provider=llm_provider,
            groq_api_key=groq_api_key,
            groq_api_keys=groq_api_keys,
            ollama_host=ollama_host,
        ),
        daemon=True,
    )
    worker.start()
    return worker
def delete_job(job_id: str) -> bool:
    """Delete a job's status file.

    Uses EAFP (remove and catch FileNotFoundError) instead of the previous
    exists()+remove() pair, which could raise if the file disappeared
    between the check and the removal.

    Returns:
        True if the file was removed, False if it did not exist.
    """
    try:
        os.remove(get_job_file_path(job_id))
        return True
    except FileNotFoundError:
        return False
def cleanup_old_jobs(max_age_hours: int = 24):
    """Delete job status files whose mtime is older than *max_age_hours* hours.

    Robustness fix: getmtime()/remove() previously ran unguarded, so a file
    deleted concurrently (e.g. by delete_job) between listdir and the stat
    would raise and abort the entire sweep; such files are now skipped.
    """
    ensure_jobs_dir()
    cutoff = datetime.now().timestamp() - (max_age_hours * 3600)
    for filename in os.listdir(JOBS_DIR):
        if not filename.endswith('.json'):
            continue
        filepath = os.path.join(JOBS_DIR, filename)
        try:
            if os.path.getmtime(filepath) < cutoff:
                os.remove(filepath)
        except OSError:
            # File vanished or is inaccessible — skip it, keep sweeping.
            continue