File size: 3,024 Bytes
496908f
 
 
 
 
 
 
 
 
 
 
 
 
 
6e10e81
 
 
 
 
 
 
 
 
 
 
 
 
 
933b3a7
4ddd9cc
8436ee9
14e3920
8436ee9
 
 
 
14e3920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b836ed6
 
 
 
 
14e3920
 
 
 
4ddd9cc
 
 
 
 
 
 
 
 
 
 
14e3920
4ddd9cc
 
14e3920
4ddd9cc
 
933b3a7
4ddd9cc
 
933b3a7
4ddd9cc
 
 
 
 
ecdee40
14e3920
 
 
933b3a7
 
b6784b7
 
 
4ddd9cc
 
b6784b7
4ddd9cc
 
b6784b7
4ddd9cc
 
b6784b7
4ddd9cc
b6784b7
 
 
 
 
 
 
 
4ddd9cc
14e3920
 
b6784b7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import random
from pathlib import Path

import gradio as gr
import PIL
import torchvision.transforms as transforms

from fastai.basics import *
from fastai.callback import *
from fastai.data.all import *
from fastai.metrics import *
from fastai.vision import models
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai
try:
    import albumentations
except ImportError:
    os.system('pip install albumentations')
    import albumentations

try:
    import toml
except ImportError:
    os.system('pip install toml')
    import toml

os.system('pip install -U gradio')

import gradio as gr


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_y_fn (x):
    """Map an image path under .../Images/color_*.jpg to its mask path .../Labels/gt_*.png."""
    label_path = str(x)
    for old, new in (("Images", "Labels"), ("color", "gt"), (".jpg", ".png")):
        label_path = label_path.replace(old, new)
    return Path(label_path)
    
def transform_image(image):
    """Convert a PIL image into a normalized (1, C, H, W) float tensor on `device`.

    Uses the standard ImageNet channel mean/std for normalization.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = preprocess(image)
    # Add the batch dimension and move to the module-level device.
    return tensor.unsqueeze(0).to(device)

class TargetMaskConvertTransform(ItemTransform):
    """Remap raw grayscale label values in the mask to compact class ids 0-4.

    Encoding: 255 -> 1, 150 -> 2, 76/74 -> 3, 29/25 -> 4; any other pixel
    value is treated as background and set to 0.
    """
    def __init__(self):
        pass

    def encodes(self, x):
        img, mask = x

        # Work on a mutable numpy copy of the mask.
        mask = np.array(mask)

        # Sequential remap is safe: target ids 1-4 never collide with the
        # remaining raw source values (255/150/76/74/29/25).
        for raw, cls in ((255, 1), (150, 2), (76, 3), (74, 3), (29, 4), (25, 4)):
            mask[mask == raw] = cls

        # Any pixel that did not receive (or already hold) a class id is background.
        mask[~np.isin(mask, (1, 2, 3, 4))] = 0

        # Back to PILMask for the fastai pipeline.
        return img, PILMask.create(mask)

from albumentations import (
    Compose,
    OneOf,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    HorizontalFlip,
    Rotate,
    Transpose,
    CLAHE,
    ShiftScaleRotate
)


class SegmentationAlbumentationsTransform(ItemTransform):
    """Apply an albumentations pipeline jointly to an (image, mask) pair."""

    split_idx = 0  # fastai convention: apply only on the training split

    def __init__(self, aug):
        self.aug = aug

    def encodes(self, x):
        img, mask = x
        # Albumentations operates on numpy arrays and transforms both together.
        augmented = self.aug(image=np.array(img), mask=np.array(mask))
        img_out = PILImage.create(augmented["image"])
        mask_out = PILMask.create(augmented["mask"])
        return img_out, mask_out

repo_id = "maviced/practica3"  # Hugging Face Hub repo holding the exported fastai learner
# Download the pretrained fastai Learner from the Hub and keep only its
# underlying PyTorch model; start it on CPU (predict() moves it to `device`).
learn = from_pretrained_fastai(repo_id)
model = learn.model
model = model.cpu()


def predict(img):
    """Segment an input image and return the predicted mask as a grayscale PIL image.

    The class-id -> grayscale mapping is the inverse of the training encoding
    defined in TargetMaskConvertTransform (255->1, 150->2, 76/74->3, 29/25->4,
    background->0), i.e. 0->0, 1->255, 2->150, 3->76, 4->25.  The previous code
    used a mapping shifted by one class (0->255, 1->150, ...), which disagreed
    with the file's own label encoding.
    """
    img = PILImage.create(img)

    # The model expects 480x640 inputs.
    image = transforms.Resize((480, 640))(img)
    tensor = transform_image(image=image)

    model.to(device)
    with torch.no_grad():
        outputs = model(tensor)

    # Per-pixel class ids, shape (1, 480, 640).
    outputs = torch.argmax(outputs, 1)
    mask = np.array(outputs.cpu())

    # Inverse of the training remap: class id -> original grayscale value.
    palette = np.array([0, 255, 150, 76, 25], dtype=np.uint8)
    mask = palette[mask]

    mask = np.reshape(mask, (480, 640))

    return Image.fromarray(mask.astype('uint8'))

    
# Creamos la interfaz y la lanzamos.
gr.Interface(fn=predict, inputs=["image"], outputs=["image"],
             examples=['color_154.jpg','color_155.jpg']).launch(share=True)