maviced committed on
Commit
4ddd9cc
·
verified ·
1 Parent(s): 12b4bdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -11
app.py CHANGED
@@ -1,26 +1,104 @@
1
- from huggingface_hub import from_pretrained_fastai
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
3
  from fastai.vision.all import *
4
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- try:
7
- import toml
8
- except ImportError:
9
- os.system('pip install toml')
10
- import toml
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # repo_id = "YOUR_USERNAME/YOUR_LEARNER_NAME"
14
  repo_id = "maviced/practica3"
15
 
16
  learner = from_pretrained_fastai(repo_id)
17
  labels = learner.dls.vocab
 
 
18
 
19
  # Definimos una función que se encarga de llevar a cabo las predicciones
20
  def predict(img):
21
- # img = PILImage.create(img)
22
- pred,pred_idx,probs = learner.predict(img)
23
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Creamos la interfaz y la lanzamos.
26
- gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Label(num_top_classes=3),examples=['color_155.jpg','color_154.jpg']).launch(share=False)
 
 
1
  import gradio as gr
2
+ import torchvision.transforms as transforms
3
+ import torchvision.transforms as transforms
4
+ import random
5
+ import PIL
6
+
7
+ from fastai.vision.all import *
8
+ from huggingface_hub import from_pretrained_fastai
9
+ from fastai.basics import *
10
+ from fastai.vision import models
11
  from fastai.vision.all import *
12
+ from fastai.metrics import *
13
+ from fastai.data.all import *
14
+ from fastai.callback import *
15
+ from pathlib import Path
16
+
17
+
18
def get_y_fn(x):
    """Map an image path under ``Images`` to its ground-truth mask path.

    Example: ``Images/color_7.jpg`` -> ``Labels/gt_7.png``.
    NOTE(review): these are plain substring replacements, so any other
    occurrence of "Images"/"color"/".jpg" in the path would also be
    rewritten — assumed not to occur in this dataset.
    """
    mask_path = str(x)
    for old, new in (("Images", "Labels"), ("color", "gt"), (".jpg", ".png")):
        mask_path = mask_path.replace(old, new)
    return Path(mask_path)
20
+
21
+ from albumentations import (
22
+ Compose,
23
+ OneOf,
24
+ ElasticTransform,
25
+ GridDistortion,
26
+ OpticalDistortion,
27
+ HorizontalFlip,
28
+ Rotate,
29
+ Transpose,
30
+ CLAHE,
31
+ ShiftScaleRotate,
32
+ VerticalFlip
33
+ )
34
+
35
class SegmentationAlbumentationsTransform(ItemTransform):
    """Apply an albumentations pipeline to an (image, mask) pair jointly.

    ``split_idx = 0`` is the fastai convention restricting this transform
    to the training split, so validation items pass through unaugmented.
    """

    split_idx = 0  # training split only

    def __init__(self, aug):
        # `aug` is a callable albumentations Compose pipeline.
        self.aug = aug

    def encodes(self, x):
        """Augment image and mask with the same random parameters."""
        image, mask = x
        augmented = self.aug(image=np.array(image), mask=np.array(mask))
        return (PILImage.create(augmented["image"]),
                PILMask.create(augmented["mask"]))
45
+
46
class TargetMaskConvertTransform(ItemTransform):
    """Collapse raw grayscale mask values into contiguous class ids 0-4.

    Raw value -> class id: 255->1, 150->2, 76/74->3, 29/25->4; every other
    value becomes background (0).
    """

    def __init__(self):
        pass

    def encodes(self, x):
        img, mask = x
        # Work on a mutable array copy of the mask.
        mask = np.array(mask)

        # Sequential in-place remapping. Order matters: a raw pixel value
        # of 1-4 that survives the earlier assignments is kept as-is by
        # the final background sweep, matching the original behavior.
        for raw, cls in ((255, 1), (150, 2), (76, 3), (74, 3), (29, 4), (25, 4)):
            mask[mask == raw] = cls
        mask[(mask != 1) & (mask != 2) & (mask != 3) & (mask != 4)] = 0

        # Back to a fastai PILMask.
        return img, PILMask.create(mask)
67
+
68
# Run inference on GPU when available; used by transform_image and predict.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def transform_image(image):
    """Convert a PIL image into a normalized 4-D float tensor on `device`.

    Applies the standard ImageNet mean/std normalization; ``unsqueeze(0)``
    adds the batch dimension the model expects.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    return preprocess(image).unsqueeze(0).to(device)
76
 
77
# repo_id = "YOUR_USERNAME/YOUR_LEARNER_NAME"
repo_id = "maviced/practica3"

# Download the exported fastai learner from the Hugging Face Hub.
learner = from_pretrained_fastai(repo_id)
labels = learner.dls.vocab  # class names, in index order

# Extract the underlying torch module and keep it on CPU until predict()
# moves it to `device`. The original `model = learner.cpu()` only worked
# because fastai's GetAttr delegation forwards `.cpu()` to learner.model;
# this is the explicit, delegation-free equivalent.
model = learner.model.cpu()
84
 
85
# Function that performs the segmentation prediction for the Gradio app.
def predict(img):
    """Segment `img` and return the 480x640 mask as a grayscale PIL image.

    Predicted class ids are mapped back to the raw mask grayscale values
    (1->255, 2->150, 3->76, 4->29); background (0) stays 0.
    """
    resized = transforms.Resize((480, 640))(img)
    batch = transform_image(image=resized)
    model.to(device)
    with torch.no_grad():
        outputs = model(batch)

    # Per-pixel argmax over the class dimension, then back to numpy on CPU.
    class_ids = torch.argmax(outputs, 1)
    mask = np.array(class_ids.cpu())

    # Restore the original grayscale encoding of each class.
    for cls, gray in ((1, 255), (2, 150), (3, 76), (4, 29)):
        mask[mask == cls] = gray

    mask = np.reshape(mask, (480, 640))
    return Image.fromarray(mask.astype('uint8'))
102
 
103
# Build and launch the Gradio demo: color image in, predicted mask out.
# FIX: the original passed gr.inputs.Image as the *output* component;
# in this legacy Gradio API the output class is gr.outputs.Image.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(480, 640)),
    outputs=gr.outputs.Image(),
    examples=['color_155.jpg', 'color_154.jpg'],
).launch(share=False)