# anomaly-detection / lightning_inference.py
# sanbasan383 — commit ac0e360: allow the object-detection threshold to be
# passed in as a parameter
from pathlib import Path
import tempfile
import cv2
import numpy as np
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from anomalib.config import get_configurable_parameters
from anomalib.data.inference import InferenceDataset
from anomalib.data.utils import InputNormalizationMethod, get_transforms
from anomalib.models import get_model
from anomalib.utils.callbacks import get_callbacks
from utils.preprocess import fill_background_and_crop, convert_to_png
from utils.save_anomaly_map_as_png import save_anomaly_map_as_png, overlay_images, overlay_anomaly_outline_on_base_image
from utils.tensor_to_list import tensor_to_list
def has_blue_pixels(image_path):
    """Return True if the image at *image_path* has any pixel with BGR value (0, 0, 255).

    NOTE(review): ``cv2.imread`` returns channels in BGR order, so this mask
    matches pure *red* pixels in RGB terms; if "blue" is really intended,
    channel 0 (not channel 2) should equal 255 — confirm against the colour
    scheme used by the saved overlay images.

    Args:
        image_path: Path (str or Path) of the image file to inspect.

    Returns:
        numpy.bool_: whether at least one matching pixel exists.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread silently returns None for a missing/unreadable file,
        # which would otherwise surface as an opaque TypeError below.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Build a boolean mask of pixels exactly equal to BGR (0, 0, 255).
    matching_pixels = (image[:, :, 0] == 0) & (
        image[:, :, 1] == 0) & (image[:, :, 2] == 255)
    return np.any(matching_pixels)
def infer(
        input: np.ndarray,
        suffix: str,
        threshold,
        threshold_median,
        edge_threshold,
        configPath: Path,
        weightsPath: Path):
    """Run anomaly-detection inference on a single image.

    Args:
        input: Image as an OpenCV/NumPy array (as produced by ``cv2.imread``).
        suffix: File suffix (e.g. ``".jpg"``) used when writing *input* to a
            temporary file before preprocessing.
        threshold: Score threshold used both when rendering the anomaly
            map/outline and in the max-minus-median anomaly decision.
        threshold_median: Threshold applied to the median anomaly score.
        edge_threshold: Edge-detection threshold forwarded to
            ``fill_background_and_crop``.
        configPath: Path to the anomalib model configuration file.
        weightsPath: Path to the model checkpoint to resume from.

    Returns:
        Tuple of ``(is_anomaly, anomaly_map_max, anomaly_map_median,
        anomaly_map_list, overlay_image, anomaly_map_image,
        overlay_outline_image)`` — the last three are arrays read back with
        ``cv2.imread``.
    """
    config = get_configurable_parameters(config_path=configPath)
    config.trainer.resume_from_checkpoint = weightsPath
    config.visualization.mode = "simple"
    # mkdtemp() creates a unique directory atomically, unlike the previous
    # "/tmp/" + randint(0, 1000000) scheme, which could collide between
    # concurrent inference calls.
    session_dir = Path(tempfile.mkdtemp())
    config.visualization.save_images = True
    config.visualization.image_save_path = str(session_dir)

    # create model and trainer
    model = get_model(config)
    callbacks = get_callbacks(config)
    trainer = Trainer(
        callbacks=callbacks,
        **config.trainer,
    )

    # get the transforms
    if "transform_config" in config.dataset.keys():
        transform_config = config.dataset.transform_config.eval
    else:
        transform_config = None
    image_size = (config.dataset.image_size[0], config.dataset.image_size[1])
    center_crop = config.dataset.get("center_crop")
    if center_crop is not None:
        center_crop = tuple(center_crop)
    normalization = InputNormalizationMethod(config.dataset.normalization)
    transform = get_transforms(
        config=transform_config,
        image_size=image_size,
        center_crop=center_crop,
        normalization=normalization
    )

    # Save the input image to a temporary file.  NamedTemporaryFile replaces
    # the deprecated, race-prone tempfile.mktemp(); delete=False because the
    # file must outlive this context manager for the preprocessing below.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        input_path = Path(tmp.name)
    cv2.imwrite(str(input_path), input)
    png_path = Path(convert_to_png(input_path))
    fill_background_and_crop(png_path, edge_threshold=edge_threshold)

    # create the dataset
    dataset = InferenceDataset(
        png_path,
        image_size=tuple(config.dataset.image_size),
        transform=transform
    )
    dataloader = DataLoader(dataset)

    # generate predictions
    result = trainer.predict(
        model=model,
        dataloaders=[dataloader],
        return_predictions=True,
    )
    anomaly_map = result[0]["anomaly_maps"].squeeze()

    # Derived artifacts go under <image_save_path>/tmp/ — "tmp" presumably
    # mirrors the parent-directory name of the input image created in the
    # system temp dir (as in the original hard-coded paths); TODO confirm
    # against the visualization callback's output layout.
    out_dir = session_dir / "tmp"
    anomaly_map_path = out_dir / "anomaly_map.png"
    save_anomaly_map_as_png(anomaly_map, str(anomaly_map_path), threshold)
    overlay_path = out_dir / "overlay.png"
    overlay_images(str(png_path), str(anomaly_map_path), str(overlay_path))
    overlay_outline_path = out_dir / "overlay_outline.png"
    overlay_anomaly_outline_on_base_image(
        str(png_path), anomaly_map, str(overlay_outline_path), threshold)
    anomaly_map_list = tensor_to_list(anomaly_map)

    # Decide anomaly status from the score distribution: either the peak
    # rises more than `threshold` above the median, or the median itself
    # exceeds `threshold_median`.
    anomaly_map_max = anomaly_map.max().item()
    anomaly_map_median = anomaly_map.median().item()
    is_anomaly = (anomaly_map_max - anomaly_map_median > threshold
                  or anomaly_map_median > threshold_median)
    return (is_anomaly, anomaly_map_max, anomaly_map_median, anomaly_map_list,
            cv2.imread(str(overlay_path)), cv2.imread(str(anomaly_map_path)),
            cv2.imread(str(overlay_outline_path)))