PablitoGil14 commited on
Commit
d7433b3
·
verified ·
1 Parent(s): 76746e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -28
app.py CHANGED
@@ -11,32 +11,19 @@ from albumentations import (
11
  Rotate,
12
  GridDistortion
13
  )
 
14
 
15
- class SegmentationAlbumentationsTransform(ItemTransform):
16
- split_idx = 0
17
-
18
- def __init__(self, aug):
19
- self.aug = aug
20
 
21
- def encodes(self, x):
22
- img,mask = x
23
- aug = self.aug(image=np.array(img), mask=np.array(mask))
24
- return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
25
-
26
-
27
-
28
- # Cargar modelo desde Hugging Face Hub
29
- model_path = hf_hub_download(repo_id="PablitoGil14/AP-Practica3", filename="model.pkl")
30
 
31
  class TargetMaskConvertTransform(ItemTransform):
32
  def __init__(self):
33
  pass
34
  def encodes(self, x):
35
  img,mask = x
36
-
37
- #Convert to array
38
  mask = np.array(mask)
39
-
40
  mask[(mask!=255) & (mask!=150) & (mask!=76) & (mask!=74) & (mask!=29) & (mask!=25)]=0
41
  mask[mask==255]=1
42
  mask[mask==150]=2
@@ -44,19 +31,23 @@ class TargetMaskConvertTransform(ItemTransform):
44
  mask[mask==74]=4
45
  mask[mask==29]=3
46
  mask[mask==25]=3
 
47
 
48
- # Back to PILMask
49
- mask = PILMask.create(mask)
50
- return img, mask
 
 
 
 
 
51
 
 
52
 
 
53
  learn = load_learner(model_path)
54
 
55
-
56
- def get_y_fn(x):
57
- return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
58
-
59
-
60
 
61
  def segmentar(img: Image.Image):
62
  img = img.resize((640, 480))
@@ -68,8 +59,7 @@ def segmentar(img: Image.Image):
68
  with torch.no_grad():
69
  preds = learn.model.eval()(x)
70
  mask = torch.argmax(preds, dim=1).squeeze().cpu().numpy()
71
-
72
- # Asignar colores según los valores de clase
73
  out_mask = np.zeros_like(mask, dtype=np.uint8)
74
  out_mask[mask == 1] = 255
75
  out_mask[mask == 2] = 150
@@ -77,7 +67,8 @@ def segmentar(img: Image.Image):
77
  out_mask[mask == 4] = 74
78
  return Image.fromarray(out_mask)
79
 
80
- # Interfaz de Gradio
 
81
  demo = gr.Interface(
82
  fn=segmentar,
83
  inputs=gr.Image(type="pil"),
 
11
  Rotate,
12
  GridDistortion
13
  )
14
+ from pathlib import Path
15
 
16
+ # --- FUNCIONES Y CLASES NECESARIAS PARA EL PICKLE ---
 
 
 
 
17
 
18
+ def get_y_fn(x):
19
+ return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
 
 
 
 
 
 
 
20
 
21
  class TargetMaskConvertTransform(ItemTransform):
22
  def __init__(self):
23
  pass
24
  def encodes(self, x):
25
  img,mask = x
 
 
26
  mask = np.array(mask)
 
27
  mask[(mask!=255) & (mask!=150) & (mask!=76) & (mask!=74) & (mask!=29) & (mask!=25)]=0
28
  mask[mask==255]=1
29
  mask[mask==150]=2
 
31
  mask[mask==74]=4
32
  mask[mask==29]=3
33
  mask[mask==25]=3
34
+ return img, PILMask.create(mask)
35
 
36
+ class SegmentationAlbumentationsTransform(ItemTransform):
37
+ split_idx = 0
38
+ def __init__(self, aug):
39
+ self.aug = aug
40
+ def encodes(self, x):
41
+ img,mask = x
42
+ aug = self.aug(image=np.array(img), mask=np.array(mask))
43
+ return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
44
 
45
+ # --- CARGAR MODELO ---
46
 
47
+ model_path = hf_hub_download(repo_id="PablitoGil14/AP-Practica3", filename="model.pkl")
48
  learn = load_learner(model_path)
49
 
50
+ # --- FUNCIÓN DE PREDICCIÓN ---
 
 
 
 
51
 
52
  def segmentar(img: Image.Image):
53
  img = img.resize((640, 480))
 
59
  with torch.no_grad():
60
  preds = learn.model.eval()(x)
61
  mask = torch.argmax(preds, dim=1).squeeze().cpu().numpy()
62
+
 
63
  out_mask = np.zeros_like(mask, dtype=np.uint8)
64
  out_mask[mask == 1] = 255
65
  out_mask[mask == 2] = 150
 
67
  out_mask[mask == 4] = 74
68
  return Image.fromarray(out_mask)
69
 
70
+ # --- INTERFAZ GRADIO ---
71
+
72
  demo = gr.Interface(
73
  fn=segmentar,
74
  inputs=gr.Image(type="pil"),