Spaces:
Running on Zero
Running on Zero
| import spaces | |
| import gradio as gr | |
| from util import imread, imsave, get_examples | |
| import torch | |
def torch_compile(*args, **kwargs):
    """Stand-in for ``torch.compile`` that never compiles anything.

    ``torch.compile`` is used in three forms, and this shim supports all of
    them:

    * ``torch.compile(fn_or_module, ...)`` -> returns ``fn_or_module`` as-is
    * ``@torch.compile`` (bare decorator)   -> same as above
    * ``@torch.compile(mode=..., ...)``     -> returns an identity decorator

    Previously only the last form worked. The first two returned the inner
    decorator function instead of the model/function.

    Args:
        *args: Positional arguments in the ``torch.compile`` style. The first
            one, if present, is the model/function to "compile".
        **kwargs: Keyword arguments. ``model=`` is honoured as the target and
            all compile options are ignored.

    Returns:
        The target itself when one is supplied, otherwise a decorator that
        returns its argument unchanged.
    """
    target = args[0] if args else kwargs.get('model')
    if target is not None:
        # Direct call / bare-decorator form: hand back the target uncompiled.
        return target

    def decorator(func):
        # Decorator-factory form: identity decorator.
        return func

    return decorator
# Temporary workaround: replace torch.compile with the local torch_compile
# shim defined above. This runs at import time, before predict() lazily
# imports the project modules.
setattr(torch, 'compile', torch_compile)

# Model identifier used when the caller / UI does not provide one.
default_model = 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c'
def _resolve_device(device):
    """Return *device*, or 'cuda'/'cpu' (depending on visible CUDA devices) if None."""
    if device is not None:
        return device
    return 'cuda' if torch.cuda.device_count() else 'cpu'


def _resolve_model_name(model):
    """Return the stripped model name, falling back to ``default_model`` when blank."""
    # Strip *before* the emptiness check so whitespace-only input also falls back.
    name = '' if model is None else str(model).strip()
    return name or default_model


def _labels_path(filename):
    """Return *filename* with its extension replaced by ``_labels.tiff``."""
    from os.path import splitext

    # splitext only touches the final path component. Files without an
    # extension, or directories containing dots, therefore keep their full
    # directory and base name.
    root, _ext = splitext(filename)
    return root + '_labels.tiff'


# NOTE(review): this file imports `spaces` but the visible code never applies
# `@spaces.GPU`. On a ZeroGPU Space, GPU work is normally done inside such a
# decorated function. Confirm whether predict should be decorated.
def predict(filename, model=None, device=None, reduce_labels=True):
    """Run a CPN cell-detection model on one image file and save its labels.

    Args:
        filename: Path of the input image. The Gradio input uses
            ``type="filepath"``, so the UI passes a path string.
        model: Model identifier passed to ``CpnInterface``. ``None``, empty or
            whitespace-only values fall back to ``default_model``.
        device: Torch device string. When ``None``, ``'cuda'`` is used if at
            least one CUDA device is visible, otherwise ``'cpu'``.
        reduce_labels: Forwarded to the model call.

    Returns:
        Tuple ``(img, vis_labels, dst)``:
        - the normalized input image,
        - a color-mapped visualisation of the label image,
        - the path of the saved label TIFF (next to the input file).

    Raises:
        gr.Error: If *filename* is not a string (e.g. no image was uploaded).
    """
    # Project imports are deferred to call time. By then the module-level
    # torch.compile patch has already run.
    from cpn import CpnInterface
    from prep import multi_norm
    from celldetection import label_cmap

    if not isinstance(filename, str):
        # An explicit error instead of `assert`: it survives `python -O` and
        # shows a readable message in the Gradio UI.
        raise gr.Error('Please upload an input image.')

    device = _resolve_device(device)
    print(dict(
        filename=filename,
        model=model,
        device=device,
        reduce_labels=reduce_labels
    ), flush=True)

    img = imread(filename)
    print('Image:', img.dtype, img.shape, (img.min(), img.max()), flush=True)

    model = _resolve_model_name(model)
    img = multi_norm(img, 'cstm-mix')  # TODO
    m = CpnInterface(model, device=device)
    y = m(img, reduce_labels=reduce_labels)
    labels = y['labels']
    vis_labels = label_cmap(labels)

    dst = _labels_path(filename)
    imsave(dst, labels)
    return img, vis_labels, dst
# Gradio UI around `predict`.
# Inputs: an uploaded image (passed to predict as a file path) and an
# editable model name. Outputs: one image, a second image and a
# downloadable file, matching predict's 3-tuple return.
input_components = [
    gr.components.Image(label="Upload Input Image", type="filepath"),
    gr.components.Textbox(label='Model Name', value=default_model, max_lines=1),
]
output_components = [
    gr.Image(label="Processed Image"),
    gr.Image(label="Label Image"),
    gr.File(label="Download Label Image"),
]

demo = gr.Interface(
    predict,
    inputs=input_components,
    outputs=output_components,
    title="Cell Detection with Contour Proposal Networks",
    examples=get_examples(default_model),
)
demo.launch()