File size: 2,346 Bytes
7b95dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import matplotlib.pyplot as plt
import numpy as np

# --- 1. 数据定义 (已根据您的指正进行修正) ---
# 简化模型名称
model_names = ['FP32 PTV3', 'Ours']
# 【修改】图表的类别现在是 FLOPs 和 Params
categories = ['FLOPs (G)', 'Params (M)']

# 【修改】定义每个模型在每个类别下的数据
# 格式: [Baseline Value, Tiny Value]
flops_data = [15.50, 1.43]
params_data = [11.76, 1.36]

# --- 2. 图表绘制 ---
# 设置全局字体
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']

# 创建图表和坐标轴
fig, ax = plt.subplots(figsize=(8, 6), dpi=100)

# 设置条形图的位置和宽度
bar_width = 0.35
x = np.arange(len(categories))

# 【修改】绘制条形图,现在使用 params_data
rects1 = ax.bar(x - bar_width/2, [flops_data[0], params_data[0]], bar_width, 
                label=model_names[0], color='#4c72b0', edgecolor='black', linewidth=1.2)
rects2 = ax.bar(x + bar_width/2, [flops_data[1], params_data[1]], bar_width, 
                label=model_names[1], color='#dd8452', edgecolor='black', linewidth=1.2)

# --- 3. 添加必要的标签 ---
ax.set_ylabel('Value', fontsize=14, labelpad=10)
ax.set_xticks(x)
ax.set_xticklabels(categories, fontsize=12)
ax.legend(fontsize=11, loc='upper right') # 调整图例位置

# 【修改】Y轴范围根据新数据调整
ax.set_ylim(0, max(flops_data[0], params_data[0]) * 1.2)
ax.grid(True, which='major', axis='y', linestyle='--', color='gray', alpha=0.6)
ax.set_axisbelow(True)

# --- 4. 在条形图上添加数值标签 ---
def autolabel(rects, ax):
    """在每个条形图上方附加一个文本标签,显示其高度。"""
    for rect in rects:
        height = rect.get_height()
        ax.annotate(f'{height:.2f}',
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 5),
                    textcoords="offset points",
                    ha='center', va='bottom',
                    fontsize=10, weight='bold')

autolabel(rects1, ax)
autolabel(rects2, ax)

# --- 5. 保存图表到文件 ---
fig.tight_layout()
# 修改保存文件名为 flops_params_comparison.png
plt.savefig('flops_params_comparison.png', dpi=300, bbox_inches='tight') 

print("FLOPs与Params对比图已成功保存为 flops_params_comparison.png")