alramil commited on
Commit
cfad999
·
1 Parent(s): 6fa4c15

Use export.pkl + load_learner en Space

Browse files
Files changed (2) hide show
  1. app.py +19 -73
  2. requirements.txt +1 -4
app.py CHANGED
@@ -1,88 +1,34 @@
1
  import gradio as gr
2
- import torch
3
  import numpy as np
4
- from fastai.vision.all import PILImage, imagenet_stats, Normalize
5
  from huggingface_hub import hf_hub_download
6
- import os
7
-
8
- # Define el repo_id y el nombre del archivo del modelo en el Hub
9
- MODEL_REPO_ID = "alramil/unet-segmentation-model"
10
- MODEL_FILE_NAME = "unet.pth"
11
-
12
- # Descarga el modelo al iniciar la app
13
- try:
14
- model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILE_NAME)
15
- print(f"Modelo descargado exitosamente a: {model_path}")
16
- except Exception as e:
17
- print(f"Error al descargar el modelo desde el Hub: {e}")
18
- print(f"Asegúrate de que el modelo '{MODEL_REPO_ID}' existe y contiene el archivo '{MODEL_FILE_NAME}'.")
19
- exit()
20
-
21
- # Carga tu modelo TorchScript desde la ruta descargada
22
- model = torch.jit.load(model_path)
23
- model.eval()
24
-
25
- # Códigos y valores originales (deberían coincidir con codesAll.txt si lo usas)
26
- # Cargar los códigos desde el archivo codesAll.txt
27
- try:
28
- with open("codesAll.txt", "r") as f:
29
- class_names = [line.strip() for line in f if line.strip()]
30
- # Reconstruir class_values basado en los códigos y el palette original (o definir un mapeo si es necesario)
31
- # Nota: Este mapeo es crucial y debe coincidir con RemapMaskLabels. Aquí se asume un mapeo fijo por ahora.
32
- class_values = [0, 29, 76, 150, 255] # Asegúrate de que esto coincida con RemapMaskLabels
33
- if len(class_names) != len(class_values):
34
- print("Advertencia: El número de clases en codesAll.txt no coincide con class_values/palette.")
35
- # Puedes ajustar class_values o palette aquí si es necesario, o salir/avisar.
36
- # Por ahora, seguimos con class_values y palette fijos.
37
-
38
- except FileNotFoundError:
39
- print("Error: codesAll.txt no encontrado en el Space. No se pudieron cargar los nombres de las clases dinámicamente.")
40
- # Usar valores y nombres por defecto si codesAll.txt no se encuentra
41
- class_values = [0,29,76,150,255]
42
- class_names = ['Background', 'Leaves', 'Wood', 'Pole', 'Grape'] # Valores por defecto
43
 
 
 
 
 
44
 
 
45
  palette = np.array([
46
- [0,0,0], # 0: Background (Negro)
47
- [0,128,0], # 29: Leaves (Verde)
48
- [139,69,19], # 76: Wood (Marrón Silla de Montar)
49
- [128,0,128], # 150: Pole (Púrpura)
50
- [128,128,128], # 255: Grape (Gris)
51
  ], dtype=np.uint8)
52
 
53
- # Transform para normalizar (debe ser el mismo que usaste al entrenar)
54
- _norm = Normalize.from_stats(*imagenet_stats)
55
-
56
  def segment_image(img):
57
- pil = img.resize((640,480))
58
- arr = np.array(pil).astype(np.float32) / 255.
59
- tensor = torch.tensor(arr).permute(2,0,1).unsqueeze(0)
60
- tensor = _norm(tensor)
61
-
62
- with torch.no_grad():
63
- out = model(tensor)
64
-
65
- seg = out.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
66
-
67
- # Re-mapear los valores predichos a los valores originales para usar la paleta
68
- # Esto es necesario porque RemapMaskLabels mapeó 0,29,76,150,255 a 0,1,2,3,4 para el entrenamiento
69
- # Pero la paleta espera los valores originales o necesitas una paleta mapeada a 0,1,2,3,4
70
- # Si tu modelo predice 0,1,2,3,4, puedes mapearlos de nuevo a los valores originales
71
- # O usar una paleta que mapee 0,1,2,3,4 a los colores.
72
- # La forma más sencilla es usar la paleta con los índices predichos (0,1,2,3,4)
73
- # Asegúrate de que la paleta está indexada correctamente.
74
- # Si el modelo predice 0,1,2,3,4 y la paleta espera índices 0,1,2,3,4, entonces:
75
- seg_color = palette[seg]
76
-
77
-
78
- return seg_color
79
 
80
  demo = gr.Interface(
81
  fn=segment_image,
82
- inputs=gr.Image(type="pil", label="Sube una imagen"),
83
- outputs=gr.Image(type="numpy", label="Máscara Segmentada"),
84
- title="Segmentación Semántica de Racimos",
85
- description="Sube una imagen de viñedo y obtén la máscara segmentada por clases (Fondo, Hojas, Madera, Poste, Racimo)."
86
  )
87
 
88
  if __name__=="__main__":
 
1
  import gradio as gr
 
2
  import numpy as np
3
+ from fastai.vision.all import load_learner, PILImage
4
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # Descargar y cargar export.pkl
7
+ pkl = hf_hub_download(repo_id="${MODEL_REPO}", filename="export.pkl")
8
+ learn = load_learner(pkl)
9
+ learn.model.eval()
10
 
11
+ # Paleta de colores para índices 0–4
12
  palette = np.array([
13
+ [0,0,0], # fondo
14
+ [0,128,0], # hojas
15
+ [139,69,19], # madera
16
+ [128,0,128], # poste
17
+ [128,128,128] # racimo
18
  ], dtype=np.uint8)
19
 
 
 
 
20
  def segment_image(img):
21
+ # FastAI internamente hace resize, tensor y normalize
22
+ pred,_,_ = learn.predict(PILImage.create(img))
23
+ seg = np.array(pred) # H×W array de 0–4
24
+ return palette[seg]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  demo = gr.Interface(
27
  fn=segment_image,
28
+ inputs=gr.Image(type="pil", label="Sube imagen"),
29
+ outputs=gr.Image(type="numpy", label="Máscara segmentada"),
30
+ title="Segmentación de Racimos de Uva",
31
+ description="Modelo U-Net entrenado con FastAI"
32
  )
33
 
34
  if __name__=="__main__":
requirements.txt CHANGED
@@ -1,7 +1,4 @@
1
  fastai
2
- torch
3
- torchvision
4
  gradio
5
- numpy
6
- pillow
7
  huggingface-hub
 
 
1
  fastai
 
 
2
  gradio
 
 
3
  huggingface-hub
4
+ numpy