# JiRackTernary_405b / Benchmark.py
# Repository page header (kept as a comment so the file parses as Python):
#   kgrabko -- "Create Benchmark.py" -- commit d1d1463 (verified)
import matplotlib.pyplot as plt
# Weight memory figures (GB) charted for a 405B model, one row per bar:
# (bar label, footprint in GB, bar colour).
# NOTE(review): 810 GB and 405 GB equal 405B parameters at 2 bytes and 1 byte
# each; 243 GB is ~0.6 bytes per parameter, which is larger than a plain 2-bit
# packing (~101 GB) -- confirm where this figure comes from.
_WEIGHT_FOOTPRINTS = (
    ('Llama-3 (BF16)', 810, '#ff9999'),
    ('INT8 Quant', 405, '#66b3ff'),
    ('JiRack Ternary (2-bit)', 243, '#99ff99'),
)

# Column-wise views of the table above; the chart code reads these three lists.
labels, vram_usage, colors = (list(column) for column in zip(*_WEIGHT_FOOTPRINTS))
def generate_vram_chart(output_path='vram_benchmark_405b.png', *, show=True):
    """Draw the 405B weight-footprint bar chart, save it, and optionally show it.

    One bar is drawn per entry of the module-level ``labels`` /
    ``vram_usage`` / ``colors`` lists, and each bar is annotated with its
    value in GB.

    Args:
        output_path: Destination passed to ``Figure.savefig``. Defaults to
            the previously hard-coded file name, so existing callers that
            pass no arguments get the same file as before.
        show: When true (the default, matching the original behaviour) the
            chart is displayed with ``plt.show()`` after saving. Pass
            ``False`` for batch or headless runs; the figure is then closed
            so repeated calls do not accumulate open figures.

    Returns:
        None. The chart is written to ``output_path`` and a confirmation
        line is printed to stdout.
    """
    # Hold an explicit figure/axes pair instead of relying on pyplot's
    # implicit "current figure", so the figure can be closed reliably below.
    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(labels, vram_usage, color=colors)

    ax.set_title('VRAM Weight Footprint: 405B Model Comparison', fontsize=14)
    ax.set_ylabel('VRAM Usage (GB)', fontsize=12)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Write each bar's value just above it (offset of 10 is in data units, GB).
    for bar in bars:
        yval = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            yval + 10,
            f'{yval} GB',
            ha='center',
            va='bottom',
            fontweight='bold',
        )

    fig.tight_layout()
    fig.savefig(output_path)
    # Build the message from the actual path so it cannot drift from the
    # file that was written.
    print(f"Benchmark chart saved as '{output_path}'")

    if show:
        plt.show()
    else:
        # Nothing will display this figure, so release it explicitly.
        plt.close(fig)
# Script entry point: build and save the chart only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    generate_vram_chart()