from __future__ import annotations

from collections.abc import Sequence

import matplotlib.pyplot as plt


# Weight-only VRAM footprint of a 405B-parameter model under each storage
# format, in GB. The three lists are parallel: index i of each list
# describes the same bar.
labels = ['Llama-3 (BF16)', 'INT8 Quant', 'JiRack Ternary (2-bit)']
vram_usage = [
    810,  # BF16: 2 bytes/param x 405e9 params = 810 GB (decimal GB)
    405,  # INT8: 1 byte/param x 405e9 params = 405 GB (decimal GB)
    243,  # NOTE(review): 2 bits/param x 405e9 params would be ~101 GB;
          # 243 GB averages ~4.8 bits/param. Confirm whether this figure
          # includes extra overhead (e.g. scales or layers kept at higher
          # precision) or whether the "2-bit" label is inaccurate.
]
colors = ['#ff9999', '#66b3ff', '#99ff99']


def generate_vram_chart(
    *,
    bar_labels: Sequence[str] | None = None,
    values_gb: Sequence[float] | None = None,
    bar_colors: Sequence[str] | None = None,
    title: str = 'VRAM Weight Footprint: 405B Model Comparison',
    output_path: str = 'vram_benchmark_405b.png',
    show: bool = True,
) -> str:
    """Draw a bar chart of VRAM weight footprints and save it to a file.

    Called with no arguments, it plots the module-level ``labels``,
    ``vram_usage`` and ``colors`` lists, saves the figure to
    ``vram_benchmark_405b.png`` and then calls ``plt.show()``.

    Args:
        bar_labels: Category name for each bar. Defaults to ``labels``.
        values_gb: Bar height (GB) for each label. Defaults to ``vram_usage``.
        bar_colors: Fill colour for each bar. Defaults to ``colors``.
        title: Chart title.
        output_path: File the figure is saved to.
        show: If true, call ``plt.show()`` after saving. If false, close the
            figure instead so batch/headless callers don't accumulate figures.

    Returns:
        ``output_path``, the file the chart was written to.
    """
    # Resolve defaults at call time so the module-level data is used.
    if bar_labels is None:
        bar_labels = labels
    if values_gb is None:
        values_gb = vram_usage
    if bar_colors is None:
        bar_colors = colors

    # Object-oriented API so this function owns (and can close) its figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(bar_labels, values_gb, color=bar_colors)

    ax.set_title(title, fontsize=14)
    ax.set_ylabel('VRAM Usage (GB)', fontsize=12)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Write each value just above its bar. The offset is given in points,
    # not data units, so the gap looks the same whatever the y-axis scale.
    for bar in bars:
        height = bar.get_height()
        ax.annotate(
            f'{height} GB',
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 3),
            textcoords='offset points',
            ha='center',
            va='bottom',
            fontweight='bold',
        )

    fig.tight_layout()
    fig.savefig(output_path)
    print(f"Benchmark chart saved as '{output_path}'")

    if show:
        # Leave the figure open: in interactive mode plt.show() returns
        # immediately, and closing the figure here would destroy the window.
        plt.show()
    else:
        plt.close(fig)
    return output_path


if __name__ == '__main__':
    # Script entry point: render the chart from the module's default data.
    generate_vram_chart()