File size: 6,924 Bytes
a496893 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | """Инференс по полному фото детали кузова.
Алгоритм:
1) Вырезаем панель из фона.
2) Скользящим окном (PATCH_SIZE с шагом PATCH_STRIDE) собираем патчи.
3) Прогоняем батчем через сеть -> вероятность "defect" для каждого патча.
4) Аккумулируем вероятности в карту дефектов того же размера, что панель.
5) Возвращаем: вердикт по детали, маску, координаты bounding box'ов дефектов,
визуализацию (наложение тепловой карты).
Запуск:
python -m src.infer --image путь/к/фото.jpg --out runs/result.jpg
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from . import config as C
from .model import build_model
from .prepare_data import crop_panel, imread_unicode, imwrite_unicode
_TRANSFORM = A.Compose([
A.Resize(C.IMG_SIZE, C.IMG_SIZE),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
])
def load_model(checkpoint: Path | str = None, device: torch.device | str = "cpu"):
ckpt_path = Path(checkpoint) if checkpoint else C.CHECKPOINTS / "best.pt"
state = torch.load(ckpt_path, map_location=device, weights_only=False)
from .model import DefectClassifier
backbone = state.get("backbone", C.BACKBONE)
model = DefectClassifier(backbone=backbone, pretrained=False).to(device)
model.load_state_dict(state["model"])
model.eval()
return model
def _slide_coords(h: int, w: int, size: int, stride: int) -> list[tuple[int, int]]:
if h < size or w < size:
return [(0, 0)]
ys = list(range(0, h - size + 1, stride))
xs = list(range(0, w - size + 1, stride))
if ys[-1] != h - size: ys.append(h - size)
if xs[-1] != w - size: xs.append(w - size)
return [(y, x) for y in ys for x in xs]
def _to_batch(patches: list[np.ndarray]) -> torch.Tensor:
tensors = [_TRANSFORM(image=cv2.cvtColor(p, cv2.COLOR_BGR2RGB))["image"]
for p in patches]
return torch.stack(tensors, dim=0)
def predict_image(image_bgr: np.ndarray, model, device,
threshold: float = C.DEFECT_THRESHOLD,
panel_defect_ratio: float = C.PANEL_DEFECT_RATIO) -> dict[str, Any]:
"""Возвращает dict с результатом анализа полного фото."""
panel = crop_panel(image_bgr) if C.PANEL_CROP else image_bgr
H, W = panel.shape[:2]
coords = _slide_coords(H, W, C.PATCH_SIZE, C.PATCH_STRIDE)
patches = [panel[y:y + C.PATCH_SIZE, x:x + C.PATCH_SIZE] for y, x in coords]
if not patches:
patches = [cv2.resize(panel, (C.PATCH_SIZE, C.PATCH_SIZE))]
coords = [(0, 0)]
# инференс батчами
bs = 32
probs = []
with torch.no_grad():
for i in range(0, len(patches), bs):
batch = _to_batch(patches[i:i + bs]).to(device)
logits = model(batch)
p = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
probs.extend(p.tolist())
# карта вероятностей дефекта по панели
heatmap = np.zeros((H, W), dtype=np.float32)
weights = np.zeros((H, W), dtype=np.float32)
for (y, x), p in zip(coords, probs):
ye = min(y + C.PATCH_SIZE, H); xe = min(x + C.PATCH_SIZE, W)
heatmap[y:ye, x:xe] += p
weights[y:ye, x:xe] += 1.0
heatmap = heatmap / np.maximum(weights, 1e-6)
# бинарная маска дефектов
mask = (heatmap >= threshold).astype(np.uint8) * 255
defect_pixels = int(mask.sum() / 255)
defect_ratio = defect_pixels / max(H * W, 1)
# bbox'ы дефектов
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = []
for c in contours:
if cv2.contourArea(c) < 200: # отсекаем шум
continue
x, y, w, h = cv2.boundingRect(c)
roi = heatmap[y:y + h, x:x + w]
boxes.append({
"x": int(x), "y": int(y), "w": int(w), "h": int(h),
"confidence": float(roi.max()),
"mean_prob": float(roi.mean()),
})
is_defect = bool(defect_ratio >= panel_defect_ratio and len(boxes) > 0)
return {
"is_defect": is_defect,
"defect_ratio": float(defect_ratio),
"max_prob": float(heatmap.max()),
"boxes": boxes,
"panel_size": {"h": int(H), "w": int(W)},
"heatmap": heatmap,
"panel": panel,
}
def render_visualization(result: dict) -> np.ndarray:
"""Накладывает тепловую карту и bbox'ы на панель."""
panel = result["panel"].copy()
hm = result["heatmap"]
hm_norm = np.clip(hm, 0.0, 1.0)
colored = cv2.applyColorMap((hm_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
overlay = cv2.addWeighted(panel, 0.6, colored, 0.4, 0)
for b in result["boxes"]:
x, y, w, h = b["x"], b["y"], b["w"], b["h"]
cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 0, 255), 3)
label = f"{b['confidence']:.2f}"
cv2.putText(overlay, label, (x, max(20, y - 8)),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
verdict = "DEFECT" if result["is_defect"] else "OK"
color = (0, 0, 255) if result["is_defect"] else (0, 200, 0)
cv2.rectangle(overlay, (0, 0), (320, 60), (0, 0, 0), -1)
cv2.putText(overlay, verdict, (12, 44), cv2.FONT_HERSHEY_SIMPLEX, 1.4, color, 3)
return overlay
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--image", required=True, type=Path)
ap.add_argument("--checkpoint", type=Path, default=None)
ap.add_argument("--out", type=Path, default=C.RUNS / "result.jpg")
ap.add_argument("--threshold", type=float, default=C.DEFECT_THRESHOLD)
args = ap.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(args.checkpoint, device)
bgr = imread_unicode(args.image)
if bgr is None:
raise SystemExit(f"Не удалось прочитать {args.image}")
res = predict_image(bgr, model, device, threshold=args.threshold)
args.out.parent.mkdir(parents=True, exist_ok=True)
imwrite_unicode(args.out, render_visualization(res), [cv2.IMWRITE_JPEG_QUALITY, 90])
# JSON-отчёт без numpy-полей
report = {k: v for k, v in res.items() if k not in {"heatmap", "panel"}}
print(json.dumps(report, indent=2, ensure_ascii=False))
print(f"\nВизуализация: {args.out}")
if __name__ == "__main__":
main()
|