File size: 1,649 Bytes
55e5aa3
fb7db60
 
 
 
 
 
 
55e5aa3
fb7db60
55e5aa3
fb7db60
 
55e5aa3
8c879e8
 
fb7db60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import torch 
import numpy as np
import cv2
import sys 
sys.path.append("DLASS4L")
from cycleGAN.training.trainer import CGANTrainer
import torchvision.transforms.functional as TF

# Inference device: the CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the CycleGAN trainer and restore its weights from the checkpoint file
# at models/epoch_100.pth.
model = CGANTrainer()
model.load_checkpoints("models/epoch_100.pth")

# Both image panels share the same label visibility and 256x256 display size.
_PANEL_OPTS = {"show_label": True, "height": 256, "width": 256}
sketch = gr.Image(sources=["upload"], type="numpy", label="Sketch", **_PANEL_OPTS)
lesion = gr.Image(label="Generated Image", **_PANEL_OPTS)

    
def generateLesion(image, *, true_label: int = 3, num_classes: int = 7):
    """Generate a synthetic lesion image from an uploaded sketch.

    The sketch is resized to 256x256 and concatenated channel-wise with a
    one-hot class map, in which every pixel of channel ``true_label`` is 1.
    The result is passed through the generator ``model.genB``.

    Args:
        image: Sketch from the ``gr.Image(type="numpy")`` input. Gradio
            delivers this as an RGB uint8 array of shape (H, W, 3). It is
            ``None`` when the form is submitted without an upload.
            Grayscale and RGBA arrays are also accepted and converted to RGB.
        true_label: Class channel to activate in the label map. Defaults to
            3, the value this app has always used.
        num_classes: Number of label channels fed to the generator. Defaults
            to 7.

    Returns:
        A ``PIL.Image`` built from the first three output channels, or
        ``None`` when no image was supplied.

    Raises:
        ValueError: If ``true_label`` is not in ``[0, num_classes)``.
    """
    size = 256

    if image is None:
        # Empty submit: clear the output instead of raising inside cv2.
        return None
    if not 0 <= true_label < num_classes:
        raise ValueError(
            f"true_label must be in [0, {num_classes}), got {true_label}"
        )

    # gr.Image(type="numpy") already yields RGB. The former BGR->RGB
    # conversion was written for the cv2.imread() path and swapped the red
    # and blue channels here. Only normalise grayscale / RGBA to 3 channels.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    resized = cv2.resize(image, (size, size))

    # HWC -> 1xCxHxW float32. Pixel values stay in [0, 255] as before.
    # NOTE(review): confirm the generator was trained on unnormalised
    # [0, 255] input rather than [0, 1] or [-1, 1].
    sketch_t = (
        torch.as_tensor(resized, dtype=torch.float32)
        .permute(2, 0, 1)
        .unsqueeze(0)
    )

    # One-hot class map: every pixel of channel `true_label` is 1.
    label_t = torch.zeros((1, num_classes, size, size), dtype=torch.float32)
    label_t[:, true_label] = 1.0

    concat_ls = torch.cat((sketch_t, label_t), dim=1).to(device)

    # Inference only: avoid building and holding an autograd graph per request.
    with torch.no_grad():
        fake_image = model.genB(concat_ls)

    # First three output channels of the single batch item -> (3, H, W).
    ret_image = fake_image[0, :3].detach().cpu()
    # NOTE(review): to_pil_image scales float tensors by 255, i.e. it assumes
    # output in [0, 1]; a tanh generator would need (x + 1) / 2 first.
    return TF.to_pil_image(ret_image)

# Wire the sketch upload panel through generateLesion to the output panel.
iface = gr.Interface(
    fn=generateLesion,
    inputs=sketch,
    outputs=lesion,
)

# Start the (blocking) server only when run as a script, e.g. `python app.py`,
# so that importing this module does not launch a web server as a side effect.
if __name__ == "__main__":
    iface.launch()