File size: 5,762 Bytes
e9bb6c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
import os
import re
import time
from dataclasses import asdict
from typing import Dict, List, Any

import numpy as np
import torch

from models.quantization import ModelLoader, QuantizationType
from core.benchmark import BenchmarkConfig, BenchmarkResult, InferenceRunner, PerplexityCalculator
from core.data import DatasetLoader
from core.utils import get_device

class ModelBenchmarker:
    """Main benchmarking agent.

    Loads a (optionally quantized and/or ``torch.compile``-d) model, runs
    generation over sampled dataset prompts, aggregates the per-prompt metrics
    and saves everything as JSON::

        benchmarker = ModelBenchmarker()
        results = benchmarker.run_benchmark(config)
        path = benchmarker.save_results(results)

    The loaded model, tokenizer and device are cached on the instance and
    reused by later ``run_benchmark`` calls whose config describes the same
    model.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Device used for inference; resolved in load_model.
        self.device = None
        # Key (see _model_key) of the config the cached model was loaded with,
        # or None when no model has been loaded through load_model.
        self._loaded_model_key = None

    @staticmethod
    def _model_key(config: BenchmarkConfig) -> tuple:
        """Return the config fields that determine which model load_model produces."""
        return (
            config.model_name,
            config.quantization_type,
            config.use_torch_compile,
            config.device,
        )

    def load_model(self, config: BenchmarkConfig):
        """Load the model and tokenizer described by ``config`` onto this instance.

        Unquantized models go through ``ModelLoader.load_standard``. Quantized
        models first try ``ModelLoader.load_quantized_transformers`` and fall
        back to ``ModelLoader.load_quantized_direct`` if that raises. When
        ``config.use_torch_compile`` is set, the loaded model is wrapped with
        ``torch.compile``.

        Any previously loaded model is released before the new one is loaded.

        Raises:
            Whatever ``QuantizationType(config.quantization_type)`` raises for an
            unknown value (presumably ``ValueError`` if it is an Enum — confirm).
        """
        # Validate the quantization type before discarding the current model.
        quant_type = QuantizationType(config.quantization_type)

        # Drop the previous model first so two models are not held in (GPU)
        # memory at the same time while the new one loads.
        self.model = None
        self.tokenizer = None
        self._loaded_model_key = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.device = get_device(config.device)

        if quant_type == QuantizationType.NONE:
            self.model, self.tokenizer = ModelLoader.load_standard(config.model_name, self.device)
        else:
            # Try Transformers integration first, fallback to direct API
            try:
                self.model, self.tokenizer = ModelLoader.load_quantized_transformers(
                    config.model_name, quant_type
                )
            except Exception as e:
                print(f"Transformers integration failed, using direct API: {e}")
                self.model, self.tokenizer = ModelLoader.load_quantized_direct(
                    config.model_name, quant_type, self.device
                )
            else:
                # This loader is not given a device, so report the device the
                # weights actually ended up on (when the model has parameters).
                # Kept outside the try so a probe failure cannot trigger a
                # second, redundant load via the fallback path.
                first_param = next(iter(self.model.parameters()), None)
                if first_param is not None:
                    self.device = str(first_param.device)

        # Apply torch.compile if requested
        if config.use_torch_compile:
            print("Applying torch.compile...")
            # NOTE(review): torch.compile returns a wrapper whose compiled path is
            # forward(); if InferenceRunner calls model.generate(), confirm the
            # compiled forward is actually exercised.
            self.model = torch.compile(self.model)

        self._loaded_model_key = self._model_key(config)

    def run_benchmark(self, config: BenchmarkConfig) -> Dict[str, Any]:
        """Run generation (and optionally perplexity) over sampled dataset prompts.

        The model is loaded if none is cached. It is also reloaded if the cached
        model was loaded by ``load_model`` for a different model, quantization,
        compile or device setting, so results are never labelled with a config
        that was not actually benchmarked. A model assigned to ``self.model``
        directly (not via ``load_model``) is used as-is.

        Returns:
            ``{"summary": <dict from _create_summary>,
               "samples": [asdict(BenchmarkResult), ...]}``
        """
        loaded_key = getattr(self, "_loaded_model_key", None)
        if self.model is None or (loaded_key is not None and loaded_key != self._model_key(config)):
            self.load_model(config)

        # Get sample prompts (the dataset indices are not used here).
        prompts, _ = DatasetLoader.get_sample_prompts(config.dataset_name, config.num_samples, config.seed)

        inference_runner = InferenceRunner(self.model, self.tokenizer, self.device)

        perplexity_calc = None
        if config.calculate_perplexity:
            perplexity_calc = PerplexityCalculator(self.model, self.tokenizer, self.device)

        results = []

        for i, prompt in enumerate(prompts):
            print(f"Processing prompt {i+1}/{len(prompts)}")

            inference_result = inference_runner.run_single_inference(prompt, config.max_new_tokens)

            # Perplexity is computed on the generated text, only when requested.
            perplexity = None
            if perplexity_calc is not None:
                perplexity = perplexity_calc.calculate(inference_result["generated_text"])

            results.append(
                BenchmarkResult(
                    prompt_id=i,
                    prompt=prompt,
                    generated_text=inference_result["generated_text"],
                    input_tokens=inference_result["input_tokens"],
                    output_tokens=inference_result["output_tokens"],
                    total_time_seconds=inference_result["total_time_seconds"],
                    tokens_per_second=inference_result["tokens_per_second"],
                    first_token_latency_seconds=inference_result["first_token_latency_seconds"],
                    peak_memory_mb=inference_result["peak_memory_mb"],
                    perplexity=perplexity,
                )
            )

        summary = self._create_summary(config, results)

        return {
            "summary": summary,
            "samples": [asdict(result) for result in results],
        }

    def _create_summary(self, config: BenchmarkConfig, results: List[BenchmarkResult]) -> Dict[str, Any]:
        """Aggregate per-prompt results into a summary dict.

        When ``results`` is empty, the averages and the memory maximum are
        ``None`` instead of raising. Perplexity is averaged only when
        ``config.calculate_perplexity`` is set, and only over finite values:
        ``None``, inf and NaN entries are skipped.
        """
        num_samples = len(results)

        avg_tokens_per_second = None
        avg_first_token_latency = None
        if num_samples:
            avg_tokens_per_second = sum(r.tokens_per_second for r in results) / num_samples
            avg_first_token_latency = sum(r.first_token_latency_seconds for r in results) / num_samples
        max_memory_mb = max((r.peak_memory_mb for r in results), default=None)

        avg_perplexity = None
        if config.calculate_perplexity:
            # np.isfinite rejects both inf and NaN; a single NaN would otherwise
            # turn the whole average into NaN.
            valid_perplexities = [
                r.perplexity
                for r in results
                if r.perplexity is not None and np.isfinite(r.perplexity)
            ]
            if valid_perplexities:
                avg_perplexity = sum(valid_perplexities) / len(valid_perplexities)

        optimization_desc = config.quantization_type
        if config.use_torch_compile:
            optimization_desc += " + torch.compile"

        # get_device's return type is not visible here; stringify so the summary
        # stays JSON-serializable even if it is a device object.
        device = str(self.device) if self.device is not None else None

        return {
            "model_name": f"{config.model_name} ({optimization_desc})",
            "device": device,
            "num_samples": num_samples,
            "avg_tokens_per_second": avg_tokens_per_second,
            "avg_first_token_latency_seconds": avg_first_token_latency,
            "max_memory_mb": max_memory_mb,
            "avg_perplexity": avg_perplexity,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "optimization_type": optimization_desc,
        }

    def save_results(self, results: Dict[str, Any], output_dir: str = "benchmark_results") -> str:
        """Write ``results`` as indented JSON into ``output_dir`` and return the file path.

        ``output_dir`` is created if missing. The file is named
        ``<model>_<YYYYmmdd_HHMMSS>.json``. ``<model>`` is the last path
        component of ``results["summary"]["model_name"]``, with any character
        that is not a word character, ``.``, ``+``, ``(``, ``)`` or ``-``
        replaced by ``_``.
        """
        os.makedirs(output_dir, exist_ok=True)

        display_name = results["summary"]["model_name"]
        # Keep only the last path component. Treat "\" as a separator as well,
        # so a Windows checkpoint path does not become nested, nonexistent
        # directories under output_dir.
        base_name = display_name.replace("\\", "/").split("/")[-1]
        # Spaces and file-name-unsafe characters become "_". For typical names
        # such as "model (int8 + torch.compile)" this matches the previous
        # space-only replacement.
        model_name = re.sub(r"[^\w.+()-]", "_", base_name) or "model"
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(output_dir, f"{model_name}_{timestamp}.json")

        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

        return output_file