"""
Evaluation tasks for AegisLM SaaS Backend.

Production-ready Celery tasks for AI evaluation processing, background job
execution, and result storage.
"""

import asyncio
import logging
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional

from celery import current_task
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

# Add the backend directory to Python path so the project-local imports below
# resolve when this module is loaded by a worker.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from workers.celery_worker import celery_app
from core.config import settings
from core.database import get_redis
from db_models.evaluation import Evaluation, EvaluationStatus
from schemas.evaluation_schema import EvaluationCreate

logger = logging.getLogger(__name__)

# Try to import AI pipeline components (now works with default path fix)
try:
    from ai.pipelines.redteam_pipeline import run_redteam_pipeline, create_pipeline_config
    from ai.models import PipelineConfig as AIPipelineConfig
    AI_PIPELINE_AVAILABLE = True
except ImportError as e:
    AI_PIPELINE_AVAILABLE = False
    logger.warning("AI pipeline not available: %s", e)

# Database session factory for tasks.
# Every task below drives its coroutine with ``asyncio.run``, which creates a
# new event loop per invocation. Async DB driver connections belong to the loop
# that opened them, so a pooled connection reused by a later task would be
# attached to a closed loop. NullPool opens a fresh connection per session and
# closes it on release, which is safe across loops.
AsyncSessionLocal = sessionmaker(
    bind=create_async_engine(settings.DATABASE_URL, poolclass=NullPool),
    class_=AsyncSession,
    expire_on_commit=False,
)


class ModelInterface:
    """
    Model interface for AI evaluation.

    Implements the interface expected by the red team pipeline: it captures
    the model settings stored on an evaluation and exposes the async
    ``generate_response`` method.
    """

    def __init__(self, model_config: Optional[Dict[str, Any]]):
        """
        Initialize model interface.

        Args:
            model_config: Model configuration mapping. ``None`` is treated as
                an empty mapping, so every field falls back to its default.
        """
        model_config = model_config or {}
        self.model_name = model_config.get("model_name", "unknown")
        self.model_type = model_config.get("model_type", "api")
        self.api_endpoint = model_config.get("api_endpoint")
        self.api_key = model_config.get("api_key")
        self.temperature = model_config.get("temperature", 0.7)
        self.max_tokens = model_config.get("max_tokens", 1000)
        self.custom_parameters = model_config.get("custom_parameters", {})

    async def generate_response(self, prompt: str) -> str:
        """
        Produce a real-time response using the AegisLM LLM Router.

        Args:
            prompt: Prompt text to send to the model.

        Returns:
            str: The generated text. If routing fails, returns an
            ``"Error: ..."`` string instead of raising (errors are reported
            in-band to the pipeline).
        """
        from ai.llm_router.llm_router import get_llm_router

        router = get_llm_router()
        try:
            # Route to the actual model provider (Groq/OpenAI/Mistral)
            response = await router.route_and_generate(
                model=self.model_name,
                prompt=prompt,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            return response.text
        except Exception as e:
            logger.error("Router failure: %s", e)
            return f"Error: Model {self.model_name} failed to respond: {str(e)}"


def _build_pipeline_config(config: Dict[str, Any]) -> Any:
    """
    Build the pipeline configuration object from an evaluation's stored settings.

    Uses ``create_pipeline_config`` when the AI pipeline imported successfully.
    Otherwise it tries ``ai.models.PipelineConfig`` directly, and as a last
    resort returns a plain attribute object with the same fields.

    Args:
        config: Stored pipeline settings (possibly empty).

    Returns:
        A pipeline configuration object (type depends on what is importable).
    """
    if AI_PIPELINE_AVAILABLE:
        return create_pipeline_config(
            max_iterations=config.get("max_iterations", 5),
            attack_types=config.get("attack_types"),
            num_prompts=config.get("num_prompts", 10),
            mutation_enabled=config.get("mutation_enabled", True),
            weights=config.get("weights"),
        )

    # Use REAL pipeline config with actual AI models if available
    try:
        from ai.models import PipelineConfig as _PipelineConfig
    except ImportError:
        # Final fallback to basic config
        return type('PipelineConfig', (), {
            'max_iterations': config.get("max_iterations", 5),
            'attack_types': config.get("attack_types", []),
            'num_prompts': config.get("num_prompts", 10),
            'mutation_enabled': config.get("mutation_enabled", True),
            'weights': config.get("weights", {}),
        })()

    return _PipelineConfig(
        max_iterations=config.get("max_iterations", 5),
        attack_types=config.get("attack_types"),
        num_prompts=config.get("num_prompts", 10),
        mutation_enabled=config.get("mutation_enabled", True),
        weights=config.get("weights"),
        use_real_models=True,  # Enable real model usage
        fallback_to_simulation=True,  # Allow fallback if needed
    )


class _SimplifiedPipelineResult:
    """
    Placeholder result used when the AI pipeline cannot be imported.

    Exposes the attributes and ``dict()`` method that ``run_evaluation_task``
    reads from a real pipeline result, reporting zero attacks.
    """

    def __init__(self, pipeline_id: str):
        self.pipeline_id = pipeline_id
        self.total_attacks = 0
        self.successful_attacks: List[Any] = []
        self.failed_attacks: List[Any] = []
        self.success_rate = 0.0
        self.execution_time_ms = 1000

    def dict(self) -> Dict[str, Any]:
        """Return the result fields as a plain mapping."""
        return {
            "pipeline_id": self.pipeline_id,
            "total_attacks": self.total_attacks,
            "successful_attacks": self.successful_attacks,
            "failed_attacks": self.failed_attacks,
            "success_rate": self.success_rate,
            "execution_time_ms": self.execution_time_ms,
        }


async def _mark_evaluation_failed(
    db: AsyncSession, evaluation: Any, error: str, job_id: str
) -> None:
    """
    Persist FAILED status for ``evaluation`` without masking the original error.

    The session is rolled back first. If the original exception came from a
    flush or commit, the session must be rolled back before it can be used
    again. Rolling back also discards any partially assigned result fields.
    Rollback expires loaded instances, so the row is refreshed explicitly
    because implicit lazy loads are not available under AsyncSession. Any
    failure here is logged rather than raised.

    Args:
        db: Active session the evaluation was loaded with.
        evaluation: The loaded ``Evaluation`` row.
        error: Error message to record.
        job_id: Evaluation job ID (for logging).
    """
    try:
        await db.rollback()
        await db.refresh(evaluation)
        evaluation.update_status(EvaluationStatus.FAILED, error)
        await db.commit()
    except Exception:
        logger.exception("Could not record failure status for job %s", job_id)


@celery_app.task(bind=True, name="tasks.evaluation_task.run_evaluation_task")
def run_evaluation_task(self, job_id: str):
    """
    Run evaluation task for given job ID.

    Loads the evaluation, marks it RUNNING, runs the red team pipeline (or a
    zero-attack placeholder when the pipeline is unavailable), stores the
    results on the row, and caches them in Redis for one hour.

    Args:
        job_id: Evaluation job ID

    Returns:
        dict: Task result. On any error the evaluation is marked FAILED and a
        ``{"status": "failed", ...}`` dict is returned instead of raising.
    """

    async def _run_evaluation():
        async with AsyncSessionLocal() as db:
            # Bound before ``try`` so the failure handler can tell whether the
            # row was ever loaded. An early DB error would otherwise raise
            # UnboundLocalError in the handler and hide the real exception.
            evaluation = None
            try:
                # Get evaluation from database
                result = await db.execute(
                    select(Evaluation).where(Evaluation.job_id == job_id)
                )
                evaluation = result.scalar_one_or_none()
                if not evaluation:
                    raise ValueError(f"Evaluation not found: {job_id}")

                # Update status to running
                evaluation.update_status(EvaluationStatus.RUNNING)
                await db.commit()

                model = ModelInterface(evaluation.model_config)
                pipeline_config = _build_pipeline_config(evaluation.pipeline_config or {})

                # Run red team pipeline
                start_time = datetime.utcnow()
                if AI_PIPELINE_AVAILABLE:
                    # INTEGRATION: Awaiting the production async pipeline
                    pipeline_result = await run_redteam_pipeline(model, pipeline_config, job_id)
                else:
                    pipeline_result = _SimplifiedPipelineResult(job_id)
                end_time = datetime.utcnow()

                # Calculate metrics
                execution_time_ms = int((end_time - start_time).total_seconds() * 1000)
                # Stored on the row as a 0.0-1.0 fraction; reported as a percentage.
                success_rate = pipeline_result.success_rate * 100 if pipeline_result.success_rate else 0.0
                result_payload = pipeline_result.dict()
                successful_count = len(pipeline_result.successful_attacks)

                # Update evaluation with results
                evaluation.status = EvaluationStatus.COMPLETED
                evaluation.result_json = result_payload
                evaluation.total_attacks = pipeline_result.total_attacks
                evaluation.successful_attacks = successful_count
                evaluation.success_rate = pipeline_result.success_rate  # Store as float (0.0 to 1.0)
                evaluation.execution_time_ms = execution_time_ms
                evaluation.completed_at = end_time
                await db.commit()

                # Store result in Redis for quick access. The row above is the
                # source of truth, so a cache failure is logged but does not turn
                # a completed evaluation into a failed one.
                # NOTE(review): the value is a Python repr (str(dict)), not JSON.
                try:
                    redis_client = await get_redis()
                    await redis_client.setex(
                        f"evaluation_result:{job_id}",
                        3600,  # 1 hour expiry
                        str(result_payload),
                    )
                except Exception:
                    logger.warning("Could not cache result for job %s", job_id, exc_info=True)

                return {
                    "job_id": job_id,
                    "status": "completed",
                    "execution_time_ms": execution_time_ms,
                    "total_attacks": pipeline_result.total_attacks,
                    "successful_attacks": successful_count,
                    "success_rate": success_rate,
                    "completed_at": end_time.isoformat(),
                }

            except Exception as e:
                logger.exception("Evaluation task failed for job %s", job_id)
                if evaluation is not None:
                    await _mark_evaluation_failed(db, evaluation, str(e), job_id)
                return {
                    "job_id": job_id,
                    "status": "failed",
                    "error": str(e),
                    "failed_at": datetime.utcnow().isoformat(),
                }

    return asyncio.run(_run_evaluation())


@celery_app.task(bind=True, name="tasks.evaluation_task.run_benchmark_task")
def run_benchmark_task(self, benchmark_id: str, job_ids: list):
    """
    Run benchmark task for multiple evaluation jobs.

    Enqueues one ``run_evaluation_task`` per job ID and records the benchmark
    metadata in Redis for 24 hours. It does not wait for the evaluations.

    Args:
        benchmark_id: Benchmark ID
        job_ids: List of evaluation job IDs

    Returns:
        dict: Benchmark task result (queued task IDs), or a failure dict.
    """

    async def _run_benchmark():
        try:
            # Trigger individual evaluation tasks
            task_results = []
            for job_id in job_ids:
                task = run_evaluation_task.delay(job_id)
                task_results.append({
                    "job_id": job_id,
                    "task_id": task.id,
                    "status": "queued",
                })

            # Store benchmark progress in Redis
            redis_client = await get_redis()
            benchmark_data = {
                "benchmark_id": benchmark_id,
                "job_ids": job_ids,
                "task_ids": [entry["task_id"] for entry in task_results],
                "status": "running",
                "started_at": datetime.utcnow().isoformat(),
                "total_jobs": len(job_ids),
                "completed_jobs": 0,
                "failed_jobs": 0,
            }
            # NOTE(review): stored as a Python repr (str(dict)), not JSON.
            await redis_client.setex(
                f"benchmark:{benchmark_id}",
                24 * 3600,  # 24 hours expiry
                str(benchmark_data),
            )

            return {
                "benchmark_id": benchmark_id,
                "status": "started",
                "total_jobs": len(job_ids),
                "task_results": task_results,
            }

        except Exception as e:
            logger.exception("Benchmark task failed for %s", benchmark_id)
            return {
                "benchmark_id": benchmark_id,
                "status": "failed",
                "error": str(e),
                "failed_at": datetime.utcnow().isoformat(),
            }

    return asyncio.run(_run_benchmark())


async def _ensure_ttl(redis_client: Any, keys: List[Any], ttl_seconds: int) -> int:
    """
    Give every key in ``keys`` an expiry and count keys that no longer exist.

    Args:
        redis_client: Async Redis client.
        keys: Keys to inspect.
        ttl_seconds: Expiry applied to keys that currently have none.

    Returns:
        int: Number of keys whose TTL reported -2 (already gone by the time
        they were inspected).
    """
    vanished = 0
    for key in keys:
        ttl = await redis_client.ttl(key)
        if ttl == -1:  # No expiry set, set one
            await redis_client.expire(key, ttl_seconds)
        elif ttl == -2:  # Key expired/removed since it was listed
            vanished += 1
    return vanished


@celery_app.task(name="tasks.evaluation_task.cleanup_expired_results")
def cleanup_expired_results():
    """
    Cleanup expired evaluation results from Redis.

    Applies a default expiry to result keys (1 hour) and benchmark keys
    (24 hours) that lack one. ``cleaned_keys`` counts keys that had already
    disappeared when inspected; this task deletes nothing itself.

    Returns:
        dict: Cleanup results
    """

    async def _cleanup():
        try:
            redis_client = await get_redis()

            # NOTE(review): KEYS scans the whole keyspace; consider SCAN on large instances.
            result_keys = await redis_client.keys("evaluation_result:*")
            cleaned_count = await _ensure_ttl(redis_client, result_keys, 3600)  # 1 hour

            # Also cleanup benchmark data
            benchmark_keys = await redis_client.keys("benchmark:*")
            cleaned_count += await _ensure_ttl(redis_client, benchmark_keys, 24 * 3600)  # 24 hours

            return {
                "status": "completed",
                "timestamp": datetime.utcnow().isoformat(),
                "cleaned_keys": cleaned_count,
                "total_result_keys": len(result_keys),
                "total_benchmark_keys": len(benchmark_keys),
            }

        except Exception as e:
            logger.exception("Cleanup task failed")
            return {
                "status": "failed",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat(),
            }

    return asyncio.run(_cleanup())


@celery_app.task(name="tasks.evaluation_task.update_evaluation_progress")
def update_evaluation_progress(job_id: str, progress: float):
    """
    Update evaluation progress.

    Writes ``evaluation_progress:<job_id>`` to Redis with a 1 hour expiry.

    Args:
        job_id: Evaluation job ID
        progress: Progress percentage (0-100)

    Returns:
        dict: Update result, or ``{"job_id", "error"}`` on failure.
    """

    async def _update_progress():
        try:
            redis_client = await get_redis()
            progress_data = {
                "job_id": job_id,
                "progress": progress,
                "updated_at": datetime.utcnow().isoformat(),
            }
            await redis_client.setex(
                f"evaluation_progress:{job_id}",
                3600,  # 1 hour expiry
                str(progress_data),
            )

            return {
                "job_id": job_id,
                "progress": progress,
                "updated_at": progress_data["updated_at"],
            }

        except Exception as e:
            logger.exception("Progress update failed for %s", job_id)
            return {
                "job_id": job_id,
                "error": str(e),
            }

    return asyncio.run(_update_progress())


@celery_app.task(name="tasks.evaluation_task.cancel_evaluation_task")
def cancel_evaluation_task(job_id: str):
    """
    Cancel evaluation task.

    Marks the evaluation CANCELLED and removes its cached result and progress
    keys from Redis.

    NOTE(review): this does not revoke a running ``run_evaluation_task``; a
    worker still processing the job will later overwrite the status.

    Args:
        job_id: Evaluation job ID

    Returns:
        dict: Cancellation result, or ``{"job_id", "error"}`` on failure.
    """

    async def _cancel_evaluation():
        async with AsyncSessionLocal() as db:
            try:
                result = await db.execute(
                    select(Evaluation).where(Evaluation.job_id == job_id)
                )
                evaluation = result.scalar_one_or_none()
                if not evaluation:
                    raise ValueError(f"Evaluation not found: {job_id}")

                evaluation.update_status(EvaluationStatus.CANCELLED)
                await db.commit()

                # Remove from Redis if exists
                redis_client = await get_redis()
                await redis_client.delete(f"evaluation_result:{job_id}")
                await redis_client.delete(f"evaluation_progress:{job_id}")

                return {
                    "job_id": job_id,
                    "status": "cancelled",
                    "cancelled_at": datetime.utcnow().isoformat(),
                }

            except Exception as e:
                logger.exception("Cancellation failed for %s", job_id)
                return {
                    "job_id": job_id,
                    "error": str(e),
                }

    return asyncio.run(_cancel_evaluation())


# Helper function to check task status
@celery_app.task(name="tasks.evaluation_task.check_task_status")
def check_task_status(task_id: str):
    """
    Check status of a Celery task.

    Args:
        task_id: Celery task ID

    Returns:
        dict: Task status. For failed or revoked tasks ``result`` is the
        ``repr`` of the stored exception rather than the exception object.
    """
    try:
        task = celery_app.AsyncResult(task_id)
        result = task.result if task.ready() else None
        # Failed/revoked tasks expose an exception instance as ``result``.
        # JSON result serializers cannot encode it, so return its text instead.
        if isinstance(result, BaseException):
            result = repr(result)
        return {
            "task_id": task_id,
            "status": task.status,
            "result": result,
            "traceback": task.traceback if task.failed() else None,
            "date_done": task.date_done.isoformat() if task.date_done else None,
        }
    except Exception as e:
        return {
            "task_id": task_id,
            "error": str(e),
        }