Spaces:
Sleeping
Sleeping
File size: 6,491 Bytes
d04393d 29c4311 d04393d 29c4311 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import os
import numpy as np
import pandas as pd
import albumentations as A
from torch.utils.data import DataLoader
from PIL import Image, ImageOps
from animaloc.datasets import CSVDataset
from animaloc.data.transforms import DownSample
from inference.utils_io import mkdir, get_temp_image_path
def build_normalize_transform(mean: list, std: list) -> A.Normalize:
    """
    Build the Albumentations normalization transform.

    According to the original author, this is meant to match the
    normalization applied during training.

    Parameters
    ----------
    mean : list
        Mean values forwarded unchanged to ``A.Normalize``.
    std : list
        Standard-deviation values forwarded unchanged to ``A.Normalize``.

    Returns
    -------
    A.Normalize
        Normalization transform configured with ``p=1.0`` (always applied).
    """
    normalize_kwargs = {"mean": mean, "std": std, "p": 1.0}
    return A.Normalize(**normalize_kwargs)
def build_end_transforms(down_ratio: int = 2):
    """
    Build the list of final transforms used during inference.

    Parameters
    ----------
    down_ratio : int
        Value forwarded to ``DownSample`` (default 2).

    Returns
    -------
    list
        A one-element list holding a ``DownSample`` configured for
        point annotations (``anno_type="point"``).
    """
    downsample = DownSample(down_ratio=down_ratio, anno_type="point")
    return [downsample]
def create_single_image_dataset(
    image_pil,
    mean: list,
    std: list,
    down_ratio: int = 2
):
    """
    Create a temporary one-image CSVDataset and its DataLoader from a PIL image.

    The image is saved as a JPEG under ``resources/uploads`` (the directory
    is created through ``mkdir``), and a one-row DataFrame pointing at that
    file is handed to ``CSVDataset``.

    Parameters
    ----------
    image_pil : PIL.Image
        Input image. Modes the JPEG writer cannot encode (e.g. RGBA, P, LA)
        are converted to RGB before saving; other modes are saved as-is.
    mean, std : list
        Normalization parameters forwarded to ``build_normalize_transform``.
    down_ratio : int
        Forwarded to ``build_end_transforms`` (default 2).

    Returns
    -------
    dataset : CSVDataset
        Temporary dataset holding the single image.
    dataloader : DataLoader
        Loader over ``dataset`` with ``batch_size=1`` and no shuffling.
    temp_path : str
        Path of the saved image, exactly as returned by
        ``get_temp_image_path``. NOTE(review): whether this path is absolute
        depends on that helper -- confirm before relying on it.
    """
    # Modes Pillow's JPEG encoder can write directly. Anything else
    # (RGBA, P, LA, ...) makes ``save(format="JPEG")`` raise OSError, so
    # those inputs are converted to RGB instead of crashing.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image_pil.mode not in jpeg_writable_modes:
        image_pil = image_pil.convert("RGB")

    # Persist the image: the dataset is built from a file name + root dir.
    upload_dir = "resources/uploads"
    mkdir(upload_dir)
    temp_path = get_temp_image_path(upload_dir)
    image_pil.save(temp_path, format="JPEG")

    # One placeholder row (point at 0,0 with label 1) referencing the file.
    # Presumably CSVDataset needs these annotation columns -- TODO confirm.
    df = pd.DataFrame({
        "images": [os.path.basename(temp_path)],
        "x": [0],
        "y": [0],
        "labels": [1],
    })

    # Reuse the module's builders so both code paths share one definition
    # of the normalization and end transforms.
    dataset = CSVDataset(
        csv_file=df,
        root_dir=os.path.dirname(temp_path),
        albu_transforms=[build_normalize_transform(mean, std)],
        end_transforms=build_end_transforms(down_ratio),
    )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    return dataset, dataloader, temp_path
def _round_up_to_multiple(value, multiple):
    """Return the smallest multiple of ``multiple`` that is >= ``value``."""
    return int(np.ceil(value / multiple) * multiple)


def _fit_to_model_geometry(image_pil, down_ratio, step):
    """Pad to multiples of ``down_ratio``, then resize to multiples of ``step``."""
    width, height = image_pil.size

    # Black padding on the right/bottom edges only, so existing pixel
    # coordinates are unchanged by this step.
    pad_w = _round_up_to_multiple(width, down_ratio) - width
    pad_h = _round_up_to_multiple(height, down_ratio) - height
    if pad_w > 0 or pad_h > 0:
        image_pil = ImageOps.expand(
            image_pil, border=(0, 0, pad_w, pad_h), fill=(0, 0, 0)
        )

    # Stretch (bilinear) up to the next multiple of the patch stride.
    # NOTE(review): this rescales the content and no scale factor is
    # returned; also, the result is a multiple of ``step`` but not
    # necessarily of ``down_ratio`` when ``step % down_ratio != 0``.
    width, height = image_pil.size
    if width % step != 0 or height % step != 0:
        target_size = (
            _round_up_to_multiple(width, step),
            _round_up_to_multiple(height, step),
        )
        image_pil = image_pil.resize(target_size, Image.BILINEAR)
    return image_pil


def create_single_image_dataset_safe(
    image_pil,
    down_ratio: int = 2,
    patch_size: int = 512,
    overlap: int = 160,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
):
    """
    Prepare a PIL image for HerdNet inference and wrap it in a dataset.

    The image is EXIF-transposed, converted to RGB, padded so both sides
    are multiples of ``down_ratio``, resized so both sides are multiples of
    ``patch_size - overlap``, saved as a JPEG (quality 95) under
    ``resources/uploads`` and exposed through a one-image ``CSVDataset``.
    The original author describes it as the replacement for
    ``create_single_image_dataset``.

    Parameters
    ----------
    image_pil : PIL.Image
        Input image (any resolution or source).
    down_ratio : int
        Spatial reduction factor of the model (default 2). Must be > 0.
    patch_size : int
        Patch size (default 512). Must be greater than ``overlap``.
    overlap : int
        Overlap between patches (default 160).
    mean, std : tuple
        Normalization parameters (defaults are the ImageNet statistics
        according to the original docstring).

    Returns
    -------
    dataset : CSVDataset
        Temporary dataset holding the single prepared image.
    dataloader : DataLoader
        Loader over ``dataset`` with ``batch_size=1`` and no shuffling.
    temp_path : str
        Path of the saved image, exactly as returned by
        ``get_temp_image_path``. NOTE(review): whether it is absolute
        depends on that helper -- confirm.

    Raises
    ------
    ValueError
        If ``down_ratio`` is not positive or ``patch_size <= overlap``.
    """
    # Validate first: a zero stride or ratio would otherwise surface later
    # as a ZeroDivisionError, and a negative stride as a nonsensical resize.
    if down_ratio <= 0:
        raise ValueError(f"down_ratio must be positive, got {down_ratio}")
    step = patch_size - overlap
    if step <= 0:
        raise ValueError(
            "patch_size must be greater than overlap, "
            f"got patch_size={patch_size}, overlap={overlap}"
        )

    # 1. Orientation and channels. After convert("RGB") the pixel data is
    # already 8-bit in [0, 255], so no extra clipping/casting is required.
    image_pil = ImageOps.exif_transpose(image_pil)
    image_pil = image_pil.convert("RGB")

    # 2. Resolution compatible with the model.
    image_pil = _fit_to_model_geometry(image_pil, down_ratio, step)

    # 3. Temporary save.
    upload_dir = "resources/uploads"
    mkdir(upload_dir)
    temp_path = get_temp_image_path(upload_dir)
    image_pil.save(temp_path, format="JPEG", quality=95)

    # 4. One placeholder row (point at 0,0 with label 1) referencing the
    # file. Presumably CSVDataset needs these columns -- TODO confirm.
    df = pd.DataFrame({
        "images": [os.path.basename(temp_path)],
        "x": [0],
        "y": [0],
        "labels": [1],
    })

    # 5. Normalization + end transforms, via the module's shared builders.
    dataset = CSVDataset(
        csv_file=df,
        root_dir=os.path.dirname(temp_path),
        albu_transforms=[build_normalize_transform(mean, std)],
        end_transforms=build_end_transforms(down_ratio),
    )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # 6. Diagnostic log.
    print("[PREPROCESS] Imagen preparada para inferencia:")
    print(f" - Guardada en: {temp_path}")
    print(f" - Tamaño final: {image_pil.size}")
    print(f" - down_ratio: {down_ratio}")
    print(f" - patch_size: {patch_size}")
    print(f" - overlap: {overlap}")
    return dataset, dataloader, temp_path
|