# ALM-2 / backend/tasks/evaluation_task.py
# Uploaded by ACA050 ("Upload 520 files"), commit 2ed8996 (verified).
"""
Evaluation tasks for AegisLM SaaS Backend.
Production-ready Celery tasks for AI evaluation processing,
background job execution, and result storage.
"""
import os
import sys
from datetime import datetime
from typing import Dict, Any, Optional

from celery import current_task
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

# Add the backend directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from workers.celery_worker import celery_app
from core.config import settings
from core.database import get_redis
from db_models.evaluation import Evaluation, EvaluationStatus
from schemas.evaluation_schema import EvaluationCreate
# Optional AI pipeline components. When the ``ai`` package cannot be imported,
# AI_PIPELINE_AVAILABLE stays False and a warning is printed instead of
# failing at import time.
AI_PIPELINE_AVAILABLE = False
try:
    from ai.pipelines.redteam_pipeline import run_redteam_pipeline, create_pipeline_config
    from ai.models import PipelineConfig as AIPipelineConfig
except ImportError as ai_import_error:
    print(f"Warning: AI pipeline not available: {str(ai_import_error)}")
else:
    AI_PIPELINE_AVAILABLE = True
# Database session factory for tasks.
#
# Every Celery task in this module drives its coroutine with asyncio.run(),
# which creates a fresh event loop per invocation. SQLAlchemy's asyncio
# documentation recommends NullPool when one engine is used from multiple
# event loops: pooled connections opened on an earlier loop must not be
# reused on a later one, so each session opens (and closes) its own
# connection instead.
AsyncSessionLocal = sessionmaker(
    bind=create_async_engine(settings.DATABASE_URL, poolclass=NullPool),
    class_=AsyncSession,
    expire_on_commit=False,
)
class ModelInterface:
    """
    Model interface for AI evaluation.

    Implements the interface expected by the red team pipeline: a single
    async ``generate_response(prompt)`` call, routed through the AegisLM
    LLM router using the sampling parameters taken from the evaluation's
    ``model_config``.
    """

    def __init__(self, model_config: Optional[Dict[str, Any]]):
        """
        Initialize model interface.

        Args:
            model_config: Model configuration mapping. ``None`` is treated as
                an empty mapping, so every field falls back to its default.
        """
        config = model_config or {}
        self.model_name = config.get("model_name", "unknown")
        self.model_type = config.get("model_type", "api")
        self.api_endpoint = config.get("api_endpoint")
        self.api_key = config.get("api_key")
        self.temperature = config.get("temperature", 0.7)
        self.max_tokens = config.get("max_tokens", 1000)
        self.custom_parameters = config.get("custom_parameters", {})

    async def generate_response(self, prompt: str) -> str:
        """
        Produce a real-time response using the AegisLM LLM Router.

        Args:
            prompt: Prompt text to send to the model.

        Returns:
            The generated text. If the router call raises, the error is
            logged and an ``"Error: Model ... failed to respond: ..."``
            string is returned instead of raising. Failing to import or
            construct the router is not caught and propagates.
        """
        import logging

        from ai.llm_router.llm_router import get_llm_router

        router = get_llm_router()
        try:
            # Route to the actual model provider (Groq/OpenAI/Mistral).
            # NOTE(review): api_key, api_endpoint and custom_parameters are
            # stored on this object but not forwarded here — confirm the
            # router resolves credentials on its own.
            response = await router.route_and_generate(
                model=self.model_name,
                prompt=prompt,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            return response.text
        except Exception as e:
            # Fetch a named logger here so this handler can never itself fail
            # with an unresolved name.
            logging.getLogger(__name__).error("Router failure: %s", e)
            return f"Error: Model {self.model_name} failed to respond: {str(e)}"
class _SimplifiedPipelineResult:
    """
    Placeholder pipeline result used when the AI pipeline cannot be imported.

    Exposes the attributes this module reads from a real pipeline result
    (``total_attacks``, ``successful_attacks``, ``success_rate`` and
    ``dict()``), describing a run with zero attacks.
    """

    def __init__(self, job_id: str):
        self.pipeline_id = job_id
        self.total_attacks = 0
        self.successful_attacks = []
        self.failed_attacks = []
        self.success_rate = 0.0
        self.execution_time_ms = 1000

    def dict(self) -> Dict[str, Any]:
        return {
            "pipeline_id": self.pipeline_id,
            "total_attacks": self.total_attacks,
            "successful_attacks": self.successful_attacks,
            "failed_attacks": self.failed_attacks,
            "success_rate": self.success_rate,
            "execution_time_ms": self.execution_time_ms,
        }


def _build_pipeline_config(config: Dict[str, Any]):
    """
    Translate a stored ``pipeline_config`` mapping into the AI pipeline config.

    Only call this when AI_PIPELINE_AVAILABLE is true: it relies on
    ``create_pipeline_config`` from the optional ``ai`` package.
    """
    return create_pipeline_config(
        max_iterations=config.get("max_iterations", 5),
        attack_types=config.get("attack_types"),
        num_prompts=config.get("num_prompts", 10),
        mutation_enabled=config.get("mutation_enabled", True),
        weights=config.get("weights"),
    )


@celery_app.task(bind=True, name="tasks.evaluation_task.run_evaluation_task")
def run_evaluation_task(self, job_id: str):
    """
    Run evaluation task for given job ID.

    Loads the evaluation row, marks it RUNNING, runs the red team pipeline
    (or a zero-attack placeholder result when the ``ai`` package could not
    be imported), then writes the outcome to the row and caches it in Redis
    under ``evaluation_result:<job_id>`` for one hour.

    Args:
        job_id: Evaluation job ID

    Returns:
        dict: Task result. ``status`` is ``"completed"`` (with metrics),
        ``"cancelled"`` if the row was set to CANCELLED while the pipeline
        ran, or ``"failed"`` (with ``error``).
    """
    async def _run_evaluation():
        async with AsyncSessionLocal() as db:
            # Bound before ``try`` so the failure handler can inspect it even
            # when the lookup query itself raises.
            evaluation = None
            try:
                result = await db.execute(
                    select(Evaluation).where(Evaluation.job_id == job_id)
                )
                evaluation = result.scalar_one_or_none()
                if not evaluation:
                    raise ValueError(f"Evaluation not found: {job_id}")

                evaluation.update_status(EvaluationStatus.RUNNING)
                await db.commit()

                # Either config may be None (e.g. a NULL column); fall back
                # to the defaults instead of crashing on ``.get``.
                model = ModelInterface(evaluation.model_config or {})
                pipeline_config_dict = evaluation.pipeline_config or {}

                pipeline_config = None
                if AI_PIPELINE_AVAILABLE:
                    pipeline_config = _build_pipeline_config(pipeline_config_dict)

                start_time = datetime.utcnow()
                if AI_PIPELINE_AVAILABLE:
                    pipeline_result = await run_redteam_pipeline(model, pipeline_config, job_id)
                else:
                    pipeline_result = _SimplifiedPipelineResult(job_id)
                end_time = datetime.utcnow()

                # cancel_evaluation_task may have marked the row CANCELLED
                # while the pipeline ran; reload it so that is not overwritten.
                await db.refresh(evaluation)
                if evaluation.status == EvaluationStatus.CANCELLED:
                    return {"job_id": job_id, "status": "cancelled"}

                execution_time_ms = int((end_time - start_time).total_seconds() * 1000)
                # Returned as a percentage; the row stores the 0.0-1.0 fraction.
                success_rate = (pipeline_result.success_rate or 0.0) * 100

                result_payload = pipeline_result.dict()
                evaluation.status = EvaluationStatus.COMPLETED
                evaluation.result_json = result_payload
                evaluation.total_attacks = pipeline_result.total_attacks
                evaluation.successful_attacks = len(pipeline_result.successful_attacks)
                evaluation.success_rate = pipeline_result.success_rate
                evaluation.execution_time_ms = execution_time_ms
                evaluation.completed_at = end_time
                await db.commit()

                # Cache the result for quick access (1 hour expiry).
                # NOTE(review): str() stores a Python repr, not JSON; readers
                # must parse it accordingly.
                redis_client = await get_redis()
                await redis_client.setex(
                    f"evaluation_result:{job_id}",
                    3600,
                    str(result_payload),
                )

                return {
                    "job_id": job_id,
                    "status": "completed",
                    "execution_time_ms": execution_time_ms,
                    "total_attacks": pipeline_result.total_attacks,
                    "successful_attacks": len(pipeline_result.successful_attacks),
                    "success_rate": success_rate,
                    "completed_at": end_time.isoformat(),
                }
            except Exception as e:
                if evaluation is not None:
                    try:
                        evaluation.update_status(EvaluationStatus.FAILED, str(e))
                        await db.commit()
                    except Exception as status_error:
                        # Best effort: a broken session must not mask the
                        # original error or prevent the failure result below.
                        await db.rollback()
                        print(f"Could not mark evaluation {job_id} as failed: {str(status_error)}")
                print(f"Evaluation task failed for job {job_id}: {str(e)}")
                return {
                    "job_id": job_id,
                    "status": "failed",
                    "error": str(e),
                    "failed_at": datetime.utcnow().isoformat(),
                }

    import asyncio
    return asyncio.run(_run_evaluation())
@celery_app.task(bind=True, name="tasks.evaluation_task.run_benchmark_task")
def run_benchmark_task(self, benchmark_id: str, job_ids: list):
    """
    Fan a benchmark out into one queued evaluation task per job.

    Each job ID is enqueued via ``run_evaluation_task.delay`` in input order.
    A tracking record (job IDs, Celery task IDs, counters) is then written to
    Redis under ``benchmark:<benchmark_id>`` with a 24-hour expiry.

    Args:
        benchmark_id: Benchmark ID
        job_ids: List of evaluation job IDs

    Returns:
        dict: ``status`` ``"started"`` with per-job queue info, or
        ``"failed"`` with the error message.
    """
    import asyncio

    async def _run_benchmark():
        try:
            queued = [
                {
                    "job_id": jid,
                    "task_id": run_evaluation_task.delay(jid).id,
                    "status": "queued",
                }
                for jid in job_ids
            ]

            client = await get_redis()
            tracking = {
                "benchmark_id": benchmark_id,
                "job_ids": job_ids,
                "task_ids": [entry["task_id"] for entry in queued],
                "status": "running",
                "started_at": datetime.utcnow().isoformat(),
                "total_jobs": len(job_ids),
                "completed_jobs": 0,
                "failed_jobs": 0,
            }
            await client.setex(f"benchmark:{benchmark_id}", 24 * 3600, str(tracking))
        except Exception as exc:
            print(f"Benchmark task failed for {benchmark_id}: {str(exc)}")
            return {
                "benchmark_id": benchmark_id,
                "status": "failed",
                "error": str(exc),
                "failed_at": datetime.utcnow().isoformat(),
            }

        return {
            "benchmark_id": benchmark_id,
            "status": "started",
            "total_jobs": len(job_ids),
            "task_results": queued,
        }

    return asyncio.run(_run_benchmark())
@celery_app.task(name="tasks.evaluation_task.cleanup_expired_results")
def cleanup_expired_results():
    """
    Re-arm missing TTLs on cached evaluation results and benchmark records.

    Keys matching ``evaluation_result:*`` that have no expiry get a 1-hour
    TTL, and ``benchmark:*`` keys get a 24-hour TTL. Keys that disappeared
    between listing and the TTL check (TTL -2) are counted as cleaned.

    Returns:
        dict: ``status``, ``timestamp``, ``cleaned_keys`` and per-pattern key
        totals, or ``status`` ``"failed"`` with the error message.
    """
    import asyncio

    async def _cleanup():
        try:
            client = await get_redis()

            async def _ensure_ttl(keys, seconds):
                # TTL -1: key has no expiry; -2: key no longer exists.
                vanished = 0
                for k in keys:
                    remaining = await client.ttl(k)
                    if remaining == -1:
                        await client.expire(k, seconds)
                    elif remaining == -2:
                        vanished += 1
                return vanished

            # NOTE(review): KEYS scans the whole keyspace in one blocking
            # call; consider SCAN if the keyspace grows large.
            evaluation_keys = await client.keys("evaluation_result:*")
            removed = await _ensure_ttl(evaluation_keys, 3600)

            bench_keys = await client.keys("benchmark:*")
            removed += await _ensure_ttl(bench_keys, 24 * 3600)
        except Exception as exc:
            print(f"Cleanup task failed: {str(exc)}")
            return {
                "status": "failed",
                "error": str(exc),
                "timestamp": datetime.utcnow().isoformat(),
            }

        return {
            "status": "completed",
            "timestamp": datetime.utcnow().isoformat(),
            "cleaned_keys": removed,
            "total_result_keys": len(evaluation_keys),
            "total_benchmark_keys": len(bench_keys),
        }

    return asyncio.run(_cleanup())
@celery_app.task(name="tasks.evaluation_task.update_evaluation_progress")
def update_evaluation_progress(job_id: str, progress: float):
    """
    Record the latest progress value for an evaluation in Redis.

    The snapshot is stored under ``evaluation_progress:<job_id>`` with a
    1-hour expiry.

    Args:
        job_id: Evaluation job ID
        progress: Progress percentage (0-100); stored as given, not validated.

    Returns:
        dict: ``job_id``, ``progress`` and ``updated_at`` on success, or
        ``job_id`` and ``error`` if writing to Redis raised.
    """
    import asyncio

    async def _update_progress():
        try:
            client = await get_redis()
            snapshot = {
                "job_id": job_id,
                "progress": progress,
                "updated_at": datetime.utcnow().isoformat(),
            }
            await client.setex(f"evaluation_progress:{job_id}", 3600, str(snapshot))
        except Exception as exc:
            print(f"Progress update failed for {job_id}: {str(exc)}")
            return {"job_id": job_id, "error": str(exc)}

        return {
            "job_id": job_id,
            "progress": progress,
            "updated_at": snapshot["updated_at"],
        }

    return asyncio.run(_update_progress())
@celery_app.task(name="tasks.evaluation_task.cancel_evaluation_task")
def cancel_evaluation_task(job_id: str):
    """
    Cancel evaluation task.

    Marks the evaluation CANCELLED in the database and deletes its
    ``evaluation_result:<job_id>`` and ``evaluation_progress:<job_id>``
    Redis keys. An evaluation that has already finished (COMPLETED or
    FAILED) is refused with an error instead, so a finished outcome is never
    overwritten or purged from the cache. This does not revoke any Celery
    task that may still be executing the evaluation.

    Args:
        job_id: Evaluation job ID

    Returns:
        dict: ``job_id``, ``status`` and ``cancelled_at`` on success, or
        ``job_id`` and ``error`` on failure.
    """
    async def _cancel_evaluation():
        async with AsyncSessionLocal() as db:
            try:
                result = await db.execute(
                    select(Evaluation).where(Evaluation.job_id == job_id)
                )
                evaluation = result.scalar_one_or_none()
                if not evaluation:
                    raise ValueError(f"Evaluation not found: {job_id}")

                # Cancelling a finished run would discard its stored outcome
                # and cached result; refuse instead.
                if evaluation.status in (EvaluationStatus.COMPLETED, EvaluationStatus.FAILED):
                    raise ValueError(
                        f"Evaluation {job_id} has already finished "
                        f"({evaluation.status}); it cannot be cancelled"
                    )

                evaluation.update_status(EvaluationStatus.CANCELLED)
                await db.commit()

                # Drop any cached result/progress for this job.
                redis_client = await get_redis()
                await redis_client.delete(f"evaluation_result:{job_id}")
                await redis_client.delete(f"evaluation_progress:{job_id}")

                return {
                    "job_id": job_id,
                    "status": "cancelled",
                    "cancelled_at": datetime.utcnow().isoformat(),
                }
            except Exception as e:
                print(f"Cancellation failed for {job_id}: {str(e)}")
                return {
                    "job_id": job_id,
                    "error": str(e),
                }

    import asyncio
    return asyncio.run(_cancel_evaluation())
# Helper function to check task status
@celery_app.task(name="tasks.evaluation_task.check_task_status")
def check_task_status(task_id: str):
    """
    Check status of a Celery task.

    Args:
        task_id: Celery task ID

    Returns:
        dict: ``task_id``, ``status``, ``result`` (None until the task is
        ready), ``traceback`` (only for failed tasks) and ``date_done`` as an
        ISO string; or ``task_id`` and ``error`` if the lookup raised. When
        the stored result is an exception (failed or revoked tasks), it is
        returned as its string form so this task's own return value stays
        serializable.
    """
    try:
        task = celery_app.AsyncResult(task_id)

        result = task.result if task.ready() else None
        if isinstance(result, BaseException):
            # Exception instances cannot be encoded by Celery's default JSON
            # serializer, which would make this task itself fail.
            result = str(result)

        return {
            "task_id": task_id,
            "status": task.status,
            "result": result,
            "traceback": task.traceback if task.failed() else None,
            "date_done": task.date_done.isoformat() if task.date_done else None,
        }
    except Exception as e:
        return {
            "task_id": task_id,
            "error": str(e),
        }