import gc
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional

import psutil
import torch
from torch.nn import CrossEntropyLoss

@dataclass
class BenchmarkConfig:
    """Configuration for benchmarking."""
    model_name: str
    dataset_name: str = "tatsu-lab/alpaca"
    num_samples: int = 20
    max_new_tokens: int = 100
    quantization_type: str = "none"
    use_torch_compile: bool = False
    calculate_perplexity: bool = False
    device: Optional[str] = None
    seed: int = 42
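
# Illustrative configuration (added example; the checkpoint name is arbitrary
# and not referenced elsewhere in this file):
# config = BenchmarkConfig(model_name="distilgpt2", num_samples=5,
#                          calculate_perplexity=True)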


@dataclass
class BenchmarkResult:
    """Single benchmark result."""
    prompt_id: int
    prompt: str
    generated_text: str
    input_tokens: int
    output_tokens: int
    total_time_seconds: float
    tokens_per_second: float
    first_token_latency_seconds: float
    peak_memory_mb: float
    perplexity: Optional[float] = None


class MemoryTracker:
    """Handles memory tracking across different devices."""

    def __init__(self, device: str):
        self.device = device

    def reset_stats(self):
        """Reset memory tracking."""
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

    def get_peak_memory_mb(self) -> float:
        """Get peak memory usage in MB."""
        if self.device == "cuda" and torch.cuda.is_available():
            return torch.cuda.max_memory_allocated() / (1024 * 1024)
        # CPU/MPS fallback: resident set size of the current process.
        return psutil.Process().memory_info().rss / (1024 * 1024)

    def synchronize(self):
        """Synchronize device operations."""
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.synchronize()
        elif self.device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
            torch.mps.synchronize()

    def clear_cache(self):
        """Clear memory cache."""
        gc.collect()
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()
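
# Typical measurement pattern (illustrative sketch, not part of the original
# file): reset stats before the measured region and synchronize before reading
# the clock, since CUDA/MPS kernels execute asynchronously and wall-clock
# timings are unreliable without a sync.
#   tracker = MemoryTracker(device)
#   tracker.reset_stats()
#   ...  # run workload
#   tracker.synchronize()
#   peak = tracker.get_peak_memory_mb()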


class PerplexityCalculator:
    """Handles perplexity calculation."""

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def calculate(self, text: str) -> Optional[float]:
        """Calculate the perplexity of `text`, or None if calculation fails."""
        try:
            encodings = self.tokenizer(text, return_tensors="pt").to(self.device)
            input_ids = encodings.input_ids
            if input_ids.size(1) <= 1:
                return float('inf')
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, labels=input_ids.clone())
            if hasattr(outputs, 'loss') and outputs.loss is not None:
                return torch.exp(outputs.loss).item()
            # Fallback: compute the loss manually from shifted logits.
            logits = outputs.logits[:, :-1, :].contiguous()
            labels = input_ids[:, 1:].contiguous()
            loss_fn = CrossEntropyLoss()
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            return torch.exp(loss).item()
        except Exception as e:
            print(f"Perplexity calculation failed: {e}")
            return None
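
# Note (added for clarity): perplexity here is exp(mean token-level
# cross-entropy), so a mean loss of 2.0 nats corresponds to a perplexity of
# e**2.0 ≈ 7.39.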


class InferenceRunner:
    """Handles model inference with timing and memory tracking."""

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.memory_tracker = MemoryTracker(device)

    def run_single_inference(self, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Run inference on a single prompt."""
        # Tokenize input
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        input_token_count = input_ids.shape[1]

        # Reset memory tracking
        self.memory_tracker.reset_stats()
        initial_memory = self.memory_tracker.get_peak_memory_mb()

        # Generation parameters
        gen_params = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        # Approximate time-to-first-token with a separate single-token
        # generation (this includes the full prefill pass).
        self.memory_tracker.synchronize()
        first_token_start = time.perf_counter()
        with torch.no_grad():
            self.model.generate(input_ids, **{**gen_params, "max_new_tokens": 1})
        self.memory_tracker.synchronize()
        first_token_latency = time.perf_counter() - first_token_start

        # Full generation
        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = self.model.generate(input_ids, **gen_params)
        self.memory_tracker.synchronize()
        total_time = time.perf_counter() - start_time

        # Calculate metrics
        output_ids = outputs[0][input_token_count:]
        generated_token_count = len(output_ids)
        tokens_per_second = generated_token_count / total_time if total_time > 0 else 0

        # Get memory usage; on CPU/MPS the RSS reading is process-wide, so
        # report the delta against the reading taken before generation.
        peak_memory_mb = self.memory_tracker.get_peak_memory_mb()
        if self.device != "cuda":
            peak_memory_mb = peak_memory_mb - initial_memory

        # Decode output
        generated_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        # Clear memory
        self.memory_tracker.clear_cache()

        return {
            "input_tokens": input_token_count,
            "output_tokens": generated_token_count,
            "total_time_seconds": total_time,
            "tokens_per_second": tokens_per_second,
            "first_token_latency_seconds": first_token_latency,
            "peak_memory_mb": peak_memory_mb,
            "generated_text": generated_text,
        }
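

if __name__ == "__main__":
    # Minimal usage sketch, added for illustration; not part of the original
    # harness. The checkpoint name is an arbitrary small model chosen so the
    # example runs quickly, and `transformers` is assumed to be installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
    model.eval()

    runner = InferenceRunner(model, tokenizer, device)
    result = runner.run_single_inference("The capital of France is", max_new_tokens=20)
    print(f"{result['tokens_per_second']:.1f} tok/s, "
          f"TTFT {result['first_token_latency_seconds']:.3f} s, "
          f"peak {result['peak_memory_mb']:.0f} MB")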