AIguysingstoo's picture
Upload 9 files
e9bb6c3 verified
import torch
import time
import gc
import psutil
from dataclasses import dataclass
from typing import Dict, List, Optional, Any
from torch.nn import CrossEntropyLoss
@dataclass
class BenchmarkConfig:
    """Settings bundle describing one benchmark run.

    Only ``model_name`` is required; every other field carries a default.
    Field order is part of the positional constructor signature.
    """

    # Required: name of the model under test.
    model_name: str

    # Workload selection.
    dataset_name: str = "tatsu-lab/alpaca"
    num_samples: int = 20
    max_new_tokens: int = 100

    # Optional execution features; all are off by default.
    quantization_type: str = "none"
    use_torch_compile: bool = False
    calculate_perplexity: bool = False

    # Target device string; None leaves the choice to the consumer of this
    # config (the selection logic is not part of this class).
    device: Optional[str] = None

    # Seed value (presumably for reproducible sampling -- confirm at the
    # call site, which is outside this class).
    seed: int = 42
@dataclass
class BenchmarkResult:
    """Record of the measurements collected for one benchmarked prompt.

    All fields except ``perplexity`` are required. ``perplexity`` defaults
    to ``None``, meaning no perplexity value is attached to the record.
    """

    # Which prompt this record belongs to, and its text.
    prompt_id: int
    prompt: str

    # Text produced by the model for the prompt.
    generated_text: str

    # Token counts for the prompt and for the generated continuation.
    input_tokens: int
    output_tokens: int

    # Timing measurements (seconds) and derived throughput.
    total_time_seconds: float
    tokens_per_second: float
    first_token_latency_seconds: float

    # Memory figure in megabytes.
    peak_memory_mb: float

    # Optional quality metric; None when not computed.
    perplexity: Optional[float] = None
class MemoryTracker:
    """Device-aware helpers for memory readings, synchronization and cache cleanup.

    The tracker is keyed on the exact device string it was built with:
    ``"cuda"`` selects the CUDA code paths (only when CUDA is actually
    available), ``"mps"`` enables MPS synchronization, and anything else
    falls back to process-level measurements with no device sync.
    """

    def __init__(self, device: str):
        """Remember the device string that selects the code paths below."""
        self.device = device

    def _uses_cuda(self) -> bool:
        """Return True when the CUDA-specific paths should be taken."""
        return self.device == "cuda" and torch.cuda.is_available()

    def reset_stats(self):
        """Reset the CUDA peak-memory counter; a no-op on other devices."""
        if self._uses_cuda():
            torch.cuda.reset_peak_memory_stats()

    def get_peak_memory_mb(self) -> float:
        """Return a memory reading in MiB.

        On CUDA this is the peak allocated memory since the last reset.
        On every other device it is the *current* resident set size of
        this process (not a true peak), so callers that want a per-run
        figure need to take a before/after difference themselves.
        """
        bytes_per_mb = 1024 * 1024
        if self._uses_cuda():
            return torch.cuda.max_memory_allocated() / bytes_per_mb
        return psutil.Process().memory_info().rss / bytes_per_mb

    def synchronize(self):
        """Block until queued device work has finished (CUDA or MPS only)."""
        if self._uses_cuda():
            torch.cuda.synchronize()
            return
        if self.device != "mps":
            return
        # `torch.backends.mps` and `torch.mps` are separate modules and a
        # torch build may ship the former without the latter, so both are
        # looked up defensively instead of assuming `torch.mps` exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is None or not mps_backend.is_available():
            # Mirror the CUDA branch: never sync a backend that is not usable.
            return
        mps_module = getattr(torch, "mps", None)
        if hasattr(mps_module, "synchronize"):
            mps_module.synchronize()

    def clear_cache(self):
        """Run the garbage collector and, on CUDA, release cached allocator blocks."""
        gc.collect()
        if self._uses_cuda():
            torch.cuda.empty_cache()
class PerplexityCalculator:
    """Compute the perplexity a causal language model assigns to a text."""

    def __init__(self, model, tokenizer, device: str):
        """Store the model, its tokenizer and the device inputs are moved to."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def calculate(self, text: str) -> Optional[float]:
        """Return the perplexity of ``text`` under the model.

        Args:
            text: The string to score.

        Returns:
            ``exp(loss)`` as a Python float; ``float('inf')`` when the text
            tokenizes to at most one token (there is nothing to predict);
            ``None`` when anything in the computation raises -- the error
            is printed and swallowed so a failed metric does not abort the
            caller.
        """
        try:
            encodings = self.tokenizer(text, return_tensors="pt").to(self.device)
            input_ids = encodings.input_ids

            # A next-token loss needs at least one (context, target) pair.
            if input_ids.size(1) <= 1:
                return float('inf')

            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, labels=input_ids.clone())

                # Preferred path: the model already computed the loss from
                # the labels it was handed.
                loss = getattr(outputs, 'loss', None)
                if loss is not None:
                    return torch.exp(loss).item()

                # Fallback: compute the shifted cross-entropy by hand.
                # Position t's logits are scored against token t+1.
                logits = outputs.logits[:, :-1, :].contiguous()
                labels = input_ids[:, 1:].contiguous()
                loss_fn = CrossEntropyLoss()
                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
                return torch.exp(loss).item()
        except Exception as e:
            # Deliberate best-effort: report and signal "unavailable".
            print(f"Perplexity calculation failed: {e}")
            return None
class InferenceRunner:
    """Run greedy generation for a prompt and collect timing and memory metrics."""

    def __init__(self, model, tokenizer, device: str):
        """Store the model/tokenizer pair and build a tracker for ``device``."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.memory_tracker = MemoryTracker(device)

    def run_single_inference(self, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Generate a continuation for ``prompt`` and measure the run.

        The prompt is generated from twice: once with ``max_new_tokens=1``
        to time the first token, then once with the requested budget to
        time the full generation. Memory statistics are reset before the
        first call, so the reported figure covers both.

        Args:
            prompt: Text to feed to the model.
            max_new_tokens: Generation budget for the full run.

        Returns:
            A dict with the keys ``input_tokens``, ``output_tokens``,
            ``total_time_seconds``, ``tokens_per_second``,
            ``first_token_latency_seconds``, ``peak_memory_mb`` and
            ``generated_text``.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        input_token_count = input_ids.shape[1]

        # Baseline reading taken right after the reset; used below to turn
        # the non-CUDA reading into a per-run difference.
        self.memory_tracker.reset_stats()
        initial_memory = self.memory_tracker.get_peak_memory_mb()

        # Greedy decoding; these kwargs are shared by both generate calls.
        shared_kwargs = {
            "do_sample": False,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        self.memory_tracker.synchronize()
        first_token_start = time.perf_counter()
        with torch.no_grad():
            self.model.generate(input_ids, max_new_tokens=1, **shared_kwargs)
        self.memory_tracker.synchronize()
        first_token_latency = time.perf_counter() - first_token_start

        # Full generation; the device is already synchronized at this point.
        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids, max_new_tokens=max_new_tokens, **shared_kwargs
            )
        self.memory_tracker.synchronize()
        total_time = time.perf_counter() - start_time

        # The returned sequence starts with the prompt; keep only new tokens.
        output_ids = outputs[0][input_token_count:]
        generated_token_count = len(output_ids)
        tokens_per_second = generated_token_count / total_time if total_time > 0 else 0.0

        peak_memory_mb = self.memory_tracker.get_peak_memory_mb()
        if self.device != "cuda":
            # Off CUDA the reading is a difference against the baseline; it
            # can shrink below the baseline, and a negative memory figure is
            # meaningless, so clamp at zero.
            peak_memory_mb = max(0.0, peak_memory_mb - initial_memory)

        generated_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        self.memory_tracker.clear_cache()

        return {
            "input_tokens": input_token_count,
            "output_tokens": generated_token_count,
            "total_time_seconds": total_time,
            "tokens_per_second": tokens_per_second,
            "first_token_latency_seconds": first_token_latency,
            "peak_memory_mb": peak_memory_mb,
            "generated_text": generated_text,
        }