MaxMagician
Initial HF Space: Gradio sitting posture demo
c3155e8
#!/usr/bin/env python3
"""
analyze.py — 单张图片坐姿检测
用法:
python analyze.py <image_path>
python analyze.py <image_path> --save
python analyze.py <image_path> --save <output_path>
"""
import argparse
import os
import sys
from pathlib import Path
# Work from this script's own directory so the relative model path used by
# app_models/load_model.py (./data/inference_models/) can be found regardless
# of where the script was launched from.
# NOTE(review): this runs at import time, before argument parsing, so any
# relative paths given on the command line are resolved against this
# directory rather than the caller's original working directory.
_SCRIPT_DIR = Path(__file__).parent
os.chdir(_SCRIPT_DIR)
import sys
import types
# yolov5 compatibility shim: newer huggingface_hub releases removed the
# `huggingface_hub.utils._errors` submodule (its contents now live in
# `huggingface_hub.errors`). If the old import path is missing, register a
# stand-in module under that name that re-exports the new module's contents.
try:
    import huggingface_hub.utils._errors  # noqa: F401
except ImportError:  # ModuleNotFoundError is a subclass, so it is covered too
    import huggingface_hub.errors as _hf_errors
    import huggingface_hub.utils as _hf_utils

    _shim = types.ModuleType("huggingface_hub.utils._errors")
    for _name in dir(_hf_errors):
        # Skip module-identity dunders (__name__, __spec__, __file__, ...) so
        # the shim keeps its own identity instead of impersonating
        # huggingface_hub.errors; __all__ is still copied for star-imports.
        if _name.startswith("__") and _name.endswith("__") and _name != "__all__":
            continue
        setattr(_shim, _name, getattr(_hf_errors, _name))
    sys.modules["huggingface_hub.utils._errors"] = _shim
    # A module that only exists in sys.modules is not attached to its parent
    # package, so attribute access `huggingface_hub.utils._errors` would
    # otherwise fail even after `import huggingface_hub.utils._errors`.
    _hf_utils._errors = _shim
import torch
# PyTorch 2.6+ defaults torch.load(weights_only=True); the legacy YOLOv5
# checkpoint used here needs full unpickling, so default it back to False.
# SECURITY: weights_only=False lets a checkpoint execute arbitrary code while
# loading, and this patch replaces torch.load process-wide, so it affects every
# torch.load call made after this point. Only load trusted checkpoints.
import functools

_orig_torch_load = torch.load


@functools.wraps(_orig_torch_load)
def _patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    An explicit ``weights_only=...`` passed by the caller is still honored.
    """
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
import cv2
from app_models.load_model import InferenceModel
def draw_result(img, x1, y1, x2, y2, label, conf):
    """Draw a yellow bounding box and a "<label> <conf>" tag onto ``img``.

    The image is modified in place and also returned for convenience.

    Args:
        img: Image array accepted by OpenCV drawing functions (colors below
            are given in BGR order).
        x1, y1, x2, y2: Box corners in pixel coordinates. Any numeric type is
            accepted; values are truncated to ``int`` because OpenCV drawing
            functions require integer points.
        label: Text shown in the tag.
        conf: Confidence score, rendered with two decimals.

    Returns:
        The same ``img`` object, with the annotations drawn on it.
    """
    color = (0, 255, 255)  # yellow in BGR
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 0.7, 2
    x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))

    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

    text = f"{label} {conf:.2f}"
    (tw, th), _ = cv2.getTextSize(text, font, font_scale, thickness)
    # Filled background behind the text keeps it readable. The tag normally
    # sits just above the box; if that would run off the top of the image
    # (box touching the top edge), draw it just inside the box instead.
    tag_h = th + 10
    if y1 - tag_h >= 0:
        top, baseline_y = y1 - tag_h, y1 - 6
    else:
        top, baseline_y = y1, y1 + th + 4
    cv2.rectangle(img, (x1, top), (x1 + tw + 4, top + tag_h), color, -1)
    cv2.putText(img, text, (x1 + 2, baseline_y), font, font_scale, (0, 0, 0), thickness)
    return img
def main():
    """Command-line entry point: detect sitting posture in a single image.

    Prints the predicted posture ("good"/"bad"), its confidence and bounding
    box, and with ``--save`` writes an annotated copy of the image. Error
    messages are written to stderr and the process exits with status 1 when
    the input image is missing/unreadable or the annotated image cannot be
    written.

    NOTE(review): this module ``os.chdir``s into the script directory at import
    time, so relative paths on the command line (the input image and an
    explicit ``--save`` path) are resolved against the script directory, not
    the caller's working directory. Use absolute paths when running elsewhere.
    """
    parser = argparse.ArgumentParser(description="坐姿检测(YOLOv5)")
    parser.add_argument("image", help="输入图片路径(JPG / PNG)")
    parser.add_argument(
        "--save",
        nargs="?",
        const="",  # bare `--save` (no path given) -> use the default output name
        metavar="OUTPUT",
        help="保存标注图;不指定路径则存为 <原文件名>_result<原扩展名>",
    )
    args = parser.parse_args()

    image_path = Path(args.image).resolve()
    if not image_path.is_file():
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(f"错误:找不到图片 {image_path}")

    img = cv2.imread(str(image_path))
    # cv2.imread returns None (rather than raising) when it cannot decode the file.
    if img is None:
        sys.exit(f"错误:无法读取图片 {image_path}")

    model = InferenceModel("small640.pt")
    results = model.predict(img)
    x1, y1, x2, y2, cls, conf = InferenceModel.get_results(results)

    # Per the original note, the model is configured with conf=0.50; a None
    # class means no detection cleared that threshold.
    if cls is None:
        print("未检测到人")
        return

    # Assumes class 0 is the "good" posture class -- TODO confirm against the
    # model's training labels.
    label = "good" if cls == 0 else "bad"
    print(f"姿势:{label}(置信度 {conf:.2f})")
    print(f"BBox:[x1={x1}, y1={y1}, x2={x2}, y2={y2}]")

    # The annotated image is only written when --save was given.
    if args.save is None:
        return

    if args.save == "":
        # Default name keeps the input's extension: <stem>_result<suffix>.
        output_path = image_path.with_name(f"{image_path.stem}_result{image_path.suffix}")
    else:
        output_path = Path(args.save)

    annotated = draw_result(img.copy(), x1, y1, x2, y2, label, conf)
    # cv2.imwrite can report failure (e.g. a nonexistent target directory) by
    # returning False instead of raising, so check it before claiming success.
    if not cv2.imwrite(str(output_path), annotated):
        sys.exit(f"错误:无法保存标注图 {output_path}")
    print(f"标注图已保存:{output_path}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()