Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import albumentations as A | |
| from torch.utils.data import DataLoader | |
| from PIL import Image, ImageOps | |
| from animaloc.datasets import CSVDataset | |
| from animaloc.data.transforms import DownSample | |
| from inference.utils_io import mkdir, get_temp_image_path | |
def build_normalize_transform(mean: list, std: list) -> A.Normalize:
    """Return an always-applied Albumentations normalization transform.

    The transform is meant to be identical to the normalization used
    during training.

    Parameters
    ----------
    mean : list
        Per-channel means forwarded to ``A.Normalize``.
    std : list
        Per-channel standard deviations forwarded to ``A.Normalize``.

    Returns
    -------
    A.Normalize
        Normalization transform created with ``p=1.0``.
    """
    normalize_kwargs = {"mean": mean, "std": std, "p": 1.0}
    return A.Normalize(**normalize_kwargs)
def build_end_transforms(down_ratio: int = 2):
    """Build the list of end transforms used during inference.

    Parameters
    ----------
    down_ratio : int
        Value forwarded to ``DownSample`` as ``down_ratio`` (default 2).

    Returns
    -------
    list
        A one-element list holding a ``DownSample`` transform configured
        with ``anno_type="point"``.
    """
    point_downsample = DownSample(down_ratio=down_ratio, anno_type="point")
    return [point_downsample]
def create_single_image_dataset(
    image_pil,
    mean: list,
    std: list,
    down_ratio: int = 2
):
    """Create a temporary one-image ``CSVDataset`` and its ``DataLoader``.

    The PIL image is written as a JPEG under ``resources/uploads`` so it
    can be consumed through the HerdNet dataset API. A single placeholder
    point annotation (x=0, y=0, label=1) is attached to that image in the
    table handed to ``CSVDataset``.

    Parameters
    ----------
    image_pil : PIL.Image
        Input image. Modes that JPEG cannot store (e.g. RGBA, LA, P) are
        converted to RGB before saving; L, RGB and CMYK are saved as-is.
    mean : list
        Per-channel normalization means.
    std : list
        Per-channel normalization standard deviations.
    down_ratio : int
        Down-sampling ratio forwarded to the end transforms (default 2).

    Returns
    -------
    dataset : CSVDataset
        Temporary dataset holding the single image.
    dataloader : DataLoader
        Loader over ``dataset`` with ``batch_size=1`` and no shuffling.
    temp_path : str
        Path of the saved temporary image, exactly as returned by
        ``get_temp_image_path``.
        NOTE(review): previously documented as absolute; confirm against
        ``get_temp_image_path``.

    Raises
    ------
    Exception
        Any error raised while saving the image or building the dataset
        is re-raised after the temporary file has been removed.
    """
    # Create the upload directory and reserve a temporary file name.
    upload_dir = "resources/uploads"
    mkdir(upload_dir)
    temp_path = get_temp_image_path(upload_dir)

    # JPEG has no alpha channel or palette support: saving such an image
    # raises OSError, so fall back to RGB for those modes only.
    if image_pil.mode not in ("L", "RGB", "CMYK"):
        image_pil = image_pil.convert("RGB")

    try:
        image_pil.save(temp_path, format="JPEG")

        # One-row annotation table in the layout passed to CSVDataset.
        df = pd.DataFrame({
            "images": [os.path.basename(temp_path)],
            "x": [0],
            "y": [0],
            "labels": [1],
        })

        # Reuse the module helpers so both dataset builders share the
        # exact same transform configuration.
        dataset = CSVDataset(
            csv_file=df,
            root_dir=os.path.dirname(temp_path),
            albu_transforms=[build_normalize_transform(mean, std)],
            end_transforms=build_end_transforms(down_ratio),
        )
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    except Exception:
        # On failure the caller never learns temp_path, so the file would
        # otherwise be left behind in the upload directory.
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise

    return dataset, dataloader, temp_path
def _round_up_to_multiple(value, multiple):
    """Return the smallest multiple of *multiple* that is >= *value*."""
    return int(np.ceil(value / multiple) * multiple)


def create_single_image_dataset_safe(
    image_pil,
    down_ratio: int = 2,
    patch_size: int = 512,
    overlap: int = 160,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
):
    """Prepare an image for HerdNet inference and wrap it in a dataset.

    Intended as a full replacement for ``create_single_image_dataset``
    that tolerates images of any source, size or format. Steps:

    1. Apply the EXIF orientation and convert to RGB.
    2. Pad right/bottom with black so both sides are multiples of
       ``down_ratio``.
    3. If needed, resize (bilinear) so both sides are multiples of the
       patch stride ``patch_size - overlap`` while remaining multiples of
       ``down_ratio``.
    4. Save as JPEG (quality 95) under ``resources/uploads`` and build a
       one-row ``CSVDataset`` plus its ``DataLoader``.

    NOTE(review): the resize in step 3 scales width and height
    independently, so the saved image may differ slightly in scale and
    aspect ratio from the input; presumably predictions are expressed in
    the coordinates of the saved image - confirm with the caller.

    Parameters
    ----------
    image_pil : PIL.Image
        Input image (any resolution or source).
    down_ratio : int
        Spatial reduction factor of the model (default 2). Must be > 0.
    patch_size : int
        Patch size used in training (default 512).
    overlap : int
        Overlap between patches (default 160). Must be < ``patch_size``.
    mean, std : tuple
        Normalization parameters (ImageNet values by default).

    Returns
    -------
    dataset : CSVDataset
        Temporary dataset holding the single prepared image.
    dataloader : DataLoader
        Loader over ``dataset`` with ``batch_size=1`` and no shuffling.
    temp_path : str
        Path of the saved temporary image, exactly as returned by
        ``get_temp_image_path``.
        NOTE(review): previously documented as absolute; confirm against
        ``get_temp_image_path``.

    Raises
    ------
    ValueError
        If ``down_ratio`` is not positive or ``overlap >= patch_size``.
    Exception
        Any error raised while saving the image or building the dataset
        is re-raised after the temporary file has been removed.
    """
    import math

    # Validate first: a non-positive ratio or stride would otherwise end
    # in a ZeroDivisionError or a nonsensical target size further down.
    if down_ratio <= 0:
        raise ValueError(f"down_ratio must be positive, got {down_ratio}")
    step = patch_size - overlap
    if step <= 0:
        raise ValueError(
            "overlap must be smaller than patch_size, got "
            f"patch_size={patch_size}, overlap={overlap}"
        )

    # =======================================================
    # 1. Orientation and channel correction
    # =======================================================
    image_pil = ImageOps.exif_transpose(image_pil)
    # RGB mode is 8 bits per channel by definition, so no extra cast or
    # clipping of the pixel values is required after this conversion.
    image_pil = image_pil.convert("RGB")

    # =======================================================
    # 2. Resolution adjustment compatible with the model
    # =======================================================
    # Pad (right/bottom, black) up to multiples of down_ratio.
    w, h = image_pil.size
    pad_w = _round_up_to_multiple(w, down_ratio) - w
    pad_h = _round_up_to_multiple(h, down_ratio) - h
    if pad_w > 0 or pad_h > 0:
        image_pil = ImageOps.expand(
            image_pil, border=(0, 0, pad_w, pad_h), fill=(0, 0, 0)
        )

    # Force multiples of the patch stride. When the stride is not itself
    # a multiple of down_ratio, align to their least common multiple so
    # the resize cannot undo the down_ratio constraint enforced above.
    align = step
    if step % down_ratio != 0:
        align = step * down_ratio // math.gcd(step, down_ratio)
    w, h = image_pil.size
    if (w % align != 0) or (h % align != 0):
        target_size = (
            _round_up_to_multiple(w, align),
            _round_up_to_multiple(h, align),
        )
        image_pil = image_pil.resize(target_size, Image.BILINEAR)

    # =======================================================
    # 3. Temporary save, annotation table and dataset
    # =======================================================
    upload_dir = "resources/uploads"
    mkdir(upload_dir)
    temp_path = get_temp_image_path(upload_dir)

    try:
        image_pil.save(temp_path, format="JPEG", quality=95)

        # One-row annotation table in the layout passed to CSVDataset.
        df = pd.DataFrame({
            "images": [os.path.basename(temp_path)],
            "x": [0],
            "y": [0],
            "labels": [1],
        })

        # Reuse the module helpers so both dataset builders share the
        # exact same transform configuration.
        dataset = CSVDataset(
            csv_file=df,
            root_dir=os.path.dirname(temp_path),
            albu_transforms=[build_normalize_transform(mean, std)],
            end_transforms=build_end_transforms(down_ratio),
        )
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    except Exception:
        # On failure the caller never learns temp_path, so the file would
        # otherwise be left behind in the upload directory.
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise

    # =======================================================
    # 4. Diagnostic log
    # =======================================================
    print("[PREPROCESS] Imagen preparada para inferencia:")
    print(f" - Guardada en: {temp_path}")
    print(f" - Tamaño final: {image_pil.size}")
    print(f" - down_ratio: {down_ratio}")
    print(f" - patch_size: {patch_size}")
    print(f" - overlap: {overlap}")
    return dataset, dataloader, temp_path