Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| analyze.py — 单张图片坐姿检测 | |
| 用法: | |
| python analyze.py <image_path> | |
| python analyze.py <image_path> --save | |
| python analyze.py <image_path> --save <output_path> | |
| """ | |
| import argparse | |
| import os | |
| import sys | |
| from pathlib import Path | |
# Switch the working directory to the folder that contains this script so the
# relative model path used by load_model.py (./data/inference_models/) can be
# found. This has to happen before the model is loaded further down.
# NOTE(review): this runs before the command line is parsed, so relative
# image / --save paths are later resolved against the script directory rather
# than the directory the user launched the script from — confirm this is intended.
os.chdir(Path(__file__).parent)
| import sys | |
| import types | |
# yolov5 compatibility shim: per the original note, newer huggingface_hub
# releases removed the utils._errors submodule. When importing it fails, a
# stand-in module is registered under the old dotted name so that later
# imports of "huggingface_hub.utils._errors" still succeed.
try:
    import huggingface_hub.utils._errors  # noqa: F401
except (ModuleNotFoundError, ImportError):
    import huggingface_hub.errors as _hf_errors
    # Empty module object that will stand in for the removed submodule.
    _shim = types.ModuleType("huggingface_hub.utils._errors")
    # Mirror every name dir() reports on huggingface_hub.errors onto the shim.
    # This includes dunder attributes such as __name__, so the shim's own
    # __name__ ends up being the one copied from huggingface_hub.errors.
    for _name in dir(_hf_errors):
        setattr(_shim, _name, getattr(_hf_errors, _name))
    # Publish the shim in the import cache under the legacy module path.
    sys.modules["huggingface_hub.utils._errors"] = _shim
| import torch | |
# PyTorch 2.6+ makes torch.load default to weights_only=True, which (per the
# original note) older yolov5 checkpoints cannot be loaded with. The wrapper
# below supplies weights_only=False whenever the caller did not pass the flag;
# an explicitly supplied value is forwarded untouched.
# NOTE(review): weights_only=False allows full pickle deserialization, so only
# checkpoints from a trusted source should be loaded through this patch.
_orig_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """Call the original torch.load with weights_only defaulting to False."""
    if "weights_only" not in kwargs:
        kwargs["weights_only"] = False
    return _orig_torch_load(*args, **kwargs)


# Install the wrapper globally so every later torch.load call goes through it.
torch.load = _patched_torch_load
| import cv2 | |
| from app_models.load_model import InferenceModel | |
def draw_result(img, x1, y1, x2, y2, label, conf):
    """Overlay a yellow detection box and a "<label> <conf>" caption on img.

    The image is modified in place and also returned.

    Args:
        img: Image array to draw on (BGR, as produced by cv2.imread).
        x1, y1: Top-left corner of the detection box, in pixels.
        x2, y2: Bottom-right corner of the detection box, in pixels.
        label: Text shown in the caption (e.g. "good" / "bad").
        conf: Confidence value, rendered with two decimals.

    Returns:
        The same img object with the box and caption drawn on it.
    """
    color = (0, 255, 255)  # yellow in OpenCV's BGR channel order
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2

    # Coerce to plain ints: the OpenCV drawing calls expect integer pixel
    # coordinates, and the detector may hand back floats or array scalars.
    x1, y1, x2, y2 = (int(v) for v in (x1, y1, x2, y2))
    cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)

    text = f"{label} {conf:.2f}"
    (tw, th), _ = cv2.getTextSize(text, font, font_scale, thickness)

    # The caption normally sits directly above the box on a filled background
    # so the text stays readable. When the box touches the top of the image
    # there is no room above it, so the caption is moved just inside the box
    # instead of being drawn off-image.
    label_top = y1 - th - 10
    if label_top < 0:
        label_top = max(y1, 0)
    label_bottom = label_top + th + 10

    cv2.rectangle(img, (x1, label_top), (x1 + tw + 4, label_bottom), color, -1)
    cv2.putText(img, text, (x1 + 2, label_bottom - 6),
                font, font_scale, (0, 0, 0), thickness)
    return img
def main():
    """Run posture detection on a single image from the command line.

    Parses the image path and the optional --save flag, runs the YOLOv5
    model on the image, prints the predicted posture, confidence and bounding
    box, and optionally writes an annotated copy of the image.

    Exits with status 1 when the input image is missing or unreadable, or
    when the annotated image cannot be written. Returns normally (after
    printing a notice) when nothing is detected.
    """
    parser = argparse.ArgumentParser(description="坐姿检测(YOLOv5)")
    parser.add_argument("image", help="输入图片路径(JPG / PNG)")
    parser.add_argument(
        "--save",
        nargs="?",
        const="",  # bare --save (no path given) selects the default output name
        metavar="OUTPUT",
        # The default output keeps the input file's extension (see below),
        # so the help text says so instead of promising ".jpg".
        help="保存标注图;不指定路径则存为 <原文件名>_result<原扩展名>",
    )
    args = parser.parse_args()

    image_path = Path(args.image).resolve()
    if not image_path.exists():
        print(f"错误:找不到图片 {image_path}")
        sys.exit(1)

    # Read the image; cv2.imread signals failure by returning None.
    img = cv2.imread(str(image_path))
    if img is None:
        print(f"错误:无法读取图片 {image_path}")
        sys.exit(1)

    # Load the model and run inference.
    model = InferenceModel("small640.pt")
    results = model.predict(img)
    x1, y1, x2, y2, cls, conf = InferenceModel.get_results(results)

    # Per the original note the model is configured with conf=0.50 (not
    # visible in this file), so an empty result means nothing reached it.
    if cls is None:
        print("未检测到人")
        return

    label = "good" if cls == 0 else "bad"
    print(f"姿势:{label}(置信度 {conf:.2f})")
    print(f"BBox:[x1={x1}, y1={y1}, x2={x2}, y2={y2}]")

    # Write the annotated image only when --save was given.
    if args.save is not None:
        if args.save == "":
            # Default name: next to the input, "<stem>_result<original suffix>".
            output_path = image_path.parent / (image_path.stem + "_result" + image_path.suffix)
        else:
            output_path = Path(args.save)

        # Draw on a copy so the image that was fed to the model stays untouched.
        annotated = draw_result(img.copy(), x1, y1, x2, y2, label, conf)

        # cv2.imwrite reports failure by returning False (e.g. the target
        # directory does not exist) or by raising cv2.error (e.g. no encoder
        # for the file extension); never claim success in either case.
        try:
            written = cv2.imwrite(str(output_path), annotated)
        except cv2.error:
            written = False
        if not written:
            print(f"错误:无法保存标注图 {output_path}")
            sys.exit(1)
        print(f"标注图已保存:{output_path}")


if __name__ == "__main__":
    main()