import gradio as gr from huggingface_hub import hf_hub_download import os, time, traceback import numpy as np from pathlib import Path from PIL import Image import torch import torchvision.transforms as T from fastai.vision.all import * # --- 1. OPTIMIZACIÓN CPU --- os.environ["OMP_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" # --- 2. DEFINICIONES NECESARIAS PARA DESERIALIZAR (COPIADAS DEL NOTEBOOK) --- # Estas clases deben existir para que 'load_learner' pueda leer el archivo, # aunque no las usaremos para la predicción real. from albumentations import Compose, ElasticTransform, GridDistortion, HorizontalFlip, VerticalFlip, Rotate # Parche para problemas de rutas entre Windows/Linux import pathlib temp = pathlib.PosixPath pathlib.WindowsPath = pathlib.PosixPath def get_y_fn(x): return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png")) def ParentSplitter(x): return Path(x).parent.name=="test" class SegmentationAlbumentationsTransform(ItemTransform): split_idx = 0 def __init__(self, aug=None): self.aug = aug def encodes(self, x): return x class TargetMaskConvertTransform(ItemTransform): split_idx = 0 def __init__(self): pass def encodes(self, x): return x # --- 3. CONFIGURACIÓN Y CARGA --- model = None # SEGÚN TU NOTEBOOK, EL REPO ES ESTE: repo_id = "acascanzal/practica3" # push_to_hub_fastai suele guardar como 'export.pkl'. Si lo cambiaste, edita esto: filename = "model.pkl" def load_model_pure_pytorch(): global model try: print("Descargando modelo del Hub...") model_path = hf_hub_download(repo_id=repo_id, filename=filename) print("Deserializando learner de FastAI...") # Cargamos el learner completo learn = load_learner(model_path, cpu=True) # EXTRAEMOS EL MODELO PURO Y DESCARTAMOS EL RESTO # Esto evita errores con las transformaciones de Albumentations en inferencia model = learn.model model.eval() model.cpu() # Limpieza de memoria del learn import gc gc.collect() print("Modelo PyTorch extraído y listo.") return model except Exception as e: print(f"Error cargando el modelo: {e}") traceback.print_exc() return None # Cargar al inicio model = load_model_pure_pytorch() # --- 4. PRE-PROCESAMIENTO MANUAL --- # Usamos las mismas estadísticas de ImageNet que usaste en el notebook # Y redimensionamos a 480x640 (tamaño usado en tu entrenamiento) preprocess = T.Compose([ T.Resize((480, 640)), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # --- 5. FUNCIÓN DE PREDICCIÓN --- def predict(img): if img is None or model is None: return img try: start_time = time.time() # A. Preparar imagen original img = img.convert("RGB") original_size = img.size # B. Transformar a Tensor para la IA input_tensor = preprocess(img).unsqueeze(0) # Batch size 1 # C. Inferencia (Sin gradientes para ahorrar RAM) with torch.no_grad(): output = model(input_tensor) # D. Post-procesado (Argmax para obtener índices de clase) # output shape: [1, 5, 480, 640] -> [480, 640] mask_idx = output.argmax(dim=1).squeeze().cpu().numpy().astype(np.uint8) # E. Coloreado (Mapeo de Clases a Colores RGB) # 0=background, 1=leaves, 2=wood, 3=pole, 4=grape colors = { 0: (0, 0, 0), # Fondo 1: (0, 255, 0), # Hojas -> Verde Lima Puro 2: (255, 140, 0), # Madera -> Naranja Intenso (se ve mejor que el marrón) 3: (0, 255, 255), # Poste -> Cyan / Azul Eléctrico 4: (255, 0, 255) # Uva -> Magenta / Fucsia } # Crear imagen RGB vacía para la máscara h, w = mask_idx.shape colored_mask = np.zeros((h, w, 3), dtype=np.uint8) for cls_id, color in colors.items(): if cls_id == 0: continue # Saltamos el fondo para mantenerlo limpio colored_mask[mask_idx == cls_id] = color # F. Redimensionar máscara al tamaño original de la foto subida mask_pil = Image.fromarray(colored_mask).resize(original_size, resample=Image.NEAREST) # G. Mezclar con la imagen original (Overlay) final_img = Image.blend(img, mask_pil, alpha=0.4) print(f"Inferencia completada en {time.time() - start_time:.2f}s") return final_img except Exception as e: print(f"Error en predicción: {e}") traceback.print_exc() return img # --- 6. INTERFAZ --- interface = gr.Interface( fn=predict, inputs=gr.Image(type="pil", label="Subir Imagen de Viñedo"), outputs=gr.Image(type="pil", label="Segmentación (Verde: Hojas, Marrón: Madera, Azul: Poste, Morado: Uva)"), title="Segmentación de Uvas - Práctica 3", description="Modelo U-Net ResNet50 entrenado para detectar 5 clases." ) if __name__ == "__main__": interface.launch()