import torch
import time
import gc
import psutil
from dataclasses import dataclass
from typing import Dict, List, Optional, Any
from torch.nn import CrossEntropyLoss

@dataclass
class BenchmarkConfig:
    """Configuration for benchmarking."""
    model_name: str
    dataset_name: str = "tatsu-lab/alpaca"
    num_samples: int = 20
    max_new_tokens: int = 100
    quantization_type: str = "none"
    use_torch_compile: bool = False
    calculate_perplexity: bool = False
    device: Optional[str] = None
    seed: int = 42
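
# A minimal construction sketch for BenchmarkConfig (not part of the original
# code): only model_name is required; every other field has a default. The
# model name and overrides below are illustrative assumptions.
def demo_config() -> BenchmarkConfig:
    return BenchmarkConfig(model_name="gpt2", max_new_tokens=50, calculate_perplexity=True)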

@dataclass
class BenchmarkResult:
    """Single benchmark result."""
    prompt_id: int
    prompt: str
    generated_text: str
    input_tokens: int
    output_tokens: int
    total_time_seconds: float
    tokens_per_second: float
    first_token_latency_seconds: float
    peak_memory_mb: float
    perplexity: Optional[float] = None

class MemoryTracker:
    """Handles memory tracking across different devices."""

    def __init__(self, device: str):
        self.device = device

    def reset_stats(self):
        """Reset memory tracking."""
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

    def get_peak_memory_mb(self) -> float:
        """Get peak memory usage in MB."""
        if self.device == "cuda" and torch.cuda.is_available():
            return torch.cuda.max_memory_allocated() / (1024 * 1024)
        else:
            # On CPU/MPS, fall back to the process resident set size
            return psutil.Process().memory_info().rss / (1024 * 1024)

    def synchronize(self):
        """Synchronize device operations."""
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.synchronize()
        elif self.device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'synchronize'):
            torch.mps.synchronize()

    def clear_cache(self):
        """Clear memory caches for the active device."""
        gc.collect()
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif self.device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
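
# A minimal usage sketch for MemoryTracker (not part of the original code):
# reset the stats, run a workload, synchronize, read the peak, then clear the
# cache. The matmul below is an arbitrary stand-in workload.
def demo_memory_tracker(device: str = "cpu") -> float:
    tracker = MemoryTracker(device)
    tracker.reset_stats()
    x = torch.randn(1024, 1024, device=device)
    _ = x @ x  # arbitrary work to allocate memory
    tracker.synchronize()
    peak_mb = tracker.get_peak_memory_mb()
    tracker.clear_cache()
    return peak_mb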

class PerplexityCalculator:
    """Handles perplexity calculation."""

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def calculate(self, text: str) -> Optional[float]:
        """Calculate perplexity of text; returns None if the calculation fails."""
        try:
            encodings = self.tokenizer(text, return_tensors="pt").to(self.device)
            input_ids = encodings.input_ids
            if input_ids.size(1) <= 1:
                return float('inf')
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, labels=input_ids.clone())
            if hasattr(outputs, 'loss') and outputs.loss is not None:
                return torch.exp(outputs.loss).item()
            # Fallback: compute the shifted causal-LM loss manually from the logits
            logits = outputs.logits[:, :-1, :].contiguous()
            labels = input_ids[:, 1:].contiguous()
            loss_fn = CrossEntropyLoss()
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            return torch.exp(loss).item()
        except Exception as e:
            print(f"Perplexity calculation failed: {e}")
            return None
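
# A minimal usage sketch for PerplexityCalculator (not part of the original
# code). "gpt2" is an illustrative assumption; any causal LM whose forward
# pass accepts labels and returns a loss works the same way.
def demo_perplexity(model_name: str = "gpt2", device: str = "cpu") -> Optional[float]:
    from transformers import AutoModelForCausalLM, AutoTokenizer  # demo-only import
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()
    calculator = PerplexityCalculator(model, tokenizer, device)
    return calculator.calculate("The quick brown fox jumps over the lazy dog.")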

class InferenceRunner:
    """Handles model inference with timing and memory tracking."""

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.memory_tracker = MemoryTracker(device)

    def run_single_inference(self, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Run inference on a single prompt."""
        # Tokenize input
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        input_token_count = input_ids.shape[1]

        # Reset memory tracking and record the baseline (used for CPU/MPS deltas)
        self.memory_tracker.reset_stats()
        initial_memory = self.memory_tracker.get_peak_memory_mb()

        # Greedy-decoding generation parameters
        gen_params = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # Time the first token: a one-token generation with otherwise identical
        # parameters. perf_counter is monotonic, unlike time.time().
        self.memory_tracker.synchronize()
        first_token_start = time.perf_counter()
        with torch.no_grad():
            self.model.generate(input_ids, **{**gen_params, "max_new_tokens": 1})
        self.memory_tracker.synchronize()
        first_token_latency = time.perf_counter() - first_token_start

        # Full generation
        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = self.model.generate(input_ids, **gen_params)
        self.memory_tracker.synchronize()
        total_time = time.perf_counter() - start_time
        # Calculate throughput over the newly generated tokens only
        output_ids = outputs[0][input_token_count:]
        generated_token_count = len(output_ids)
        tokens_per_second = generated_token_count / total_time if total_time > 0 else 0.0

        # Get memory usage; on CPU/MPS, report the delta over the baseline RSS
        peak_memory_mb = self.memory_tracker.get_peak_memory_mb()
        if self.device != "cuda":
            peak_memory_mb = peak_memory_mb - initial_memory

        # Decode output
        generated_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        # Clear memory
        self.memory_tracker.clear_cache()

        return {
            "input_tokens": input_token_count,
            "output_tokens": generated_token_count,
            "total_time_seconds": total_time,
            "tokens_per_second": tokens_per_second,
            "first_token_latency_seconds": first_token_latency,
            "peak_memory_mb": peak_memory_mb,
            "generated_text": generated_text,
        }
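
# An end-to-end usage sketch tying the pieces together (not part of the
# original code). The model name and single hard-coded prompt are illustrative
# assumptions; a real run would iterate over config.num_samples prompts drawn
# from config.dataset_name.
def demo_benchmark() -> Dict[str, Any]:
    from transformers import AutoModelForCausalLM, AutoTokenizer  # demo-only import
    config = BenchmarkConfig(model_name="gpt2", max_new_tokens=20)
    device = config.device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(config.seed)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = AutoModelForCausalLM.from_pretrained(config.model_name).to(device)
    model.eval()
    runner = InferenceRunner(model, tokenizer, device)
    return runner.run_single_inference("Explain what a benchmark measures.", config.max_new_tokens)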