Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import tempfile | |
| import cv2 | |
| import numpy as np | |
| from pytorch_lightning import Trainer | |
| from torch.utils.data import DataLoader | |
| from anomalib.config import get_configurable_parameters | |
| from anomalib.data.inference import InferenceDataset | |
| from anomalib.data.utils import InputNormalizationMethod, get_transforms | |
| from anomalib.models import get_model | |
| from anomalib.utils.callbacks import get_callbacks | |
| from utils.preprocess import fill_background_and_crop, convert_to_png | |
| from utils.save_anomaly_map_as_png import save_anomaly_map_as_png, overlay_images, overlay_anomaly_outline_on_base_image | |
| from utils.tensor_to_list import tensor_to_list | |
def has_blue_pixels(image_path):
    """Report whether any pixel of the image has channel values (0, 0, 255).

    The file is loaded with ``cv2.imread`` and a pixel matches when its
    channels 0, 1 and 2 are exactly 0, 0 and 255 respectively.

    NOTE(review): ``cv2.imread`` returns channels in BGR order by default,
    so (0, 0, 255) corresponds to pure red rather than blue -- confirm
    whether the function name or the channel test reflects the intent.

    Args:
        image_path: Location of the image file; it is converted with
            ``str`` before being handed to OpenCV.

    Returns:
        The result of ``np.any`` over the per-pixel match mask.
    """
    # Load the image from disk.
    pixels = cv2.imread(str(image_path))

    # Build one boolean plane per channel, then combine them into a mask of
    # the pixels that match the target colour on all three channels.
    first_is_zero = pixels[:, :, 0] == 0
    second_is_zero = pixels[:, :, 1] == 0
    third_is_full = pixels[:, :, 2] == 255
    match_mask = first_is_zero & second_is_zero & third_is_full

    # True as soon as at least one pixel matches.
    return np.any(match_mask)
def infer(
        input: np.ndarray,
        suffix: str,
        threshold,
        threshold_median,
        edge_threshold,
        configPath: Path,
        weightsPath: Path):
    """Run anomaly inference on a single image and return the results.

    The image is written to a temporary file, converted to PNG, background
    filled and cropped, and then pushed through the anomalib model described
    by ``configPath`` / ``weightsPath``. All files created along the way
    (the temporary input, the converted PNG and the per-call session
    directory under ``/tmp``) are removed before returning; the returned
    images are already loaded into memory at that point.

    Args:
        input: Image array; written to disk with ``cv2.imwrite``.
        suffix: File suffix (e.g. ``".jpg"``) used for the temporary input
            file, which also selects the format ``cv2.imwrite`` encodes to.
        threshold: Passed to the anomaly-map rendering helpers, and compared
            against ``max - median`` of the anomaly map.
        threshold_median: Compared against the median of the anomaly map.
        edge_threshold: Forwarded to ``fill_background_and_crop``.
        configPath: Path of the anomalib configuration file.
        weightsPath: Checkpoint path, assigned to
            ``config.trainer.resume_from_checkpoint``.

    Returns:
        A 7-tuple ``(is_anomaly, anomaly_map_max, anomaly_map_median,
        anomaly_map_list, overlay, anomaly_map_image, overlay_outline)``.
        ``is_anomaly`` is true when ``max - median > threshold`` or
        ``median > threshold_median``. ``anomaly_map_list`` is the output of
        ``tensor_to_list`` for the anomaly map. The last three items are the
        results of ``cv2.imread`` on the rendered PNG files.

    Raises:
        ValueError: If the input image cannot be written to the temporary
            file.
        RuntimeError: If the trainer returns no predictions.
    """
    config = get_configurable_parameters(config_path=configPath)
    config.trainer.resume_from_checkpoint = weightsPath
    config.visualization.mode = "simple"
    config.visualization.save_images = True

    # A uniquely named directory is created atomically for this call. The
    # previous random-integer name could collide between calls (and repeats
    # whenever NumPy's global RNG is reseeded to the same value).
    with tempfile.TemporaryDirectory(dir="/tmp") as session_dir:
        config.visualization.image_save_path = session_dir

        # Create model and trainer.
        model = get_model(config)
        callbacks = get_callbacks(config)
        trainer = Trainer(
            callbacks=callbacks,
            **config.trainer,
        )

        # Build the evaluation transforms.
        if "transform_config" in config.dataset.keys():
            transform_config = config.dataset.transform_config.eval
        else:
            transform_config = None
        image_size = (config.dataset.image_size[0],
                      config.dataset.image_size[1])
        center_crop = config.dataset.get("center_crop")
        if center_crop is not None:
            center_crop = tuple(center_crop)
        normalization = InputNormalizationMethod(config.dataset.normalization)
        transform = get_transforms(
            config=transform_config,
            image_size=image_size,
            center_crop=center_crop,
            normalization=normalization
        )

        # Reserve the temporary input file atomically (tempfile.mktemp is
        # deprecated because the name can be taken between generation and
        # use). delete=False keeps the file so OpenCV can write to it by name.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
            input_path = Path(handle.name)
        scratch_files = [input_path]
        try:
            if not cv2.imwrite(str(input_path), input):
                raise ValueError(
                    f"Could not write the input image to {input_path}")
            png_path = Path(convert_to_png(input_path))
            if png_path != input_path:
                scratch_files.append(png_path)
            fill_background_and_crop(png_path, edge_threshold=edge_threshold)

            # Create the dataset.
            dataset = InferenceDataset(
                png_path,
                image_size=tuple(config.dataset.image_size),
                transform=transform
            )
            dataloader = DataLoader(dataset)

            # Generate predictions.
            result = trainer.predict(
                model=model,
                dataloaders=[dataloader],
                return_predictions=True,
            )
            if not result:
                raise RuntimeError("Inference produced no predictions")
            anomaly_map = result[0]["anomaly_maps"].squeeze()

            # Same layout as before (<session>/tmp/...), but the directory is
            # created explicitly instead of assuming it already exists.
            output_dir = Path(session_dir) / "tmp"
            output_dir.mkdir(parents=True, exist_ok=True)
            anomaly_map_path = str(output_dir / "anomaly_map.png")
            overlay_path = str(output_dir / "overlay.png")
            overlay_outline_path = str(output_dir / "overlay_outline.png")

            save_anomaly_map_as_png(anomaly_map, anomaly_map_path, threshold)
            overlay_images(str(png_path), anomaly_map_path, overlay_path)
            overlay_anomaly_outline_on_base_image(
                str(png_path), anomaly_map, overlay_outline_path, threshold)
            anomaly_map_list = tensor_to_list(anomaly_map)

            anomaly_map_max = anomaly_map.max().item()
            anomaly_map_median = anomaly_map.median().item()
            is_anomaly = (
                (anomaly_map_max - anomaly_map_median > threshold)
                or (anomaly_map_median > threshold_median)
            )

            # The images are read into memory here, so the files can be
            # removed safely once the tuple has been built.
            return (
                is_anomaly,
                anomaly_map_max,
                anomaly_map_median,
                anomaly_map_list,
                cv2.imread(overlay_path),
                cv2.imread(anomaly_map_path),
                cv2.imread(overlay_outline_path),
            )
        finally:
            # Remove the temporary input files; a file that is already gone
            # (e.g. replaced during conversion) is not an error.
            for scratch in scratch_files:
                try:
                    scratch.unlink()
                except FileNotFoundError:
                    pass