# ALM-2 / backend/analytics/analytics_service.py
# ACA050's picture
# Upload 520 files
# 2ed8996 verified
"""
Analytics Service for AegisLM Multi-Run Analysis.
Provides database integration and service layer for analytics operations
including run fetching, filtering, and data access.
"""
import uuid
import logging
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc
from fastapi import HTTPException, status
from core.database import get_db
from db_models.user import User
from db_models.evaluation import Evaluation, EvaluationStatus
from experiments.experiment_manager import get_experiment_manager
from schemas.experiment_schema import Experiment, ExperimentStatus as ExpStatus, ExperimentFilter
from .comparison_engine import ComparisonEngine, get_comparison_engine
from .trend_analyzer import TrendAnalyzer, get_trend_analyzer
from .aggregation_utils import AggregationUtils, get_aggregation_utils
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class AnalyticsService:
    """
    Service layer for analytics operations.

    Fetches experiment runs from the experiment manager's store, filters
    them in memory, and delegates comparison, trend analysis and
    aggregation to lazily-loaded analytics components. Public methods
    translate failures into ``HTTPException`` so they can be called
    directly from API route handlers.
    """

    def __init__(self, db: AsyncSession):
        """
        Initialize analytics service.

        Args:
            db: Async database session. It is stored on the instance; the
                methods of this class read experiments from the experiment
                manager's store rather than through this session.
        """
        self.db = db
        self.experiment_manager = get_experiment_manager()
        # Components come from async factories, which cannot be awaited in
        # __init__, so they are resolved on first use instead.
        self.comparison_engine = None
        self.trend_analyzer = None
        self.aggregation_utils = None

    async def _get_analytics_components(self):
        """Resolve any analytics component that has not been loaded yet."""
        # Explicit `is None` checks so a component that happens to be falsy
        # (e.g. defines __len__/__bool__) is not re-fetched on every call.
        if self.comparison_engine is None:
            self.comparison_engine = await get_comparison_engine()
        if self.trend_analyzer is None:
            self.trend_analyzer = await get_trend_analyzer()
        if self.aggregation_utils is None:
            self.aggregation_utils = await get_aggregation_utils()

    async def fetch_runs_by_ids(self, run_ids: List[str], user_id: Optional[int] = None) -> List[Experiment]:
        """
        Fetch experiments by run IDs.

        IDs that are malformed, not found, or fail to load are logged and
        skipped rather than aborting the whole request.

        Args:
            run_ids: Run IDs to fetch; strings are parsed with ``uuid.UUID``,
                any other value is passed to the store unchanged.
            user_id: Accepted for permission filtering but not applied yet,
                because the experiment system has no user mapping.

        Returns:
            List[Experiment]: The experiments that were found, in input order.

        Raises:
            HTTPException: 404 if none of the IDs resolved to an experiment.
        """
        await self._get_analytics_components()
        experiments = []
        for run_id in run_ids:
            try:
                run_uuid = uuid.UUID(run_id) if isinstance(run_id, str) else run_id
                experiment = self.experiment_manager.store.get_experiment(run_uuid)
            except ValueError:
                logger.warning("Invalid run ID format: %s", run_id)
                continue
            except Exception:
                # One broken record should not fail the whole batch.
                logger.exception("Error fetching experiment %s", run_id)
                continue

            if not experiment:
                logger.warning("Experiment not found: %s", run_id)
                continue

            # TODO: apply user_id filtering once experiments carry a proper
            # user mapping; until then every found experiment is returned.
            experiments.append(experiment)

        if not experiments:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No valid experiments found"
            )
        return experiments

    @staticmethod
    def _matches_filter(
        exp: Experiment,
        filters: ExperimentFilter,
        status_filter: Optional[ExpStatus]
    ) -> bool:
        """Return True if ``exp`` satisfies every criterion set in ``filters``."""
        if filters.model_name and exp.model_name != filters.model_name:
            return False
        if filters.dataset_name and exp.dataset_name != filters.dataset_name:
            return False
        if status_filter is not None and exp.status != status_filter:
            return False
        if filters.attack_types and not any(atk in exp.attack_types for atk in filters.attack_types):
            return False
        if filters.created_after and exp.created_at < filters.created_after:
            return False
        if filters.created_before and exp.created_at > filters.created_before:
            return False
        if filters.min_prompt_count and exp.prompt_count < filters.min_prompt_count:
            return False
        if filters.max_prompt_count and exp.prompt_count > filters.max_prompt_count:
            return False
        return True

    async def fetch_runs_by_filter(
        self,
        filters: ExperimentFilter,
        user_id: Optional[int] = None,
        limit: int = 100
    ) -> List[Experiment]:
        """
        Fetch experiments matching the given filter criteria.

        Filtering happens in memory over a window of ``limit * 2`` experiments
        listed from the store, so matching runs outside that window are not
        considered.

        Args:
            filters: Filter criteria; unset (falsy) fields are ignored.
            user_id: Accepted for permission filtering but not applied yet,
                because the experiment system has no user mapping.
            limit: Maximum number of results; values <= 0 yield an empty list.

        Returns:
            List[Experiment]: At most ``limit`` matching experiments.

        Raises:
            ValueError: If ``filters.status`` is not a valid experiment status.
        """
        await self._get_analytics_components()
        if limit <= 0:
            # Without this guard the loop below would still return one item.
            return []

        # Over-fetch so enough runs survive in-memory filtering.
        all_experiments = self.experiment_manager.store.list_experiments(limit=limit * 2)
        # Convert the status filter once instead of once per experiment.
        status_filter = ExpStatus(filters.status) if filters.status else None

        filtered_experiments = []
        for exp in all_experiments:
            if not self._matches_filter(exp, filters, status_filter):
                continue
            # TODO: apply user_id filtering once experiments carry a proper
            # user mapping.
            filtered_experiments.append(exp)
            if len(filtered_experiments) >= limit:
                break
        return filtered_experiments

    async def fetch_recent_runs(
        self,
        days: int = 30,
        user_id: Optional[int] = None,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None
    ) -> List[Experiment]:
        """
        Fetch experiments created within the last ``days`` days.

        Args:
            days: Number of days to look back from now.
            user_id: Passed through to ``fetch_runs_by_filter`` (not applied yet).
            model_name: Optional exact model-name filter.
            dataset_name: Optional exact dataset-name filter.

        Returns:
            List[Experiment]: Up to 500 recent experiments.
        """
        await self._get_analytics_components()
        # Naive UTC timestamp; NOTE(review): assumes exp.created_at is also
        # naive UTC, since naive and aware datetimes cannot be compared.
        threshold_date = datetime.utcnow() - timedelta(days=days)
        filters = ExperimentFilter(
            created_after=threshold_date,
            model_name=model_name,
            dataset_name=dataset_name
        )
        return await self.fetch_runs_by_filter(filters, user_id, limit=500)

    async def compare_runs(self, run_ids: List[str], user_id: Optional[int] = None) -> Dict[str, Any]:
        """
        Compare multiple experiment runs.

        Args:
            run_ids: Run IDs to compare.
            user_id: Passed through to ``fetch_runs_by_ids`` (not applied yet).

        Returns:
            Dict[str, Any]: JSON-serializable comparison results.

        Raises:
            HTTPException: 404 if no run ID resolves to an experiment, 400 if
                the comparison engine rejects the input with ``ValueError``,
                500 for any other failure.
        """
        await self._get_analytics_components()
        try:
            # Validation only: raises a 404 HTTPException if no ID resolves.
            await self.fetch_runs_by_ids(run_ids, user_id)
            comparison_result = await self.comparison_engine.compare_runs(run_ids)
            return await self._serialize_comparison_result(comparison_result)
        except HTTPException:
            # Propagate deliberate HTTP errors (e.g. the 404 above) instead of
            # letting the generic handler below turn them into a 500.
            raise
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e)
            ) from e
        except Exception as e:
            logger.exception("Comparison failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Comparison analysis failed"
            ) from e

    async def analyze_trends(self, run_ids: List[str], user_id: Optional[int] = None) -> Dict[str, Any]:
        """
        Analyze trends across multiple runs.

        Args:
            run_ids: Run IDs to analyze.
            user_id: Passed through to ``fetch_runs_by_ids`` (not applied yet).

        Returns:
            Dict[str, Any]: JSON-serializable trend analysis results.

        Raises:
            HTTPException: 404 if no run ID resolves to an experiment, 400 if
                the trend analyzer rejects the input with ``ValueError``,
                500 for any other failure.
        """
        await self._get_analytics_components()
        try:
            # Validation only: raises a 404 HTTPException if no ID resolves.
            await self.fetch_runs_by_ids(run_ids, user_id)
            trend_result = await self.trend_analyzer.analyze_trend(run_ids)
            return await self._serialize_trend_result(trend_result)
        except HTTPException:
            # Keep deliberate HTTP errors (e.g. 404) from becoming a 500.
            raise
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e)
            ) from e
        except Exception as e:
            logger.exception("Trend analysis failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Trend analysis failed"
            ) from e

    async def get_aggregated_metrics(
        self,
        filters: Optional[ExperimentFilter] = None,
        user_id: Optional[int] = None,
        group_by: Optional[str] = None  # 'model', 'dataset', 'time_window'
    ) -> Dict[str, Any]:
        """
        Get aggregated metrics with optional grouping.

        Args:
            filters: Optional filter criteria; when omitted, runs from the
                last 90 days are aggregated.
            user_id: Passed through to the fetch helpers (not applied yet).
            group_by: 'model', 'dataset' or 'time_window'; any other value
                (including None) produces a single "overall" aggregation.

        Returns:
            Dict[str, Any]: Serialized aggregations keyed by group, or a
                message dict when no experiments match.

        Raises:
            HTTPException: 500 if fetching or aggregation fails.
        """
        await self._get_analytics_components()
        try:
            if filters is not None:
                experiments = await self.fetch_runs_by_filter(filters, user_id, limit=1000)
            else:
                experiments = await self.fetch_recent_runs(days=90, user_id=user_id)

            if not experiments:
                return {"message": "No experiments found for aggregation"}

            if group_by == 'model':
                aggregations = await self.aggregation_utils.aggregate_by_model(experiments)
            elif group_by == 'dataset':
                aggregations = await self.aggregation_utils.aggregate_by_dataset(experiments)
            elif group_by == 'time_window':
                aggregations = await self.aggregation_utils.aggregate_by_time_window(experiments)
            else:
                overall = await self.aggregation_utils.aggregate_metrics(experiments)
                aggregations = {"overall": overall}

            return await self._serialize_aggregations(aggregations)
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Aggregation failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Metrics aggregation failed"
            ) from e

    async def get_top_performers(
        self,
        metric: str = 'robustness_score',
        top_n: int = 10,
        user_id: Optional[int] = None,
        model_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get the top performing experiments of the last 90 days by a metric.

        Args:
            metric: Metric to rank by.
            top_n: Number of top performers to return.
            user_id: Passed through to ``fetch_recent_runs`` (not applied yet).
            model_name: Optional exact model-name filter.

        Returns:
            Dict[str, Any]: The metric, ``top_n`` and a list of performers
                (run ID, value, and experiment name or a short run-ID prefix),
                or a message dict when no experiments are found.

        Raises:
            HTTPException: 500 if fetching or ranking fails.
        """
        await self._get_analytics_components()
        try:
            experiments = await self.fetch_recent_runs(days=90, user_id=user_id, model_name=model_name)
            if not experiments:
                return {"message": "No experiments found"}

            top_performers = await self.aggregation_utils.get_top_performers(experiments, metric, top_n)
            return {
                "metric": metric,
                "top_n": top_n,
                "performers": [
                    {
                        "run_id": run_id,
                        "value": value,
                        # Fall back to a short run-ID prefix for unnamed runs.
                        "experiment_name": name or run_id[:8]
                    }
                    for run_id, value, name in top_performers
                ]
            }
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Top performers analysis failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Top performers analysis failed"
            ) from e

    async def get_analytics_summary(self, user_id: Optional[int] = None, days: int = 30) -> Dict[str, Any]:
        """
        Get a summary of recent analytics activity.

        Args:
            user_id: Passed through to ``fetch_recent_runs`` (not applied yet).
            days: Look-back window in days; also reported as
                ``analysis_period_days``. Defaults to 30.

        Returns:
            Dict[str, Any]: Summary statistics, per-model run counts, an
                optional trend summary over the 10 most recent runs (None
                when fewer than 3 runs exist or trend analysis fails), the
                total run count and the analysis period; or a message dict
                when no experiments are found.

        Raises:
            HTTPException: 500 if fetching or aggregation fails.
        """
        await self._get_analytics_components()
        try:
            experiments = await self.fetch_recent_runs(days=days, user_id=user_id)
            if not experiments:
                return {"message": "No experiments found"}

            overall_aggregation = await self.aggregation_utils.aggregate_metrics(experiments)
            summary_stats = await self.aggregation_utils.get_summary_statistics(overall_aggregation)

            model_counts: Dict[str, int] = {}
            for exp in experiments:
                model_counts[exp.model_name] = model_counts.get(exp.model_name, 0) + 1

            # Sort by creation time so the slice really holds the 10 most
            # recent runs, whatever order the store lists them in. Every run
            # here passed the created_after comparison, so created_at is set.
            most_recent = sorted(experiments, key=lambda e: e.created_at)[-10:]
            recent_run_ids = [exp.run_id.hex for exp in most_recent]

            trend_summary = None
            if len(recent_run_ids) >= 3:
                try:
                    trend_result = await self.trend_analyzer.analyze_trend(recent_run_ids)
                    trend_summary = {
                        "overall_direction": trend_result.overall_direction.value,
                        "health_score": trend_result.overall_health_score,
                        "key_insights": trend_result.key_insights[:3]  # Top 3 insights
                    }
                except Exception as e:
                    # Best effort: trend analysis can fail on sparse data and
                    # must not break the summary. Catching Exception (not a
                    # bare except) lets cancellation propagate.
                    logger.warning("Trend analysis for summary skipped: %s", e)

            return {
                "summary_statistics": summary_stats,
                "model_distribution": model_counts,
                "recent_trends": trend_summary,
                "total_experiments": len(experiments),
                "analysis_period_days": days
            }
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Analytics summary failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Analytics summary failed"
            ) from e

    async def _serialize_comparison_result(self, result) -> Dict[str, Any]:
        """Serialize a comparison result to a JSON-compatible dict."""
        return {
            "run_ids": result.run_ids,
            "comparison_date": result.comparison_date.isoformat(),
            "total_runs": result.total_runs,
            "best_run": result.best_run,
            "worst_run": result.worst_run,
            "rankings": [
                {
                    "run_id": r.run_id,
                    "experiment_name": r.experiment_name,
                    "rank": r.rank,
                    "total_runs": r.total_runs,
                    "robustness_score": r.robustness_score,
                    "risk_score": r.risk_score,
                    "success_rate": r.success_rate,
                    "execution_time_ms": r.execution_time_ms,
                    "performance_tier": r.performance_tier,
                    "is_best": r.is_best,
                    "is_worst": r.is_worst
                }
                for r in result.rankings
            ],
            "metric_averages": result.metric_averages,
            "improvement_opportunities": result.improvement_opportunities,
            "key_differences": result.key_differences,
            "consistency_score": result.consistency_score,
            "chart_data": result.chart_data
        }

    async def _serialize_trend_result(self, result) -> Dict[str, Any]:
        """Serialize a trend result to a JSON-compatible dict."""
        return {
            "run_ids": result.run_ids,
            "analysis_date": result.analysis_date.isoformat(),
            "time_period_days": result.time_period_days,
            "total_runs": result.total_runs,
            "overall_direction": result.overall_direction.value,
            "overall_health_score": result.overall_health_score,
            "key_insights": result.key_insights,
            "recommendations": result.recommendations,
            "warning_indicators": result.warning_indicators,
            "improvement_summary": result.improvement_summary,
            "degradation_summary": result.degradation_summary,
            "metric_trends": {
                metric: {
                    "metric_name": trend.metric_name,
                    "direction": trend.metrics.direction.value,
                    "strength": trend.metrics.strength.value,
                    "improvement_rate": trend.metrics.improvement_rate,
                    "stability_score": trend.metrics.stability_score,
                    # Only counts are exposed, not the underlying series.
                    "data_points": len(trend.data_points),
                    "anomalies_count": len(trend.anomalies),
                    "significant_changes_count": len(trend.significant_changes)
                }
                for metric, trend in result.metric_trends.items()
            },
            "chart_data": result.chart_data
        }

    @staticmethod
    def _serialize_stats(stats) -> Dict[str, Any]:
        """Serialize one metric-statistics object (robustness or risk) to a dict."""
        return {
            "mean": stats.mean,
            "median": stats.median,
            "std_deviation": stats.std_deviation,
            "min_value": stats.min_value,
            "max_value": stats.max_value,
            "data_quality_score": stats.data_quality_score
        }

    async def _serialize_aggregations(self, aggregations: Dict[str, Any]) -> Dict[str, Any]:
        """
        Serialize aggregation results to a JSON-compatible dict.

        ``robustness_stats`` / ``risk_stats`` entries are only included when
        the corresponding statistics are present on the aggregation.
        """
        serialized = {}
        for key, aggregation in aggregations.items():
            entry = {
                "total_experiments": aggregation.total_experiments,
                "completed_experiments": aggregation.completed_experiments,
                "failed_experiments": aggregation.failed_experiments,
                "success_rate": aggregation.success_rate,
                "overall_health_score": aggregation.overall_health_score,
                "time_period_days": aggregation.time_period_days,
                "avg_experiments_per_day": aggregation.avg_experiments_per_day,
                "performance_tiers": aggregation.performance_tiers or {}
            }
            if aggregation.robustness_stats:
                entry["robustness_stats"] = self._serialize_stats(aggregation.robustness_stats)
            if aggregation.risk_stats:
                entry["risk_stats"] = self._serialize_stats(aggregation.risk_stats)
            serialized[key] = entry
        return serialized