# Source: cherrydata / run.py (author: Voidljc, commit aa24fe8 "Your commit message", 922 bytes)
from ultralytics import YOLO
import matplotlib.pyplot as plt
import numpy as np
# Paths to the trained weights and the dataset config used for evaluation.
# Adjust these (or pass overrides to main()) when running on another machine.
WEIGHTS_PATH = "D:/wampee_yolov8/wampee/wampee/ultralytics/runs/detect/train4/weights/best.pt"
DATA_YAML = "D:/wampee_yolov8/wampee/wampee/ultralytics/train.yaml"
OUTPUT_PNG = "prf1_per_class.png"


def per_class_scores(box, names):
    """Return per-class class names and P/R/F1 scores from a box metric.

    Args:
        box: The ``results.box`` object returned by ``YOLO.val()``; it must
            expose ``p``, ``r``, ``f1`` and ``ap_class_index``.
        names: Mapping (or sequence) from integer class id to class name,
            e.g. ``results.names``.

    Returns:
        A tuple ``(class_names, precision, recall, f1)``. The three score
        arrays are aligned element-wise with ``class_names``.
    """
    # The per-class score arrays are ordered by ``ap_class_index`` (only the
    # classes that actually had labels in the evaluated split), not by raw
    # class id 0..n-1. Looking names up by position mislabels the plot as
    # soon as any class is missing from the split.
    class_names = [names[int(class_id)] for class_id in box.ap_class_index]
    return (
        class_names,
        np.asarray(box.p),
        np.asarray(box.r),
        np.asarray(box.f1),
    )


def plot_prf1(class_names, precision, recall, f1, output_path=OUTPUT_PNG, show=True):
    """Plot precision, recall and F1 for each class and save the figure.

    Args:
        class_names: Labels for the x-axis, one per class.
        precision: Per-class precision values, aligned with ``class_names``.
        recall: Per-class recall values, aligned with ``class_names``.
        f1: Per-class F1 values, aligned with ``class_names``.
        output_path: File the figure is written to.
        show: When true, also open an interactive window via ``plt.show()``.
    """
    x = np.arange(len(class_names))
    plt.figure(figsize=(8, 5))
    plt.plot(x, precision, marker="o", label="Precision")
    plt.plot(x, recall, marker="s", label="Recall")
    plt.plot(x, f1, marker="^", label="F1-score")
    plt.xticks(x, class_names, rotation=45)
    plt.xlabel("Class")
    plt.ylabel("Score")
    plt.title("Precision / Recall / F1-score per Class")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path)
    if show:
        plt.show()


def main(weights=WEIGHTS_PATH, data=DATA_YAML, output_path=OUTPUT_PNG):
    """Evaluate the model on the test split and plot per-class P/R/F1.

    Args:
        weights: Path to the trained ``.pt`` weights.
        data: Path to the dataset YAML; it must define a ``test`` split.
        output_path: Where to save the resulting figure.
    """
    # 1. Load the trained model.
    model = YOLO(weights)
    # 2. Evaluate on the test split.
    results = model.val(data=data, split="test")
    # 3. Extract per-class metrics, aligned with their class names.
    class_names, precision, recall, f1 = per_class_scores(results.box, results.names)
    if not class_names:
        print("No labelled classes found in the test split; nothing to plot.")
        return
    # 4. Plot and save.
    plot_prf1(class_names, precision, recall, f1, output_path)


if __name__ == "__main__":
    main()