| """ |
| Analytics Service for AegisLM Multi-Run Analysis. |
| |
| Provides database integration and service layer for analytics operations |
| including run fetching, filtering, and data access. |
| """ |
|
|
| import uuid |
| import logging |
| from typing import Dict, List, Any, Optional, Tuple |
| from datetime import datetime, timedelta |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from sqlalchemy import select, and_, or_, func, desc |
| from fastapi import HTTPException, status |
|
|
| from core.database import get_db |
| from db_models.user import User |
| from db_models.evaluation import Evaluation, EvaluationStatus |
| from experiments.experiment_manager import get_experiment_manager |
| from schemas.experiment_schema import Experiment, ExperimentStatus as ExpStatus, ExperimentFilter |
|
|
| from .comparison_engine import ComparisonEngine, get_comparison_engine |
| from .trend_analyzer import TrendAnalyzer, get_trend_analyzer |
| from .aggregation_utils import AggregationUtils, get_aggregation_utils |
|
|
# Module-level logger, named after this module so handlers/levels can be
# configured per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class AnalyticsService:
    """
    Service layer for analytics operations.

    Provides run fetching, filtering, and coordination between the
    comparison engine, trend analyzer and aggregation utilities, which are
    loaded lazily on first use (see ``_get_analytics_components``).

    NOTE(review): public methods accept ``user_id`` for permission
    filtering, but no ownership/permission check is enforced yet — every
    experiment returned by the experiment store is visible.
    """

    def __init__(self, db: AsyncSession):
        """
        Initialize analytics service.

        Args:
            db: Async database session (stored on the instance; not used
                directly by the methods of this class).
        """
        self.db = db
        self.experiment_manager = get_experiment_manager()
        # Populated lazily by _get_analytics_components().
        self.comparison_engine = None
        self.trend_analyzer = None
        self.aggregation_utils = None

    async def _get_analytics_components(self):
        """Load every analytics component that has not been loaded yet."""
        # Compare against None (not truthiness) so a loaded component that
        # happens to define __len__/__bool__ is never re-fetched.
        if self.comparison_engine is None:
            self.comparison_engine = await get_comparison_engine()
        if self.trend_analyzer is None:
            self.trend_analyzer = await get_trend_analyzer()
        if self.aggregation_utils is None:
            self.aggregation_utils = await get_aggregation_utils()

    async def fetch_runs_by_ids(self, run_ids: List[str], user_id: Optional[int] = None) -> List[Experiment]:
        """
        Fetch experiments by run IDs.

        IDs that are malformed, not present in the store, or fail to load
        are logged and skipped rather than aborting the whole request.

        Args:
            run_ids: Run IDs to fetch (UUID strings; non-str values are
                passed to the store unchanged)
            user_id: Optional user ID for permission filtering
                (accepted but not enforced yet)

        Returns:
            List[Experiment]: Experiments that were found, in input order

        Raises:
            HTTPException: 404 if none of the run IDs resolve to an experiment
        """
        await self._get_analytics_components()

        # TODO(review): user_id is accepted but no permission check is
        # performed (the original branch for it was an empty placeholder).
        experiments = []

        for run_id in run_ids:
            try:
                run_uuid = uuid.UUID(run_id) if isinstance(run_id, str) else run_id
            except ValueError:
                logger.warning("Invalid run ID format: %s", run_id)
                continue

            try:
                experiment = self.experiment_manager.store.get_experiment(run_uuid)
            except Exception:
                # Best-effort: one unreadable run must not fail the batch.
                logger.exception("Error fetching experiment %s", run_id)
                continue

            if not experiment:
                logger.warning("Experiment not found: %s", run_id)
                continue

            experiments.append(experiment)

        if not experiments:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No valid experiments found"
            )

        return experiments

    @staticmethod
    def _matches_filter(exp: Experiment, filters: ExperimentFilter, wanted_status) -> bool:
        """Return True if ``exp`` satisfies every criterion set in ``filters``."""
        if filters.model_name and exp.model_name != filters.model_name:
            return False
        if filters.dataset_name and exp.dataset_name != filters.dataset_name:
            return False
        if wanted_status is not None and exp.status != wanted_status:
            return False
        # An experiment with no attack types cannot match an attack filter.
        if filters.attack_types and not any(
            atk in (exp.attack_types or ()) for atk in filters.attack_types
        ):
            return False
        if filters.created_after is not None and exp.created_at < filters.created_after:
            return False
        if filters.created_before is not None and exp.created_at > filters.created_before:
            return False
        # `is not None` so an explicit bound of 0 is honoured, not ignored.
        if filters.min_prompt_count is not None and exp.prompt_count < filters.min_prompt_count:
            return False
        if filters.max_prompt_count is not None and exp.prompt_count > filters.max_prompt_count:
            return False
        return True

    async def fetch_runs_by_filter(
        self,
        filters: ExperimentFilter,
        user_id: Optional[int] = None,
        limit: int = 100
    ) -> List[Experiment]:
        """
        Fetch experiments matching filter criteria.

        Filtering happens in memory over at most ``limit * 2`` experiments
        returned by the store, so matches outside that window are not seen.

        Args:
            filters: Filter criteria; unset fields (None / empty) are ignored
            user_id: Optional user ID for permission filtering
                (accepted but not enforced yet)
            limit: Maximum number of results (<= 0 yields an empty list)

        Returns:
            List[Experiment]: At most ``limit`` matching experiments, in
            store order
        """
        await self._get_analytics_components()

        # Over-fetch so some headroom remains after in-memory filtering.
        all_experiments = self.experiment_manager.store.list_experiments(limit=limit * 2)

        # Convert the status filter once rather than per experiment.
        wanted_status = ExpStatus(filters.status) if filters.status else None

        filtered_experiments = []
        for exp in all_experiments:
            # Check before appending so limit <= 0 returns nothing.
            if len(filtered_experiments) >= limit:
                break
            if self._matches_filter(exp, filters, wanted_status):
                filtered_experiments.append(exp)

        return filtered_experiments

    async def fetch_recent_runs(
        self,
        days: int = 30,
        user_id: Optional[int] = None,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None
    ) -> List[Experiment]:
        """
        Fetch experiments created within the last ``days`` days.

        Args:
            days: Number of days to look back
            user_id: Optional user ID for permission filtering
            model_name: Optional model filter
            dataset_name: Optional dataset filter

        Returns:
            List[Experiment]: Up to 500 recent experiments
        """
        await self._get_analytics_components()

        # Naive UTC threshold; presumably Experiment.created_at is also naive
        # UTC, since the two are compared directly — TODO confirm.
        threshold_date = datetime.utcnow() - timedelta(days=days)

        filters = ExperimentFilter(
            created_after=threshold_date,
            model_name=model_name,
            dataset_name=dataset_name
        )

        return await self.fetch_runs_by_filter(filters, user_id, limit=500)

    async def compare_runs(self, run_ids: List[str], user_id: Optional[int] = None) -> Dict[str, Any]:
        """
        Compare multiple experiment runs.

        Args:
            run_ids: List of run IDs to compare
            user_id: Optional user ID for permission filtering

        Returns:
            Dict[str, Any]: Serialized comparison results

        Raises:
            HTTPException: 404 if no run exists, 400 if the comparison
                engine raises ValueError, 500 for any other failure
        """
        await self._get_analytics_components()

        try:
            # Called for its validation side effect: raises 404 when none of
            # the runs exist.
            await self.fetch_runs_by_ids(run_ids, user_id)

            comparison_result = await self.comparison_engine.compare_runs(run_ids)

            return await self._serialize_comparison_result(comparison_result)

        except HTTPException:
            # Deliberate HTTP errors (e.g. the 404 above) must not be masked
            # as a 500 by the generic handler below.
            raise
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e)
            ) from e
        except Exception as e:
            logger.exception("Comparison failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Comparison analysis failed"
            ) from e

    async def analyze_trends(self, run_ids: List[str], user_id: Optional[int] = None) -> Dict[str, Any]:
        """
        Analyze trends across multiple runs.

        Args:
            run_ids: List of run IDs to analyze
            user_id: Optional user ID for permission filtering

        Returns:
            Dict[str, Any]: Serialized trend analysis results

        Raises:
            HTTPException: 404 if no run exists, 400 if the trend analyzer
                raises ValueError, 500 for any other failure
        """
        await self._get_analytics_components()

        try:
            # Called for its validation side effect: raises 404 when none of
            # the runs exist.
            await self.fetch_runs_by_ids(run_ids, user_id)

            trend_result = await self.trend_analyzer.analyze_trend(run_ids)

            return await self._serialize_trend_result(trend_result)

        except HTTPException:
            # Keep the 404 from fetch_runs_by_ids instead of turning it into a 500.
            raise
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e)
            ) from e
        except Exception as e:
            logger.exception("Trend analysis failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Trend analysis failed"
            ) from e

    async def get_aggregated_metrics(
        self,
        filters: Optional[ExperimentFilter] = None,
        user_id: Optional[int] = None,
        group_by: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get aggregated metrics with optional grouping.

        Args:
            filters: Optional filter criteria; without it, runs from the
                last 90 days are used
            user_id: Optional user ID for permission filtering
            group_by: 'model', 'dataset' or 'time_window'; any other value
                yields a single "overall" aggregation

        Returns:
            Dict[str, Any]: Serialized aggregations keyed by group, or a
            message dict when no experiments match

        Raises:
            HTTPException: 500 if aggregation fails
        """
        await self._get_analytics_components()

        try:
            if filters:
                experiments = await self.fetch_runs_by_filter(filters, user_id, limit=1000)
            else:
                experiments = await self.fetch_recent_runs(days=90, user_id=user_id)

            if not experiments:
                return {"message": "No experiments found for aggregation"}

            if group_by == 'model':
                aggregations = await self.aggregation_utils.aggregate_by_model(experiments)
            elif group_by == 'dataset':
                aggregations = await self.aggregation_utils.aggregate_by_dataset(experiments)
            elif group_by == 'time_window':
                aggregations = await self.aggregation_utils.aggregate_by_time_window(experiments)
            else:
                overall = await self.aggregation_utils.aggregate_metrics(experiments)
                aggregations = {"overall": overall}

            return await self._serialize_aggregations(aggregations)

        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Aggregation failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Metrics aggregation failed"
            ) from e

    async def get_top_performers(
        self,
        metric: str = 'robustness_score',
        top_n: int = 10,
        user_id: Optional[int] = None,
        model_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get top performing experiments (last 90 days) by metric.

        Args:
            metric: Metric to rank by
            top_n: Number of top performers
            user_id: Optional user ID for permission filtering
            model_name: Optional model filter

        Returns:
            Dict[str, Any]: Metric, top_n and a list of performers, or a
            message dict when no experiments are found

        Raises:
            HTTPException: 500 if the analysis fails
        """
        await self._get_analytics_components()

        try:
            experiments = await self.fetch_recent_runs(days=90, user_id=user_id, model_name=model_name)

            if not experiments:
                return {"message": "No experiments found"}

            # Iterated below as (run_id, value, name) triples.
            top_performers = await self.aggregation_utils.get_top_performers(experiments, metric, top_n)

            return {
                "metric": metric,
                "top_n": top_n,
                "performers": [
                    {
                        "run_id": run_id,
                        "value": value,
                        # str() so a non-string run_id cannot break slicing.
                        "experiment_name": name or str(run_id)[:8]
                    }
                    for run_id, value, name in top_performers
                ]
            }

        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Top performers analysis failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Top performers analysis failed"
            ) from e

    async def get_analytics_summary(self, user_id: Optional[int] = None) -> Dict[str, Any]:
        """
        Get an analytics summary of the last 30 days.

        Trend information is best-effort: if the trend analysis fails, the
        summary is still returned with ``recent_trends`` set to None.

        Args:
            user_id: Optional user ID for permission filtering

        Returns:
            Dict[str, Any]: Summary statistics, per-model run counts, recent
            trends and totals, or a message dict when no experiments exist

        Raises:
            HTTPException: 500 if the summary cannot be built
        """
        from collections import Counter

        await self._get_analytics_components()

        period_days = 30

        try:
            experiments = await self.fetch_recent_runs(days=period_days, user_id=user_id)

            if not experiments:
                return {"message": "No experiments found"}

            overall_aggregation = await self.aggregation_utils.aggregate_metrics(experiments)
            summary_stats = await self.aggregation_utils.get_summary_statistics(overall_aggregation)

            model_counts = dict(Counter(exp.model_name for exp in experiments))

            # Store order is not guaranteed to be chronological, so sort by
            # creation time before taking the ten most recent runs.
            recent = sorted(experiments, key=lambda exp: exp.created_at)[-10:]
            recent_run_ids = [exp.run_id.hex for exp in recent]

            trend_summary = None
            if len(recent_run_ids) >= 3:
                try:
                    trend_result = await self.trend_analyzer.analyze_trend(recent_run_ids)
                    trend_summary = {
                        "overall_direction": trend_result.overall_direction.value,
                        "health_score": trend_result.overall_health_score,
                        "key_insights": trend_result.key_insights[:3]
                    }
                except Exception:
                    # Trends are optional here; log and continue without them.
                    # (`except Exception`, not a bare except, so task
                    # cancellation is not swallowed.)
                    logger.warning("Trend summary unavailable", exc_info=True)

            return {
                "summary_statistics": summary_stats,
                "model_distribution": model_counts,
                "recent_trends": trend_summary,
                "total_experiments": len(experiments),
                "analysis_period_days": period_days
            }

        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Analytics summary failed: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Analytics summary failed"
            ) from e

    async def _serialize_comparison_result(self, result) -> Dict[str, Any]:
        """Serialize a comparison-engine result to a JSON-compatible dict."""
        return {
            "run_ids": result.run_ids,
            "comparison_date": result.comparison_date.isoformat(),
            "total_runs": result.total_runs,
            "best_run": result.best_run,
            "worst_run": result.worst_run,
            "rankings": [
                {
                    "run_id": r.run_id,
                    "experiment_name": r.experiment_name,
                    "rank": r.rank,
                    "total_runs": r.total_runs,
                    "robustness_score": r.robustness_score,
                    "risk_score": r.risk_score,
                    "success_rate": r.success_rate,
                    "execution_time_ms": r.execution_time_ms,
                    "performance_tier": r.performance_tier,
                    "is_best": r.is_best,
                    "is_worst": r.is_worst
                }
                for r in result.rankings
            ],
            "metric_averages": result.metric_averages,
            "improvement_opportunities": result.improvement_opportunities,
            "key_differences": result.key_differences,
            "consistency_score": result.consistency_score,
            "chart_data": result.chart_data
        }

    async def _serialize_trend_result(self, result) -> Dict[str, Any]:
        """Serialize a trend-analyzer result to a JSON-compatible dict."""
        return {
            "run_ids": result.run_ids,
            "analysis_date": result.analysis_date.isoformat(),
            "time_period_days": result.time_period_days,
            "total_runs": result.total_runs,
            "overall_direction": result.overall_direction.value,
            "overall_health_score": result.overall_health_score,
            "key_insights": result.key_insights,
            "recommendations": result.recommendations,
            "warning_indicators": result.warning_indicators,
            "improvement_summary": result.improvement_summary,
            "degradation_summary": result.degradation_summary,
            "metric_trends": {
                # Per-metric summary: counts instead of the raw point lists.
                metric: {
                    "metric_name": trend.metric_name,
                    "direction": trend.metrics.direction.value,
                    "strength": trend.metrics.strength.value,
                    "improvement_rate": trend.metrics.improvement_rate,
                    "stability_score": trend.metrics.stability_score,
                    "data_points": len(trend.data_points),
                    "anomalies_count": len(trend.anomalies),
                    "significant_changes_count": len(trend.significant_changes)
                }
                for metric, trend in result.metric_trends.items()
            },
            "chart_data": result.chart_data
        }

    @staticmethod
    def _serialize_stats(stats) -> Dict[str, Any]:
        """Serialize a descriptive-statistics object to a plain dict."""
        return {
            "mean": stats.mean,
            "median": stats.median,
            "std_deviation": stats.std_deviation,
            "min_value": stats.min_value,
            "max_value": stats.max_value,
            "data_quality_score": stats.data_quality_score
        }

    async def _serialize_aggregations(self, aggregations: Dict[str, Any]) -> Dict[str, Any]:
        """
        Serialize aggregation results to a JSON-compatible dict.

        Robustness/risk statistics are included only when present on the
        aggregation.
        """
        serialized = {}

        for key, aggregation in aggregations.items():
            entry = {
                "total_experiments": aggregation.total_experiments,
                "completed_experiments": aggregation.completed_experiments,
                "failed_experiments": aggregation.failed_experiments,
                "success_rate": aggregation.success_rate,
                "overall_health_score": aggregation.overall_health_score,
                "time_period_days": aggregation.time_period_days,
                "avg_experiments_per_day": aggregation.avg_experiments_per_day,
                "performance_tiers": aggregation.performance_tiers or {}
            }

            if aggregation.robustness_stats:
                entry["robustness_stats"] = self._serialize_stats(aggregation.robustness_stats)

            if aggregation.risk_stats:
                entry["risk_stats"] = self._serialize_stats(aggregation.risk_stats)

            serialized[key] = entry

        return serialized
|
|