File size: 3,283 Bytes
31949e2
7bab3c7
441c086
2e7df12
441c086
 
 
 
 
 
 
d7433b3
441c086
 
b3da728
441c086
31949e2
 
441c086
 
 
 
 
 
87fccf6
 
441c086
 
87fccf6
441c086
 
 
87fccf6
441c086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87fccf6
d7433b3
 
441c086
 
 
 
d7433b3
441c086
d7433b3
 
87fccf6
441c086
cf8d99f
31949e2
441c086
 
76746e3
441c086
 
 
 
 
 
 
 
31949e2
2e7df12
441c086
 
 
 
 
 
 
 
 
 
2e7df12
441c086
 
 
 
ebb86b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from huggingface_hub import from_pretrained_fastai
import gradio as gr
from fastai.vision.all import *
import torchvision.transforms as transforms
import torchvision.transforms as transforms
from fastai.basics import *
from fastai.vision import models
from fastai.vision.all import *
from fastai.metrics import *
from fastai.data.all import *
from fastai.callback import *
from pathlib import Path
import random
import PIL

# First we define all the functions, classes and variables needed for this to work.
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def transform_image(image):
    """Convert a PIL image into a normalized, batched tensor on `device`.

    Applies ToTensor followed by the standard ImageNet mean/std
    normalization, then adds a leading batch dimension.

    Args:
        image: RGB PIL image to preprocess.

    Returns:
        Float tensor of shape (1, C, H, W) moved to `device`.
    """
    my_transforms = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    # The original aliased `image` to `image_aux` for no reason; use it directly.
    return my_transforms(image).unsqueeze(0).to(device)

class TargetMaskConvertTransform(ItemTransform):
    """Remap raw grayscale label values in a mask to contiguous class ids.

    Pixels whose value is not one of the known raw labels are treated as
    background (class 0).
    """
    def __init__(self):
        pass

    def encodes(self, x):
        img, mask = x

        # Work on a plain numpy array.
        mask = np.array(mask)

        # raw gray value -> class id (targets never collide with raw keys,
        # so sequential in-place remapping is safe).
        value_map = {255: 1, 150: 2, 76: 4, 74: 4, 29: 3, 25: 3}
        mask[~np.isin(mask, list(value_map))] = 0
        for raw, cls in value_map.items():
            mask[mask == raw] = cls

        # Back to PILMask
        return img, PILMask.create(mask)

from albumentations import (
    Compose,
    OneOf,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    HorizontalFlip,
    Rotate,
    Transpose,
    CLAHE,
    ShiftScaleRotate
)

def get_y_fn (x):
    """Map an image path to its ground-truth mask path.

    Rewrites the directory (Images -> Labels), the filename prefix
    (color -> gt) and the extension (.jpg -> .png).
    """
    path_str = str(x)
    for old, new in (("Images", "Labels"), ("color", "gt"), (".jpg", ".png")):
        path_str = path_str.replace(old, new)
    return Path(path_str)

class SegmentationAlbumentationsTransform(ItemTransform):
    """Apply an albumentations augmentation pipeline to an (image, mask) pair."""

    # Restrict this transform to split index 0 (the training split).
    split_idx = 0

    def __init__(self, aug):
        self.aug = aug

    def encodes(self, x):
        image, target = x
        # albumentations works on numpy arrays and returns a dict.
        augmented = self.aug(image=np.array(image), mask=np.array(target))
        new_image = PILImage.create(augmented["image"])
        new_mask = PILMask.create(augmented["mask"])
        return new_image, new_mask

# Load the trained fastai learner from the Hugging Face Hub.
repo_id = "PablitoGil14/AP-Practica3"
learn = from_pretrained_fastai(repo_id)
# Keep only the bare PyTorch model; it is moved to `device` at prediction time.
model = learn.model
model = model.cpu()


# Define the function that performs the predictions.
def predict(img_ruta, size=(480, 640)):
    """Segment an input image and return the class mask as a grayscale image.

    Args:
        img_ruta: input image as a numpy array (as supplied by gradio).
        size: (height, width) the image is resized to before inference;
            must match the spatial size the model was trained with.

    Returns:
        A PIL image whose pixel values are the display gray levels of the
        predicted classes.
    """
    img = PIL.Image.fromarray(img_ruta)
    image = transforms.Resize(size)(img)
    tensor = transform_image(image=image)
    model.to(device)
    with torch.no_grad():
        outputs = model(tensor)

    # Per-pixel class id = argmax over the channel dimension.
    outputs = torch.argmax(outputs, 1)
    mask = np.array(outputs.cpu())
    # Map class ids back to the gray levels used by the dataset labels
    # (targets never collide with the remaining ids, so in-place is safe).
    mask[mask==1]=255
    mask[mask==2]=150
    mask[mask==3]=29
    mask[mask==4]=74
    mask = np.reshape(mask, size)
    return Image.fromarray(mask.astype('uint8'))
# Build the Gradio interface and launch it (the example files must exist next to this script).
gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Image(), examples=['color_161.jpg','color_162.jpg']).launch(share=True, server_name="0.0.0.0", server_port=7860)