Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import gradio as gr | |
| from glob import glob | |
| from functools import partial | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as TF | |
| from transformers import SegformerForSemanticSegmentation | |
class Configs:
    """Static configuration for the SegFormer segmentation demo.

    All values are read directly as class attributes (``Configs.X``).
    """

    # Foreground class names; label ``i + 1`` in the prediction map corresponds
    # to ``CLASSES[i]`` (label 0 is background).
    CLASSES: tuple = ("Large bowel", "Small bowel", "Stomach")
    # Number of model output labels: one per foreground class plus background.
    # Derived from CLASSES so the two cannot drift apart (== 4 here).
    NUM_CLASSES: int = len(CLASSES) + 1  # including background.
    # Network input size as (W, H); reverse it to get torch's (H, W) ordering.
    IMAGE_SIZE: tuple[int, int] = (288, 288)  # W, H
    # Per-channel mean / std used to normalize input images.
    MEAN: tuple = (0.485, 0.456, 0.406)
    STD: tuple = (0.229, 0.224, 0.225)
    # Directory holding the trained weights, resolved against the current
    # working directory at import time.
    MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
def get_model(*, model_path, num_classes):
    """Load a SegFormer semantic-segmentation model from ``model_path``.

    The classification head is sized to ``num_classes`` labels; checkpoint
    weights whose shapes do not match are ignored rather than raising.
    """
    return SegformerForSemanticSegmentation.from_pretrained(
        model_path,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )
def predict(input_image, model=None, preprocess_fn=None, device="cpu", class_names=None):
    """Segment one image and format the result for ``gr.AnnotatedImage``.

    Args:
        input_image: Image object exposing ``.size`` as (W, H), e.g. the PIL
            image delivered by a ``gr.Image(type="pil")`` input. ``None`` when
            nothing was uploaded.
        model: Callable invoked as ``model(pixel_values=..., return_dict=True)``
            whose result has a ``"logits"`` entry of shape (N, C, h, w).
        preprocess_fn: Callable turning ``input_image`` into a (C, H, W) tensor.
        device: Device the model lives on; the input tensor is moved there.
        class_names: Foreground class names in label order (label 0 is
            background). Defaults to ``Configs.CLASSES``.

    Returns:
        ``(input_image, [(mask, class_name), ...])`` where each mask is a
        boolean array at the original image resolution (H, W), or ``None``
        when ``input_image`` is ``None`` (clears the output component).
    """
    if input_image is None:
        # Button clicked without an upload: clear the output instead of crashing.
        return None
    if class_names is None:
        class_names = Configs.CLASSES

    # PIL-style size is (W, H); interpolate expects (H, W).
    height_width = tuple(input_image.size[::-1])

    input_tensor = preprocess_fn(input_image).unsqueeze(0).to(device)

    # Inference only: no autograd graph needed, saving memory and time.
    with torch.no_grad():
        outputs = model(pixel_values=input_tensor, return_dict=True)
        # Upsample logits back to the original resolution before taking the argmax.
        logits = F.interpolate(outputs["logits"], size=height_width, mode="bilinear", align_corners=False)
        # Index the batch dim rather than squeeze(): squeeze would also drop a
        # spatial dim for 1-pixel-high/wide images and yield 1-D masks.
        label_map = logits.argmax(dim=1)[0].cpu().numpy()

    # Label 0 is background, so the i-th foreground class is label i + 1.
    seg_info = [(label_map == idx, name) for idx, name in enumerate(class_names, start=1)]
    return input_image, seg_info
if __name__ == "__main__":
    # Legend colour for each class shown in the AnnotatedImage output.
    class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}

    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
    model.to(DEVICE)
    model.eval()

    # Warm-up forward pass. IMAGE_SIZE is (W, H), so reverse it for (N, C, H, W).
    # The result is discarded, so skip building an autograd graph.
    with torch.no_grad():
        _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))

    preprocess = TF.Compose(
        [
            TF.Resize(size=Configs.IMAGE_SIZE[::-1]),
            TF.ToTensor(),
            # inplace is safe here: ToTensor always produces a fresh tensor.
            TF.Normalize(Configs.MEAN, Configs.STD, inplace=True),
        ]
    )

    with gr.Blocks(title="ImageAlchemy") as demo:
        gr.Markdown("""<h1><center>ImageAlchemy</center></h1>""")
        with gr.Row():
            img_input = gr.Image(type="pil", height=360, width=360, label="Input image")
            img_output = gr.AnnotatedImage(label="Predictions", height=360, width=360, color_map=class2hexcolor)
        section_btn = gr.Button("Generate Predictions")
        section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)

        # Offer up to 10 random sample images. np.random.choice with
        # replace=False raises ValueError when asked for more items than exist,
        # so cap the count and skip the gallery entirely when there are none.
        images_dir = glob(os.path.join(os.getcwd(), "samples", "*.png"))
        if images_dir:
            num_examples = min(10, len(images_dir))
            examples = np.random.choice(images_dir, size=num_examples, replace=False).tolist()
            gr.Examples(examples=examples, inputs=img_input, outputs=img_output)

    demo.launch()