| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
| |
| |
# Measurements per model, in plotting order ('FP32 PTV3' first, then 'Ours').
# Each tuple is (FLOPs in G, Params in M), matching the `categories` labels.
_measurements = {
    'FP32 PTV3': (15.50, 11.76),
    'Ours': (1.43, 1.36),
}

# Legend labels, one per model (dict insertion order is preserved).
model_names = list(_measurements)

# x-axis groups; each group gets one bar per model.
categories = ['FLOPs (G)', 'Params (M)']

# Per-metric value lists, indexed by model position (same order as model_names).
flops_data = [flops for flops, _ in _measurements.values()]
params_data = [params for _, params in _measurements.values()]
|
|
| |
| |
# Use a sans-serif font family; candidate fonts are listed in priority order.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
})
|
|
| |
fig, ax = plt.subplots(figsize=(8, 6), dpi=100)

# One group per category at x = 0, 1, ...; the two model bars sit side by
# side, shifted half a bar width left/right of each group's tick.
bar_width = 0.35
x = np.arange(len(categories))
half = bar_width / 2

# Bar heights per model, in `categories` order (FLOPs, then Params).
baseline_heights = [flops_data[0], params_data[0]]
ours_heights = [flops_data[1], params_data[1]]

rects1 = ax.bar(
    x - half, baseline_heights, bar_width,
    label=model_names[0], color='#4c72b0', edgecolor='black', linewidth=1.2,
)
rects2 = ax.bar(
    x + half, ours_heights, bar_width,
    label=model_names[1], color='#dd8452', edgecolor='black', linewidth=1.2,
)
|
|
| |
# Axis label, category tick labels and legend.
ax.set_ylabel('Value', fontsize=14, labelpad=10)
ax.set_xticks(x)
ax.set_xticklabels(categories, fontsize=12)
ax.legend(fontsize=11, loc='upper right')

# Scale the y-axis to the tallest bar of ANY model, not just the first one,
# so no bar is clipped when another model's value is larger. The extra 20%
# headroom keeps the value labels drawn above the bars inside the axes.
ax.set_ylim(0, max(max(flops_data), max(params_data)) * 1.2)
# Dashed horizontal grid lines, drawn behind the bars.
ax.grid(True, which='major', axis='y', linestyle='--', color='gray', alpha=0.6)
ax.set_axisbelow(True)
|
|
| |
def autolabel(rects, ax, *, fmt='{:.2f}'):
    """Annotate each bar in *rects* with its height, just beyond the bar's end.

    Args:
        rects: Iterable of bar patches (e.g. what ``ax.bar`` returns). Each
            item must provide ``get_x``, ``get_y``, ``get_width`` and
            ``get_height``.
        ax: Axes object whose ``annotate`` method draws the labels.
        fmt: ``str.format`` template applied to each height. The default
            gives two decimal places, as before.

    Labels are centred horizontally on the bar. Bars with a non-negative
    height are labelled 5 points above their top end. Bars with a negative
    height are labelled 5 points below their lower end, so the text does not
    overlap the bar.
    """
    for rect in rects:
        height = rect.get_height()
        # The bar spans get_y() .. get_y() + height. Anchoring at that end,
        # rather than at `height`, also places labels correctly for bars
        # drawn with a non-zero `bottom`.
        end = rect.get_y() + height
        upward = height >= 0
        ax.annotate(fmt.format(height),
                    xy=(rect.get_x() + rect.get_width() / 2, end),
                    xytext=(0, 5 if upward else -5),
                    textcoords="offset points",
                    ha='center', va='bottom' if upward else 'top',
                    fontsize=10, weight='bold')
|
|
# Annotate both models' bars with their values.
for model_bars in (rects1, rects2):
    autolabel(model_bars, ax)

# Tighten the layout, then export at 300 dpi, cropping surplus whitespace.
fig.tight_layout()
plt.savefig('flops_params_comparison.png', dpi=300, bbox_inches='tight')

# Success message (Chinese: "FLOPs vs Params comparison chart saved as ...").
print("FLOPs与Params对比图已成功保存为 flops_params_comparison.png")