LuisCe committed on
Commit
264b326
verified
1 Parent(s): df972f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -16
app.py CHANGED
@@ -1,24 +1,90 @@
1
  import gradio as gr
 
2
  from fastai.learner import load_learner
3
  from PIL import Image
4
 
5
- # Load the learner
6
- learn_inf = load_learner("LuisCe/Practica03.pkl")
7
-
8
- # Define the prediction function
9
- def predict_image(img):
10
- # Convert the PIL image to a format that fastai expects
11
- img_fastai = Image.fromarray(img.astype('uint8'), 'RGB')
12
- # Make prediction
13
- pred, _, _ = learn_inf.predict(img_fastai)
14
- # Return prediction
15
- return pred
16
-
17
- # Create Gradio interface
18
- gr.Interface(predict_image,
19
- inputs="image",
20
- outputs="text",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  title="Grape Segmentation",
22
  description="Segment grapes in the image.",
23
  theme="compact",
24
  allow_flagging=False).launch()
 
 
1
import gradio as gr
from fastai.vision.all import *
from fastai.learner import load_learner
from huggingface_hub import from_pretrained_fastai
from PIL import Image
 
6
def get_y_fn(x):
    """Map an image path to the path of its ground-truth mask.

    Rewrites ``.../Images/color*.jpg`` to ``.../Labels/gt*.png``.
    Used as the label function when the segmentation DataLoaders
    were built, so it must exist before the learner is loaded.
    """
    mask_path = str(x)
    for old, new in (("Images", "Labels"), ("color", "gt"), (".jpg", ".png")):
        mask_path = mask_path.replace(old, new)
    return Path(mask_path)
9
+
10
# Transform that converts raw grayscale mask values into class ids.
# Must be importable for the pickled learner to deserialize.
class TargetMaskConvertTransform(ItemTransform):
    """Remap grayscale pixel values in the target mask to contiguous class ids.

    255 -> 1, 150 -> 2, {74, 76} -> 3, {29, 25} -> 4; other values
    (e.g. 0 background) are left untouched.
    """

    def __init__(self):
        pass

    def encodes(self, x):
        img, mask = x
        mask = np.array(mask)
        # Grayscale value -> class id. 74/76 and 29/25 encode the same
        # class (slightly different gray levels in the annotations).
        # None of the target ids (1-4) appear as later source values,
        # so applying the rules in order cannot double-remap a pixel.
        value_to_class = {255: 1, 150: 2, 74: 3, 76: 3, 29: 4, 25: 4}
        for value, class_id in value_to_class.items():
            mask[mask == value] = class_id
        return img, PILMask.create(mask)
26
+
27
# Placeholder transform kept so the pickled learner can deserialize;
# at inference time it is an identity mapping.
class SegmentationAlbumentationsTransform(ItemTransform):
    """Identity item transform standing in for the training-time
    Albumentations augmentation pipeline.

    To enable augmentation, apply an Albumentations callable here, e.g.:
        transformed = my_albumentations_function(image=img, mask=mask)
        return transformed['image'], transformed['mask']
    """

    def __init__(self):
        pass

    def encodes(self, x):
        # No augmentation at inference: pass the pair through unchanged.
        img, mask = x
        return img, mask
39
+
40
# Load the learner from the HuggingFace Hub. The custom classes above
# must already be defined or unpickling the learner fails.
# NOTE: from_pretrained_fastai comes from huggingface_hub (see imports);
# the original file called it without ever importing it (NameError).
repo_id = "LuisCe/Practica03"
learner = from_pretrained_fastai(repo_id)

# Pull out the underlying PyTorch model and pin it to the CPU for
# inference (Spaces CPU hardware).
model = learner.model
model = model.cpu()
model.eval()

# BUG FIX: `device` is referenced by transform_image()/prediccion()
# below but was never defined anywhere in the file, so every prediction
# raised NameError. The model is explicitly on the CPU, so use that.
device = torch.device("cpu")
49
+
50
import torchvision.transforms as transforms

def transform_image(image):
    """Convert a PIL image into a normalized batch tensor for the model.

    Applies ToTensor plus the standard ImageNet mean/std normalization
    (the statistics the fastai learner was trained with) and adds a
    leading batch dimension.

    Returns a float tensor of shape (1, C, H, W).
    """
    my_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # BUG FIX: the original ended with .to(device), but `device` was never
    # defined in the file (NameError). The model runs on the CPU, where
    # newly created tensors already live, so no transfer is needed.
    return my_transforms(image).unsqueeze(0)
58
+
59
def prediccion(img):
    """Segment one input image with the loaded model.

    Parameters:
        img: the image supplied by the Gradio "image" input — a numpy
            array (H, W, 3) by default, or a PIL image.

    Returns:
        A (480, 640) numpy array: the predicted mask with class ids
        mapped back to the grayscale palette used in the annotations
        (1->255, 2->150, 3->74, 4->29); background stays 0.
    """
    # torchvision transforms expect a PIL image; Gradio hands us a numpy
    # array by default, so convert first.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype('uint8'), 'RGB')

    # The model was trained at a fixed 480x640 resolution.
    image = transforms.Resize((480, 640))(img)
    tensor = transform_image(image=image)

    # BUG FIX: the original called model.to(device) with `device`
    # undefined (NameError). The model is already on the CPU, so no
    # transfer is needed here.
    with torch.no_grad():
        outputs = model(tensor)

    # Per-pixel class id, then back to the grayscale palette. Applying
    # the remaps in this order is safe: no palette value (255/150/74/29)
    # equals a class id that is remapped later.
    outputs = torch.argmax(outputs, 1)
    mask = np.array(outputs.cpu())
    for class_id, gray in ((1, 255), (2, 150), (3, 74), (4, 29)):
        mask[mask == class_id] = gray

    return np.reshape(mask, (480, 640))
81
+
82
# Build and launch the Gradio demo: an image goes in, the predicted
# segmentation mask comes back as an image.
interface = gr.Interface(
    fn=prediccion,
    inputs="image",
    outputs="image",
    title="Grape Segmentation",
    description="Segment grapes in the image.",
    theme="compact",
    allow_flagging=False,
)
interface.launch()
90
+