|
|
|
|
|
""" |
|
|
遍历 /opt/data/chinese_celeb_dataset 下的图片,使用 YOLO 人脸检测并删除没有检测到人脸的图片。 |
|
|
|
|
|
用法示例: |
|
|
python test/remove_faceless_images.py --dry-run |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Iterable, List, Optional |
|
|
|
|
|
import config |
|
|
|
|
|
try: |
|
|
from ultralytics import YOLO |
|
|
except ImportError as exc: |
|
|
raise SystemExit("缺少 ultralytics,请先执行 pip install ultralytics") from exc |
|
|
|
|
|
|
|
|
# Root directory scanned when --dataset-dir is not supplied on the command line.
DEFAULT_DATASET_DIR: Path = Path("/opt/data/chinese_celeb_dataset")

# Model location and model name; both values come from the project config module.
MODEL_DIR: Path = Path(config.MODELS_PATH)
YOLO_MODEL_NAME = config.YOLO_MODEL
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        The parsed namespace with the attributes ``dataset_dir`` (Path),
        ``extensions`` (comma-separated string), ``confidence`` (converted
        with ``float`` when given on the command line; otherwise the value of
        ``config.FACE_CONFIDENCE``), ``dry_run`` and ``verbose`` (booleans).
    """
    cli = argparse.ArgumentParser(
        description="使用 YOLO 检测 /opt/data/chinese_celeb_dataset 中的图片并删除无脸图片"
    )

    # Options that carry a value: (flag, converter, default, help text).
    valued_options = (
        (
            "--dataset-dir",
            Path,
            DEFAULT_DATASET_DIR,
            "需要检查的根目录(默认:/opt/data/chinese_celeb_dataset)",
        ),
        (
            "--extensions",
            str,
            ".jpg,.jpeg,.png,.webp,.bmp",
            "需要检查的图片扩展名,逗号分隔",
        ),
        (
            "--confidence",
            float,
            config.FACE_CONFIDENCE,
            "YOLO 检测的人脸置信度阈值",
        ),
    )
    for flag, converter, default_value, help_text in valued_options:
        cli.add_argument(flag, type=converter, default=default_value, help=help_text)

    # Boolean switches: (flag, help text).
    switches = (
        ("--dry-run", "仅输出将被删除的文件,不真正删除,便于先预览结果"),
        ("--verbose", "输出更多调试信息"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()
|
|
|
|
|
|
|
|
def load_yolo_model() -> YOLO:
    """Load the configured YOLO model, preferring a local weights file.

    The file ``MODEL_DIR / YOLO_MODEL_NAME`` is tried first when it exists.
    The bare model name is always tried last; per the original author's note,
    passing the name alone triggers an automatic download.

    Returns:
        The first ``YOLO`` instance that was constructed successfully.

    Raises:
        RuntimeError: If every candidate failed; the last failure is chained
            as the cause.
    """
    local_weights = MODEL_DIR / YOLO_MODEL_NAME

    # The bare name is the fallback; a local file, if present, goes in front.
    sources: List[str] = [YOLO_MODEL_NAME]
    if local_weights.exists():
        sources.insert(0, str(local_weights))

    failure: Optional[Exception] = None
    for source in sources:
        try:
            config.logger.info("尝试加载 YOLO 模型:%s", source)
            loaded = YOLO(source)
        except Exception as exc:
            # Remember the failure and move on to the next candidate.
            failure = exc
            config.logger.warning("加载 YOLO 模型失败:%s -> %s", source, exc)
        else:
            return loaded

    raise RuntimeError(f"无法加载 YOLO 模型:{YOLO_MODEL_NAME}") from failure
|
|
|
|
|
|
|
|
def iter_image_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
    """Yield every regular file under *root* whose suffix is in *extensions*.

    Matching is case-insensitive. Each extension is stripped of surrounding
    whitespace and lower-cased; blank entries are ignored. An extension
    written without its leading dot (``"jpg"``) is treated like ``".jpg"``,
    because ``Path.suffix`` always includes the dot and an undotted entry
    could otherwise never match anything.

    Args:
        root: Directory that is searched recursively.
        extensions: File extensions to accept, with or without a leading dot.

    Yields:
        Paths of matching files, in the order ``root.rglob`` produces them.
    """
    wanted = set()
    for raw in extensions:
        ext = raw.strip().lower()
        if not ext:
            continue
        wanted.add(ext if ext.startswith(".") else f".{ext}")

    # Nothing can match, so do not walk the tree at all.
    if not wanted:
        return

    for path in root.rglob("*"):
        # Directories may carry a matching suffix too; only files count.
        if path.is_file() and path.suffix.lower() in wanted:
            yield path
|
|
|
|
|
|
|
|
def has_face(model: YOLO, image_path: Path, confidence: float, verbose: bool = False) -> bool:
    """Return True when the model reports at least one box for the image.

    Any detected box counts as a face; the class of the box is not checked,
    so the model is presumably a face-detection model (confirm against the
    configured weights).

    Detection errors are deliberately NOT converted into ``False``. The caller
    deletes images for which this function returns ``False``, so reporting a
    failed detection as "no face" would delete images that were never
    actually inspected.

    Args:
        model: Callable invoked as ``model(image_path, conf=..., verbose=False)``
            that returns an iterable of results exposing a ``boxes`` attribute.
        image_path: Image to inspect.
        confidence: Confidence threshold forwarded to the model as ``conf``.
        verbose: When true, log the class id and score of every detected box.

    Returns:
        True if any result contains at least one box, otherwise False.

    Raises:
        Exception: Whatever the model call raises is propagated unchanged so
            the caller can skip the image instead of treating it as faceless.
    """
    results = model(image_path, conf=confidence, verbose=False)

    for result in results:
        boxes = getattr(result, "boxes", None)
        if boxes is None or len(boxes) == 0:
            continue

        if verbose:
            faces = []
            for box in boxes:
                # Fall back to sentinel values when a box lacks cls/conf.
                cls_id = int(box.cls[0]) if getattr(box, "cls", None) is not None else -1
                score = float(box.conf[0]) if getattr(box, "conf", None) is not None else 0.0
                faces.append({"cls": cls_id, "conf": score})
            config.logger.info("检测到人脸:%s -> %s", image_path, faces)
        return True

    return False
|
|
|
|
|
|
|
|
def main() -> None:
    """Scan the dataset directory and delete images without a detected face.

    Steps:
        1. Parse the command line and resolve the dataset directory.
        2. Collect every file whose extension is in ``--extensions``.
        3. Load the YOLO model, only if there is at least one image.
        4. Run detection on each image. Images with a face are kept. Images
           whose detection raised are counted as errors and kept. Images
           without a face are deleted, or only listed with ``--dry-run``.
        5. Print a summary line.

    Raises:
        SystemExit: If the dataset path is not an existing directory.
    """
    args = parse_args()
    dataset_dir: Path = args.dataset_dir.expanduser().resolve()
    # is_dir() instead of exists(): a path to a regular file cannot be scanned.
    if not dataset_dir.is_dir():
        raise SystemExit(f"目录不存在:{dataset_dir}")

    # Collect the files first so an empty directory exits without loading
    # the model.
    image_paths = list(iter_image_files(dataset_dir, args.extensions.split(",")))
    total = len(image_paths)
    if total == 0:
        print(f"目录 {dataset_dir} 下没有匹配到图片文件")
        return

    model = load_yolo_model()

    removed = 0
    errored = 0
    for idx, image_path in enumerate(image_paths, start=1):
        if idx % 100 == 0 or args.verbose:
            print(f"[{idx}/{total}] 正在处理 {image_path}")

        try:
            face_found = has_face(model, image_path, args.confidence, args.verbose)
        except Exception as exc:
            # A failed detection says nothing about the image, so keep it.
            errored += 1
            config.logger.error("检测过程中发生异常,跳过 %s:%s", image_path, exc)
            continue

        if face_found:
            continue

        if args.dry_run:
            print(f"[DRY-RUN] 将删除:{image_path}")
        else:
            try:
                image_path.unlink()
            except OSError as exc:
                errored += 1
                config.logger.error("删除失败 %s:%s", image_path, exc)
                continue
            print(f"已删除:{image_path}")
        removed += 1

    # In dry-run mode nothing was deleted, so the summary must not claim it.
    action = "将删除" if args.dry_run else "删除"
    print(
        f"扫描完成,检测图片 {total} 张,{action} {removed} 张无脸图片,异常 {errored} 张,数据保存在:{dataset_dir}"
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
try: |
|
|
main() |
|
|
except KeyboardInterrupt: |
|
|
sys.exit("用户中断") |
|
|
|