LuisCe commited on
Commit
44dc216
verified
1 Parent(s): 3f575a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -32
app.py CHANGED
@@ -5,56 +5,107 @@ from fastai.vision.all import *
5
  from fastai.learner import load_learner
6
  from PIL import Image
7
 
8
- # Define la funci贸n get_y_fn si se utiliza en tu modelo
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def get_y_fn (x):
11
  return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
12
 
13
- # Define la clase TargetMaskConvertTransform
 
 
 
 
 
 
 
 
 
 
14
  class TargetMaskConvertTransform(ItemTransform):
15
  def __init__(self):
16
  pass
17
  def encodes(self, x):
18
- img, mask = x
 
 
19
  mask = np.array(mask)
20
- mask[mask == 255] = 1
21
- mask[mask == 150] = 2
22
- mask[mask == 74] = 3
23
- mask[mask == 76] = 3
24
- mask[mask == 29] = 4
25
- mask[mask == 25] = 4
26
- mask = PILMask.create(mask)
27
- return img, mask
28
-
29
- # Define la clase SegmentationAlbumentationsTransform
30
- class SegmentationAlbumentationsTransform(ItemTransform):
31
- def __init__(self):
32
- pass
33
 
34
- def encodes(self, x):
35
- img, mask = x
36
- # Aqu铆 deber铆as definir tu transformaci贸n de Albumentations
37
- # Por ejemplo, podr铆as tener algo como:
38
- # transformed = my_albumentations_function(image=img, mask=mask)
39
- # return transformed['image'], transformed['mask']
 
 
 
 
 
 
40
  return img, mask
41
 
42
-
43
  # Carga el modelo despu茅s de definir la clase
44
  repo_id = "LuisCe/Practica03"
45
  learner = from_pretrained_fastai(repo_id)
46
 
47
- # Define la funci贸n de predicci贸n
48
- def predict_image(img):
49
- img_fastai = Image.fromarray(img.astype('uint8'), 'RGB')
50
- pred, _, _ = learner.predict(img_fastai)
51
- return pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  # Crea la interfaz Gradio
54
- gr.Interface(predict_image,
55
- inputs="image",
56
- outputs="text",
57
  title="Grape Segmentation",
58
  description="Segment grapes in the image.",
59
  theme="compact",
60
- allow_flagging=False).launch()
 
5
  from fastai.learner import load_learner
6
  from PIL import Image
7
 
8
+ from albumentations import (
9
+ Compose,
10
+ OneOf,
11
+ ElasticTransform,
12
+ GridDistortion,
13
+ OpticalDistortion,
14
+ HorizontalFlip,
15
+ Rotate,
16
+ Transpose,
17
+ CLAHE,
18
+ ShiftScaleRotate
19
+ )
20
 
21
  def get_y_fn (x):
22
  return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
23
 
24
+ class SegmentationAlbumentationsTransform(ItemTransform):
25
+ split_idx = 0
26
+
27
+ def __init__(self, aug):
28
+ self.aug = aug
29
+
30
+ def encodes(self, x):
31
+ img,mask = x
32
+ aug = self.aug(image=np.array(img), mask=np.array(mask))
33
+ return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
34
+
35
  class TargetMaskConvertTransform(ItemTransform):
36
  def __init__(self):
37
  pass
38
  def encodes(self, x):
39
+ img,mask = x
40
+
41
+ #Convert to array
42
  mask = np.array(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ # mask[mask!=255]=0
45
+ # Change 255 for 1
46
+ mask[mask==255]=1
47
+ mask[mask==150]=2
48
+ mask[mask==74]=3
49
+ mask[mask==76]=3
50
+ mask[mask==29]=4
51
+ mask[mask==25]=4
52
+ # mask[mask==255]=1
53
+
54
+ # Back to PILMask
55
+ mask = PILMask.create(mask)
56
  return img, mask
57
 
 
58
  # Carga el modelo despu茅s de definir la clase
59
  repo_id = "LuisCe/Practica03"
60
  learner = from_pretrained_fastai(repo_id)
61
 
62
+
63
+ # Carga el modelo previamente entrenado
64
+ model = learner.model
65
+ model = model.cpu()
66
+ model.eval()
67
+
68
+ import torchvision.transforms as transforms
69
+ def transform_image(image):
70
+ my_transforms = transforms.Compose([transforms.ToTensor(),
71
+ transforms.Normalize(
72
+ [0.485, 0.456, 0.406],
73
+ [0.229, 0.224, 0.225])])
74
+ image_aux = image
75
+ return my_transforms(image_aux).unsqueeze(0).to(device)
76
+
77
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
+
79
+
80
+ def prediccion(img):
81
+ img = Image.fromarray(img)
82
+ image = transforms.Resize((480,640))(img)
83
+ tensor = transform_image(image=image)
84
+
85
+
86
+ model.to(device)
87
+ with torch.no_grad():
88
+ outputs = model(tensor)
89
+
90
+ outputs = torch.argmax(outputs,1)
91
+
92
+
93
+
94
+ mask = np.array(outputs.cpu())
95
+ mask[mask==1]=255
96
+ mask[mask==2]=150
97
+ mask[mask==3]=74
98
+ mask[mask==4]=29
99
+
100
+ mask=np.reshape(mask,(480,640))
101
+
102
+ return(mask)
103
 
104
  # Crea la interfaz Gradio
105
+ gr.Interface(prediccion,
106
+ inputs="image",
107
+ outputs="image",
108
  title="Grape Segmentation",
109
  description="Segment grapes in the image.",
110
  theme="compact",
111
+ allow_flagging=False).launch(debug=True)