Spaces:
Sleeping
Sleeping
Upload 9 files
Browse files- README.md +65 -5
- agent/benchmarker.py +136 -0
- core/benchmark.py +172 -0
- core/data.py +34 -0
- core/utils.py +36 -0
- interfaces/gradio_app.py +252 -0
- main.py +28 -0
- models/quantization.py +91 -0
- requirements.txt +283 -0
README.md
CHANGED
|
@@ -1,12 +1,72 @@
|
|
| 1 |
---
|
| 2 |
title: Optimization Engineer
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.33.1
|
| 8 |
-
app_file:
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Optimization Engineer
|
| 3 |
+
emoji: ⚡
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: gray
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.33.1
|
| 8 |
+
app_file: main.py
|
| 9 |
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
short_description: A modular, simplified model-optimizing agent!
|
| 12 |
+
tags:
|
| 13 |
+
- mcp-server-track
|
| 14 |
+
- optimization
|
| 15 |
+
- mcp-server
|
| 16 |
+
- gradio
|
| 17 |
---
|
| 18 |
|
| 19 |
+
# Optimization Engineer 🚀
|
| 20 |
+
|
| 21 |
+
An intelligent optimization engineer that serves as both a Gradio web application and an MCP (Model Context Protocol) server for advanced optimization tasks.
|
| 22 |
+
|
| 23 |
+
## 🎯 MCP Server Track Submission
|
| 24 |
+
|
| 25 |
+
This space is submitted for the MCP Server Track. It functions as:
|
| 26 |
+
1. **Gradio App**: Interactive web interface for optimization tasks
|
| 27 |
+
2. **MCP Server**: Can be connected to MCP clients like Claude Desktop, Cursor, etc.
|
| 28 |
+
|
| 29 |
+
## 🎥 Demo Video
|
| 30 |
+
|
| 31 |
+
[Link to demo video showing MCP server in action - TO BE ADDED]
|
| 32 |
+
|
| 33 |
+
## ✨ Features
|
| 34 |
+
|
| 35 |
+
- Interactive optimization interface
|
| 36 |
+
- MCP server capabilities for external tool integration
|
| 37 |
+
- Advanced optimization algorithms and techniques
|
| 38 |
+
- Real-time performance monitoring and benchmarking
|
| 39 |
+
|
| 40 |
+
## 🚀 Usage
|
| 41 |
+
|
| 42 |
+
### As a Gradio App
|
| 43 |
+
Simply use the interface above to interact with the optimization tools.
|
| 44 |
+
|
| 45 |
+
### As an MCP Server
|
| 46 |
+
(Local Only) Copy the generated MCP server URL to your Host's integration to access the tools through Claude Desktop.
|
| 47 |
+
|
| 48 |
+
## 🛠️ Development
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
# Clone the repository
|
| 52 |
+
git clone https://huggingface.co/spaces/AIguysingstoo/optimization-engineer
|
| 53 |
+
|
| 54 |
+
# Install dependencies
|
| 55 |
+
pip install -r requirements.txt
|
| 56 |
+
|
| 57 |
+
# Run locally
|
| 58 |
+
python main.py gradio
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## 📋 Requirements
|
| 62 |
+
|
| 63 |
+
- Python 3.10+
|
| 64 |
+
- Dependencies listed in requirements.txt
|
| 65 |
+
|
| 66 |
+
## 🤝 Contributing
|
| 67 |
+
|
| 68 |
+
Feel free to submit issues and enhancement requests!
|
| 69 |
+
|
| 70 |
+
## 📄 License
|
| 71 |
+
|
| 72 |
+
Apache 2.0 License
|
agent/benchmarker.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import Dict, List, Any
|
| 7 |
+
from dataclasses import asdict
|
| 8 |
+
|
| 9 |
+
from models.quantization import ModelLoader, QuantizationType
|
| 10 |
+
from core.benchmark import BenchmarkConfig, BenchmarkResult, InferenceRunner, PerplexityCalculator
|
| 11 |
+
from core.data import DatasetLoader
|
| 12 |
+
from core.utils import get_device
|
| 13 |
+
|
| 14 |
+
class ModelBenchmarker:
    """Main benchmarking agent.

    Lazily loads an (optionally quantized) model and runs prompt-level
    benchmarks, aggregating per-prompt results into a summary dict.
    """

    def __init__(self):
        self.model = None       # lazily populated by load_model()
        self.tokenizer = None
        self.device = None      # resolved device string, e.g. "cuda"/"mps"/"cpu"

    def load_model(self, config: BenchmarkConfig):
        """Load model and tokenizer based on configuration.

        For quantized configs, tries the Transformers quantization
        integration first and falls back to the direct API on failure.
        """
        self.device = get_device(config.device)

        quant_type = QuantizationType(config.quantization_type)

        if quant_type == QuantizationType.NONE:
            self.model, self.tokenizer = ModelLoader.load_standard(config.model_name, self.device)
        else:
            # Try Transformers integration first, fallback to direct API
            try:
                self.model, self.tokenizer = ModelLoader.load_quantized_transformers(config.model_name, quant_type)
                # The integration may place the model itself; record where the
                # parameters actually ended up.
                self.device = str(next(self.model.parameters()).device)
            except Exception as e:
                print(f"Transformers integration failed, using direct API: {e}")
                self.model, self.tokenizer = ModelLoader.load_quantized_direct(config.model_name, quant_type, self.device)

        # Apply torch.compile if requested
        if config.use_torch_compile:
            print("Applying torch.compile...")
            self.model = torch.compile(self.model)

    def run_benchmark(self, config: BenchmarkConfig) -> Dict[str, Any]:
        """Run a benchmark with the given configuration.

        NOTE(review): an already-loaded model is reused even if `config`
        names a different model/quantization — use a fresh instance when
        changing those fields (compare_optimizations already does).

        Returns a dict with a "summary" dict and a "samples" list of
        per-prompt result dicts.
        """
        if self.model is None:
            self.load_model(config)

        # Get sample prompts
        prompts, indices = DatasetLoader.get_sample_prompts(config.dataset_name, config.num_samples, config.seed)

        # Setup inference runner
        inference_runner = InferenceRunner(self.model, self.tokenizer, self.device)

        # Setup perplexity calculator if needed
        perplexity_calc = None
        if config.calculate_perplexity:
            perplexity_calc = PerplexityCalculator(self.model, self.tokenizer, self.device)

        results = []

        for i, prompt in enumerate(prompts):
            print(f"Processing prompt {i+1}/{len(prompts)}")

            # Run inference
            inference_result = inference_runner.run_single_inference(prompt, config.max_new_tokens)

            # Calculate perplexity if requested (may be None on failure)
            perplexity = None
            if perplexity_calc:
                perplexity = perplexity_calc.calculate(inference_result["generated_text"])

            # Create result
            result = BenchmarkResult(
                prompt_id=i,
                prompt=prompt,
                generated_text=inference_result["generated_text"],
                input_tokens=inference_result["input_tokens"],
                output_tokens=inference_result["output_tokens"],
                total_time_seconds=inference_result["total_time_seconds"],
                tokens_per_second=inference_result["tokens_per_second"],
                first_token_latency_seconds=inference_result["first_token_latency_seconds"],
                peak_memory_mb=inference_result["peak_memory_mb"],
                perplexity=perplexity
            )

            results.append(result)

        # Calculate summary
        summary = self._create_summary(config, results)

        return {
            "summary": summary,
            "samples": [asdict(result) for result in results]
        }

    def _create_summary(self, config: BenchmarkConfig, results: List[BenchmarkResult]) -> Dict[str, Any]:
        """Create the benchmark summary dict.

        Fix: guards against an empty result list (previously raised
        ZeroDivisionError when no prompts were processed).
        """
        num_results = len(results)
        if num_results:
            avg_tokens_per_second = sum(r.tokens_per_second for r in results) / num_results
            avg_first_token_latency = sum(r.first_token_latency_seconds for r in results) / num_results
            max_memory_mb = max(r.peak_memory_mb for r in results)
        else:
            avg_tokens_per_second = 0.0
            avg_first_token_latency = 0.0
            max_memory_mb = 0.0

        avg_perplexity = None
        if config.calculate_perplexity:
            # Ignore failed (None) and degenerate (inf) perplexities.
            valid_perplexities = [r.perplexity for r in results if r.perplexity is not None and not np.isinf(r.perplexity)]
            if valid_perplexities:
                avg_perplexity = sum(valid_perplexities) / len(valid_perplexities)

        optimization_desc = config.quantization_type
        if config.use_torch_compile:
            optimization_desc += " + torch.compile"

        return {
            "model_name": f"{config.model_name} ({optimization_desc})",
            "device": self.device,
            "num_samples": num_results,
            "avg_tokens_per_second": avg_tokens_per_second,
            "avg_first_token_latency_seconds": avg_first_token_latency,
            "max_memory_mb": max_memory_mb,
            "avg_perplexity": avg_perplexity,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "optimization_type": optimization_desc
        }

    def save_results(self, results: Dict[str, Any], output_dir: str = "benchmark_results") -> str:
        """Save benchmark results as pretty-printed JSON; return the path."""
        os.makedirs(output_dir, exist_ok=True)

        # "org/model (opt)" -> "model_(opt)": drop the org prefix and make
        # the name filesystem-safe.
        model_name = results["summary"]["model_name"].split('/')[-1].replace(' ', '_')
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(output_dir, f"{model_name}_{timestamp}.json")

        with open(output_file, 'w') as f:
            json.dump(results, f, indent=2)

        return output_file
|
core/benchmark.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
import gc
|
| 4 |
+
import psutil
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Dict, List, Optional, Any
|
| 7 |
+
from torch.nn import CrossEntropyLoss
|
| 8 |
+
|
| 9 |
+
@dataclass
class BenchmarkConfig:
    """Configuration for benchmarking."""
    model_name: str                         # HF model id, e.g. "facebook/opt-iml-max-1.3b"
    dataset_name: str = "tatsu-lab/alpaca"  # HF dataset providing the prompts
    num_samples: int = 20                   # number of prompts to benchmark
    max_new_tokens: int = 100               # generation budget per prompt
    quantization_type: str = "none"         # fed to QuantizationType(...) by the loader
    use_torch_compile: bool = False         # wrap the model in torch.compile after loading
    calculate_perplexity: bool = False      # also score each generated text
    device: Optional[str] = None            # None = auto-detect (cuda/mps/cpu)
    seed: int = 42                          # RNG seed for prompt sampling
|
| 21 |
+
|
| 22 |
+
@dataclass
class BenchmarkResult:
    """Single benchmark result (one prompt)."""
    prompt_id: int                          # index of the prompt within this run
    prompt: str                             # input text
    generated_text: str                     # decoded model output (prompt stripped)
    input_tokens: int                       # prompt length in tokens
    output_tokens: int                      # number of generated tokens
    total_time_seconds: float               # wall time of the full generation
    tokens_per_second: float                # output_tokens / total_time_seconds
    first_token_latency_seconds: float      # measured via a separate 1-token pass
    peak_memory_mb: float                   # CUDA allocator peak, or RSS delta elsewhere
    perplexity: Optional[float] = None      # None when disabled or calculation failed
|
| 35 |
+
|
| 36 |
+
class MemoryTracker:
    """Handles memory tracking across different devices.

    On CUDA this reads the allocator's peak counters; on other devices it
    falls back to the process resident set size.
    """

    def __init__(self, device: str):
        self.device = device

    def _on_cuda(self) -> bool:
        """True when tracking a CUDA device that is actually available."""
        return self.device == "cuda" and torch.cuda.is_available()

    def reset_stats(self):
        """Reset peak-memory counters (CUDA only; a no-op elsewhere)."""
        if self._on_cuda():
            torch.cuda.reset_peak_memory_stats()

    def get_peak_memory_mb(self) -> float:
        """Return peak memory in MB: CUDA allocator peak, else process RSS."""
        if self._on_cuda():
            return torch.cuda.max_memory_allocated() / (1024 * 1024)
        return psutil.Process().memory_info().rss / (1024 * 1024)

    def synchronize(self):
        """Block until pending device operations have completed."""
        if self._on_cuda():
            torch.cuda.synchronize()
        elif self.device == "mps" and hasattr(torch.backends, 'mps'):
            if hasattr(torch.mps, 'synchronize'):
                torch.mps.synchronize()

    def clear_cache(self):
        """Run the garbage collector and release cached device memory."""
        gc.collect()
        if self._on_cuda():
            torch.cuda.empty_cache()
|
| 67 |
+
|
| 68 |
+
class PerplexityCalculator:
    """Handles perplexity calculation for generated text."""

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def calculate(self, text: str) -> Optional[float]:
        """Calculate the perplexity of `text` under the model.

        Fix: the return annotation now reflects reality — the method
        returns None on failure (it was annotated `-> float` before).

        Returns:
            The perplexity, float('inf') for texts of <= 1 token, or None
            when the computation fails (the error is printed, not raised).
        """
        try:
            encodings = self.tokenizer(text, return_tensors="pt").to(self.device)
            input_ids = encodings.input_ids

            # Perplexity is undefined for a single token: there is no
            # next-token target to score.
            if input_ids.size(1) <= 1:
                return float('inf')

            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, labels=input_ids.clone())

            # Preferred path: HF-style outputs carry the mean CE loss.
            if hasattr(outputs, 'loss') and outputs.loss is not None:
                return torch.exp(outputs.loss).item()

            # Fallback manual calculation: shift logits/labels and average
            # the cross-entropy ourselves.
            logits = outputs.logits[:, :-1, :].contiguous()
            labels = input_ids[:, 1:].contiguous()

            loss_fn = CrossEntropyLoss()
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            return torch.exp(loss).item()

        except Exception as e:
            # Best-effort: the benchmark continues; the caller filters None.
            print(f"Perplexity calculation failed: {e}")
            return None
|
| 102 |
+
|
| 103 |
+
class InferenceRunner:
    """Handles model inference with timing and memory tracking.

    Wraps a model/tokenizer pair and produces per-prompt latency,
    throughput, and peak-memory metrics for greedy generation.
    """

    def __init__(self, model, tokenizer, device: str):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Device-aware helper for peak-memory stats and synchronization.
        self.memory_tracker = MemoryTracker(device)

    def run_single_inference(self, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Run inference on a single prompt.

        Measures first-token latency with a dedicated 1-token generation
        pass, then times a full greedy generation of up to `max_new_tokens`
        tokens. Returns a dict with token counts, timings, throughput,
        peak memory in MB, and the decoded text.
        """
        # Tokenize input
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        input_token_count = input_ids.shape[1]

        # Reset memory tracking; the baseline is used for the RSS-delta
        # reporting on non-CUDA devices below.
        self.memory_tracker.reset_stats()
        initial_memory = self.memory_tracker.get_peak_memory_mb()

        # Generation parameters (greedy decoding, pad with the EOS token)
        gen_params = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "pad_token_id": self.tokenizer.eos_token_id
        }

        # Time first token via a separate single-token pass.
        # NOTE(review): the full generation below restarts from scratch, so
        # this pass adds overhead but keeps the two timings independent.
        self.memory_tracker.synchronize()
        first_token_start = time.time()

        with torch.no_grad():
            first_output = self.model.generate(input_ids, max_new_tokens=1, **{k: v for k, v in gen_params.items() if k != 'max_new_tokens'})

        self.memory_tracker.synchronize()
        first_token_latency = time.time() - first_token_start

        # Full generation
        start_time = time.time()

        with torch.no_grad():
            outputs = self.model.generate(input_ids, **gen_params)

        self.memory_tracker.synchronize()
        total_time = time.time() - start_time

        # Calculate metrics: strip the prompt tokens from the output.
        output_ids = outputs[0][input_token_count:]
        generated_token_count = len(output_ids)
        tokens_per_second = generated_token_count / total_time if total_time > 0 else 0

        # Get memory usage; on non-CUDA devices report the delta over the
        # pre-inference RSS baseline rather than an absolute value.
        peak_memory_mb = self.memory_tracker.get_peak_memory_mb()
        if self.device != "cuda":
            peak_memory_mb = peak_memory_mb - initial_memory

        # Decode output
        generated_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        # Clear memory
        self.memory_tracker.clear_cache()

        return {
            "input_tokens": input_token_count,
            "output_tokens": generated_token_count,
            "total_time_seconds": total_time,
            "tokens_per_second": tokens_per_second,
            "first_token_latency_seconds": first_token_latency,
            "peak_memory_mb": peak_memory_mb,
            "generated_text": generated_text
        }
|
core/data.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
class DatasetLoader:
    """Handles dataset loading and sampling."""

    @staticmethod
    def get_sample_prompts(dataset_name: str, num_samples: int, seed: int = 42) -> Tuple[List[str], List[int]]:
        """Get up to `num_samples` random prompts (and their indices).

        Fixes:
        - `num_samples` is clamped to the split size (random.sample raised
          ValueError when asked for more items than exist).
        - A dataset item with no string-valued field now raises a clear
          ValueError instead of an opaque StopIteration.
        """
        print(f"Loading dataset: {dataset_name}")

        dataset = load_dataset(dataset_name)
        # Prefer the train split; otherwise take whichever split exists first.
        split_name = 'train' if 'train' in dataset else list(dataset.keys())[0]

        split_size = len(dataset[split_name])
        random.seed(seed)
        indices = random.sample(range(split_size), min(num_samples, split_size))

        # Handle different dataset formats: well-known prompt fields first.
        samples = []
        for idx in indices:
            item = dataset[split_name][idx]
            if 'instruction' in item:
                samples.append(item['instruction'])
            elif 'text' in item:
                samples.append(item['text'])
            elif 'prompt' in item:
                samples.append(item['prompt'])
            else:
                # Fallback - first string-valued field, if any.
                text_field = next((k for k, v in item.items() if isinstance(v, str)), None)
                if text_field is None:
                    raise ValueError(f"No text field found in dataset item at index {idx}")
                samples.append(item[text_field])

        return samples, indices
|
core/utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import psutil
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
def get_device(device: Optional[str] = None) -> str:
    """Return the requested device string, auto-detecting when None.

    Detection order: CUDA, then Apple MPS, then CPU.
    """
    if device is not None:
        return device
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
|
| 15 |
+
|
| 16 |
+
def get_system_info() -> str:
    """Return system information formatted as a Markdown string."""
    lines = ["# System Information\n"]

    # CPU and RAM
    lines.append(f"**CPU**: {psutil.cpu_count(logical=False)} physical, {psutil.cpu_count()} logical cores")
    lines.append(f"**Memory**: {psutil.virtual_memory().total / (1024**3):.2f} GB")

    # CUDA GPU, when present
    if torch.cuda.is_available():
        lines.append(f"**CUDA**: {torch.cuda.get_device_name(0)}")
        lines.append(f"**CUDA Memory**: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
        lines.append(f"**CUDA Version**: {torch.version.cuda}")

    # Apple Silicon
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        lines.append("**Apple Silicon**: MPS Available")

    lines.append(f"**PyTorch**: {torch.__version__}")

    return "\n".join(lines)
|
interfaces/gradio_app.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import plotly.express as px
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
from agent.benchmarker import ModelBenchmarker
|
| 8 |
+
from core.benchmark import BenchmarkConfig
|
| 9 |
+
from core.utils import get_system_info
|
| 10 |
+
|
| 11 |
+
class GradioApp:
    """Gradio web interface for model benchmarking.

    Wraps a ModelBenchmarker and exposes four tabs: single benchmark,
    optimization comparison, run history, and system info.
    """

    def __init__(self):
        self.benchmarker = ModelBenchmarker()
        # Full result dicts of every completed single benchmark, in order.
        self.history = []

    def benchmark_single(
        self,
        model_name: str,
        dataset_name: str,
        num_samples: int,
        max_tokens: int,
        quantization: str,
        torch_compile: bool,
        perplexity: bool,
        device: str
    ) -> Tuple[str, str, str]:
        """Run single model benchmark.

        Returns:
            (markdown summary, HTML samples table, status message).
            On failure all three slots carry the error text instead.
        """
        try:
            config = BenchmarkConfig(
                model_name=model_name,
                dataset_name=dataset_name,
                num_samples=num_samples,
                max_new_tokens=max_tokens,
                quantization_type=quantization,
                use_torch_compile=torch_compile,
                calculate_perplexity=perplexity,
                # "auto" in the UI maps to None = auto-detect.
                device=device if device != "auto" else None
            )

            results = self.benchmarker.run_benchmark(config)
            self.history.append(results)

            # Format summary as Markdown for the results panel.
            summary = results["summary"]
            summary_text = f"""## Benchmark Results

**Model**: {summary['model_name']}
**Device**: {summary['device']}
**Optimization**: {summary['optimization_type']}

### Performance Metrics
- **Throughput**: {summary['avg_tokens_per_second']:.2f} tokens/second
- **First Token Latency**: {summary['avg_first_token_latency_seconds']:.4f} seconds
- **Peak Memory**: {summary['max_memory_mb']:.2f} MB
- **Samples**: {summary['num_samples']}
{f"- **Perplexity**: {summary['avg_perplexity']:.4f}" if summary.get('avg_perplexity') else ""}
"""

            # Sample results table: first 10 rows of the key metric columns.
            samples_df = pd.DataFrame(results['samples'])
            if not samples_df.empty:
                display_cols = ['prompt_id', 'input_tokens', 'output_tokens', 'tokens_per_second', 'first_token_latency_seconds']
                samples_table = samples_df[display_cols].head(10).to_html(index=False)
            else:
                samples_table = "No sample data available"

            return summary_text, samples_table, "✅ Benchmark completed!"

        except Exception as e:
            return f"❌ Error: {str(e)}", "", f"❌ Failed: {str(e)}"

    def compare_optimizations(
        self,
        model_name: str,
        dataset_name: str,
        num_samples: int,
        optimizations: List[str]
    ) -> Tuple[str, go.Figure, str]:
        """Compare different quantization.

        Runs one full benchmark per selected optimization (perplexity
        enabled) and returns (markdown summary, dual-axis plot, status).
        """
        try:
            results = []

            for opt in optimizations:
                config = BenchmarkConfig(
                    model_name=model_name,
                    dataset_name=dataset_name,
                    num_samples=num_samples,
                    quantization_type=opt,
                    calculate_perplexity=True
                )

                # Fresh instance so each optimization reloads the model.
                benchmarker = ModelBenchmarker()
                result = benchmarker.run_benchmark(config)
                results.append(result["summary"])

            # Create comparison table data
            df = pd.DataFrame(results)

            # Create plot: bars for throughput (left axis), line for
            # memory (right axis).
            fig = go.Figure()

            fig.add_trace(go.Bar(
                name='Throughput',
                x=df['optimization_type'],
                y=df['avg_tokens_per_second'],
                yaxis='y'
            ))

            fig.add_trace(go.Scatter(
                name='Memory (MB)',
                x=df['optimization_type'],
                y=df['max_memory_mb'],
                yaxis='y2',
                mode='lines+markers',
                line=dict(color='red')
            ))

            fig.update_layout(
                title=f'Optimization Comparison: {model_name}',
                xaxis_title='Optimization',
                yaxis=dict(title='Throughput (tok/s)', side='left'),
                yaxis2=dict(title='Memory (MB)', side='right', overlaying='y')
            )

            # Summary text: best configuration per metric.
            best_throughput = max(results, key=lambda x: x['avg_tokens_per_second'])
            best_memory = min(results, key=lambda x: x['max_memory_mb'])

            summary = f"""## Comparison Results

### Best Configurations
- **Highest Throughput**: {best_throughput['optimization_type']} ({best_throughput['avg_tokens_per_second']:.2f} tok/s)
- **Lowest Memory**: {best_memory['optimization_type']} ({best_memory['max_memory_mb']:.2f} MB)

### Results Table
| Optimization | Throughput | Memory | Perplexity |
|--------------|-----------|---------|-----------|
{chr(10).join([f"| {r['optimization_type']} | {r['avg_tokens_per_second']:.2f} | {r['max_memory_mb']:.2f} | {r.get('avg_perplexity', 'N/A')} |" for r in results])}
"""

            return summary, fig, "✅ Comparison completed!"

        except Exception as e:
            return f"❌ Error: {str(e)}", go.Figure(), f"❌ Failed: {str(e)}"

    def get_history(self) -> str:
        """Return the benchmark history as a Markdown document."""
        if not self.history:
            return "No benchmarks run yet."

        history_text = "# Benchmark History\n\n"
        for i, result in enumerate(self.history):
            summary = result["summary"]
            history_text += f"""## Run {i+1}
- **Model**: {summary['model_name']}
- **Time**: {summary['timestamp']}
- **Throughput**: {summary['avg_tokens_per_second']:.2f} tok/s
- **Memory**: {summary['max_memory_mb']:.2f} MB

---
"""

        return history_text

    def create_interface(self):
        """Create and return the Gradio Blocks interface (not launched)."""
        with gr.Blocks(title="Model Benchmark Agent", theme=gr.themes.Soft()) as app:
            gr.Markdown("# 🚀 Model Benchmark Agent")
            gr.Markdown("Benchmark Hugging Face models with optimum-quanto quantization")

            with gr.Tabs():
                # Single Benchmark Tab: config inputs left, results right.
                with gr.TabItem("Single Benchmark"):
                    with gr.Row():
                        with gr.Column():
                            model_input = gr.Textbox("facebook/opt-iml-max-1.3b", label="Model Name")
                            dataset_input = gr.Textbox("tatsu-lab/alpaca", label="Dataset")
                            num_samples = gr.Slider(1, 100, 20, step=1, label="Samples")
                            max_tokens = gr.Slider(10, 512, 100, label="Max Tokens")
                            quantization = gr.Dropdown(
                                ["none", "int8", "int4", "int2", "float8"],
                                value="none",
                                label="Quantization"
                            )
                            torch_compile = gr.Checkbox(label="Use torch.compile")
                            perplexity = gr.Checkbox(label="Calculate Perplexity")
                            device = gr.Dropdown(["auto", "cuda", "cpu", "mps"], value="auto", label="Device")

                            benchmark_btn = gr.Button("🚀 Run Benchmark", variant="primary")

                        with gr.Column():
                            results_md = gr.Markdown()
                            samples_html = gr.HTML()
                            status_text = gr.Textbox(label="Status", interactive=False)

                    benchmark_btn.click(
                        self.benchmark_single,
                        inputs=[model_input, dataset_input, num_samples, max_tokens, quantization, torch_compile, perplexity, device],
                        outputs=[results_md, samples_html, status_text]
                    )

                # Comparison Tab: run several optimizations side by side.
                with gr.TabItem("Compare Optimizations"):
                    with gr.Row():
                        with gr.Column():
                            comp_model = gr.Textbox("facebook/opt-iml-max-1.3b", label="Model")
                            comp_dataset = gr.Textbox("tatsu-lab/alpaca", label="Dataset")
                            comp_samples = gr.Slider(1, 50, 10, step=1, label="Samples")
                            comp_opts = gr.CheckboxGroup(
                                ["none", "int8", "int4", "int2"],
                                value=["none", "int8"],
                                label="Optimizations to Compare"
                            )

                            compare_btn = gr.Button("📊 Compare", variant="primary")

                        with gr.Column():
                            comp_results = gr.Markdown()
                            comp_plot = gr.Plot()
                            comp_status = gr.Textbox(label="Status", interactive=False)

                    compare_btn.click(
                        self.compare_optimizations,
                        inputs=[comp_model, comp_dataset, comp_samples, comp_opts],
                        outputs=[comp_results, comp_plot, comp_status]
                    )

                # History Tab: manual refresh of past run summaries.
                with gr.TabItem("History"):
                    history_md = gr.Markdown()
                    refresh_btn = gr.Button("🔄 Refresh")
                    refresh_btn.click(self.get_history, outputs=[history_md])

                # System Info Tab: hardware/software report on demand.
                with gr.TabItem("System Info"):
                    sys_info_md = gr.Markdown()
                    sys_info_btn = gr.Button("📋 Get System Info")
                    sys_info_btn.click(get_system_info, outputs=[sys_info_md])

        return app
|
| 243 |
+
|
| 244 |
+
def launch_app():
    """Build the benchmark Gradio interface and serve it.

    Binds to all interfaces on port 7860 (the Hugging Face Spaces
    convention) and enables the MCP server so the app can also be used
    from MCP clients.
    """
    interface = GradioApp().create_interface()
    interface.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        mcp_server=True,
    )
|
main.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Main entry point for the Model Benchmark Agent."""
import os
import sys

# Disable tokenizer parallelism to avoid forking issues. This must be set
# BEFORE importing interfaces.gradio_app (which pulls in HF tokenizers),
# hence the deliberately late import below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from interfaces.gradio_app import launch_app


def main():
    """Dispatch to the requested run mode.

    Launches the Gradio interface when running on Hugging Face Spaces or
    when ``gradio`` is passed on the command line; otherwise prints usage.
    """
    # Hugging Face Spaces sets SPACE_ID in the environment.
    is_huggingface = os.getenv("SPACE_ID") is not None

    if is_huggingface or (len(sys.argv) > 1 and sys.argv[1] == "gradio"):
        launch_app()
    else:
        print("Usage: python main.py [gradio]")
        print("Available modes:")
        print("  gradio - Launch Gradio interface")


if __name__ == "__main__":
    main()
|
models/quantization.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
|
| 3 |
+
from optimum.quanto import quantize, freeze, qint8, qint4, qint2, qfloat8
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Tuple, Any, Optional
|
| 6 |
+
|
| 7 |
+
class QuantizationType(Enum):
    """Supported quantization types.

    Each member's value is the string identifier used by the model
    loaders (e.g. passed as ``QuantoConfig(weights=quant_type.value)``),
    so ``.value`` can be forwarded directly.
    """
    # No quantization; load the model at its native precision.
    NONE = "none"
    # 8-bit integer weights.
    INT8 = "int8"
    # 4-bit integer weights.
    INT4 = "int4"
    # 2-bit integer weights.
    INT2 = "int2"
    # 8-bit float weights.
    FLOAT8 = "float8"
|
| 14 |
+
|
| 15 |
+
class ModelLoader:
    """Handles model loading with different quantization strategies.

    Every public loader returns a ``(model, tokenizer)`` tuple. The
    shared tokenizer/base-model setup lives in private helpers so the
    three strategies cannot drift apart.
    """

    @staticmethod
    def _load_tokenizer(model_name: str) -> Any:
        """Load the tokenizer for *model_name*, guaranteeing a pad token.

        Many causal-LM tokenizers ship without a pad token; fall back to
        the EOS token so padded/batched inference does not fail.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @staticmethod
    def _load_base_model(model_name: str, device: str) -> Any:
        """Load a causal LM (fp16 on CUDA, fp32 otherwise) onto *device*.

        ``device_map`` handles placement for accelerator devices; for CPU
        we skip it and move the model explicitly.
        """
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map=device if device != "cpu" else None,
        )
        if device == "cpu":
            model = model.to(device)
        return model

    @staticmethod
    def load_standard(model_name: str, device: str) -> Tuple[Any, Any]:
        """Load model without quantization.

        Args:
            model_name: Hugging Face model identifier or local path.
            device: Target device string (e.g. ``"cpu"``, ``"cuda"``).

        Returns:
            ``(model, tokenizer)`` tuple.
        """
        print(f"Loading {model_name} (standard)")
        model = ModelLoader._load_base_model(model_name, device)
        return model, ModelLoader._load_tokenizer(model_name)

    @staticmethod
    def load_quantized_transformers(model_name: str, quant_type: QuantizationType) -> Tuple[Any, Any]:
        """Load model using the Transformers ``QuantoConfig`` integration.

        Args:
            model_name: Hugging Face model identifier or local path.
            quant_type: Weight quantization to apply; its ``.value`` is
                forwarded directly to ``QuantoConfig``.

        Returns:
            ``(model, tokenizer)`` tuple.
        """
        print(f"Loading {model_name} with {quant_type.value} quantization (Transformers)")

        quant_config = QuantoConfig(weights=quant_type.value)

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map="auto",
            quantization_config=quant_config,
        )

        return model, ModelLoader._load_tokenizer(model_name)

    @staticmethod
    def load_quantized_direct(model_name: str, quant_type: QuantizationType, device: str) -> Tuple[Any, Any]:
        """Load model using the direct quanto quantization API.

        Args:
            model_name: Hugging Face model identifier or local path.
            quant_type: Weight quantization to apply (must not be ``NONE``).
            device: Target device string (e.g. ``"cpu"``, ``"cuda"``).

        Returns:
            ``(model, tokenizer)`` tuple.

        Raises:
            ValueError: If *quant_type* has no direct quanto mapping
                (e.g. ``QuantizationType.NONE`` — use ``load_standard``).
        """
        print(f"Loading {model_name} with {quant_type.value} quantization (Direct API)")

        quant_map = {
            QuantizationType.INT8: qint8,
            QuantizationType.INT4: qint4,
            QuantizationType.INT2: qint2,
            QuantizationType.FLOAT8: qfloat8,
        }
        # Fail fast with a clear message instead of a bare KeyError below.
        if quant_type not in quant_map:
            raise ValueError(
                f"Unsupported quantization type for direct API: {quant_type.value}; "
                "use load_standard() for unquantized loading"
            )

        model = ModelLoader._load_base_model(model_name, device)

        # quantize() rewrites the model's layers in place; freeze()
        # materializes the quantized weights for inference.
        quantize(model, weights=quant_map[quant_type])
        freeze(model)

        return model, ModelLoader._load_tokenizer(model_name)
|
requirements.txt
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile pyproject.toml --output-file requirements.txt
|
| 3 |
+
accelerate==1.7.0
|
| 4 |
+
# via model-benchmark-agent (pyproject.toml)
|
| 5 |
+
aiofiles==24.1.0
|
| 6 |
+
# via gradio
|
| 7 |
+
aiohappyeyeballs==2.6.1
|
| 8 |
+
# via aiohttp
|
| 9 |
+
aiohttp==3.12.12
|
| 10 |
+
# via fsspec
|
| 11 |
+
aiosignal==1.3.2
|
| 12 |
+
# via aiohttp
|
| 13 |
+
annotated-types==0.7.0
|
| 14 |
+
# via pydantic
|
| 15 |
+
anyio==4.9.0
|
| 16 |
+
# via
|
| 17 |
+
# gradio
|
| 18 |
+
# httpx
|
| 19 |
+
# mcp
|
| 20 |
+
# sse-starlette
|
| 21 |
+
# starlette
|
| 22 |
+
attrs==25.3.0
|
| 23 |
+
# via aiohttp
|
| 24 |
+
certifi==2025.4.26
|
| 25 |
+
# via
|
| 26 |
+
# httpcore
|
| 27 |
+
# httpx
|
| 28 |
+
# requests
|
| 29 |
+
charset-normalizer==3.4.2
|
| 30 |
+
# via requests
|
| 31 |
+
click==8.2.1
|
| 32 |
+
# via
|
| 33 |
+
# typer
|
| 34 |
+
# uvicorn
|
| 35 |
+
datasets==3.6.0
|
| 36 |
+
# via model-benchmark-agent (pyproject.toml)
|
| 37 |
+
dill==0.3.8
|
| 38 |
+
# via
|
| 39 |
+
# datasets
|
| 40 |
+
# multiprocess
|
| 41 |
+
fastapi==0.115.12
|
| 42 |
+
# via gradio
|
| 43 |
+
ffmpy==0.6.0
|
| 44 |
+
# via gradio
|
| 45 |
+
filelock==3.18.0
|
| 46 |
+
# via
|
| 47 |
+
# datasets
|
| 48 |
+
# huggingface-hub
|
| 49 |
+
# torch
|
| 50 |
+
# transformers
|
| 51 |
+
frozenlist==1.7.0
|
| 52 |
+
# via
|
| 53 |
+
# aiohttp
|
| 54 |
+
# aiosignal
|
| 55 |
+
fsspec==2025.3.0
|
| 56 |
+
# via
|
| 57 |
+
# datasets
|
| 58 |
+
# gradio-client
|
| 59 |
+
# huggingface-hub
|
| 60 |
+
# torch
|
| 61 |
+
gradio==5.33.1
|
| 62 |
+
# via model-benchmark-agent (pyproject.toml)
|
| 63 |
+
gradio-client==1.10.3
|
| 64 |
+
# via gradio
|
| 65 |
+
groovy==0.1.2
|
| 66 |
+
# via gradio
|
| 67 |
+
h11==0.16.0
|
| 68 |
+
# via
|
| 69 |
+
# httpcore
|
| 70 |
+
# uvicorn
|
| 71 |
+
hf-xet==1.1.3
|
| 72 |
+
# via huggingface-hub
|
| 73 |
+
httpcore==1.0.9
|
| 74 |
+
# via httpx
|
| 75 |
+
httpx==0.28.1
|
| 76 |
+
# via
|
| 77 |
+
# gradio
|
| 78 |
+
# gradio-client
|
| 79 |
+
# mcp
|
| 80 |
+
# safehttpx
|
| 81 |
+
httpx-sse==0.4.0
|
| 82 |
+
# via mcp
|
| 83 |
+
huggingface-hub==0.32.5
|
| 84 |
+
# via
|
| 85 |
+
# accelerate
|
| 86 |
+
# datasets
|
| 87 |
+
# gradio
|
| 88 |
+
# gradio-client
|
| 89 |
+
# optimum-quanto
|
| 90 |
+
# tokenizers
|
| 91 |
+
# transformers
|
| 92 |
+
idna==3.10
|
| 93 |
+
# via
|
| 94 |
+
# anyio
|
| 95 |
+
# httpx
|
| 96 |
+
# requests
|
| 97 |
+
# yarl
|
| 98 |
+
jinja2==3.1.6
|
| 99 |
+
# via
|
| 100 |
+
# gradio
|
| 101 |
+
# torch
|
| 102 |
+
markdown-it-py==3.0.0
|
| 103 |
+
# via rich
|
| 104 |
+
markupsafe==3.0.2
|
| 105 |
+
# via
|
| 106 |
+
# gradio
|
| 107 |
+
# jinja2
|
| 108 |
+
mcp==1.9.3
|
| 109 |
+
# via model-benchmark-agent (pyproject.toml)
|
| 110 |
+
mdurl==0.1.2
|
| 111 |
+
# via markdown-it-py
|
| 112 |
+
mpmath==1.3.0
|
| 113 |
+
# via sympy
|
| 114 |
+
multidict==6.4.4
|
| 115 |
+
# via
|
| 116 |
+
# aiohttp
|
| 117 |
+
# yarl
|
| 118 |
+
multiprocess==0.70.16
|
| 119 |
+
# via datasets
|
| 120 |
+
narwhals==1.42.0
|
| 121 |
+
# via plotly
|
| 122 |
+
networkx==3.5
|
| 123 |
+
# via torch
|
| 124 |
+
ninja==1.11.1.4
|
| 125 |
+
# via optimum-quanto
|
| 126 |
+
numpy==2.3.0
|
| 127 |
+
# via
|
| 128 |
+
# model-benchmark-agent (pyproject.toml)
|
| 129 |
+
# accelerate
|
| 130 |
+
# datasets
|
| 131 |
+
# gradio
|
| 132 |
+
# optimum-quanto
|
| 133 |
+
# pandas
|
| 134 |
+
# transformers
|
| 135 |
+
optimum-quanto==0.2.7
|
| 136 |
+
# via model-benchmark-agent (pyproject.toml)
|
| 137 |
+
orjson==3.10.18
|
| 138 |
+
# via gradio
|
| 139 |
+
packaging==25.0
|
| 140 |
+
# via
|
| 141 |
+
# accelerate
|
| 142 |
+
# datasets
|
| 143 |
+
# gradio
|
| 144 |
+
# gradio-client
|
| 145 |
+
# huggingface-hub
|
| 146 |
+
# plotly
|
| 147 |
+
# transformers
|
| 148 |
+
pandas==2.3.0
|
| 149 |
+
# via
|
| 150 |
+
# model-benchmark-agent (pyproject.toml)
|
| 151 |
+
# datasets
|
| 152 |
+
# gradio
|
| 153 |
+
pillow==11.2.1
|
| 154 |
+
# via gradio
|
| 155 |
+
plotly==6.1.2
|
| 156 |
+
# via model-benchmark-agent (pyproject.toml)
|
| 157 |
+
propcache==0.3.2
|
| 158 |
+
# via
|
| 159 |
+
# aiohttp
|
| 160 |
+
# yarl
|
| 161 |
+
psutil==7.0.0
|
| 162 |
+
# via
|
| 163 |
+
# model-benchmark-agent (pyproject.toml)
|
| 164 |
+
# accelerate
|
| 165 |
+
pyarrow==20.0.0
|
| 166 |
+
# via datasets
|
| 167 |
+
pydantic==2.11.5
|
| 168 |
+
# via
|
| 169 |
+
# model-benchmark-agent (pyproject.toml)
|
| 170 |
+
# fastapi
|
| 171 |
+
# gradio
|
| 172 |
+
# mcp
|
| 173 |
+
# pydantic-settings
|
| 174 |
+
pydantic-core==2.33.2
|
| 175 |
+
# via pydantic
|
| 176 |
+
pydantic-settings==2.9.1
|
| 177 |
+
# via mcp
|
| 178 |
+
pydub==0.25.1
|
| 179 |
+
# via gradio
|
| 180 |
+
pygments==2.19.1
|
| 181 |
+
# via rich
|
| 182 |
+
python-dateutil==2.9.0.post0
|
| 183 |
+
# via pandas
|
| 184 |
+
python-dotenv==1.1.0
|
| 185 |
+
# via pydantic-settings
|
| 186 |
+
python-multipart==0.0.20
|
| 187 |
+
# via
|
| 188 |
+
# gradio
|
| 189 |
+
# mcp
|
| 190 |
+
pytz==2025.2
|
| 191 |
+
# via pandas
|
| 192 |
+
pyyaml==6.0.2
|
| 193 |
+
# via
|
| 194 |
+
# accelerate
|
| 195 |
+
# datasets
|
| 196 |
+
# gradio
|
| 197 |
+
# huggingface-hub
|
| 198 |
+
# transformers
|
| 199 |
+
regex==2024.11.6
|
| 200 |
+
# via transformers
|
| 201 |
+
requests==2.32.4
|
| 202 |
+
# via
|
| 203 |
+
# datasets
|
| 204 |
+
# huggingface-hub
|
| 205 |
+
# transformers
|
| 206 |
+
rich==14.0.0
|
| 207 |
+
# via typer
|
| 208 |
+
ruff==0.11.13
|
| 209 |
+
# via gradio
|
| 210 |
+
safehttpx==0.1.6
|
| 211 |
+
# via gradio
|
| 212 |
+
safetensors==0.5.3
|
| 213 |
+
# via
|
| 214 |
+
# accelerate
|
| 215 |
+
# optimum-quanto
|
| 216 |
+
# transformers
|
| 217 |
+
semantic-version==2.10.0
|
| 218 |
+
# via gradio
|
| 219 |
+
setuptools==80.9.0
|
| 220 |
+
# via torch
|
| 221 |
+
shellingham==1.5.4
|
| 222 |
+
# via typer
|
| 223 |
+
six==1.17.0
|
| 224 |
+
# via python-dateutil
|
| 225 |
+
sniffio==1.3.1
|
| 226 |
+
# via anyio
|
| 227 |
+
sse-starlette==2.3.6
|
| 228 |
+
# via mcp
|
| 229 |
+
starlette==0.46.2
|
| 230 |
+
# via
|
| 231 |
+
# fastapi
|
| 232 |
+
# gradio
|
| 233 |
+
# mcp
|
| 234 |
+
sympy==1.14.0
|
| 235 |
+
# via torch
|
| 236 |
+
tokenizers==0.21.1
|
| 237 |
+
# via transformers
|
| 238 |
+
tomlkit==0.13.3
|
| 239 |
+
# via gradio
|
| 240 |
+
torch==2.7.1
|
| 241 |
+
# via
|
| 242 |
+
# model-benchmark-agent (pyproject.toml)
|
| 243 |
+
# accelerate
|
| 244 |
+
# optimum-quanto
|
| 245 |
+
tqdm==4.67.1
|
| 246 |
+
# via
|
| 247 |
+
# datasets
|
| 248 |
+
# huggingface-hub
|
| 249 |
+
# transformers
|
| 250 |
+
transformers==4.52.4
|
| 251 |
+
# via model-benchmark-agent (pyproject.toml)
|
| 252 |
+
typer==0.16.0
|
| 253 |
+
# via gradio
|
| 254 |
+
typing-extensions==4.14.0
|
| 255 |
+
# via
|
| 256 |
+
# anyio
|
| 257 |
+
# fastapi
|
| 258 |
+
# gradio
|
| 259 |
+
# gradio-client
|
| 260 |
+
# huggingface-hub
|
| 261 |
+
# pydantic
|
| 262 |
+
# pydantic-core
|
| 263 |
+
# torch
|
| 264 |
+
# typer
|
| 265 |
+
# typing-inspection
|
| 266 |
+
typing-inspection==0.4.1
|
| 267 |
+
# via
|
| 268 |
+
# pydantic
|
| 269 |
+
# pydantic-settings
|
| 270 |
+
tzdata==2025.2
|
| 271 |
+
# via pandas
|
| 272 |
+
urllib3==2.4.0
|
| 273 |
+
# via requests
|
| 274 |
+
uvicorn==0.34.3
|
| 275 |
+
# via
|
| 276 |
+
# gradio
|
| 277 |
+
# mcp
|
| 278 |
+
websockets==15.0.1
|
| 279 |
+
# via gradio-client
|
| 280 |
+
xxhash==3.5.0
|
| 281 |
+
# via datasets
|
| 282 |
+
yarl==1.20.1
|
| 283 |
+
# via aiohttp
|