practica3 / app.py
acascanzal's picture
Update app.py
ab56b83 verified
import gradio as gr
from huggingface_hub import hf_hub_download
import os, time, traceback
import numpy as np
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms as T
from fastai.vision.all import *
# --- 1. OPTIMIZACIÓN CPU ---
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
# --- 2. DEFINICIONES NECESARIAS PARA DESERIALIZAR (COPIADAS DEL NOTEBOOK) ---
# Estas clases deben existir para que 'load_learner' pueda leer el archivo,
# aunque no las usaremos para la predicción real.
from albumentations import Compose, ElasticTransform, GridDistortion, HorizontalFlip, VerticalFlip, Rotate
# Parche para problemas de rutas entre Windows/Linux
import pathlib
temp = pathlib.PosixPath
pathlib.WindowsPath = pathlib.PosixPath
def get_y_fn(x):
return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
def ParentSplitter(x):
return Path(x).parent.name=="test"
class SegmentationAlbumentationsTransform(ItemTransform):
split_idx = 0
def __init__(self, aug=None): self.aug = aug
def encodes(self, x): return x
class TargetMaskConvertTransform(ItemTransform):
split_idx = 0
def __init__(self): pass
def encodes(self, x): return x
# --- 3. CONFIGURACIÓN Y CARGA ---
model = None
# SEGÚN TU NOTEBOOK, EL REPO ES ESTE:
repo_id = "acascanzal/practica3"
# push_to_hub_fastai suele guardar como 'export.pkl'. Si lo cambiaste, edita esto:
filename = "model.pkl"
def load_model_pure_pytorch():
global model
try:
print("Descargando modelo del Hub...")
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
print("Deserializando learner de FastAI...")
# Cargamos el learner completo
learn = load_learner(model_path, cpu=True)
# EXTRAEMOS EL MODELO PURO Y DESCARTAMOS EL RESTO
# Esto evita errores con las transformaciones de Albumentations en inferencia
model = learn.model
model.eval()
model.cpu()
# Limpieza de memoria
del learn
import gc
gc.collect()
print("Modelo PyTorch extraído y listo.")
return model
except Exception as e:
print(f"Error cargando el modelo: {e}")
traceback.print_exc()
return None
# Cargar al inicio
model = load_model_pure_pytorch()
# --- 4. PRE-PROCESAMIENTO MANUAL ---
# Usamos las mismas estadísticas de ImageNet que usaste en el notebook
# Y redimensionamos a 480x640 (tamaño usado en tu entrenamiento)
preprocess = T.Compose([
T.Resize((480, 640)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# --- 5. FUNCIÓN DE PREDICCIÓN ---
def predict(img):
if img is None or model is None:
return img
try:
start_time = time.time()
# A. Preparar imagen original
img = img.convert("RGB")
original_size = img.size
# B. Transformar a Tensor para la IA
input_tensor = preprocess(img).unsqueeze(0) # Batch size 1
# C. Inferencia (Sin gradientes para ahorrar RAM)
with torch.no_grad():
output = model(input_tensor)
# D. Post-procesado (Argmax para obtener índices de clase)
# output shape: [1, 5, 480, 640] -> [480, 640]
mask_idx = output.argmax(dim=1).squeeze().cpu().numpy().astype(np.uint8)
# E. Coloreado (Mapeo de Clases a Colores RGB)
# 0=background, 1=leaves, 2=wood, 3=pole, 4=grape
colors = {
0: (0, 0, 0), # Fondo
1: (0, 255, 0), # Hojas -> Verde Lima Puro
2: (255, 140, 0), # Madera -> Naranja Intenso (se ve mejor que el marrón)
3: (0, 255, 255), # Poste -> Cyan / Azul Eléctrico
4: (255, 0, 255) # Uva -> Magenta / Fucsia
}
# Crear imagen RGB vacía para la máscara
h, w = mask_idx.shape
colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
for cls_id, color in colors.items():
if cls_id == 0: continue # Saltamos el fondo para mantenerlo limpio
colored_mask[mask_idx == cls_id] = color
# F. Redimensionar máscara al tamaño original de la foto subida
mask_pil = Image.fromarray(colored_mask).resize(original_size, resample=Image.NEAREST)
# G. Mezclar con la imagen original (Overlay)
final_img = Image.blend(img, mask_pil, alpha=0.4)
print(f"Inferencia completada en {time.time() - start_time:.2f}s")
return final_img
except Exception as e:
print(f"Error en predicción: {e}")
traceback.print_exc()
return img
# --- 6. INTERFAZ ---
interface = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil", label="Subir Imagen de Viñedo"),
outputs=gr.Image(type="pil", label="Segmentación (Verde: Hojas, Marrón: Madera, Azul: Poste, Morado: Uva)"),
title="Segmentación de Uvas - Práctica 3",
description="Modelo U-Net ResNet50 entrenado para detectar 5 clases."
)
if __name__ == "__main__":
interface.launch()