File size: 2,346 Bytes
7b95dc2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | import matplotlib.pyplot as plt
import numpy as np
# --- 1. 数据定义 (已根据您的指正进行修正) ---
# 简化模型名称
model_names = ['FP32 PTV3', 'Ours']
# 【修改】图表的类别现在是 FLOPs 和 Params
categories = ['FLOPs (G)', 'Params (M)']
# 【修改】定义每个模型在每个类别下的数据
# 格式: [Baseline Value, Tiny Value]
flops_data = [15.50, 1.43]
params_data = [11.76, 1.36]
# --- 2. 图表绘制 ---
# 设置全局字体
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
# 创建图表和坐标轴
fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
# 设置条形图的位置和宽度
bar_width = 0.35
x = np.arange(len(categories))
# 【修改】绘制条形图,现在使用 params_data
rects1 = ax.bar(x - bar_width/2, [flops_data[0], params_data[0]], bar_width,
label=model_names[0], color='#4c72b0', edgecolor='black', linewidth=1.2)
rects2 = ax.bar(x + bar_width/2, [flops_data[1], params_data[1]], bar_width,
label=model_names[1], color='#dd8452', edgecolor='black', linewidth=1.2)
# --- 3. 添加必要的标签 ---
ax.set_ylabel('Value', fontsize=14, labelpad=10)
ax.set_xticks(x)
ax.set_xticklabels(categories, fontsize=12)
ax.legend(fontsize=11, loc='upper right') # 调整图例位置
# 【修改】Y轴范围根据新数据调整
ax.set_ylim(0, max(flops_data[0], params_data[0]) * 1.2)
ax.grid(True, which='major', axis='y', linestyle='--', color='gray', alpha=0.6)
ax.set_axisbelow(True)
# --- 4. 在条形图上添加数值标签 ---
def autolabel(rects, ax):
"""在每个条形图上方附加一个文本标签,显示其高度。"""
for rect in rects:
height = rect.get_height()
ax.annotate(f'{height:.2f}',
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 5),
textcoords="offset points",
ha='center', va='bottom',
fontsize=10, weight='bold')
autolabel(rects1, ax)
autolabel(rects2, ax)
# --- 5. 保存图表到文件 ---
fig.tight_layout()
# 修改保存文件名为 flops_params_comparison.png
plt.savefig('flops_params_comparison.png', dpi=300, bbox_inches='tight')
print("FLOPs与Params对比图已成功保存为 flops_params_comparison.png") |