Debito committed on
Commit
eefb8cb
·
verified ·
1 Parent(s): 65aa3db

Upload 3 files

Browse files
monitoring/evaluator.py ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Model Evaluator for Mamba Swarm
Comprehensive evaluation system for model performance and quality
"""

# NOTE: the module docstring, the import list and both result dataclasses were
# present twice, verbatim, at the top of this file. They are kept exactly once
# here; no import has been dropped.
import time
import json
import logging
import torch
import numpy as np
from typing import Dict, List, Any, Optional, Tuple, Callable, Union
from dataclasses import dataclass, field
from collections import defaultdict
import math
import re
from datetime import datetime
from pathlib import Path
import asyncio
import concurrent.futures


# Evaluation metrics
@dataclass
class EvaluationResult:
    """A single named measurement produced by one evaluation step."""

    # Identifier of the metric, e.g. "perplexity" or "bleu_score".
    metric_name: str
    # Raw (un-normalised) value of the metric.
    score: float
    # Free-form context for the measurement; a fresh dict per instance.
    details: Dict[str, Any] = field(default_factory=dict)
    # Creation time as returned by time.time().
    timestamp: float = field(default_factory=time.time)


@dataclass
class BenchmarkResult:
    """Aggregate outcome of a benchmark run."""

    # Name of the benchmark that produced this result.
    benchmark_name: str
    # Combined score computed from the individual metrics.
    overall_score: float
    # Every EvaluationResult collected during the run.
    individual_metrics: List[EvaluationResult]
    # Duration of the run in seconds.
    execution_time: float
    # Description of the evaluated model, as supplied by the caller.
    model_info: Dict[str, Any]
    # Creation time as returned by time.time().
    timestamp: float = field(default_factory=time.time)
74
+
75
class PerplexityCalculator:
    """Calculate perplexity for language models.

    The model is called as ``model(tokens)`` and must return either the
    logits tensor itself or an object exposing it as ``.logits``.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Inputs are moved to the device of the model's first parameter.
        self.device = next(model.parameters()).device

    def calculate_perplexity(self, text: str, max_length: int = 512) -> float:
        """Calculate perplexity for given text.

        Args:
            text: Text to score.
            max_length: Maximum number of tokens kept by the tokenizer.

        Returns:
            ``exp`` of the mean next-token cross-entropy, or ``float("inf")``
            when the text encodes to fewer than two tokens (there is then no
            token left to predict, so the loss is undefined).
        """
        # Tokenize text
        tokens = self.tokenizer.encode(
            text, return_tensors="pt", max_length=max_length, truncation=True
        )
        tokens = tokens.to(self.device)

        # With fewer than two tokens the shifted label tensor is empty and the
        # mean cross-entropy would silently evaluate to NaN.
        if tokens.size(-1) < 2:
            return float("inf")

        with torch.no_grad():
            # Get model outputs
            outputs = self.model(tokens)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs

            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two tensors line up.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = tokens[..., 1:].contiguous()

            loss_fn = torch.nn.CrossEntropyLoss()
            loss = loss_fn(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Convert to perplexity
            perplexity = torch.exp(loss)

        return perplexity.item()
105
+
106
class BLEUScore:
    """BLEU score calculator for text generation.

    Tokenisation is lower-cased whitespace splitting and a single reference
    is supported. No smoothing is applied: if any n-gram order has zero
    precision the score is 0.0.
    """

    def __init__(self, n_grams: int = 4):
        """Create a calculator using n-gram orders ``1..n_grams``.

        Raises:
            ValueError: If ``n_grams`` is smaller than 1; the geometric mean
                over zero precisions would otherwise divide by zero later.
        """
        if n_grams < 1:
            raise ValueError("n_grams must be at least 1")
        self.n_grams = n_grams

    def calculate_bleu(self, reference: str, candidate: str) -> float:
        """Calculate BLEU score between reference and candidate.

        Returns:
            A value in ``[0, 1]``; 0.0 for an empty candidate or when any
            n-gram order has no match.
        """
        ref_tokens = self._tokenize(reference)
        cand_tokens = self._tokenize(candidate)

        if not cand_tokens:
            return 0.0

        # Calculate n-gram precisions
        precisions = [
            self._calculate_n_gram_precision(ref_tokens, cand_tokens, n)
            for n in range(1, self.n_grams + 1)
        ]

        # log(0) is undefined, so a single empty order zeroes the score.
        if any(p == 0 for p in precisions):
            return 0.0

        # Brevity penalty
        bp = self._brevity_penalty(len(ref_tokens), len(cand_tokens))

        # Geometric mean of the precisions, computed in log space.
        log_precisions = [math.log(p) for p in precisions]
        return bp * math.exp(sum(log_precisions) / len(log_precisions))

    def _tokenize(self, text: str) -> List[str]:
        """Simple tokenization"""
        return text.lower().split()

    def _calculate_n_gram_precision(self, ref_tokens: List[str], cand_tokens: List[str], n: int) -> float:
        """Calculate clipped n-gram precision of the candidate."""
        if len(cand_tokens) < n:
            return 0.0

        # Get n-grams
        ref_ngrams = self._get_ngrams(ref_tokens, n)
        cand_ngrams = self._get_ngrams(cand_tokens, n)

        if not cand_ngrams:
            return 0.0

        # Each candidate n-gram is credited at most as often as it occurs in
        # the reference.
        matches = sum(
            min(count, ref_ngrams[ngram])
            for ngram, count in cand_ngrams.items()
            if ngram in ref_ngrams
        )

        return matches / sum(cand_ngrams.values())

    def _get_ngrams(self, tokens: List[str], n: int) -> Dict[Tuple[str, ...], int]:
        """Get n-gram counts"""
        ngrams = defaultdict(int)
        for i in range(len(tokens) - n + 1):
            ngrams[tuple(tokens[i:i + n])] += 1
        return ngrams

    def _brevity_penalty(self, ref_len: int, cand_len: int) -> float:
        """Calculate brevity penalty"""
        if cand_len > ref_len:
            return 1.0
        if cand_len == 0:
            return 0.0
        return math.exp(1 - ref_len / cand_len)
178
+
179
class ROUGEScore:
    """ROUGE score calculator (ROUGE-L over lower-cased whitespace tokens)."""

    def __init__(self):
        pass

    def calculate_rouge_l(self, reference: str, candidate: str) -> float:
        """Return the ROUGE-L F1 score of ``candidate`` against ``reference``.

        The score is 0.0 when either text is empty or the two share no token.
        """
        ref_words = reference.lower().split()
        cand_words = candidate.lower().split()
        if not (ref_words and cand_words):
            return 0.0

        common = self._lcs_length(ref_words, cand_words)
        if common == 0:
            return 0.0

        # LCS length relative to each side gives precision and recall; both
        # are strictly positive here, so the harmonic mean is well defined.
        precision = common / len(cand_words)
        recall = common / len(ref_words)
        return 2 * precision * recall / (precision + recall)

    def _lcs_length(self, seq1: List[str], seq2: List[str]) -> int:
        """Return the length of the longest common subsequence."""
        # Row-by-row dynamic programme; only the previous row is needed.
        previous = [0] * (len(seq2) + 1)
        for left in seq1:
            current = [0]
            for col, right in enumerate(seq2, start=1):
                if left == right:
                    current.append(previous[col - 1] + 1)
                else:
                    current.append(max(previous[col], current[col - 1]))
            previous = current
        return previous[-1]
223
+
224
class CoherenceAnalyzer:
    """Analyze text coherence and quality using simple lexical heuristics."""

    def __init__(self):
        pass

    def analyze_coherence(self, text: str) -> Dict[str, float]:
        """Analyze text coherence.

        Returns:
            A dict with ``coherence_score``, ``repetition_score`` and
            ``diversity_score``. Repetition and diversity are always measured
            on the actual words of ``text``; previously a text with fewer
            than two sentences received fixed placeholder values (1.0 and
            0.5), which hid degenerate output such as "the the the the".
            ``coherence_score`` is 1.0 when there are fewer than two
            sentences to compare.
        """
        sentences = self._split_sentences(text)

        return {
            "coherence_score": self._calculate_coherence(sentences),
            "repetition_score": self._calculate_repetition(text),
            "diversity_score": self._calculate_diversity(text),
        }

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences on runs of '.', '!' and '?'."""
        # Simple sentence splitting; empty fragments are discarded.
        return [part.strip() for part in re.split(r'[.!?]+', text) if part.strip()]

    def _calculate_coherence(self, sentences: List[str]) -> float:
        """Return the mean word-overlap similarity of adjacent sentences."""
        if len(sentences) < 2:
            return 1.0

        similarities = [
            self._sentence_similarity(first, second)
            for first, second in zip(sentences, sentences[1:])
        ]
        return sum(similarities) / len(similarities)

    def _sentence_similarity(self, sent1: str, sent2: str) -> float:
        """Return the Jaccard similarity of the two sentences' word sets."""
        words1 = set(sent1.lower().split())
        words2 = set(sent2.lower().split())

        if not words1 or not words2:
            return 0.0

        return len(words1 & words2) / len(words1 | words2)

    def _calculate_repetition(self, text: str) -> float:
        """Calculate repetition score (1.0 is best: no word is repeated).

        For two or more words this is ``unique words / total words``, i.e.
        the same ratio as ``_calculate_diversity``.
        """
        words = text.lower().split()
        if len(words) < 2:
            return 1.0

        repetition_ratio = len(words) / len(set(words))

        # Invert so the score lies in (0, 1] with 1 meaning no repetition.
        return 1.0 / repetition_ratio

    def _calculate_diversity(self, text: str) -> float:
        """Calculate lexical diversity (unique words / total words)."""
        words = text.lower().split()
        if not words:
            return 0.0

        return len(set(words)) / len(words)
299
+
300
class LatencyBenchmark:
    """Benchmark model latency and throughput."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Inputs are moved to the device of the model's first parameter.
        self.device = next(model.parameters()).device

    def benchmark_inference_speed(self, prompts: List[str], max_length: int = 100, num_runs: int = 5) -> Dict[str, float]:
        """Benchmark inference speed.

        Every prompt is generated ``num_runs`` times. The timed region covers
        tokenisation plus ``model.generate``.

        Args:
            prompts: Prompts to generate from.
            max_length: Forwarded to ``model.generate`` as ``max_length``.
            num_runs: Number of passes over ``prompts``.

        Returns:
            ``avg_latency_ms``, ``p95_latency_ms``,
            ``throughput_tokens_per_sec`` and ``total_runs``. All values are
            zero when nothing was measured (no prompts or ``num_runs`` of 0).
        """
        latencies = []
        token_counts = []

        for _ in range(num_runs):
            for prompt in prompts:
                # perf_counter is monotonic and high resolution, unlike the
                # wall clock which may jump while a run is in flight.
                started = time.perf_counter()

                # Tokenize input
                inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

                # Generate
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs,
                        max_length=max_length,
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id
                    )

                # CUDA kernels run asynchronously; wait for them so the timer
                # covers the work that was actually queued.
                if self.device.type == "cuda":
                    torch.cuda.synchronize(self.device)

                latencies.append(time.perf_counter() - started)
                # The returned sequence is expected to start with the prompt
                # tokens, so the difference is the number of new tokens.
                token_counts.append(outputs.shape[1] - inputs.shape[1])

        if not latencies:
            # Statistics over an empty sample are undefined.
            return {
                "avg_latency_ms": 0.0,
                "p95_latency_ms": 0.0,
                "throughput_tokens_per_sec": 0.0,
                "total_runs": 0
            }

        # Calculate statistics
        total_tokens = sum(token_counts)
        total_time = sum(latencies)
        throughput = total_tokens / total_time if total_time > 0 else 0.0

        return {
            # Plain floats keep the result free of NumPy scalar types.
            "avg_latency_ms": float(np.mean(latencies)) * 1000,
            "p95_latency_ms": float(np.percentile(latencies, 95)) * 1000,
            "throughput_tokens_per_sec": float(throughput),
            "total_runs": len(latencies)
        }
351
+
352
class QualityEvaluator:
    """Comprehensive quality evaluation.

    Bundles the perplexity, BLEU, ROUGE, coherence and latency helpers
    around one model/tokenizer pair.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.perplexity_calc = PerplexityCalculator(model, tokenizer)
        self.bleu_calc = BLEUScore()
        self.rouge_calc = ROUGEScore()
        self.coherence_analyzer = CoherenceAnalyzer()
        self.latency_benchmark = LatencyBenchmark(model, tokenizer)

    def evaluate_generation_quality(self, prompts: List[str], references: Optional[List[str]] = None, max_length: int = 100) -> List["EvaluationResult"]:
        """Evaluate generation quality.

        For every prompt one continuation is generated and scored.

        Args:
            prompts: Prompts to generate from.
            references: Optional reference texts, matched to prompts by
                index; BLEU and ROUGE-L are only computed where one exists.
            max_length: Forwarded to ``model.generate`` as ``max_length``.

        Returns:
            A flat list of results; each carries ``prompt_index`` in its
            ``details``.
        """
        results = []

        for i, prompt in enumerate(prompts):
            # Generate text
            generated_text = self._generate_text(prompt, max_length)

            # Perplexity is best-effort: a failure is logged and the
            # remaining metrics are still produced.
            try:
                perplexity = self.perplexity_calc.calculate_perplexity(generated_text)
                results.append(EvaluationResult(
                    metric_name="perplexity",
                    score=perplexity,
                    details={"prompt_index": i, "generated_text": generated_text[:100]}
                ))
            except Exception as e:
                logging.warning(f"Failed to calculate perplexity: {e}")

            # Calculate coherence metrics
            coherence_metrics = self.coherence_analyzer.analyze_coherence(generated_text)
            for metric_name, score in coherence_metrics.items():
                results.append(EvaluationResult(
                    metric_name=metric_name,
                    score=score,
                    details={"prompt_index": i}
                ))

            # Calculate BLEU and ROUGE if references are provided
            if references and i < len(references):
                reference = references[i]

                bleu_score = self.bleu_calc.calculate_bleu(reference, generated_text)
                results.append(EvaluationResult(
                    metric_name="bleu_score",
                    score=bleu_score,
                    details={"prompt_index": i, "reference": reference[:100]}
                ))

                rouge_score = self.rouge_calc.calculate_rouge_l(reference, generated_text)
                results.append(EvaluationResult(
                    metric_name="rouge_l",
                    score=rouge_score,
                    details={"prompt_index": i}
                ))

        return results

    def _generate_text(self, prompt: str, max_length: int) -> str:
        """Generate text from prompt and return only the continuation.

        The prompt is removed by token count rather than by character count.
        Decoding does not always reproduce the prompt byte-for-byte
        (whitespace or case may be normalised), so slicing the decoded string
        by ``len(prompt)`` could cut into the continuation or leave part of
        the prompt behind.
        """
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(next(self.model.parameters()).device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_length=max_length,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # The returned sequence is expected to start with the prompt tokens;
        # everything after them is newly generated.
        new_tokens = outputs[0][inputs.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
430
+
431
class MambaSwarmEvaluator:
    """Main evaluator for Mamba Swarm models.

    Runs performance, quality, scalability and resource checks against a
    swarm engine and folds them into one weighted score.

    Recognised ``config`` keys:
        total_encoders: Denominator of the encoder utilisation ratio
            (default 100).
    """

    def __init__(self, swarm_engine, config: Optional[Dict[str, Any]] = None):
        self.swarm_engine = swarm_engine
        self.config = config or {}
        self.logger = logging.getLogger(__name__)

        # Initialize evaluators
        self.quality_evaluator = None
        self._initialize_evaluators()

        # Benchmark datasets
        self.benchmark_prompts = [
            "The future of artificial intelligence is",
            "In a world where technology advances rapidly,",
            "The most important challenge facing humanity today is",
            "Scientific discoveries have always been driven by",
            "The relationship between humans and machines will"
        ]

    def _initialize_evaluators(self):
        """Initialize quality evaluators.

        Failure is logged and leaves ``quality_evaluator`` as ``None``, in
        which case model-based metrics are skipped.
        """
        try:
            # Get model and tokenizer from swarm engine
            model = self.swarm_engine.get_model()
            tokenizer = self.swarm_engine.get_tokenizer()

            # Compare against None explicitly: objects that define __len__
            # can be falsy even though they are perfectly usable.
            if model is not None and tokenizer is not None:
                self.quality_evaluator = QualityEvaluator(model, tokenizer)
        except Exception as e:
            self.logger.warning(f"Failed to initialize evaluators: {e}")

    def run_comprehensive_evaluation(self) -> "BenchmarkResult":
        """Run comprehensive evaluation of the Mamba Swarm.

        Returns:
            A ``BenchmarkResult`` named ``"comprehensive_evaluation"`` with
            every collected metric, the weighted overall score and the
            elapsed time in seconds.
        """
        started = time.perf_counter()
        all_results = []

        # Performance benchmarks
        all_results.extend(self._evaluate_performance())

        # Quality benchmarks
        if self.quality_evaluator:
            all_results.extend(self._evaluate_quality())

        # Scalability benchmarks
        all_results.extend(self._evaluate_scalability())

        # Resource utilization
        all_results.extend(self._evaluate_resource_utilization())

        # Calculate overall score
        overall_score = self._calculate_overall_score(all_results)

        execution_time = time.perf_counter() - started

        # Get model info
        model_info = self.swarm_engine.get_model_info()

        return BenchmarkResult(
            benchmark_name="comprehensive_evaluation",
            overall_score=overall_score,
            individual_metrics=all_results,
            execution_time=execution_time,
            model_info=model_info
        )

    def _evaluate_performance(self) -> List["EvaluationResult"]:
        """Evaluate performance metrics (latency and request throughput)."""
        results = []

        try:
            # Latency benchmark
            if self.quality_evaluator:
                latency_metrics = self.quality_evaluator.latency_benchmark.benchmark_inference_speed(
                    self.benchmark_prompts[:3]  # Use subset for speed
                )

                for metric_name, score in latency_metrics.items():
                    results.append(EvaluationResult(
                        metric_name=f"performance_{metric_name}",
                        score=score,
                        details={"category": "performance"}
                    ))

            # Throughput test
            throughput = self._measure_throughput()
            results.append(EvaluationResult(
                metric_name="throughput_requests_per_sec",
                score=throughput,
                details={"category": "performance"}
            ))

        except Exception as e:
            self.logger.error(f"Performance evaluation failed: {e}")

        return results

    def _evaluate_quality(self) -> List["EvaluationResult"]:
        """Evaluate generation quality on the built-in benchmark prompts."""
        results = []

        try:
            # Quality evaluation
            quality_results = self.quality_evaluator.evaluate_generation_quality(
                self.benchmark_prompts
            )

            # Add category to results
            for result in quality_results:
                result.details["category"] = "quality"
                results.append(result)

        except Exception as e:
            self.logger.error(f"Quality evaluation failed: {e}")

        return results

    def _evaluate_scalability(self) -> List["EvaluationResult"]:
        """Evaluate scalability metrics.

        For each load level the requests are issued back to back and the
        success rate plus the mean time per request (in seconds) are
        recorded.
        """
        results = []

        try:
            # Test with different loads
            load_levels = [1, 5, 10]

            for load in load_levels:
                started = time.perf_counter()

                # NOTE(review): the requests run sequentially, not
                # concurrently, so this measures sustained sequential load.
                outcomes = [self._simulate_inference_request() for _ in range(load)]

                success_count = sum(1 for ok in outcomes if ok)
                total_time = time.perf_counter() - started

                # Calculate metrics
                success_rate = success_count / load
                avg_response_time = total_time / load

                results.append(EvaluationResult(
                    metric_name=f"scalability_success_rate_load_{load}",
                    score=success_rate,
                    details={"category": "scalability", "load_level": load}
                ))

                results.append(EvaluationResult(
                    metric_name=f"scalability_avg_response_time_load_{load}",
                    score=avg_response_time,
                    details={"category": "scalability", "load_level": load}
                ))

        except Exception as e:
            self.logger.error(f"Scalability evaluation failed: {e}")

        return results

    def _evaluate_resource_utilization(self) -> List["EvaluationResult"]:
        """Evaluate resource utilization (memory and active encoders)."""
        results = []

        try:
            # Get memory stats
            memory_stats = self.swarm_engine.memory_manager.get_memory_stats()

            results.append(EvaluationResult(
                metric_name="memory_utilization_gb",
                score=memory_stats.used_memory,
                details={"category": "resources", "type": "memory"}
            ))

            results.append(EvaluationResult(
                metric_name="gpu_memory_utilization_gb",
                score=memory_stats.gpu_memory,
                details={"category": "resources", "type": "gpu_memory"}
            ))

            # Encoder utilization
            active_encoders = len(self.swarm_engine.get_active_encoders())
            # Defaults to the 100 encoders named in the requirements but can
            # be overridden through the config.
            total_encoders = self.config.get("total_encoders", 100)

            results.append(EvaluationResult(
                metric_name="encoder_utilization_ratio",
                score=active_encoders / total_encoders,
                details={"category": "resources", "active": active_encoders, "total": total_encoders}
            ))

        except Exception as e:
            self.logger.error(f"Resource evaluation failed: {e}")

        return results

    def _measure_throughput(self) -> float:
        """Measure system throughput in requests per second.

        Returns 0.0 if the measurement fails.
        """
        try:
            num_requests = 10
            started = time.perf_counter()

            for _ in range(num_requests):
                self._simulate_inference_request()

            total_time = time.perf_counter() - started
            return num_requests / total_time
        except Exception as e:
            self.logger.error(f"Throughput measurement failed: {e}")
            return 0.0

    def _simulate_inference_request(self) -> bool:
        """Simulate an inference request.

        Returns:
            True when the engine returned something other than ``None``;
            False when it returned ``None`` or raised.
        """
        try:
            prompt = "This is a test prompt for evaluation."
            result = self.swarm_engine.generate(prompt, max_length=50)
            return result is not None
        except Exception as e:
            self.logger.error(f"Simulated request failed: {e}")
            return False

    def _calculate_overall_score(self, results: List["EvaluationResult"]) -> float:
        """Calculate overall benchmark score.

        Metrics are normalised to 0-1, averaged per category and combined
        with fixed category weights. Categories without metrics are left out
        and the remaining weights renormalised. Returns 0.0 for no results.
        """
        if not results:
            return 0.0

        # Weight different categories
        weights = {
            "performance": 0.3,
            "quality": 0.4,
            "scalability": 0.2,
            "resources": 0.1
        }

        category_scores = defaultdict(list)

        for result in results:
            # "total_runs" is bookkeeping (a sample count), not a quality
            # signal; scoring it would always add a perfect 1.0.
            if result.metric_name.endswith("total_runs"):
                continue

            category = result.details.get("category", "other")
            category_scores[category].append(self._normalize_score(result))

        # Calculate weighted average
        total_score = 0.0
        total_weight = 0.0

        for category, scores in category_scores.items():
            if category in weights and scores:
                weight = weights[category]
                total_score += (sum(scores) / len(scores)) * weight
                total_weight += weight

        return total_score / total_weight if total_weight > 0 else 0.0

    def _normalize_score(self, result: "EvaluationResult") -> float:
        """Normalize score to 0-1 range (1 is best).

        The rule is chosen from substrings of the metric name.
        """
        metric_name = result.metric_name
        score = result.score

        if "perplexity" in metric_name:
            # Lower is better; a perplexity of 100 or more maps to 0.
            return max(0.0, 1.0 - min(score / 100.0, 1.0))
        elif "latency" in metric_name or "response_time" in metric_name:
            # Lower is better, with one second as the worst case. Metrics
            # whose name ends in "_ms" are in milliseconds; the response
            # times produced by this class are in seconds.
            one_second = 1000.0 if metric_name.endswith("_ms") else 1.0
            return max(0.0, 1.0 - min(score / one_second, 1.0))
        elif "throughput" in metric_name:
            # Higher is better, saturating at 100 per second.
            return min(score / 100.0, 1.0)
        elif "success_rate" in metric_name or "utilization" in metric_name:
            # Ratios are already 0-1; clamping also bounds the *_gb memory
            # metrics, which share the "utilization" name.
            # NOTE(review): absolute GB values have no natural 0-1 meaning;
            # consider excluding them from the score.
            return min(max(score, 0.0), 1.0)
        else:
            # Default: assume higher is better and clamp to 0-1
            return min(max(score, 0.0), 1.0)

    def export_evaluation_report(self, result: "BenchmarkResult", filename: Optional[str] = None) -> str:
        """Export evaluation report to a JSON file.

        Args:
            result: Benchmark result to serialise.
            filename: Target path; defaults to
                ``mamba_swarm_evaluation_<YYYYmmdd_HHMMSS>.json`` in the
                current working directory.

        Returns:
            The path that was written.
        """
        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"mamba_swarm_evaluation_{timestamp}.json"

        # Convert to serializable format
        report = {
            "benchmark_name": result.benchmark_name,
            "overall_score": result.overall_score,
            "execution_time": result.execution_time,
            "timestamp": result.timestamp,
            "model_info": result.model_info,
            "metrics": [
                {
                    "name": metric.metric_name,
                    "score": metric.score,
                    "details": metric.details,
                    "timestamp": metric.timestamp
                }
                for metric in result.individual_metrics
            ]
        }

        # Explicit encoding keeps the report independent of the platform
        # default. default=str stringifies values json cannot encode itself.
        with open(filename, 'w', encoding="utf-8") as f:
            json.dump(report, f, indent=2, default=str)

        self.logger.info(f"Evaluation report saved to {filename}")
        return filename
743
+
744
# Example usage
if __name__ == "__main__":
    # With a real SwarmEngine instance the full pipeline would be:
    #     evaluator = MambaSwarmEvaluator(swarm_engine)
    #     result = evaluator.run_comprehensive_evaluation()
    #     report_file = evaluator.export_evaluation_report(result)

    # Stand-alone demo of the components that need no model.
    print("Mamba Swarm Evaluator components initialized successfully")

    # BLEU between a reference sentence and a paraphrase of it.
    demo_bleu = BLEUScore().calculate_bleu(
        "The quick brown fox jumps over the lazy dog",
        "The fast brown fox leaps over the sleepy dog",
    )
    print(f"BLEU score: {demo_bleu:.3f}")

    # Lexical coherence heuristics on a short three-sentence paragraph.
    demo_metrics = CoherenceAnalyzer().analyze_coherence(
        "This is a coherent text. It flows well from sentence to sentence. The ideas are connected logically."
    )
    print(f"Coherence metrics: {demo_metrics}")
monitoring/metrics.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Metrics Collection and Monitoring System for Mamba Swarm
3
+ Tracks performance, resource usage, and model behavior
4
+ """
5
+
6
+ import time
7
+ import threading
8
+ import json
9
+ import logging
10
+ from typing import Dict, List, Any, Optional, Callable
11
+ from dataclasses import dataclass, field, asdict
12
+ from collections import defaultdict, deque
13
+ from enum import Enum
14
+ import torch
15
+ import psutil
16
+ import numpy as np
17
+ from datetime import datetime, timedelta
18
+
19
class MetricType(Enum):
    """Kinds of metric this module can record."""

    # Cumulative value that only grows.
    COUNTER = "counter"
    # Value that may rise and fall.
    GAUGE = "gauge"
    # Distribution tracked through bucket counts.
    HISTOGRAM = "histogram"
    # Distribution tracked through quantiles.
    SUMMARY = "summary"
24
+
25
@dataclass
class MetricPoint:
    """One timestamped sample of a metric."""

    # Time of the sample.
    timestamp: float
    # Sampled value.
    value: float
    # Label name -> label value; a fresh dict per instance.
    labels: Dict[str, str] = field(default_factory=dict)
30
+
31
@dataclass
class HistogramBucket:
    """A histogram bucket: an upper bound and a count."""

    # Upper bound of the bucket.
    upper_bound: float
    # Number of observations counted in the bucket; starts at zero.
    count: int = 0
35
+
36
class Metric:
    """Base metric class.

    Holds the metadata shared by all metric kinds, a lock guarding the
    per-label state of subclasses, and the helper that turns keyword label
    values into a storage key.
    """

    def __init__(self, name: str, description: str, labels: Optional[List[str]] = None):
        self.name = name
        self.description = description
        self.labels = labels or []
        self.lock = threading.Lock()
        self.created_at = time.time()

    def _make_label_key(self, label_values: Dict[str, str]) -> str:
        """Create key from label values.

        Labels are sorted by name so the key does not depend on the order in
        which keyword arguments were passed; no labels gives ``""``.
        """
        return "|".join(f"{k}={v}" for k, v in sorted(label_values.items()))


class Counter(Metric):
    """Counter metric - monotonically increasing"""

    def __init__(self, name: str, description: str, labels: Optional[List[str]] = None):
        super().__init__(name, description, labels)
        self.values = defaultdict(float)

    def inc(self, value: float = 1.0, **label_values):
        """Increment counter.

        Raises:
            ValueError: If ``value`` is negative; a counter must never
                decrease.
        """
        if value < 0:
            raise ValueError("Counter can only be incremented by non-negative amounts")
        label_key = self._make_label_key(label_values)
        with self.lock:
            self.values[label_key] += value

    def get(self, **label_values) -> float:
        """Get counter value (0.0 for a label set never incremented)."""
        label_key = self._make_label_key(label_values)
        with self.lock:
            return self.values.get(label_key, 0.0)


class Gauge(Metric):
    """Gauge metric - can go up and down"""

    def __init__(self, name: str, description: str, labels: Optional[List[str]] = None):
        super().__init__(name, description, labels)
        self.values = defaultdict(float)

    def set(self, value: float, **label_values):
        """Set gauge value"""
        label_key = self._make_label_key(label_values)
        with self.lock:
            self.values[label_key] = value

    def inc(self, value: float = 1.0, **label_values):
        """Increment gauge"""
        label_key = self._make_label_key(label_values)
        with self.lock:
            self.values[label_key] += value

    def dec(self, value: float = 1.0, **label_values):
        """Decrement gauge"""
        self.inc(-value, **label_values)

    def get(self, **label_values) -> float:
        """Get gauge value (0.0 for a label set never written)."""
        label_key = self._make_label_key(label_values)
        with self.lock:
            return self.values.get(label_key, 0.0)


class Histogram(Metric):
    """Histogram metric - tracks distribution of values.

    Bucket counts are cumulative: an observation is counted in every bucket
    whose upper bound is greater than or equal to it.
    """

    def __init__(self, name: str, description: str, buckets: Optional[List[float]] = None, labels: Optional[List[str]] = None):
        super().__init__(name, description, labels)
        bounds = sorted(buckets) if buckets else [
            0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0
        ]
        # Always end with +inf so an observation above the largest finite
        # bound is still counted in a bucket.
        if bounds[-1] != float('inf'):
            bounds.append(float('inf'))
        self.buckets = bounds
        self.bucket_counts = defaultdict(lambda: defaultdict(int))
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)

    def observe(self, value: float, **label_values):
        """Observe a value"""
        label_key = self._make_label_key(label_values)
        with self.lock:
            self.sums[label_key] += value
            self.counts[label_key] += 1

            series = self.bucket_counts[label_key]
            for bucket in self.buckets:
                if value <= bucket:
                    series[bucket] += 1

    # The getters use .get() rather than indexing: indexing a defaultdict
    # would create an empty series for every label set that is merely read.

    def get_buckets(self, **label_values) -> Dict[float, int]:
        """Get bucket counts as a copy (buckets never hit are absent)."""
        label_key = self._make_label_key(label_values)
        with self.lock:
            return dict(self.bucket_counts.get(label_key, {}))

    def get_sum(self, **label_values) -> float:
        """Get sum of observed values"""
        label_key = self._make_label_key(label_values)
        with self.lock:
            return self.sums.get(label_key, 0.0)

    def get_count(self, **label_values) -> int:
        """Get count of observations"""
        label_key = self._make_label_key(label_values)
        with self.lock:
            return self.counts.get(label_key, 0)
137
+
138
class Summary(Metric):
    """Summary metric - tracks quantiles over a sliding time window.

    Observations older than ``max_age`` seconds are discarded; the per-label
    sum and count are reduced accordingly, so they describe the current
    window rather than the lifetime of the process.
    """

    def __init__(self, name: str, description: str, quantiles: Optional[List[float]] = None, labels: Optional[List[str]] = None, max_age: float = 300.0):
        """Create a summary.

        Args:
            name: Metric name.
            description: Help text.
            quantiles: Quantiles reported by exporters; defaults to
                0.5, 0.9, 0.95 and 0.99.
            labels: Label names forwarded to the base class.
            max_age: Window length in seconds.
        """
        super().__init__(name, description, labels)
        self.quantiles = quantiles or [0.5, 0.9, 0.95, 0.99]
        self.max_age = max_age
        self.observations = defaultdict(lambda: deque())
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)

    def observe(self, value: float, **label_values):
        """Observe a value"""
        label_key = self._make_label_key(label_values)
        timestamp = time.time()

        with self.lock:
            self.observations[label_key].append((timestamp, value))
            self.sums[label_key] += value
            self.counts[label_key] += 1

            # Clean old observations
            self._clean_old_observations(label_key, timestamp)

    def get_quantile(self, quantile: float, **label_values) -> float:
        """Get quantile value.

        Returns 0.0 when the label combination is unknown or its window is
        empty. Expired observations are pruned first, so a series that has
        stopped receiving values does not keep reporting stale quantiles.
        """
        label_key = self._make_label_key(label_values)
        with self.lock:
            # Membership test instead of indexing: indexing the defaultdict
            # would register an empty series for every label ever queried.
            if label_key not in self.observations:
                return 0.0
            self._clean_old_observations(label_key, time.time())

            obs = self.observations[label_key]
            if not obs:
                return 0.0

            values = sorted(v for _, v in obs)
            # Clamp on both sides so out-of-range quantiles cannot index
            # from the wrong end of the list.
            index = max(0, min(int(quantile * len(values)), len(values) - 1))
            return values[index]

    def get_sum(self, **label_values) -> float:
        """Get sum of observed values (0.0 for an unknown label combination)"""
        label_key = self._make_label_key(label_values)
        with self.lock:
            return self.sums.get(label_key, 0.0)

    def get_count(self, **label_values) -> int:
        """Get count of observations (0 for an unknown label combination)"""
        label_key = self._make_label_key(label_values)
        with self.lock:
            return self.counts.get(label_key, 0)

    def _clean_old_observations(self, label_key: str, current_time: float):
        """Remove observations older than ``max_age``. Caller must hold the lock."""
        cutoff_time = current_time - self.max_age
        obs = self.observations[label_key]

        while obs and obs[0][0] < cutoff_time:
            _, value = obs.popleft()
            self.sums[label_key] -= value
            self.counts[label_key] -= 1

        if not obs:
            # Reset rather than trust repeated subtraction, which leaves
            # floating-point residue in the sum.
            self.sums[label_key] = 0.0
            self.counts[label_key] = 0

    def _make_label_key(self, label_values: Dict[str, str]) -> str:
        """Build the series key: sorted ``name=value`` pairs joined by ``|``."""
        return "|".join(f"{k}={v}" for k, v in sorted(label_values.items()))
197
+
198
class MetricsRegistry:
    """Name-indexed collection of metric objects."""

    def __init__(self):
        self.metrics: Dict[str, Metric] = {}
        self.lock = threading.Lock()

    def register(self, metric: Metric):
        """Add ``metric`` under its ``name``.

        Raises:
            ValueError: if a metric with the same name is already present.
        """
        with self.lock:
            if metric.name not in self.metrics:
                self.metrics[metric.name] = metric
                return
            raise ValueError(f"Metric {metric.name} already registered")

    def get_metric(self, name: str) -> Optional[Metric]:
        """Return the metric registered as ``name``, or ``None``."""
        try:
            return self.metrics[name]
        except KeyError:
            return None

    def get_all_metrics(self) -> Dict[str, Metric]:
        """Return a shallow copy of the name -> metric mapping."""
        return dict(self.metrics)
219
+
220
class MambaSwarmMetrics:
    """Metrics collector for Mamba Swarm.

    Creates the default request / model / system / swarm / error metrics,
    registers them in a ``MetricsRegistry``, can sample host memory and GPU
    figures on a background thread, and renders everything either as a plain
    dictionary summary or as Prometheus text.
    """

    def __init__(self):
        self.registry = MetricsRegistry()
        self.logger = logging.getLogger(__name__)
        self._setup_default_metrics()

        # System monitoring
        self.monitoring_thread = None
        self.monitoring_interval = 10.0  # seconds
        self.should_monitor = False
        # Lets stop_monitoring() interrupt the wait between two samples.
        self._stop_event = threading.Event()

    def _setup_default_metrics(self):
        """Create the default metrics and register each one."""
        # Request metrics
        self.requests_total = Counter("requests_total", "Total number of requests", ["method", "endpoint", "status"])
        self.request_duration = Histogram("request_duration_seconds", "Request duration in seconds", labels=["method", "endpoint"])

        # Model metrics
        self.inference_duration = Histogram("inference_duration_seconds", "Inference duration in seconds", labels=["model_unit"])
        self.tokens_generated = Counter("tokens_generated_total", "Total tokens generated", ["model_unit"])
        self.model_load = Gauge("model_load", "Current model load", ["model_unit"])

        # System metrics
        self.memory_usage = Gauge("memory_usage_bytes", "Memory usage in bytes", ["type"])
        self.gpu_utilization = Gauge("gpu_utilization_percent", "GPU utilization percentage", ["device"])
        self.active_connections = Gauge("active_connections", "Number of active connections")

        # Swarm metrics
        self.encoder_utilization = Gauge("encoder_utilization", "Encoder utilization", ["encoder_id"])
        self.routing_decisions = Counter("routing_decisions_total", "Routing decisions", ["strategy", "target"])
        self.load_balancing_decisions = Counter("load_balancing_decisions_total", "Load balancing decisions", ["algorithm"])

        # Error metrics
        self.errors_total = Counter("errors_total", "Total number of errors", ["type", "component"])

        # Register every metric stored on the instance. Sorting by attribute
        # name keeps the registration order stable and alphabetical.
        for _, attr in sorted(vars(self).items()):
            if isinstance(attr, Metric):
                self.registry.register(attr)

    def start_monitoring(self):
        """Start system monitoring on a daemon thread (no-op if already started)."""
        if self.monitoring_thread is not None:
            return

        self.should_monitor = True
        self._stop_event.clear()
        self.monitoring_thread = threading.Thread(target=self._monitoring_loop, daemon=True)
        self.monitoring_thread.start()
        self.logger.info("Started metrics monitoring")

    def stop_monitoring(self):
        """Stop system monitoring and wait up to 5 seconds for the thread."""
        self.should_monitor = False
        # Wake the loop immediately instead of waiting out the full interval.
        self._stop_event.set()
        if self.monitoring_thread:
            self.monitoring_thread.join(timeout=5.0)
            self.monitoring_thread = None
        self.logger.info("Stopped metrics monitoring")

    def _monitoring_loop(self):
        """Sample system metrics every ``monitoring_interval`` seconds until stopped."""
        while self.should_monitor:
            try:
                self._collect_system_metrics()
            except Exception as e:
                self.logger.error(f"Error in monitoring loop: {e}")
            # The wait sits outside the try block so a failing collector is
            # retried once per interval rather than in a tight loop.
            if self._stop_event.wait(self.monitoring_interval):
                break

    def _collect_system_metrics(self):
        """Record host memory and, when CUDA is available, per-device GPU memory."""
        # Memory metrics
        memory = psutil.virtual_memory()
        self.memory_usage.set(memory.used, type="system")
        self.memory_usage.set(memory.available, type="available")

        # GPU metrics
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                # GPU memory
                gpu_memory = torch.cuda.memory_allocated(i)
                self.memory_usage.set(gpu_memory, type=f"gpu_{i}")

                # GPU utilization (simplified): current allocation relative
                # to the peak allocation, not compute utilization.
                # In practice, you might use nvidia-ml-py for more detailed metrics
                peak_memory = torch.cuda.max_memory_allocated(i)
                utilization = min(100.0, (gpu_memory / peak_memory) * 100) if peak_memory > 0 else 0.0
                self.gpu_utilization.set(utilization, device=f"cuda:{i}")

    def record_request(self, method: str, endpoint: str, status_code: int, duration: float):
        """Record request metrics"""
        self.requests_total.inc(method=method, endpoint=endpoint, status=str(status_code))
        self.request_duration.observe(duration, method=method, endpoint=endpoint)

    def record_inference(self, model_unit: str, duration: float, tokens: int):
        """Record inference metrics"""
        self.inference_duration.observe(duration, model_unit=model_unit)
        self.tokens_generated.inc(tokens, model_unit=model_unit)

    def record_error(self, error_type: str, component: str):
        """Record error metrics"""
        self.errors_total.inc(type=error_type, component=component)

    def update_model_load(self, model_unit: str, load: float):
        """Update model load"""
        self.model_load.set(load, model_unit=model_unit)

    def update_encoder_utilization(self, encoder_id: str, utilization: float):
        """Update encoder utilization"""
        self.encoder_utilization.set(utilization, encoder_id=encoder_id)

    def record_routing_decision(self, strategy: str, target: str):
        """Record routing decision"""
        self.routing_decisions.inc(strategy=strategy, target=target)

    def get_metrics_summary(self) -> Dict[str, Any]:
        """Return a dictionary snapshot of every registered metric, keyed by name."""
        summary = {}

        for name, metric in self.registry.get_all_metrics().items():
            if isinstance(metric, Counter):
                summary[name] = {
                    "type": "counter",
                    "values": dict(metric.values)
                }
            elif isinstance(metric, Gauge):
                summary[name] = {
                    "type": "gauge",
                    "values": dict(metric.values)
                }
            elif isinstance(metric, Histogram):
                summary[name] = {
                    "type": "histogram",
                    # list() snapshots guard against writers adding series
                    # while this thread iterates.
                    "buckets": {k: dict(v) for k, v in list(metric.bucket_counts.items())},
                    "sums": dict(metric.sums),
                    "counts": dict(metric.counts)
                }
            elif isinstance(metric, Summary):
                label_keys = list(metric.observations.keys())
                summary[name] = {
                    "type": "summary",
                    "sums": dict(metric.sums),
                    "counts": dict(metric.counts),
                    "quantiles": {
                        q: {k: metric.get_quantile(q, **self._parse_label_key(k)) for k in label_keys}
                        for q in metric.quantiles
                    }
                }

        return summary

    def _parse_label_key(self, label_key: str) -> Dict[str, str]:
        """Parse a ``name=value|name=value`` key back into a dictionary.

        Only the first ``=`` of each pair separates name from value. A label
        value that itself contains ``|`` cannot be recovered from the key.
        """
        if not label_key:
            return {}

        labels = {}
        for pair in label_key.split("|"):
            if "=" in pair:
                k, v = pair.split("=", 1)
                labels[k] = v
        return labels

    @staticmethod
    def _format_bucket_bound(bucket: float) -> str:
        """Render a histogram upper bound; infinity is spelled ``+Inf``."""
        return "+Inf" if bucket == float("inf") else str(bucket)

    def export_prometheus_format(self) -> str:
        """Export metrics in Prometheus text format"""
        output = []

        for name, metric in self.registry.get_all_metrics().items():
            # Help text
            output.append(f"# HELP {name} {metric.description}")

            if isinstance(metric, Counter):
                output.append(f"# TYPE {name} counter")
                for label_key, value in list(metric.values.items()):
                    labels = self._format_prometheus_labels(label_key)
                    output.append(f"{name}{labels} {value}")

            elif isinstance(metric, Gauge):
                output.append(f"# TYPE {name} gauge")
                for label_key, value in list(metric.values.items()):
                    labels = self._format_prometheus_labels(label_key)
                    output.append(f"{name}{labels} {value}")

            elif isinstance(metric, Histogram):
                output.append(f"# TYPE {name} histogram")
                for label_key in list(metric.bucket_counts.keys()):
                    labels_dict = self._parse_label_key(label_key)

                    # Buckets
                    for bucket, count in list(metric.bucket_counts[label_key].items()):
                        bucket_labels = {**labels_dict, "le": self._format_bucket_bound(bucket)}
                        bucket_label_str = self._format_prometheus_labels_dict(bucket_labels)
                        output.append(f"{name}_bucket{bucket_label_str} {count}")

                    # Sum and count
                    base_labels = self._format_prometheus_labels(label_key)
                    output.append(f"{name}_sum{base_labels} {metric.sums[label_key]}")
                    output.append(f"{name}_count{base_labels} {metric.counts[label_key]}")

            elif isinstance(metric, Summary):
                output.append(f"# TYPE {name} summary")
                for label_key in list(metric.observations.keys()):
                    labels_dict = self._parse_label_key(label_key)

                    # Quantiles
                    for quantile in metric.quantiles:
                        quantile_labels = {**labels_dict, "quantile": str(quantile)}
                        quantile_label_str = self._format_prometheus_labels_dict(quantile_labels)
                        quantile_value = metric.get_quantile(quantile, **labels_dict)
                        output.append(f"{name}{quantile_label_str} {quantile_value}")

                    # Sum and count
                    base_labels = self._format_prometheus_labels(label_key)
                    output.append(f"{name}_sum{base_labels} {metric.sums[label_key]}")
                    output.append(f"{name}_count{base_labels} {metric.counts[label_key]}")

            output.append("")  # Empty line between metrics

        return "\n".join(output)

    def _format_prometheus_labels(self, label_key: str) -> str:
        """Format a series key as a Prometheus label block ("" when there are no labels)."""
        if not label_key:
            return ""

        labels = self._parse_label_key(label_key)
        return self._format_prometheus_labels_dict(labels)

    def _format_prometheus_labels_dict(self, labels: Dict[str, str]) -> str:
        """Format a label dictionary as ``{name="value",...}``, sorted by name.

        Values are coerced to ``str`` and have backslashes, double quotes and
        newlines escaped.
        """
        if not labels:
            return ""

        formatted_labels = []
        for k, v in sorted(labels.items()):
            # Backslash must be escaped first so later escapes are not doubled.
            escaped_value = str(v).replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
            formatted_labels.append(f'{k}="{escaped_value}"')

        return "{" + ",".join(formatted_labels) + "}"
459
+
460
+ # Context managers for timing
461
class timer:
    """Context manager for timing operations.

    On exit, the elapsed time in seconds is passed to ``metric.observe``
    together with the labels given at construction. The observation is
    recorded whether or not the body raised; exceptions propagate.
    """

    def __init__(self, metric: Histogram, **labels):
        self.metric = metric
        self.labels = labels
        self.start_time = None  # wall-clock time at entry
        self._started = None    # perf_counter reading used for the duration

    def __enter__(self):
        self.start_time = time.time()
        # The duration uses a monotonic clock so system clock adjustments
        # cannot produce negative or inflated measurements.
        self._started = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._started is not None:
            duration = time.perf_counter() - self._started
            self.metric.observe(duration, **self.labels)
477
+
478
class request_timer:
    """Context manager for timing requests.

    On exit it calls ``metrics.record_request(method, endpoint, status,
    duration)``. The status defaults to 200, can be changed with
    ``set_status``, and is forced to 500 when the body raised.
    """

    def __init__(self, metrics: MambaSwarmMetrics, method: str, endpoint: str):
        self.metrics = metrics
        self.method = method
        self.endpoint = endpoint
        self.start_time = None  # wall-clock time at entry
        self._started = None    # perf_counter reading used for the duration
        self.status_code = 200

    def __enter__(self):
        self.start_time = time.time()
        # Monotonic clock: immune to system clock adjustments mid-request.
        self._started = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            self.status_code = 500

        if self._started is not None:
            duration = time.perf_counter() - self._started
            self.metrics.record_request(self.method, self.endpoint, self.status_code, duration)

    def set_status(self, status_code: int):
        """Set the response status code"""
        self.status_code = status_code
503
+
504
+ # Decorator for automatic metrics collection
505
def measure_time(metric_name: str, **labels):
    """Decorator to measure function execution time.

    The decorated callable is timed only when its first positional argument
    has a ``metrics`` attribute whose registry holds a ``Histogram`` under
    ``metric_name``; otherwise it is called untouched.

    Args:
        metric_name: Name of the histogram to look up in the registry.
        **labels: Label values attached to every observation.
    """
    import functools  # local import so the block does not depend on file-level imports

    def decorator(func):
        # wraps() keeps the decorated function's name, docstring and signature
        # metadata instead of exposing the inner wrapper's.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Assume first argument is self and has metrics attribute
            if args and hasattr(args[0], 'metrics'):
                metrics = args[0].metrics
                metric = metrics.registry.get_metric(metric_name)
                if metric and isinstance(metric, Histogram):
                    with timer(metric, **labels):
                        return func(*args, **kwargs)

            return func(*args, **kwargs)
        return wrapper
    return decorator
520
+
521
+ # Metrics aggregator for multiple instances
522
class MetricsAggregator:
    """Aggregates metrics from multiple Mamba Swarm instances"""

    def __init__(self):
        self.instances: Dict[str, MambaSwarmMetrics] = {}
        self.lock = threading.Lock()

    def add_instance(self, instance_id: str, metrics: MambaSwarmMetrics):
        """Add (or replace) the metrics collector registered as ``instance_id``."""
        with self.lock:
            self.instances[instance_id] = metrics

    def remove_instance(self, instance_id: str):
        """Remove a metrics collector; unknown ids are ignored."""
        with self.lock:
            self.instances.pop(instance_id, None)

    def get_aggregated_summary(self) -> Dict[str, Any]:
        """Combine counters and gauges across all registered instances.

        Returns:
            A mapping from ``"<metric>_<label key>"`` (or just the metric
            name when there are no labels) to:

            * counters: ``{"sum": total, "instances": n}``
            * gauges:   ``{"avg": arithmetic mean, "instances": n}``

            Histograms and summaries are not aggregated.
        """
        # Snapshot under the lock, then call into the collectors without it
        # so a slow summary cannot block add/remove.
        with self.lock:
            collectors = list(self.instances.values())

        aggregated = defaultdict(lambda: defaultdict(float))

        for metrics in collectors:
            summary = metrics.get_metrics_summary()

            for metric_name, metric_data in summary.items():
                metric_type = metric_data["type"]
                if metric_type not in ("counter", "gauge"):
                    continue

                for label_key, value in metric_data["values"].items():
                    key = f"{metric_name}_{label_key}" if label_key else metric_name
                    entry = aggregated[key]
                    seen = int(entry.get("instances", 0)) + 1

                    if metric_type == "counter":
                        entry["sum"] += value
                    else:  # gauge
                        # Incremental mean. Halving the running value on every
                        # update does not yield the average (a single instance
                        # reporting v would come out as v / 2).
                        previous = entry.get("avg", 0.0)
                        entry["avg"] = previous + (value - previous) / seen
                    entry["instances"] = seen

        return dict(aggregated)
559
+
560
+ # FastAPI integration
561
+ from fastapi import FastAPI, Response
562
+ from fastapi.responses import PlainTextResponse
563
+
564
def add_metrics_endpoints(app: FastAPI, metrics: MambaSwarmMetrics):
    """Attach metrics routes and request-timing middleware to ``app``.

    Registers ``GET /metrics`` (summary dictionary), ``GET
    /metrics/prometheus`` (plain-text export) and an HTTP middleware that
    records method, path, status code and duration of every request.
    """

    @app.get("/metrics")
    async def get_metrics():
        """Get metrics in JSON format"""
        summary = metrics.get_metrics_summary()
        return summary

    @app.get("/metrics/prometheus")
    async def get_prometheus_metrics():
        """Get metrics in Prometheus format"""
        return PlainTextResponse(metrics.export_prometheus_format(), media_type="text/plain")

    @app.middleware("http")
    async def metrics_middleware(request, call_next):
        # Time the downstream handler; the context manager records the
        # request on exit and reports status 500 if the handler raised.
        http_method = request.method
        url_path = request.url.path

        with request_timer(metrics, http_method, url_path) as timer_ctx:
            response = await call_next(request)
            timer_ctx.set_status(response.status_code)
            return response
588
+
589
+ # Example usage
590
if __name__ == "__main__":
    # Demo: start a collector, record a few sample events, then print both
    # export formats before shutting the monitoring thread down.
    metrics = MambaSwarmMetrics()
    metrics.start_monitoring()

    # Sample events
    metrics.record_request("POST", "/generate", 200, 0.5)
    metrics.record_inference("encoder_1", 0.3, 50)
    metrics.update_encoder_utilization("encoder_1", 0.8)

    # Dictionary summary rendered as JSON
    summary = metrics.get_metrics_summary()
    print(json.dumps(summary, indent=2))

    # Prometheus text export
    prometheus_data = metrics.export_prometheus_format()
    print("\nPrometheus format:")
    print(prometheus_data)

    metrics.stop_monitoring()
monitoring/profiler.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Performance Profiler for Mamba Swarm
3
+ Advanced profiling tools for performance analysis and optimization
4
+ """
5
+
6
+ import time
7
+ import cProfile
8
+ import pstats
9
+ import io
10
+ import threading
11
+ import functools
12
+ import traceback
13
+ import psutil
14
+ import torch
15
+ import numpy as np
16
+ from typing import Dict, List, Any, Optional, Callable, Union
17
+ from dataclasses import dataclass, field
18
+ from collections import defaultdict, deque
19
+ from contextlib import contextmanager
20
+ import logging
21
+ import json
22
+ from datetime import datetime
23
+ import os
24
+ import gc
25
+
26
@dataclass
class ProfileResult:
    """Profiling figures for a single function."""

    function_name: str       # name of the profiled function
    total_time: float        # time attributed to the function, in seconds
    cumulative_time: float   # time including callees, in seconds
    call_count: int          # number of calls recorded
    per_call_time: float     # total_time divided by call_count (0.0 if never called)
    filename: str            # file in which the function is defined
    line_number: int         # line of the definition within `filename`
35
+
36
@dataclass
class MemorySnapshot:
    """Point-in-time memory reading."""

    timestamp: float      # time.time() at which the sample was taken
    total_memory: float   # used system memory, in GB
    gpu_memory: float     # allocated CUDA memory, in GB (0.0 without CUDA)
    python_objects: int   # number of objects tracked by the garbage collector
    tensor_count: int     # how many of those objects are torch tensors
    cache_size: float     # cache size estimate
44
+
45
@dataclass
class PerformanceProfile:
    """Result of one profiling session."""

    timestamp: float                        # time.time() when the session ended
    duration: float                         # session length, in seconds
    cpu_usage: float                        # CPU usage, in percent
    memory_usage: float                     # system memory usage, in percent
    gpu_usage: float                        # GPU memory usage, in percent
    function_calls: List[ProfileResult]     # per-function CPU profiling rows
    memory_snapshots: List[MemorySnapshot]  # memory samples
    bottlenecks: List[str]                  # human-readable bottleneck findings
    recommendations: List[str]              # human-readable optimization hints
56
+
57
class FunctionTimer:
    """Timer for individual function calls.

    Keeps the durations of the most recent ``MAX_CALLS`` calls;
    ``total_time`` and ``call_count`` describe that window, while
    ``min_time`` / ``max_time`` are never reduced when calls are evicted.
    """

    MAX_CALLS = 1000  # size of the sliding window of retained durations

    def __init__(self, name: str):
        self.name = name
        # deque gives O(1) eviction from the left; list.pop(0) shifts every
        # remaining element.
        self.calls = deque()
        self.total_time = 0.0
        self.call_count = 0
        self.min_time = float('inf')
        self.max_time = 0.0
        self.lock = threading.Lock()

    def add_call(self, duration: float):
        """Add a function call duration"""
        with self.lock:
            self.calls.append(duration)
            self.total_time += duration
            self.call_count += 1
            self.min_time = min(self.min_time, duration)
            self.max_time = max(self.max_time, duration)

            # Keep only recent calls
            if len(self.calls) > self.MAX_CALLS:
                old_call = self.calls.popleft()
                self.total_time -= old_call
                self.call_count -= 1

    @property
    def avg_time(self) -> float:
        """Mean duration over the retained window (0.0 when empty)."""
        with self.lock:
            return self.total_time / max(self.call_count, 1)

    @property
    def percentile_95(self) -> float:
        """95th percentile of the retained durations (0.0 when empty)."""
        with self.lock:
            sorted_calls = sorted(self.calls)
        return self._percentile_95_of(sorted_calls)

    @staticmethod
    def _percentile_95_of(sorted_calls: List[float]) -> float:
        """Pick the 95th-percentile element from an already sorted list."""
        if not sorted_calls:
            return 0.0
        index = int(0.95 * len(sorted_calls))
        return sorted_calls[min(index, len(sorted_calls) - 1)]

    def get_stats(self) -> Dict[str, Any]:
        """Return all statistics as one dictionary.

        Every field is read under a single lock acquisition so the values
        are mutually consistent even while other threads add calls.
        """
        with self.lock:
            total_time = self.total_time
            call_count = self.call_count
            min_time = self.min_time
            max_time = self.max_time
            sorted_calls = sorted(self.calls)

        return {
            "name": self.name,
            "total_time": total_time,
            "call_count": call_count,
            "avg_time": total_time / max(call_count, 1),
            "min_time": min_time if min_time != float('inf') else 0.0,
            "max_time": max_time,
            "percentile_95": self._percentile_95_of(sorted_calls)
        }
106
+
107
class MemoryProfiler:
    """Memory usage profiler.

    Samples memory on a daemon thread every ``sample_interval`` seconds and
    keeps the most recent 1000 snapshots.
    """

    def __init__(self, sample_interval: float = 0.1):
        self.sample_interval = sample_interval
        self.snapshots = deque(maxlen=1000)
        self.monitoring = False
        self.monitor_thread = None
        self.lock = threading.Lock()
        # Lets stop_monitoring() interrupt the wait between two samples.
        self._stop_event = threading.Event()

    def start_monitoring(self):
        """Start memory monitoring (no-op if already running)"""
        if self.monitoring:
            return

        self.monitoring = True
        self._stop_event.clear()
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()

    def stop_monitoring(self):
        """Stop memory monitoring and wait up to 1 second for the thread."""
        self.monitoring = False
        self._stop_event.set()
        if self.monitor_thread:
            self.monitor_thread.join(timeout=1.0)
            self.monitor_thread = None

    def _monitor_loop(self):
        """Memory monitoring loop"""
        while self.monitoring:
            try:
                snapshot = self._take_snapshot()
                with self.lock:
                    self.snapshots.append(snapshot)
            except Exception as e:
                logging.error(f"Memory monitoring error: {e}")
            # The wait sits outside the try block so a failing sampler is
            # retried once per interval rather than in a tight loop.
            if self._stop_event.wait(self.sample_interval):
                break

    def _take_snapshot(self) -> MemorySnapshot:
        """Take a memory snapshot"""
        # System memory
        memory = psutil.virtual_memory()
        total_memory = memory.used / (1024**3)  # GB

        # GPU memory
        gpu_memory = 0.0
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated() / (1024**3)

        # Walk the GC-tracked objects once and reuse the list for both
        # counts; each gc.get_objects() call builds a fresh list.
        tracked_objects = gc.get_objects()
        python_objects = len(tracked_objects)
        tensor_count = sum(1 for obj in tracked_objects if isinstance(obj, torch.Tensor))
        del tracked_objects  # drop the references promptly

        # Cache size estimation
        cache_size = 0.0  # Could be calculated based on specific cache implementations

        return MemorySnapshot(
            timestamp=time.time(),
            total_memory=total_memory,
            gpu_memory=gpu_memory,
            python_objects=python_objects,
            tensor_count=tensor_count,
            cache_size=cache_size
        )

    def get_peak_memory(self) -> float:
        """Get peak combined (system + GPU) memory in GB; 0.0 with no snapshots."""
        with self.lock:
            return max(
                (snapshot.total_memory + snapshot.gpu_memory for snapshot in self.snapshots),
                default=0.0,
            )

    def get_memory_trend(self) -> List[float]:
        """Get combined (system + GPU) memory in GB for each retained snapshot."""
        with self.lock:
            return [snapshot.total_memory + snapshot.gpu_memory for snapshot in self.snapshots]
186
+
187
class CPUProfiler:
    """CPU profiling using cProfile"""

    def __init__(self):
        self.profiler = None
        self.profiling = False
        self.lock = threading.Lock()

    def start_profiling(self):
        """Start CPU profiling (no-op if already profiling)"""
        with self.lock:
            if self.profiling:
                return

            self.profiler = cProfile.Profile()
            self.profiler.enable()
            self.profiling = True

    def stop_profiling(self) -> List[ProfileResult]:
        """Stop CPU profiling and return results.

        Returns:
            One ``ProfileResult`` per profiled function, sorted by cumulative
            time in descending order, or ``[]`` when profiling was not active.
        """
        with self.lock:
            if not self.profiling or not self.profiler:
                return []

            self.profiler.disable()
            self.profiling = False
            # Keep a local reference so analysis is unaffected if another
            # thread starts a new session.
            profiler = self.profiler

        stats = pstats.Stats(profiler)

        results = []
        # Each pstats entry is a 5-tuple:
        # (primitive calls, total calls, total time, cumulative time, callers).
        # Unpacking it into four names raises ValueError.
        for func, (_primitive_calls, call_count, total_time, cumulative_time, _callers) in stats.stats.items():
            filename, line_number, function_name = func

            results.append(ProfileResult(
                function_name=function_name,
                total_time=total_time,
                cumulative_time=cumulative_time,
                call_count=call_count,
                per_call_time=total_time / call_count if call_count > 0 else 0.0,
                filename=filename,
                line_number=line_number
            ))

        # Sort by cumulative time
        results.sort(key=lambda x: x.cumulative_time, reverse=True)
        return results
237
+
238
class GPUProfiler:
    """GPU profiling for CUDA operations.

    All methods are no-ops (returning ``None`` / ``{}``) when CUDA is not
    available.
    """

    def __init__(self):
        self.events = []
        self.profiling = False
        self.lock = threading.Lock()

    def start_profiling(self):
        """Start GPU profiling, discarding events from any earlier session."""
        if not torch.cuda.is_available():
            return

        with self.lock:
            if self.profiling:
                return

            self.events = []
            self.profiling = True
            torch.cuda.synchronize()

    def stop_profiling(self) -> Dict[str, Any]:
        """Stop GPU profiling.

        Returns:
            Memory figures in GB for the current CUDA device, its memory
            utilization in percent and the number of recorded events; ``{}``
            when CUDA is unavailable or profiling was not active.
        """
        if not torch.cuda.is_available():
            return {}

        with self.lock:
            if not self.profiling:
                return {}

            torch.cuda.synchronize()
            self.profiling = False
            event_count = len(self.events)

        # Read every figure from the same device. Taking total memory from
        # device 0 while allocation comes from the current device gives a
        # meaningless ratio on multi-GPU hosts.
        device = torch.cuda.current_device()
        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024**3)
        allocated_memory = torch.cuda.memory_allocated(device) / (1024**3)
        cached_memory = torch.cuda.memory_reserved(device) / (1024**3)

        return {
            "total_memory_gb": total_memory,
            "allocated_memory_gb": allocated_memory,
            "cached_memory_gb": cached_memory,
            "memory_utilization": allocated_memory / total_memory * 100,
            "events": event_count
        }

    @contextmanager
    def profile_operation(self, name: str):
        """Context manager for profiling GPU operations.

        While profiling is active, records the elapsed time of the enclosed
        block (in milliseconds, from CUDA events) under ``name``. Otherwise
        the block runs unmeasured.
        """
        if not torch.cuda.is_available() or not self.profiling:
            yield
            return

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        start_event.record()
        try:
            yield
        finally:
            end_event.record()
            # Wait for the GPU so that elapsed_time() is valid.
            torch.cuda.synchronize()

            elapsed_time = start_event.elapsed_time(end_event)
            with self.lock:
                self.events.append({
                    "name": name,
                    "duration_ms": elapsed_time,
                    "timestamp": time.time()
                })
308
+
309
+ class MambaSwarmProfiler:
310
+ """Comprehensive profiler for Mamba Swarm"""
311
+
312
def __init__(self, enable_memory_monitoring: bool = True):
    """Create the sub-profilers and bookkeeping state.

    Args:
        enable_memory_monitoring: When true (the default), background
            memory sampling is started immediately.
    """
    self.logger = logging.getLogger(__name__)

    # One sub-profiler per resource
    self.cpu_profiler = CPUProfiler()
    self.memory_profiler = MemoryProfiler()
    self.gpu_profiler = GPUProfiler()

    # Named timers for decorated functions and profiled blocks
    self.function_timers: Dict[str, FunctionTimer] = {}
    self.timer_lock = threading.Lock()

    # State of the current profiling session
    self.profiling_active = False
    self.profile_start_time = 0.0

    # Most recent 100 completed profiles
    self.performance_history = deque(maxlen=100)

    if not enable_memory_monitoring:
        return
    self.memory_profiler.start_monitoring()
334
+
335
def start_profiling(self, include_cpu: bool = True, include_gpu: bool = True):
    """Begin a profiling session.

    Logs a warning and does nothing if a session is already active.

    Args:
        include_cpu: Start the CPU profiler.
        include_gpu: Start the GPU profiler.
    """
    if self.profiling_active:
        self.logger.warning("Profiling already active")
        return

    self.profile_start_time = time.time()
    self.profiling_active = True

    # Start the requested sub-profilers, CPU first.
    requested = (
        (include_cpu, self.cpu_profiler),
        (include_gpu, self.gpu_profiler),
    )
    for wanted, profiler in requested:
        if wanted:
            profiler.start_profiling()

    self.logger.info("Started performance profiling")
351
+
352
def stop_profiling(self) -> Optional[PerformanceProfile]:
    """Stop profiling and return results.

    Returns:
        The assembled ``PerformanceProfile`` (also appended to
        ``performance_history``), or ``None`` with a logged warning when no
        session is active.
    """
    if not self.profiling_active:
        self.logger.warning("Profiling not active")
        return None

    end_time = time.time()
    duration = end_time - self.profile_start_time
    self.profiling_active = False

    # Get CPU profile
    cpu_results = self.cpu_profiler.stop_profiling()

    # Get GPU profile
    gpu_results = self.gpu_profiler.stop_profiling()

    # Get system metrics
    cpu_percent = psutil.cpu_percent()
    memory_info = psutil.virtual_memory()
    memory_percent = memory_info.percent

    gpu_usage = 0.0
    if torch.cuda.is_available():
        # The peak is 0 until something has been allocated on the device;
        # dividing by it unguarded raises ZeroDivisionError.
        peak_allocated = torch.cuda.max_memory_allocated()
        if peak_allocated > 0:
            gpu_usage = torch.cuda.memory_allocated() / peak_allocated * 100

    # Copy the snapshots under the memory profiler's lock: its sampling
    # thread appends concurrently, and a deque mutated during iteration
    # raises RuntimeError.
    with self.memory_profiler.lock:
        memory_snapshots = list(self.memory_profiler.snapshots)

    # Analyze bottlenecks
    bottlenecks = self._analyze_bottlenecks(cpu_results, gpu_results)

    # Generate recommendations
    recommendations = self._generate_recommendations(cpu_results, gpu_results, memory_snapshots)

    profile = PerformanceProfile(
        timestamp=end_time,
        duration=duration,
        cpu_usage=cpu_percent,
        memory_usage=memory_percent,
        gpu_usage=gpu_usage,
        function_calls=cpu_results,
        memory_snapshots=memory_snapshots,
        bottlenecks=bottlenecks,
        recommendations=recommendations
    )

    self.performance_history.append(profile)
    self.logger.info(f"Completed performance profiling (duration: {duration:.2f}s)")

    return profile
402
+
403
+ def _analyze_bottlenecks(self, cpu_results: List[ProfileResult], gpu_results: Dict[str, Any]) -> List[str]:
404
+ """Analyze performance bottlenecks"""
405
+ bottlenecks = []
406
+
407
+ # CPU bottlenecks
408
+ if cpu_results:
409
+ top_cpu_functions = cpu_results[:5]
410
+ for func in top_cpu_functions:
411
+ if func.cumulative_time > 1.0: # More than 1 second
412
+ bottlenecks.append(f"CPU: {func.function_name} ({func.cumulative_time:.2f}s)")
413
+
414
+ # Memory bottlenecks
415
+ peak_memory = self.memory_profiler.get_peak_memory()
416
+ if peak_memory > 8.0: # More than 8GB
417
+ bottlenecks.append(f"Memory: High usage ({peak_memory:.2f}GB)")
418
+
419
+ # GPU bottlenecks
420
+ if gpu_results and gpu_results.get("memory_utilization", 0) > 90:
421
+ bottlenecks.append("GPU: High memory utilization")
422
+
423
+ return bottlenecks
424
+
425
def _generate_recommendations(self, cpu_results: List[ProfileResult],
                              gpu_results: Dict[str, Any],
                              memory_snapshots: List[MemorySnapshot]) -> List[str]:
    """Derive optimization hints from the collected profiling data.

    At most one hint is produced per category, in CPU, memory, GPU order:

    * CPU - any entry in ``cpu_results`` with ``per_call_time`` above 0.1.
    * Memory - the largest ``tensor_count`` across ``memory_snapshots``
      exceeds 10000.
    * GPU - ``gpu_results["memory_utilization"]`` (0 when absent) exceeds 85.

    Args:
        cpu_results: Per-function profile entries; may be empty.
        gpu_results: Mapping of GPU statistics; may be empty.
        memory_snapshots: Recorded memory snapshots; may be empty.

    Returns:
        The recommendation messages (possibly an empty list).
    """
    hints = []

    # CPU recommendations: one slow function is enough to trigger the hint.
    if cpu_results and any(entry.per_call_time > 0.1 for entry in cpu_results):
        hints.append("Consider optimizing slow functions or using caching")

    # Memory recommendations: look at the highest tensor count ever sampled.
    if memory_snapshots:
        highest_tensor_count = max(snapshot.tensor_count for snapshot in memory_snapshots)
        if highest_tensor_count > 10000:
            hints.append("High tensor count detected - consider tensor cleanup")

    # GPU recommendations: a missing "memory_utilization" key counts as 0.
    if gpu_results and gpu_results.get("memory_utilization", 0) > 85:
        hints.append("Consider reducing batch size or using gradient checkpointing")

    return hints
449
+
450
def profile_function(self, func_name: str):
    """Return a decorator that times every call of the wrapped function.

    Each call's elapsed time is recorded on the timer stored in
    ``self.function_timers`` under ``func_name``; the timer is created with
    ``FunctionTimer(func_name)`` on first use. The duration is recorded even
    when the wrapped function raises, and the exception still propagates.

    Args:
        func_name: Key under which the timings are accumulated. Several
            functions decorated with the same name share one timer.

    Returns:
        A decorator preserving the wrapped function's metadata
        (via ``functools.wraps``) and its return value.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter() is monotonic, so the measured interval cannot go
            # negative or jump when the system wall clock is adjusted, which
            # time.time() does not guarantee.
            started = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - started

                # The timer registry is shared; mutate it only under the lock.
                with self.timer_lock:
                    if func_name not in self.function_timers:
                        self.function_timers[func_name] = FunctionTimer(func_name)
                    self.function_timers[func_name].add_call(elapsed)

        return wrapper
    return decorator
469
+
470
@contextmanager
def profile_block(self, block_name: str):
    """Context manager that times the enclosed code block.

    The elapsed time is recorded on the timer stored in
    ``self.function_timers`` under ``block_name``; the timer is created with
    ``FunctionTimer(block_name)`` on first use. The duration is recorded even
    when the block raises, and the exception still propagates.

    Args:
        block_name: Key under which the timings are accumulated.

    Yields:
        None.
    """
    # perf_counter() is monotonic, so the measured interval cannot go negative
    # or jump when the system wall clock is adjusted, unlike time.time().
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started

        # The timer registry is shared; mutate it only under the lock.
        with self.timer_lock:
            if block_name not in self.function_timers:
                self.function_timers[block_name] = FunctionTimer(block_name)
            self.function_timers[block_name].add_call(elapsed)
483
+
484
def get_function_stats(self) -> Dict[str, Dict[str, Any]]:
    """Return the statistics of every registered timer, keyed by timer name.

    The registry in ``self.function_timers`` is read while holding
    ``self.timer_lock``; each value is whatever that timer's ``get_stats()``
    returns.
    """
    stats_by_name = {}
    with self.timer_lock:
        for timer_name, timer in self.function_timers.items():
            stats_by_name[timer_name] = timer.get_stats()
    return stats_by_name
488
+
489
def export_profile_report(self, filename: Optional[str] = None) -> str:
    """Write a JSON profile report to disk and return its path.

    The report contains the export time, the per-timer statistics, the peak
    memory, the last 50 memory-trend samples and a summary of the last 10
    entries of ``self.performance_history``.

    Args:
        filename: Destination path. When omitted (or empty) a name of the
            form ``mamba_swarm_profile_<YYYYmmdd_HHMMSS>.json`` is generated
            from the current local time.

    Returns:
        The path the report was written to.
    """
    if not filename:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"mamba_swarm_profile_{stamp}.json"

    exported_at = time.time()

    profiler_stats = {
        "function_timers": self.get_function_stats(),
        "peak_memory_gb": self.memory_profiler.get_peak_memory(),
        "memory_trend": self.memory_profiler.get_memory_trend()[-50:],  # Last 50 samples
    }

    # Only scalar fields and message lists are exported for each profile.
    recent_profiles = []
    for entry in list(self.performance_history)[-10:]:  # Last 10 profiles
        recent_profiles.append({
            "timestamp": entry.timestamp,
            "duration": entry.duration,
            "cpu_usage": entry.cpu_usage,
            "memory_usage": entry.memory_usage,
            "gpu_usage": entry.gpu_usage,
            "bottlenecks": entry.bottlenecks,
            "recommendations": entry.recommendations,
        })

    report = {
        "timestamp": exported_at,
        "profiler_stats": profiler_stats,
        "performance_history": recent_profiles,
    }

    with open(filename, 'w') as out_file:
        json.dump(report, out_file, indent=2)

    self.logger.info(f"Profile report exported to {filename}")
    return filename
521
+
522
def cleanup(self):
    """Release profiler resources.

    Stops the memory profiler's monitoring first, then ends the profiling
    session via ``stop_profiling()`` if ``self.profiling_active`` is set.
    """
    self.memory_profiler.stop_monitoring()

    # Nothing more to do when no profiling session is running.
    if not self.profiling_active:
        return
    self.stop_profiling()
527
+
528
+ # Utility functions and decorators
529
def profile_inference(profiler: MambaSwarmProfiler):
    """Return a timing decorator for inference functions.

    Intended use is ``@profile_inference(profiler)``. Every function
    decorated this way is timed through ``profiler.profile_function`` under
    the shared name ``"inference"``.
    """
    timer_name = "inference"
    return profiler.profile_function(timer_name)
532
+
533
def profile_training_step(profiler: MambaSwarmProfiler):
    """Return a timing decorator for training-step functions.

    Intended use is ``@profile_training_step(profiler)``. Every function
    decorated this way is timed through ``profiler.profile_function`` under
    the shared name ``"training_step"``.
    """
    timer_name = "training_step"
    return profiler.profile_function(timer_name)
536
+
537
def profile_forward_pass(profiler: MambaSwarmProfiler):
    """Return a timing decorator for forward-pass functions.

    Intended use is ``@profile_forward_pass(profiler)``. Every function
    decorated this way is timed through ``profiler.profile_function`` under
    the shared name ``"forward_pass"``.
    """
    timer_name = "forward_pass"
    return profiler.profile_function(timer_name)
540
+
541
+ # Example usage
542
if __name__ == "__main__":
    # Demonstration run: profile a decorated function and a code block,
    # print the resulting profile and export a JSON report.

    # Create profiler
    profiler = MambaSwarmProfiler()

    # cleanup() sits in a finally clause so the profiler is always released,
    # even when one of the steps below raises.
    try:
        # Start profiling
        profiler.start_profiling()

        # Simulate some work
        @profiler.profile_function("test_function")
        def test_function():
            time.sleep(0.1)
            return "result"

        # Run test
        for _ in range(10):
            test_function()

        # Use context manager
        with profiler.profile_block("test_block"):
            time.sleep(0.05)

        # Stop profiling
        profile_result = profiler.stop_profiling()

        # Print results (stop_profiling() may return a falsy value)
        if profile_result:
            print(f"Profile duration: {profile_result.duration:.2f}s")
            print(f"CPU usage: {profile_result.cpu_usage:.1f}%")
            print(f"Memory usage: {profile_result.memory_usage:.1f}%")
            print(f"Bottlenecks: {profile_result.bottlenecks}")
            print(f"Recommendations: {profile_result.recommendations}")

        # Export report
        report_file = profiler.export_profile_report()
        print(f"Report saved to: {report_file}")
    finally:
        # Cleanup
        profiler.cleanup()