# picpocket2/test/remove_faceless_images.py
# Repository listing metadata: chawin.chen, commit 7a6cb13 ("init")
#!/usr/bin/env python3
"""
遍历 /opt/data/chinese_celeb_dataset 下的图片,使用 YOLO 人脸检测并删除没有检测到人脸的图片。
用法示例:
python test/remove_faceless_images.py --dry-run
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Iterable, List, Optional
import config
try:
from ultralytics import YOLO
except ImportError as exc: # pragma: no cover - 运行期缺依赖提示
raise SystemExit("缺少 ultralytics,请先执行 pip install ultralytics") from exc
# Default dataset location and model configuration.
# DEFAULT_DATASET_DIR is the fallback root scanned when --dataset-dir is not given.
DEFAULT_DATASET_DIR = Path("/opt/data/chinese_celeb_dataset")
# Directory checked first for a local copy of the YOLO weights (taken from project config).
MODEL_DIR = Path(config.MODELS_PATH)
# Model name from project config; joined onto MODEL_DIR, and also passed to YOLO() as-is as a fallback.
YOLO_MODEL_NAME = config.YOLO_MODEL
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the faceless-image cleanup script.

    Returns:
        The parsed namespace with ``dataset_dir`` (Path), ``extensions``
        (comma-separated string), ``confidence`` (float), ``dry_run`` (bool)
        and ``verbose`` (bool).
    """
    # Each entry is (flag, keyword arguments for add_argument); options are
    # registered in this order so the generated --help output is unchanged.
    option_table = (
        (
            "--dataset-dir",
            {
                "type": Path,
                "default": DEFAULT_DATASET_DIR,
                "help": "需要检查的根目录(默认:/opt/data/chinese_celeb_dataset)",
            },
        ),
        (
            "--extensions",
            {
                "type": str,
                "default": ".jpg,.jpeg,.png,.webp,.bmp",
                "help": "需要检查的图片扩展名,逗号分隔",
            },
        ),
        (
            "--confidence",
            {
                "type": float,
                "default": config.FACE_CONFIDENCE,
                "help": "YOLO 检测的人脸置信度阈值",
            },
        ),
        (
            "--dry-run",
            {
                "action": "store_true",
                "help": "仅输出将被删除的文件,不真正删除,便于先预览结果",
            },
        ),
        (
            "--verbose",
            {
                "action": "store_true",
                "help": "输出更多调试信息",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="使用 YOLO 检测 /opt/data/chinese_celeb_dataset 中的图片并删除无脸图片"
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def load_yolo_model() -> YOLO:
    """Load the configured YOLO model, preferring a local weights file.

    The file ``MODEL_DIR / YOLO_MODEL_NAME`` is tried first when it exists;
    the bare model name is always tried afterwards as a fallback (the
    original author notes that this triggers an automatic download).

    Returns:
        The first model instance that loads successfully.

    Raises:
        RuntimeError: If every candidate fails; chained to the last failure.
    """
    bundled_weights = MODEL_DIR / YOLO_MODEL_NAME
    sources = (
        [str(bundled_weights), YOLO_MODEL_NAME]
        if bundled_weights.exists()
        else [YOLO_MODEL_NAME]
    )

    failure = None
    for source in sources:
        try:
            config.logger.info("尝试加载 YOLO 模型:%s", source)
            loaded = YOLO(source)
        except Exception as err:  # pragma: no cover
            # Remember the failure and move on to the next candidate.
            failure = err
            config.logger.warning("加载 YOLO 模型失败:%s -> %s", source, err)
        else:
            return loaded

    raise RuntimeError(f"无法加载 YOLO 模型:{YOLO_MODEL_NAME}") from failure
def iter_image_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
    """Yield every regular file below *root* whose suffix is in *extensions*.

    Matching is case-insensitive. Each extension is stripped of surrounding
    whitespace, and blank entries are ignored. Extensions may be written with
    or without the leading dot (``"jpg"`` and ``".jpg"`` are equivalent);
    without this normalisation a dot-less entry could never equal
    ``Path.suffix`` and would silently match nothing.

    Args:
        root: Directory that is searched recursively.
        extensions: File extensions to accept.

    Yields:
        Paths of matching regular files, in ``Path.rglob`` order.
    """
    wanted = set()
    for raw in extensions:
        ext = raw.strip().lower()
        if not ext:
            continue
        # Path.suffix always includes the dot, so make the filter agree.
        wanted.add(ext if ext.startswith(".") else f".{ext}")

    for path in root.rglob("*"):
        # Compare the suffix first: it is a pure string operation, so the
        # filesystem stat behind is_file() only runs for candidate files.
        if path.suffix.lower() in wanted and path.is_file():
            yield path
def has_face(model: YOLO, image_path: Path, confidence: float, verbose: bool = False) -> bool:
    """Report whether the YOLO model detects at least one face in an image.

    A single detection box in any returned result counts as "has a face".

    Inference errors are deliberately NOT caught here. Returning ``False``
    on failure would be indistinguishable from "no face found", and the
    caller deletes such files, so an unreadable image or a transient model
    error would cause data loss. The exception propagates so the caller can
    skip the file instead.

    Args:
        model: Callable detector, invoked as
            ``model(image_path, conf=confidence, verbose=False)``.
        image_path: Image to inspect.
        confidence: Confidence threshold forwarded to the model as ``conf``.
        verbose: When true, log the class id and score of every box found.

    Returns:
        ``True`` if any result carries a non-empty ``boxes`` collection,
        ``False`` otherwise.

    Raises:
        Exception: Whatever the model raises while running inference.
    """
    results = model(image_path, conf=confidence, verbose=False)

    for result in results:
        boxes = getattr(result, "boxes", None)
        if boxes is None or len(boxes) == 0:
            continue

        if verbose:
            faces = []
            for box in boxes:
                # Fall back to placeholder values when a box lacks the attribute.
                cls_id = int(box.cls[0]) if getattr(box, "cls", None) is not None else -1
                score = float(box.conf[0]) if getattr(box, "conf", None) is not None else 0.0
                faces.append({"cls": cls_id, "conf": score})
            config.logger.info("检测到人脸:%s -> %s", image_path, faces)
        return True

    return False
def main() -> None:
    """Scan the dataset directory and delete images without a detected face.

    Steps: parse the CLI options, validate the dataset directory, load the
    YOLO model, collect all matching image files, then run detection on each
    file. Files for which detection reports no face are deleted, or only
    listed when ``--dry-run`` is given. A summary line is printed at the end.

    Files whose detection or deletion raises an exception are counted as
    errors and left in place.

    Raises:
        SystemExit: If the dataset path is not an existing directory.
    """
    args = parse_args()
    dataset_dir: Path = args.dataset_dir.expanduser().resolve()
    # is_dir() rather than exists(): a regular file passed as --dataset-dir
    # would otherwise pass the check and be reported as "no images found".
    if not dataset_dir.is_dir():
        raise SystemExit(f"目录不存在:{dataset_dir}")

    model = load_yolo_model()

    # Materialise the file list up front so the total is known for progress
    # output and the tree is not being walked while files are deleted.
    image_paths = list(iter_image_files(dataset_dir, args.extensions.split(",")))
    total = len(image_paths)
    if total == 0:
        print(f"目录 {dataset_dir} 下没有匹配到图片文件")
        return

    removed = 0
    errored = 0
    for idx, image_path in enumerate(image_paths, start=1):
        # Progress line every 100 files, or for every file in verbose mode.
        if idx % 100 == 0 or args.verbose:
            print(f"[{idx}/{total}] 正在处理 {image_path}")

        try:
            face_found = has_face(model, image_path, args.confidence, args.verbose)
        except Exception as exc:  # pragma: no cover
            # Detection failed: keep the file and record the error.
            errored += 1
            config.logger.error("检测过程中发生异常,跳过 %s:%s", image_path, exc)
            continue
        if face_found:
            continue

        if args.dry_run:
            print(f"[DRY-RUN] 将删除:{image_path}")
        else:
            try:
                image_path.unlink()
                print(f"已删除:{image_path}")
            except Exception as exc:  # pragma: no cover
                errored += 1
                config.logger.error("删除失败 %s:%s", image_path, exc)
                continue
        # Counts real deletions, or would-be deletions in dry-run mode.
        removed += 1

    print(
        f"扫描完成,检测图片 {total} 张,删除 {removed} 张无脸图片,异常 {errored} 张,数据保存在:{dataset_dir}"
    )
if __name__ == "__main__":
    # Script entry point: run the cleanup and turn Ctrl+C into a short
    # exit message instead of a KeyboardInterrupt traceback.
    try:
        main()
    except KeyboardInterrupt:  # pragma: no cover
        raise SystemExit("用户中断")