import matplotlib.pyplot as plt import numpy as np # --- 1. 数据定义 (已根据您的指正进行修正) --- # 简化模型名称 model_names = ['FP32 PTV3', 'Ours'] # 【修改】图表的类别现在是 FLOPs 和 Params categories = ['FLOPs (G)', 'Params (M)'] # 【修改】定义每个模型在每个类别下的数据 # 格式: [Baseline Value, Tiny Value] flops_data = [15.50, 1.43] params_data = [11.76, 1.36] # --- 2. 图表绘制 --- # 设置全局字体 plt.rcParams['font.family'] = 'sans-serif' plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans'] # 创建图表和坐标轴 fig, ax = plt.subplots(figsize=(8, 6), dpi=100) # 设置条形图的位置和宽度 bar_width = 0.35 x = np.arange(len(categories)) # 【修改】绘制条形图,现在使用 params_data rects1 = ax.bar(x - bar_width/2, [flops_data[0], params_data[0]], bar_width, label=model_names[0], color='#4c72b0', edgecolor='black', linewidth=1.2) rects2 = ax.bar(x + bar_width/2, [flops_data[1], params_data[1]], bar_width, label=model_names[1], color='#dd8452', edgecolor='black', linewidth=1.2) # --- 3. 添加必要的标签 --- ax.set_ylabel('Value', fontsize=14, labelpad=10) ax.set_xticks(x) ax.set_xticklabels(categories, fontsize=12) ax.legend(fontsize=11, loc='upper right') # 调整图例位置 # 【修改】Y轴范围根据新数据调整 ax.set_ylim(0, max(flops_data[0], params_data[0]) * 1.2) ax.grid(True, which='major', axis='y', linestyle='--', color='gray', alpha=0.6) ax.set_axisbelow(True) # --- 4. 在条形图上添加数值标签 --- def autolabel(rects, ax): """在每个条形图上方附加一个文本标签,显示其高度。""" for rect in rects: height = rect.get_height() ax.annotate(f'{height:.2f}', xy=(rect.get_x() + rect.get_width() / 2, height), xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=10, weight='bold') autolabel(rects1, ax) autolabel(rects2, ax) # --- 5. 保存图表到文件 --- fig.tight_layout() # 修改保存文件名为 flops_params_comparison.png plt.savefig('flops_params_comparison.png', dpi=300, bbox_inches='tight') print("FLOPs与Params对比图已成功保存为 flops_params_comparison.png")