| """ |
| Evaluation tasks for AegisLM SaaS Backend. |
| |
| Production-ready Celery tasks for AI evaluation processing, |
| background job execution, and result storage. |
| """ |
|
|
import asyncio
import logging
import os
import sys
from datetime import datetime
from typing import Dict, Any, Optional

from celery import current_task
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

# Put the directory two levels above this file on sys.path so the
# project-local imports below resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from workers.celery_worker import celery_app
from core.config import settings
from core.database import get_redis
from db_models.evaluation import Evaluation, EvaluationStatus
from schemas.evaluation_schema import EvaluationCreate
|
|
| |
logger = logging.getLogger(__name__)

# The red-team pipeline is an optional dependency: if it cannot be imported the
# module still loads, and run_evaluation_task falls back to a zero-attack
# placeholder result instead of running attacks.
try:
    from ai.pipelines.redteam_pipeline import run_redteam_pipeline, create_pipeline_config
    from ai.models import PipelineConfig as AIPipelineConfig
    AI_PIPELINE_AVAILABLE = True
except ImportError as e:
    AI_PIPELINE_AVAILABLE = False
    # Go through logging (not print) so the warning follows the worker's
    # logging configuration.
    logger.warning("AI pipeline not available: %s", e)
|
|
|
|
| |
# Session factory shared by every task in this module.
#
# Each task drives its coroutine with asyncio.run(), i.e. on a brand-new event
# loop per invocation. Pooled asyncio DB connections stay bound to the loop that
# opened them, so reusing them from a later task's loop fails. NullPool opens a
# connection per session and closes it afterwards, so no connection outlives the
# loop that created it.
#
# expire_on_commit=False keeps ORM attributes readable after commit; under
# AsyncSession an expired attribute cannot be reloaded implicitly.
AsyncSessionLocal = sessionmaker(
    bind=create_async_engine(settings.DATABASE_URL, poolclass=NullPool),
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
class ModelInterface:
    """
    Model interface for AI evaluation.

    Adapts a stored model configuration to the ``generate_response`` coroutine
    expected by the red team pipeline, delegating generation to the AegisLM
    LLM router.
    """

    def __init__(self, model_config: Optional[Dict[str, Any]]):
        """
        Initialize model interface.

        Args:
            model_config: Model configuration mapping. ``None`` is treated as
                an empty configuration, so every field takes its default.
        """
        config = model_config or {}
        self.model_name = config.get("model_name", "unknown")
        self.model_type = config.get("model_type", "api")
        # Stored for other consumers; generate_response below only forwards
        # model_name, temperature and max_tokens to the router.
        self.api_endpoint = config.get("api_endpoint")
        self.api_key = config.get("api_key")
        self.temperature = config.get("temperature", 0.7)
        self.max_tokens = config.get("max_tokens", 1000)
        self.custom_parameters = config.get("custom_parameters", {})

    async def generate_response(self, prompt: str) -> str:
        """
        Produce a response for *prompt* through the AegisLM LLM router.

        Args:
            prompt: Prompt text sent to the model.

        Returns:
            The generated text. If the router call raises, the failure is
            logged and an ``"Error: ..."`` string is returned instead.
        """
        from ai.llm_router.llm_router import get_llm_router
        router = get_llm_router()
        try:
            response = await router.route_and_generate(
                model=self.model_name,
                prompt=prompt,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            return response.text
        except Exception as e:
            # Resolve the logger here so this error path never depends on a
            # module-level name existing.
            logging.getLogger(__name__).exception("Router failure: %s", e)
            return f"Error: Model {self.model_name} failed to respond: {str(e)}"
|
|
|
|
def _build_pipeline_config(config_dict: Dict[str, Any]):
    """
    Build the red-team pipeline config from an evaluation's stored settings.

    Only call this when ``AI_PIPELINE_AVAILABLE`` is true: it uses the
    optionally imported ``create_pipeline_config``.

    Args:
        config_dict: The evaluation's ``pipeline_config`` mapping.

    Returns:
        Whatever ``create_pipeline_config`` returns for these settings.
    """
    return create_pipeline_config(
        max_iterations=config_dict.get("max_iterations", 5),
        attack_types=config_dict.get("attack_types"),
        num_prompts=config_dict.get("num_prompts", 10),
        mutation_enabled=config_dict.get("mutation_enabled", True),
        weights=config_dict.get("weights"),
    )


class _SimulatedPipelineResult:
    """
    Zero-attack stand-in for a pipeline result when the AI pipeline is missing.

    Provides the attributes and ``dict()`` method that run_evaluation_task
    reads from a real pipeline result.
    """

    def __init__(self, job_id: str):
        self.pipeline_id = job_id
        self.total_attacks = 0
        self.successful_attacks = []
        self.failed_attacks = []
        self.success_rate = 0.0
        self.execution_time_ms = 1000

    def dict(self) -> Dict[str, Any]:
        return {
            "pipeline_id": self.pipeline_id,
            "total_attacks": self.total_attacks,
            "successful_attacks": self.successful_attacks,
            "failed_attacks": self.failed_attacks,
            "success_rate": self.success_rate,
            "execution_time_ms": self.execution_time_ms,
        }


async def _record_failure(db: AsyncSession, evaluation, error: Exception, job_id: str) -> None:
    """
    Best-effort: persist FAILED status and *error* on *evaluation*.

    Rolls back first, because a failed flush/commit leaves the session unusable
    until it is rolled back; this also discards any half-applied COMPLETED
    fields. Rollback expires loaded instances, so the row is refreshed before
    update_status touches it (AsyncSession cannot lazy-load implicitly).
    Errors here are logged and swallowed so the caller can still report the
    original failure.
    """
    try:
        await db.rollback()
        await db.refresh(evaluation)
        evaluation.update_status(EvaluationStatus.FAILED, str(error))
        await db.commit()
    except Exception:
        logging.getLogger(__name__).exception(
            "Could not record FAILED status for job %s", job_id
        )


@celery_app.task(bind=True, name="tasks.evaluation_task.run_evaluation_task")
def run_evaluation_task(self, job_id: str):
    """
    Run the red-team evaluation for *job_id* and persist its outcome.

    Marks the evaluation RUNNING, runs the pipeline (or a zero-attack
    placeholder when the AI pipeline is unavailable), stores the results on
    the Evaluation row, then caches them in Redis for one hour.

    Args:
        job_id: Evaluation job ID

    Returns:
        dict: ``status="completed"`` with summary metrics, or
        ``status="failed"`` with the error message. Errors are reported in
        the return value rather than raised.
    """
    log = logging.getLogger(__name__)

    async def _run_evaluation():
        async with AsyncSessionLocal() as db:
            # Bound before the try so the failure path can test it even when
            # the lookup query itself raises.
            evaluation = None
            try:
                result = await db.execute(
                    select(Evaluation).where(Evaluation.job_id == job_id)
                )
                evaluation = result.scalar_one_or_none()
                if not evaluation:
                    raise ValueError(f"Evaluation not found: {job_id}")

                evaluation.update_status(EvaluationStatus.RUNNING)
                await db.commit()

                model = ModelInterface(evaluation.model_config)

                # Only the real pipeline consumes a config; the placeholder
                # result ignores it, so none is built in that case.
                pipeline_config = None
                if AI_PIPELINE_AVAILABLE:
                    pipeline_config = _build_pipeline_config(evaluation.pipeline_config or {})

                start_time = datetime.utcnow()
                if AI_PIPELINE_AVAILABLE:
                    pipeline_result = await run_redteam_pipeline(model, pipeline_config, job_id)
                else:
                    pipeline_result = _SimulatedPipelineResult(job_id)
                end_time = datetime.utcnow()

                execution_time_ms = int((end_time - start_time).total_seconds() * 1000)
                # The row keeps the pipeline's raw success_rate; the returned
                # value is scaled by 100 (presumably fraction -> percent).
                success_rate = pipeline_result.success_rate * 100 if pipeline_result.success_rate else 0.0
                result_payload = pipeline_result.dict()
                successful_count = len(pipeline_result.successful_attacks)

                evaluation.status = EvaluationStatus.COMPLETED
                evaluation.result_json = result_payload
                evaluation.total_attacks = pipeline_result.total_attacks
                evaluation.successful_attacks = successful_count
                evaluation.success_rate = pipeline_result.success_rate
                evaluation.execution_time_ms = execution_time_ms
                evaluation.completed_at = end_time
                await db.commit()
            except Exception as e:
                log.exception("Evaluation task failed for job %s", job_id)
                if evaluation is not None:
                    await _record_failure(db, evaluation, e, job_id)
                return {
                    "job_id": job_id,
                    "status": "failed",
                    "error": str(e),
                    "failed_at": datetime.utcnow().isoformat(),
                }

            # Caching is best-effort: the evaluation is already committed as
            # COMPLETED, so a Redis problem must not turn it into FAILED.
            try:
                redis_client = await get_redis()
                await redis_client.setex(
                    f"evaluation_result:{job_id}",
                    3600,
                    str(result_payload),
                )
            except Exception:
                log.warning("Could not cache evaluation result for job %s", job_id, exc_info=True)

            return {
                "job_id": job_id,
                "status": "completed",
                "execution_time_ms": execution_time_ms,
                "total_attacks": pipeline_result.total_attacks,
                "successful_attacks": successful_count,
                "success_rate": success_rate,
                "completed_at": end_time.isoformat(),
            }

    return asyncio.run(_run_evaluation())
|
|
|
|
@celery_app.task(bind=True, name="tasks.evaluation_task.run_benchmark_task")
def run_benchmark_task(self, benchmark_id: str, job_ids: list):
    """
    Queue one run_evaluation_task per job and record the benchmark in Redis.

    The record is stored under ``benchmark:<benchmark_id>`` for 24 hours with
    ``status="running"`` and zero completed/failed counters; nothing in this
    task updates those counters afterwards.

    Args:
        benchmark_id: Benchmark ID
        job_ids: List of evaluation job IDs

    Returns:
        dict: ``status="started"`` with the queued task IDs, or
        ``status="failed"`` with the error. On failure, ``task_results`` lists
        the evaluations that were already queued (they still run).
    """
    async def _run_benchmark():
        # Declared outside the try so a failure can still report what was queued.
        task_results = []
        try:
            for job_id in job_ids:
                task = run_evaluation_task.delay(job_id)
                task_results.append({
                    "job_id": job_id,
                    "task_id": task.id,
                    "status": "queued",
                })

            redis_client = await get_redis()
            benchmark_data = {
                "benchmark_id": benchmark_id,
                "job_ids": job_ids,
                "task_ids": [entry["task_id"] for entry in task_results],
                "status": "running",
                "started_at": datetime.utcnow().isoformat(),
                "total_jobs": len(job_ids),
                "completed_jobs": 0,
                "failed_jobs": 0,
            }
            await redis_client.setex(
                f"benchmark:{benchmark_id}",
                24 * 3600,
                str(benchmark_data),
            )

            return {
                "benchmark_id": benchmark_id,
                "status": "started",
                "total_jobs": len(job_ids),
                "task_results": task_results,
            }

        except Exception as e:
            logging.getLogger(__name__).exception(
                "Benchmark task failed for %s", benchmark_id
            )
            return {
                "benchmark_id": benchmark_id,
                "status": "failed",
                "error": str(e),
                "failed_at": datetime.utcnow().isoformat(),
                "task_results": task_results,
            }

    return asyncio.run(_run_benchmark())
|
|
|
|
@celery_app.task(name="tasks.evaluation_task.cleanup_expired_results")
def cleanup_expired_results():
    """
    Enforce expiry on cached evaluation results and benchmark records.

    Keys matching ``evaluation_result:*`` or ``benchmark:*`` whose TTL reply is
    -1 (no expiry) are given the expiry their writers use. Keys whose TTL reply
    is -2 (no longer exist by the time they are checked) are counted in
    ``cleaned_keys``.

    Returns:
        dict: ``status="completed"`` with key counts, or ``status="failed"``
        with the error.
    """
    # (key pattern, expiry in seconds for matching keys that have none). The
    # expiries mirror those set when these keys are written in this module.
    expiry_policies = (
        ("evaluation_result:*", 3600),
        ("benchmark:*", 24 * 3600),
    )

    async def _cleanup():
        try:
            redis_client = await get_redis()

            cleaned_count = 0
            key_counts = {}
            for pattern, expiry_seconds in expiry_policies:
                # NOTE(review): KEYS walks the whole keyspace in one blocking
                # call; SCAN would avoid stalling Redis if the client offers it.
                keys = await redis_client.keys(pattern)
                key_counts[pattern] = len(keys)
                for key in keys:
                    ttl = await redis_client.ttl(key)
                    if ttl == -1:
                        await redis_client.expire(key, expiry_seconds)
                    elif ttl == -2:
                        # Expired or deleted since KEYS returned it.
                        cleaned_count += 1

            return {
                "status": "completed",
                "timestamp": datetime.utcnow().isoformat(),
                "cleaned_keys": cleaned_count,
                "total_result_keys": key_counts["evaluation_result:*"],
                "total_benchmark_keys": key_counts["benchmark:*"],
            }

        except Exception as e:
            logging.getLogger(__name__).exception("Cleanup task failed")
            return {
                "status": "failed",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat(),
            }

    return asyncio.run(_cleanup())
|
|
|
|
@celery_app.task(name="tasks.evaluation_task.update_evaluation_progress")
def update_evaluation_progress(job_id: str, progress: float):
    """
    Record the progress of an evaluation in Redis.

    Stored under ``evaluation_progress:<job_id>`` with a one-hour expiry, as the
    ``str()`` of a dict (not JSON).

    Args:
        job_id: Evaluation job ID
        progress: Progress percentage (0-100); stored as given, not validated.

    Returns:
        dict: The job ID, progress and update timestamp, or ``error`` if the
        Redis write fails.
    """
    async def _update_progress():
        try:
            redis_client = await get_redis()

            progress_data = {
                "job_id": job_id,
                "progress": progress,
                "updated_at": datetime.utcnow().isoformat(),
            }
            await redis_client.setex(
                f"evaluation_progress:{job_id}",
                3600,
                str(progress_data),
            )

            return {
                "job_id": job_id,
                "progress": progress,
                "updated_at": progress_data["updated_at"],
            }

        except Exception as e:
            logging.getLogger(__name__).exception(
                "Progress update failed for %s", job_id
            )
            return {
                "job_id": job_id,
                "error": str(e),
            }

    return asyncio.run(_update_progress())
|
|
|
|
@celery_app.task(name="tasks.evaluation_task.cancel_evaluation_task")
def cancel_evaluation_task(job_id: str):
    """
    Mark an evaluation CANCELLED and drop its cached result and progress.

    NOTE(review): this only updates the database row. It does not revoke a
    run_evaluation_task already executing for the job, and that task later
    sets the status to COMPLETED or FAILED without checking for CANCELLED.

    Args:
        job_id: Evaluation job ID

    Returns:
        dict: ``status="cancelled"`` with a timestamp, or ``error`` if the
        evaluation is missing or the status update fails.
    """
    log = logging.getLogger(__name__)

    async def _cancel_evaluation():
        async with AsyncSessionLocal() as db:
            try:
                result = await db.execute(
                    select(Evaluation).where(Evaluation.job_id == job_id)
                )
                evaluation = result.scalar_one_or_none()
                if not evaluation:
                    raise ValueError(f"Evaluation not found: {job_id}")

                evaluation.update_status(EvaluationStatus.CANCELLED)
                await db.commit()
            except Exception as e:
                log.exception("Cancellation failed for %s", job_id)
                return {
                    "job_id": job_id,
                    "error": str(e),
                }

        # The cancellation is already committed; clearing cached entries is
        # best-effort and must not make the cancellation look failed.
        try:
            redis_client = await get_redis()
            await redis_client.delete(
                f"evaluation_result:{job_id}",
                f"evaluation_progress:{job_id}",
            )
        except Exception:
            log.warning("Could not clear cached entries for job %s", job_id, exc_info=True)

        return {
            "job_id": job_id,
            "status": "cancelled",
            "cancelled_at": datetime.utcnow().isoformat(),
        }

    return asyncio.run(_cancel_evaluation())
|
|
|
|
| |
@celery_app.task(name="tasks.evaluation_task.check_task_status")
def check_task_status(task_id: str):
    """
    Report the state of another Celery task.

    Args:
        task_id: Celery task ID

    Returns:
        dict: ``status``; ``result`` once the task is ready (as a ``repr``
        string when it is an exception); ``traceback`` if it failed; and
        ``date_done`` (ISO string when the backend gives a datetime). Contains
        ``error`` instead if the lookup itself fails.
    """
    try:
        task = celery_app.AsyncResult(task_id)

        result = task.result if task.ready() else None
        if isinstance(result, BaseException):
            # Failed/revoked tasks expose an exception instance as their result.
            # This dict is itself a task result, and exception objects cannot be
            # encoded by Celery's default JSON result serializer.
            result = repr(result)

        date_done = task.date_done
        if isinstance(date_done, datetime):
            date_done = date_done.isoformat()

        return {
            "task_id": task_id,
            "status": task.status,
            "result": result,
            "traceback": task.traceback if task.failed() else None,
            "date_done": date_done,
        }

    except Exception as e:
        return {
            "task_id": task_id,
            "error": str(e),
        }
|
|