PablitoGil14 commited on
Commit
7bab3c7
·
verified ·
1 Parent(s): 9e9defa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -82
app.py CHANGED
@@ -1,107 +1,84 @@
1
  from huggingface_hub import from_pretrained_fastai
2
- import gradio as gr
3
  from fastai.vision.all import *
 
 
 
4
  import torchvision.transforms as transforms
5
- import torchvision.transforms as transforms
6
- from fastai.basics import *
7
- from fastai.vision import models
8
- from fastai.vision.all import *
9
- from fastai.metrics import *
10
- from fastai.data.all import *
11
- from fastai.callback import *
12
  from pathlib import Path
13
- import random
14
- import PIL
15
- from fastai.callback.fp16 import AMPMode
16
 
17
- #Primero definimos todas las funciones, clases y variables que sopn necesarias para que esto funcione
 
 
 
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
19
  def transform_image(image):
20
- my_transforms = transforms.Compose([transforms.ToTensor(),
21
- transforms.Normalize(
22
- [0.485, 0.456, 0.406],
23
- [0.229, 0.224, 0.225])])
24
- image_aux = image
25
- return my_transforms(image_aux).unsqueeze(0).to(device)
 
 
 
 
26
 
27
  class TargetMaskConvertTransform(ItemTransform):
28
- def __init__(self):
29
- pass
30
  def encodes(self, x):
31
- img,mask = x
32
-
33
- #Convert to array
34
  mask = np.array(mask)
35
-
36
- mask[(mask!=255) & (mask!=150) & (mask!=76) & (mask!=74) & (mask!=29) & (mask!=25)]=0
37
- mask[mask==255]=1
38
- mask[mask==150]=2
39
- mask[mask==76]=4
40
- mask[mask==74]=4
41
- mask[mask==29]=3
42
- mask[mask==25]=3
43
-
44
- # Back to PILMask
45
- mask = PILMask.create(mask)
46
- return img, mask
47
-
48
- from albumentations import (
49
- Compose,
50
- OneOf,
51
- ElasticTransform,
52
- GridDistortion,
53
- OpticalDistortion,
54
- HorizontalFlip,
55
- Rotate,
56
- Transpose,
57
- CLAHE,
58
- ShiftScaleRotate
59
- )
60
-
61
- def get_y_fn (x):
62
- return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
63
 
64
  class SegmentationAlbumentationsTransform(ItemTransform):
65
  split_idx = 0
66
-
67
- def __init__(self, aug):
68
- self.aug = aug
69
-
70
  def encodes(self, x):
71
- img,mask = x
72
  aug = self.aug(image=np.array(img), mask=np.array(mask))
73
  return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
74
 
75
- #Cargamos el modelo
76
  repo_id = "PablitoGil14/AP-Practica3"
77
  learn = from_pretrained_fastai(repo_id)
78
- model = learn.model
79
- model = model.cpu()
80
-
81
 
82
- # Definimos una función que se encarga de llevar a cabo las predicciones
83
- def predict(img_ruta):
84
- # img = PIL.Image.open(img_ruta) #esto si el parámetro de entrada es una ruta a una imagen
85
- # img = img_ruta # esto si el parámetro de entrada es una imagen
86
- img = PIL.Image.fromarray(img_ruta)
87
- image = transforms.Resize((480,640))(img)
88
- tensor = transform_image(image=image)
89
  model.to(device)
90
  with torch.no_grad():
91
- outputs = model(tensor)
92
-
93
- outputs = torch.argmax(outputs,1)
94
- mask = np.array(outputs.cpu())
95
- mask[mask==1]=255
96
- mask[mask==2]=150
97
- mask[mask==3]=29
98
- mask[mask==4]=74
99
- mask = np.reshape(mask,(480,640))
100
- return Image.fromarray(mask.astype('uint8'))
 
 
 
 
 
 
 
 
 
 
101
 
102
-
103
- #img = PILImage.create(img) #igual hay que usar esto en vez de PIL.Image.open
104
-
105
- # Creamos la interfaz y la lanzamos.
106
- gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(480, 640)), outputs=gr.inputs.Image(shape=(480, 640)), examples=['color_61.jpg','color_62.jpg']).launch(share=False)
107
 
 
1
  from huggingface_hub import from_pretrained_fastai
 
2
  from fastai.vision.all import *
3
+ import gradio as gr
4
+ import numpy as np
5
+ from PIL import Image
6
  import torchvision.transforms as transforms
7
+ from albumentations import (
8
+ Compose, GridDistortion, HorizontalFlip, Rotate
9
+ )
 
 
 
 
10
  from pathlib import Path
11
+ import torch
 
 
12
 
13
+ # 1. Evita imports repetidos e innecesarios
14
+ # ❌ ya importaste fastai.vision.all que lo incluye todo
15
+ # ❌ ya tienes PIL y transforms más arriba
16
+
17
+ # ✅ 2. Función de preprocesamiento
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
  def transform_image(image):
21
+ preprocess = transforms.Compose([
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406],
24
+ [0.229, 0.224, 0.225])
25
+ ])
26
+ return preprocess(image).unsqueeze(0).to(device)
27
+
28
+ # ✅ 3. Funciones y clases necesarias para cargar el modelo
29
+ def get_y_fn(x):
30
+ return Path(str(x).replace("Images", "Labels").replace("color", "gt").replace(".jpg", ".png"))
31
 
32
  class TargetMaskConvertTransform(ItemTransform):
 
 
33
  def encodes(self, x):
34
+ img, mask = x
 
 
35
  mask = np.array(mask)
36
+ mask[(mask!=255) & (mask!=150) & (mask!=76) & (mask!=74) & (mask!=29) & (mask!=25)] = 0
37
+ mask[mask==255] = 1
38
+ mask[mask==150] = 2
39
+ mask[(mask==76) | (mask==74)] = 4
40
+ mask[(mask==29) | (mask==25)] = 3
41
+ return img, PILMask.create(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  class SegmentationAlbumentationsTransform(ItemTransform):
44
  split_idx = 0
45
+ def __init__(self, aug): self.aug = aug
 
 
 
46
  def encodes(self, x):
47
+ img, mask = x
48
  aug = self.aug(image=np.array(img), mask=np.array(mask))
49
  return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
50
 
51
+ # ✅ 4. Carga del modelo desde el hub
52
  repo_id = "PablitoGil14/AP-Practica3"
53
  learn = from_pretrained_fastai(repo_id)
54
+ model = learn.model.cpu()
 
 
55
 
56
+ # 5. Función de predicción
57
+ def predict(img_input):
58
+ image = Image.fromarray(img_input).resize((640, 480))
59
+ tensor = transform_image(image)
 
 
 
60
  model.to(device)
61
  with torch.no_grad():
62
+ output = model(tensor)
63
+ pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()
64
+
65
+ # Recolorear la máscara
66
+ colored = np.zeros_like(pred, dtype=np.uint8)
67
+ colored[pred == 1] = 255
68
+ colored[pred == 2] = 150
69
+ colored[pred == 3] = 29
70
+ colored[pred == 4] = 74
71
+ return Image.fromarray(colored)
72
+
73
+ # ✅ 6. Interfaz Gradio moderna (gr.Image en lugar de .inputs.Image)
74
+ demo = gr.Interface(
75
+ fn=predict,
76
+ inputs=gr.Image(type="numpy", label="Sube una imagen", shape=(480, 640)),
77
+ outputs=gr.Image(type="pil", label="Máscara generada"),
78
+ examples=["color_61.jpg", "color_62.jpg"],
79
+ title="Segmentador de Viñedos",
80
+ description="Sube una imagen y este modelo segmentará racimos de uva, hojas, madera y postes."
81
+ )
82
 
83
+ demo.launch()
 
 
 
 
84