File size: 6,024 Bytes
e9bb6c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
import time
import gc
import psutil
from dataclasses import dataclass
from typing import Dict, List, Optional, Any
from torch.nn import CrossEntropyLoss

@dataclass
class BenchmarkConfig:
    """Settings describing one benchmark run.

    Only ``model_name`` is required; every other field carries a default.
    This class only stores the values -- how each one is consumed is decided
    by code outside this definition, so the per-field notes below are
    hedged where the meaning is inferred from the field name alone.
    """

    # Required: identifier of the model under test.
    model_name: str
    # Dataset identifier (default looks like a Hugging Face hub id -- confirm).
    dataset_name: str = "tatsu-lab/alpaca"
    # Presumably the number of prompts to benchmark -- verify against caller.
    num_samples: int = 20
    # Presumably the generation budget per prompt -- verify against caller.
    max_new_tokens: int = 100
    # Quantization mode label; "none" is the default. Valid values are not
    # defined here.
    quantization_type: str = "none"
    # Flag for torch.compile usage (consumed elsewhere).
    use_torch_compile: bool = False
    # Flag for perplexity computation (consumed elsewhere).
    calculate_perplexity: bool = False
    # Target device string; None leaves the choice to the consumer.
    device: Optional[str] = None
    # Random seed value.
    seed: int = 42

@dataclass
class BenchmarkResult:
    """Measurements recorded for one benchmarked prompt.

    All fields except ``perplexity`` are required and must be supplied in
    the declared order when constructing positionally.
    """

    # Index of the prompt within the run.
    prompt_id: int
    # The prompt text and the text generated for it.
    prompt: str
    generated_text: str
    # Token counts for the prompt and for the generated continuation.
    input_tokens: int
    output_tokens: int
    # Timing figures, in seconds (throughput in tokens per second).
    total_time_seconds: float
    tokens_per_second: float
    first_token_latency_seconds: float
    # Memory figure in megabytes.
    peak_memory_mb: float
    # Optional perplexity score; None when it was not computed.
    perplexity: Optional[float] = None

class MemoryTracker:
    """Device-aware helpers for memory statistics, synchronization and cleanup.

    The device is matched by *type*, so ``"cuda"``, ``"cuda:0"`` and a
    ``torch.device`` all take the CUDA code paths (an exact string compare
    would silently treat ``"cuda:0"`` as a CPU device).

    Args:
        device: Device string such as ``"cpu"``, ``"cuda"``, ``"cuda:1"`` or
            ``"mps"``. Stored unchanged on ``self.device``.
    """

    def __init__(self, device: str):
        self.device = device

    def _device_type(self) -> str:
        """Return the device type without its index (``"cuda:1"`` -> ``"cuda"``)."""
        return str(self.device).split(":", 1)[0]

    def _uses_cuda(self) -> bool:
        """Return True when the device is a CUDA device and CUDA is usable."""
        return self._device_type() == "cuda" and torch.cuda.is_available()

    def reset_stats(self):
        """Reset CUDA peak-memory statistics; a no-op on other devices."""
        if self._uses_cuda():
            torch.cuda.reset_peak_memory_stats(torch.device(self.device))

    def get_peak_memory_mb(self) -> float:
        """Return a memory figure in MB for the tracked device.

        Returns:
            On CUDA, the peak allocated memory since the last statistics
            reset. On any other device, the process's *current* resident set
            size as reported by psutil -- not a high-water mark, so callers
            wanting a delta must sample before and after.
        """
        if self._uses_cuda():
            peak_bytes = torch.cuda.max_memory_allocated(torch.device(self.device))
        else:
            peak_bytes = psutil.Process().memory_info().rss
        return peak_bytes / (1024 * 1024)

    def synchronize(self):
        """Wait for pending device work to finish (CUDA and MPS only)."""
        if self._uses_cuda():
            torch.cuda.synchronize(torch.device(self.device))
        elif self._device_type() == "mps":
            # Look the modules up defensively: older torch builds lack
            # torch.mps, and builds without an MPS backend raise when asked
            # to synchronize, so require the backend to be available.
            mps_backend = getattr(torch.backends, "mps", None)
            mps_module = getattr(torch, "mps", None)
            if (
                mps_backend is not None
                and mps_backend.is_available()
                and hasattr(mps_module, "synchronize")
            ):
                mps_module.synchronize()

    def clear_cache(self):
        """Run the garbage collector and release cached CUDA memory."""
        gc.collect()
        if self._uses_cuda():
            torch.cuda.empty_cache()

class PerplexityCalculator:
    """Computes the perplexity of a text under a language model.

    Args:
        model: Callable accepting ``input_ids`` and ``labels`` keyword
            arguments and returning an object exposing ``loss`` and/or
            ``logits``.
        tokenizer: Callable turning text into an encoding that supports
            ``.to(device)`` and exposes ``input_ids``.
        device: Device the encoded input is moved to.
    """

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def calculate(self, text: str) -> Optional[float]:
        """Return ``exp(loss)`` for ``text``.

        Args:
            text: The text to score.

        Returns:
            The perplexity as a float; ``float('inf')`` when the text encodes
            to at most one token (no next-token prediction is possible); or
            ``None`` when any step raises, after printing the error.
        """
        # Deliberately best-effort: any failure is reported and mapped to None
        # so one bad sample does not abort the caller.
        try:
            encodings = self.tokenizer(text, return_tensors="pt").to(self.device)
            input_ids = encodings.input_ids

            if input_ids.size(1) <= 1:
                return float('inf')

            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, labels=input_ids.clone())

                loss = getattr(outputs, "loss", None)
                if loss is None:
                    # The model did not report a loss: compute next-token
                    # cross-entropy by hand, pairing the logits at position i
                    # with the token at position i + 1.
                    shift_logits = outputs.logits[:, :-1, :].contiguous()
                    shift_labels = input_ids[:, 1:].contiguous()
                    loss = CrossEntropyLoss()(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1),
                    )
                return torch.exp(loss).item()

        except Exception as e:
            print(f"Perplexity calculation failed: {e}")
            return None

class InferenceRunner:
    """Runs generation for a prompt while recording timing and memory figures.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Callable tokenizer exposing ``eos_token_id`` and ``decode``.
        device: Device string the input ids are moved to; also handed to the
            ``MemoryTracker`` created for this runner.
    """

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.memory_tracker = MemoryTracker(device)

    def run_single_inference(self, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Generate a continuation for ``prompt`` and measure the run.

        Generation is executed twice: once limited to a single new token to
        estimate first-token latency, then once with the full token budget
        for throughput. Decoding is greedy (``do_sample=False``).

        Args:
            prompt: Text to feed to the model.
            max_new_tokens: Token budget for the full generation pass.

        Returns:
            A dict with the keys ``input_tokens``, ``output_tokens``,
            ``total_time_seconds``, ``tokens_per_second``,
            ``first_token_latency_seconds``, ``peak_memory_mb`` and
            ``generated_text``.

        Raises:
            Whatever the tokenizer or model raises; the memory cache is
            cleared before the exception propagates.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        input_token_count = input_ids.shape[1]

        # Baseline taken right after the reset so a delta can be computed on
        # devices where the tracker reports current rather than peak usage.
        self.memory_tracker.reset_stats()
        initial_memory = self.memory_tracker.get_peak_memory_mb()

        # Settings shared by both generation passes; the EOS id doubles as
        # the padding id.
        shared_params = {
            "do_sample": False,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        try:
            # Pass 1: a single new token, used as the first-token latency.
            # perf_counter is monotonic, unlike the wall clock, so the
            # intervals cannot be skewed by system clock adjustments.
            self.memory_tracker.synchronize()
            first_token_start = time.perf_counter()
            with torch.no_grad():
                self.model.generate(input_ids, max_new_tokens=1, **shared_params)
            self.memory_tracker.synchronize()
            first_token_latency = time.perf_counter() - first_token_start

            # Pass 2: the full generation, timed end to end.
            start_time = time.perf_counter()
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids, max_new_tokens=max_new_tokens, **shared_params
                )
            self.memory_tracker.synchronize()
            total_time = time.perf_counter() - start_time

            # Drop the echoed prompt tokens; only the continuation counts.
            output_ids = outputs[0][input_token_count:]
            generated_token_count = len(output_ids)
            tokens_per_second = (
                generated_token_count / total_time if total_time > 0 else 0.0
            )

            # Only the exact device string "cuda" is reported as an absolute
            # figure; every other device reports the change from the baseline.
            peak_memory_mb = self.memory_tracker.get_peak_memory_mb()
            if self.device != "cuda":
                peak_memory_mb = peak_memory_mb - initial_memory

            generated_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
        finally:
            # Release cached memory even when generation or decoding fails,
            # so a failed prompt does not inflate the next measurement.
            self.memory_tracker.clear_cache()

        return {
            "input_tokens": input_token_count,
            "output_tokens": generated_token_count,
            "total_time_seconds": total_time,
            "tokens_per_second": tokens_per_second,
            "first_token_latency_seconds": first_token_latency,
            "peak_memory_mb": peak_memory_mb,
            "generated_text": generated_text,
        }