File size: 20,510 Bytes
2ed8996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
"""
Aggregation Utilities for AegisLM Analytics.

Provides comprehensive metrics aggregation capabilities including
statistical analysis, percentile calculations, and data summarization.
"""

import statistics
from typing import Dict, List, Any, Optional, Tuple, Union
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
import logging

from schemas.experiment_schema import Experiment, ResultSummary

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class AggregationMethod(str, Enum):
    """
    Available aggregation methods.

    Members also subclass ``str``, so each member compares equal to its
    string value (for example ``AggregationMethod.MEAN == "mean"``).
    """
    # Central tendency
    MEAN = "mean"
    MEDIAN = "median"
    # Extremes and totals
    MIN = "min"
    MAX = "max"
    SUM = "sum"
    # Dispersion
    STD_DEV = "std_dev"
    VARIANCE = "variance"
    # Sample size
    COUNT = "count"
    # Percentiles
    PERCENTILE_25 = "percentile_25"
    PERCENTILE_75 = "percentile_75"
    PERCENTILE_90 = "percentile_90"
    PERCENTILE_95 = "percentile_95"


@dataclass
class MetricStats:
    """
    Statistical summary for a metric.

    All fields are required (no defaults). Within this module, instances are
    built by AggregationUtils._calculate_metric_statistics from a non-empty
    list of float samples.
    """
    # Name of the summarized metric (e.g. 'robustness_score')
    metric_name: str
    # Number of samples that went into the summary
    count: int
    mean: float
    median: float
    min_value: float
    max_value: float
    # Sample standard deviation / variance; the producer in this module
    # stores 0.0 for both when only one sample is available
    std_deviation: float
    variance: float
    # Sum of all samples (note: the field name shadows the builtin only
    # inside this class namespace)
    sum: float
    
    # Percentiles
    p25: float
    p75: float
    p90: float
    p95: float
    
    # Additional statistics
    # max_value - min_value
    range_size: float
    # std_deviation / mean; the producer in this module stores 0.0 when the
    # mean is exactly zero
    coefficient_of_variation: float
    
    # Distribution info
    # Samples falling outside the 1.5 * IQR fences around p25 / p75
    outliers_count: int
    outliers_values: List[float]
    
    # Quality indicators
    data_quality_score: float  # 0-1 based on completeness and consistency


@dataclass
class AggregatedMetrics:
    """
    Container for aggregated experiment metrics.

    Only the four count/rate fields are required. Every other field has a
    default and is filled in afterwards by the AggregationUtils helpers, so
    an instance built for an empty experiment list keeps those defaults.
    """
    # Number of experiments that were passed to the aggregation
    total_experiments: int
    # Experiments whose status is 'completed' and that carry a result summary
    completed_experiments: int
    # Experiments whose status is 'failed'
    failed_experiments: int
    # completed_experiments / total_experiments (0.0 for an empty input)
    success_rate: float
    
    # Metric statistics
    # Each entry stays None when no experiment supplied a value for it
    robustness_stats: Optional[MetricStats] = None
    risk_stats: Optional[MetricStats] = None
    success_rate_stats: Optional[MetricStats] = None
    confidence_stats: Optional[MetricStats] = None
    hallucination_stats: Optional[MetricStats] = None
    toxicity_stats: Optional[MetricStats] = None
    execution_time_stats: Optional[MetricStats] = None
    
    # Time-based aggregations
    # Whole days between the earliest and latest created_at (at least 1 once
    # computed; 0 means "not computed")
    time_period_days: int = 0
    avg_experiments_per_day: float = 0.0
    
    # Model and dataset aggregations
    # Name -> number of experiments; None until the distributions are computed
    model_distribution: Optional[Dict[str, int]] = None
    dataset_distribution: Optional[Dict[str, int]] = None
    attack_type_distribution: Optional[Dict[str, int]] = None
    
    # Performance summary
    overall_health_score: float = 0.0
    # Stays None when robustness or risk statistics are unavailable
    performance_tiers: Optional[Dict[str, int]] = None  # excellent, good, average, poor


class AggregationUtils:
    """
    Utilities for aggregating experiment metrics.
    
    Provides comprehensive statistical analysis and data summarization
    capabilities for experiment analytics.
    """
    
    def __init__(self):
        """Initialize aggregation utilities."""
        pass
    
    async def aggregate_metrics(self, experiments: List[Experiment]) -> AggregatedMetrics:
        """
        Build the full aggregation report for a batch of experiments.

        Counts completed and failed runs, then lets each private helper fill
        in its part of the result (metric statistics, time span,
        distributions, performance summary), in that order.

        Args:
            experiments: List of experiments to aggregate

        Returns:
            AggregatedMetrics: Populated report; an all-zero report when the
                input is empty
        """
        if not experiments:
            # Nothing to aggregate: return the zeroed container right away.
            return AggregatedMetrics(
                total_experiments=0,
                completed_experiments=0,
                failed_experiments=0,
                success_rate=0.0
            )

        total = len(experiments)
        logger.info(f"Aggregating metrics for {len(experiments)} experiments")

        # A run counts as completed only if it also produced a result summary.
        completed = sum(
            1 for exp in experiments
            if exp.status.value == 'completed' and exp.result_summary
        )
        failed = sum(1 for exp in experiments if exp.status.value == 'failed')

        report = AggregatedMetrics(
            total_experiments=total,
            completed_experiments=completed,
            failed_experiments=failed,
            success_rate=completed / total if total > 0 else 0.0
        )

        # Order matters: the performance summary reads the metric statistics
        # written by the first step.
        pipeline = (
            self._calculate_metric_stats,
            self._calculate_time_aggregations,
            self._calculate_distribution_aggregations,
            self._calculate_performance_summary,
        )
        for step in pipeline:
            await step(experiments, report)

        return report
    
    async def _calculate_metric_stats(self, experiments: List[Experiment], aggregated: AggregatedMetrics):
        """
        Calculate statistical summaries for each metric.
        
        Args:
            experiments: List of experiments
            aggregated: AggregatedMetrics to update
        """
        # Extract metric values
        metrics_mapping = {
            'robustness_stats': 'robustness_score',
            'risk_stats': 'risk_score',
            'success_rate_stats': 'success_rate',
            'confidence_stats': 'confidence_score',
            'hallucination_stats': 'hallucination_rate',
            'toxicity_stats': 'toxicity_rate',
            'execution_time_stats': 'execution_time_ms'
        }
        
        for stats_attr, metric_name in metrics_mapping.items():
            values = []
            
            for exp in experiments:
                if exp.result_summary and hasattr(exp.result_summary, metric_name):
                    value = getattr(exp.result_summary, metric_name)
                    if value is not None:
                        values.append(float(value))
            
            if values:
                stats = await self._calculate_metric_statistics(metric_name, values)
                setattr(aggregated, stats_attr, stats)
    
    async def _calculate_metric_statistics(self, metric_name: str, values: List[float]) -> MetricStats:
        """
        Calculate comprehensive statistics for a metric.
        
        Args:
            metric_name: Name of the metric
            values: List of metric values
            
        Returns:
            MetricStats: Statistical summary
        """
        if not values:
            raise ValueError(f"No values provided for metric {metric_name}")
        
        # Basic statistics
        count = len(values)
        mean_val = statistics.mean(values)
        median_val = statistics.median(values)
        min_val = min(values)
        max_val = max(values)
        sum_val = sum(values)
        
        # Standard deviation and variance
        if count > 1:
            std_dev = statistics.stdev(values)
            variance = statistics.variance(values)
        else:
            std_dev = 0.0
            variance = 0.0
        
        # Percentiles
        sorted_values = sorted(values)
        p25 = self._calculate_percentile(sorted_values, 25)
        p75 = self._calculate_percentile(sorted_values, 75)
        p90 = self._calculate_percentile(sorted_values, 90)
        p95 = self._calculate_percentile(sorted_values, 95)
        
        # Additional statistics
        range_size = max_val - min_val
        coefficient_of_variation = std_dev / mean_val if mean_val != 0 else 0.0
        
        # Outlier detection using IQR method
        q1, q3 = p25, p75
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        
        outliers = [v for v in values if v < lower_bound or v > upper_bound]
        
        # Data quality score based on completeness and consistency
        completeness_score = 1.0  # All values are present
        consistency_score = max(0.0, 1.0 - coefficient_of_variation)  # Lower variation = higher consistency
        data_quality_score = (completeness_score + consistency_score) / 2.0
        
        return MetricStats(
            metric_name=metric_name,
            count=count,
            mean=mean_val,
            median=median_val,
            min_value=min_val,
            max_value=max_val,
            std_deviation=std_dev,
            variance=variance,
            sum=sum_val,
            p25=p25,
            p75=p75,
            p90=p90,
            p95=p95,
            range_size=range_size,
            coefficient_of_variation=coefficient_of_variation,
            outliers_count=len(outliers),
            outliers_values=outliers,
            data_quality_score=data_quality_score
        )
    
    def _calculate_percentile(self, sorted_values: List[float], percentile: int) -> float:
        """
        Calculate percentile value from sorted values.
        
        Args:
            sorted_values: Sorted list of values
            percentile: Percentile to calculate (0-100)
            
        Returns:
            float: Percentile value
        """
        if not sorted_values:
            return 0.0
        
        index = (percentile / 100) * (len(sorted_values) - 1)
        
        if index.is_integer():
            return sorted_values[int(index)]
        else:
            lower_index = int(index)
            upper_index = lower_index + 1
            weight = index - lower_index
            return sorted_values[lower_index] * (1 - weight) + sorted_values[upper_index] * weight
    
    async def _calculate_time_aggregations(self, experiments: List[Experiment], aggregated: AggregatedMetrics):
        """
        Calculate time-based aggregations.
        
        Args:
            experiments: List of experiments
            aggregated: AggregatedMetrics to update
        """
        if not experiments:
            return
        
        # Calculate time period
        timestamps = [exp.created_at for exp in experiments]
        min_time = min(timestamps)
        max_time = max(timestamps)
        time_period = max_time - min_time
        time_period_days = max(1, time_period.days)
        
        aggregated.time_period_days = time_period_days
        aggregated.avg_experiments_per_day = len(experiments) / time_period_days
    
    async def _calculate_distribution_aggregations(self, experiments: List[Experiment], aggregated: AggregatedMetrics):
        """
        Calculate distribution aggregations for models, datasets, and attack types.
        
        Args:
            experiments: List of experiments
            aggregated: AggregatedMetrics to update
        """
        # Model distribution
        model_counts = {}
        for exp in experiments:
            model_name = exp.model_name
            model_counts[model_name] = model_counts.get(model_name, 0) + 1
        aggregated.model_distribution = model_counts
        
        # Dataset distribution
        dataset_counts = {}
        for exp in experiments:
            if exp.dataset_name:
                dataset_name = exp.dataset_name
                dataset_counts[dataset_name] = dataset_counts.get(dataset_name, 0) + 1
        aggregated.dataset_distribution = dataset_counts
        
        # Attack type distribution
        attack_counts = {}
        for exp in experiments:
            for attack_type in exp.attack_types:
                attack_counts[attack_type] = attack_counts.get(attack_type, 0) + 1
        aggregated.attack_type_distribution = attack_counts
    
    async def _calculate_performance_summary(self, experiments: List[Experiment], aggregated: AggregatedMetrics):
        """
        Calculate performance summary and health score.
        
        Args:
            experiments: List of experiments
            aggregated: AggregatedMetrics to update
        """
        if not aggregated.robustness_stats or not aggregated.risk_stats:
            return
        
        # Calculate health score based on key metrics
        health_factors = []
        
        # Robustness (higher is better)
        robustness_score = min(1.0, aggregated.robustness_stats.mean / 0.8)  # Normalize to 0.8 as excellent
        health_factors.append(robustness_score)
        
        # Risk (lower is better)
        risk_score = max(0.0, 1.0 - (aggregated.risk_stats.mean / 0.5))  # Normalize to 0.5 as high risk
        health_factors.append(risk_score)
        
        # Success rate (higher is better)
        if aggregated.success_rate_stats:
            success_score = aggregated.success_rate_stats.mean
            health_factors.append(success_score)
        
        # Data quality
        quality_factors = [
            aggregated.robustness_stats.data_quality_score,
            aggregated.risk_stats.data_quality_score
        ]
        if aggregated.success_rate_stats:
            quality_factors.append(aggregated.success_rate_stats.data_quality_score)
        
        data_quality_score = sum(quality_factors) / len(quality_factors)
        health_factors.append(data_quality_score)
        
        # Overall health score
        aggregated.overall_health_score = sum(health_factors) / len(health_factors)
        
        # Calculate performance tiers
        performance_tiers = {"excellent": 0, "good": 0, "average": 0, "poor": 0}
        
        for exp in experiments:
            if exp.result_summary:
                robustness = exp.result_summary.robustness_score
                risk = exp.result_summary.risk_score
                
                # Simple tier classification
                if robustness >= 0.8 and risk <= 0.1:
                    performance_tiers["excellent"] += 1
                elif robustness >= 0.6 and risk <= 0.2:
                    performance_tiers["good"] += 1
                elif robustness >= 0.4 and risk <= 0.3:
                    performance_tiers["average"] += 1
                else:
                    performance_tiers["poor"] += 1
        
        aggregated.performance_tiers = performance_tiers
    
    async def aggregate_by_time_window(
        self, 
        experiments: List[Experiment], 
        window_days: int = 7
    ) -> Dict[str, AggregatedMetrics]:
        """
        Aggregate metrics by time windows.
        
        Args:
            experiments: List of experiments
            window_days: Size of time window in days
            
        Returns:
            Dict[str, AggregatedMetrics]: Aggregations by time window
        """
        if not experiments:
            return {}
        
        # Sort experiments by creation time
        sorted_experiments = sorted(experiments, key=lambda x: x.created_at)
        
        # Group by time windows
        windowed_experiments = {}
        
        for exp in sorted_experiments:
            # Calculate window key
            days_since_start = (exp.created_at - sorted_experiments[0].created_at).days
            window_key = f"week_{days_since_start // window_days}"
            
            if window_key not in windowed_experiments:
                windowed_experiments[window_key] = []
            windowed_experiments[window_key].append(exp)
        
        # Aggregate each window
        windowed_aggregations = {}
        for window_key, window_exps in windowed_experiments.items():
            aggregation = await self.aggregate_metrics(window_exps)
            windowed_aggregations[window_key] = aggregation
        
        return windowed_aggregations
    
    async def aggregate_by_model(self, experiments: List[Experiment]) -> Dict[str, AggregatedMetrics]:
        """
        Aggregate metrics by model.
        
        Args:
            experiments: List of experiments
            
        Returns:
            Dict[str, AggregatedMetrics]: Aggregations by model
        """
        model_experiments = {}
        
        for exp in experiments:
            model_name = exp.model_name
            if model_name not in model_experiments:
                model_experiments[model_name] = []
            model_experiments[model_name].append(exp)
        
        model_aggregations = {}
        for model_name, model_exps in model_experiments.items():
            aggregation = await self.aggregate_metrics(model_exps)
            model_aggregations[model_name] = aggregation
        
        return model_aggregations
    
    async def aggregate_by_dataset(self, experiments: List[Experiment]) -> Dict[str, AggregatedMetrics]:
        """
        Aggregate metrics by dataset.
        
        Args:
            experiments: List of experiments
            
        Returns:
            Dict[str, AggregatedMetrics]: Aggregations by dataset
        """
        dataset_experiments = {}
        
        for exp in experiments:
            dataset_name = exp.dataset_name or "unknown"
            if dataset_name not in dataset_experiments:
                dataset_experiments[dataset_name] = []
            dataset_experiments[dataset_name].append(exp)
        
        dataset_aggregations = {}
        for dataset_name, dataset_exps in dataset_experiments.items():
            aggregation = await self.aggregate_metrics(dataset_exps)
            dataset_aggregations[dataset_name] = aggregation
        
        return dataset_aggregations
    
    async def get_top_performers(
        self, 
        experiments: List[Experiment], 
        metric: str = 'robustness_score',
        top_n: int = 5
    ) -> List[Tuple[str, float, str]]:
        """
        Get top performing experiments by metric.
        
        Args:
            experiments: List of experiments
            metric: Metric to rank by
            top_n: Number of top performers to return
            
        Returns:
            List[Tuple[str, float, str]]: (run_id, metric_value, experiment_name)
        """
        valid_experiments = []
        
        for exp in experiments:
            if exp.result_summary and hasattr(exp.result_summary, metric):
                value = getattr(exp.result_summary, metric)
                if value is not None:
                    valid_experiments.append((exp.run_id.hex, float(value), exp.experiment_name))
        
        # Sort by metric value (descending for most metrics)
        reverse_order = metric not in ['risk_score', 'execution_time_ms']
        sorted_experiments = sorted(valid_experiments, key=lambda x: x[1], reverse=reverse_order)
        
        return sorted_experiments[:top_n]
    
    async def get_summary_statistics(self, aggregated: AggregatedMetrics) -> Dict[str, Any]:
        """
        Get summary statistics from aggregated metrics.
        
        Args:
            aggregated: Aggregated metrics
            
        Returns:
            Dict[str, Any]: Summary statistics
        """
        summary = {
            'experiment_summary': {
                'total': aggregated.total_experiments,
                'completed': aggregated.completed_experiments,
                'failed': aggregated.failed_experiments,
                'success_rate': aggregated.success_rate
            },
            'time_summary': {
                'period_days': aggregated.time_period_days,
                'avg_per_day': aggregated.avg_experiments_per_day
            },
            'health_score': aggregated.overall_health_score,
            'performance_tiers': aggregated.performance_tiers or {}
        }
        
        # Add metric summaries
        metric_stats = {
            'robustness': aggregated.robustness_stats,
            'risk': aggregated.risk_stats,
            'success_rate': aggregated.success_rate_stats,
            'confidence': aggregated.confidence_stats,
            'execution_time': aggregated.execution_time_stats
        }
        
        for metric_name, stats in metric_stats.items():
            if stats:
                summary[f'{metric_name}_summary'] = {
                    'mean': stats.mean,
                    'median': stats.median,
                    'std_dev': stats.std_deviation,
                    'min': stats.min_value,
                    'max': stats.max_value,
                    'quality_score': stats.data_quality_score
                }
        
        return summary


# Global aggregation utilities instance, created once when this module is
# imported and handed out by get_aggregation_utils().
aggregation_utils = AggregationUtils()


async def get_aggregation_utils() -> AggregationUtils:
    """
    Get the global aggregation utilities instance.

    Declared ``async`` although it awaits nothing, so calling it yields a
    coroutine that has to be awaited to obtain the instance.

    Returns:
        AggregationUtils: The module-level ``aggregation_utils`` instance
    """
    # Always the same shared object; no new instance is created per call.
    return aggregation_utils