Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import sys | |
| sys.path.append("DLASS4L") | |
| from cycleGAN.training.trainer import CGANTrainer | |
| import torchvision.transforms.functional as TF | |
# Pick the compute device: CUDA when torch reports it available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the trainer object from the cycleGAN package and load the checkpoint
# file into it (presumably the trained generator weights — confirm against
# CGANTrainer.load_checkpoints).
model = CGANTrainer()
model.load_checkpoints("models/epoch_100.pth")

# Input widget: upload-only image, handed to the callback as a NumPy array.
sketch = gr.Image(
    label="Sketch",
    show_label=True,
    type="numpy",
    sources=["upload"],
    height=256,
    width=256,
)

# Output widget that displays the image returned by the callback.
lesion = gr.Image(
    label="Generated Image",
    show_label=True,
    height=256,
    width=256,
)
def generateLesion(image):
    """Generate a lesion image from an uploaded sketch.

    The sketch is resized to 256x256, concatenated channel-wise with a
    one-hot label volume (7 classes, class 3 selected) and passed through
    ``model.genB``. The first three channels of the output are returned.

    Args:
        image: Sketch delivered by the Gradio ``Image`` input as a NumPy
            array (RGB channel order), or ``None`` when nothing was uploaded.

    Returns:
        A PIL image built from the first three channels of the generator
        output for the single input sample.

    Raises:
        gr.Error: If no sketch was uploaded.
    """
    if image is None:
        # Gradio passes None when the user submits without an upload; surface
        # a readable message instead of an opaque OpenCV failure.
        raise gr.Error("Please upload a sketch image first.")

    # Gradio's numpy images are already RGB, so the BGR->RGB swap that was
    # needed for cv2.imread input must not be applied here. Only normalise
    # the channel count so the generator always sees three image channels.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    resized = cv2.resize(image, (256, 256))

    num_classes = 7
    true_label = 3  # class index the generator is conditioned on

    # One-hot label volume: the channel of the selected class is all ones.
    encoded_label = np.zeros((num_classes, 256, 256), dtype=np.float32)
    encoded_label[true_label, :, :] = 1

    # Add a batch axis to both tensors; the sketch also goes HWC -> CHW.
    label_batch = torch.tensor(encoded_label, dtype=torch.float32).unsqueeze(0)
    sketch_batch = torch.tensor(resized, dtype=torch.float32).unsqueeze(0)
    sketch_batch = torch.permute(sketch_batch, (0, 3, 1, 2))

    label_batch = label_batch.to(device)
    sketch_batch = sketch_batch.to(device)

    # NOTE(review): the sketch is fed in its raw 0-255 range and the output is
    # handed to to_pil_image unscaled — confirm this matches the value ranges
    # used when the checkpoint was trained.
    generator_input = torch.cat((sketch_batch, label_batch), dim=1)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        fake_image = model.genB(generator_input)

    # Indexing sample 0 already yields a (3, H, W) tensor, so no squeeze is needed.
    ret_image = fake_image[0, :3, :, :].detach().cpu()
    return TF.to_pil_image(ret_image)
# Wrap generateLesion in a Gradio interface: the sketch component is the
# single input and the lesion component shows the returned image.
iface = gr.Interface(generateLesion, sketch, lesion)

# Start the Gradio app.
iface.launch()