| import numpy as np |
| import gradio as gr |
| import cv2 |
| from cellpose import models |
| from matplotlib.colors import hsv_to_rgb |
| import matplotlib.pyplot as plt |
| import os, io, base64 |
|
|
# Load the cellpose "cyto3" model once, at import time and on CPU
# (gpu=False), so every request reuses the same instance. The app cannot do
# anything without it, so a load failure terminates the process.
try:
    model = models.CellposeModel(gpu=False, pretrained_model="cyto3")
except Exception as e:
    # `raise SystemExit(msg)` prints msg to stderr and exits with status 1.
    # Unlike the `exit()` helper, it does not depend on the `site` module,
    # which is only meant for interactive use.
    raise SystemExit(f"Error loading model: {e}") from e
|
|
def plot_flows(y):
    """Render a two-component flow field as an RGB colour-wheel image.

    Hue encodes direction. Each component is percentile-normalised with
    ``normalize99``, clipped to [0, 1] and stretched to [-1, 1], then the
    angle is taken with ``arctan2``. Saturation and value both come from
    ``normalize99`` applied to the sum of the squared raw components. All HSV
    channels are clipped to [0, 1] before conversion to RGB.

    Args:
        y: indexable so that ``y[0][0]`` and ``y[1][0]`` are same-shape 2-D
            arrays (presumably the vertical and horizontal flow components;
            TODO confirm against the caller).

    Returns:
        uint8 RGB array with shape ``(rows, cols, 3)``.
    """
    first, second = y[0][0], y[1][0]

    def _to_signed_unit(component):
        # Percentile-normalise, clip to [0, 1], then map onto [-1, 1].
        return (np.clip(normalize99(component), 0, 1) - 0.5) * 2

    angle = np.arctan2(_to_signed_unit(first), _to_signed_unit(second))
    hue = (angle + np.pi) / (2 * np.pi)
    magnitude = normalize99(first ** 2 + second ** 2)

    channels = [hue, magnitude, magnitude]
    hsv = np.concatenate([ch[:, :, np.newaxis] for ch in channels], axis=-1)
    hsv = np.clip(hsv, 0.0, 1.0)
    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)
|
|
def plot_outlines(img, masks):
    """Draw red outlines of every labelled region of ``masks`` on top of ``img``.

    Contours are traced with ``cv2.findContours`` (``RETR_FLOODFILL`` on an
    int32 copy of the label image). Contours with 4 or fewer points are
    skipped, and the rest go through ``cv2.approxPolyDP`` with a 0.001 px
    tolerance. The drawing is done with matplotlib on a black figure whose
    longer side is 6 inches. That figure is rendered to PNG in memory and
    decoded back into an array.

    Args:
        img: image to draw on. Its height/width set the plot extent, so it
            must be on the same pixel grid as ``masks``.
        masks: integer label image (0 = background; presumably cellpose
            masks, so confirm for other callers).

    Returns:
        uint8 RGB array of the rendered figure. Its pixel size comes from the
        figure size and matplotlib's dpi, not from ``img``.
    """
    # OpenCV 4 returns (contours, hierarchy) and OpenCV 3 returns
    # (image, contours, hierarchy); indexing with [-2] works with both.
    contours = cv2.findContours(
        masks.astype(np.int32),  # RETR_FLOODFILL needs a CV_32SC1 label image
        mode=cv2.RETR_FLOODFILL,
        method=cv2.CHAIN_APPROX_SIMPLE,
    )[-2]
    outlines = [
        cv2.approxPolyDP(contour, 0.001, True)[:, 0, :]
        for contour in contours
        if len(contour) > 4
    ]

    height, width = img.shape[:2]
    # The longer side of the figure is 6 inches; the other keeps the aspect ratio.
    if height > width:
        figsize = (6 * width / height, 6)
    else:
        figsize = (6, 6 * height / width)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, width])
        ax.set_ylim([0, height])
        # The y axis points up (ylim 0..height), so the image is flipped
        # vertically and the outline y coordinates are mirrored to match.
        ax.imshow(img[::-1], origin='upper')
        for outline in outlines:
            ax.plot(outline[:, 0], height - outline[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')
        # Save this specific figure rather than pyplot's global "current
        # figure", and make sure the buffer is closed even on error.
        with io.BytesIO() as buffer:
            fig.savefig(buffer, format='png', facecolor=fig.get_facecolor(), edgecolor='none')
            # getvalue() returns an independent bytes copy, so the array stays
            # valid after the buffer is closed.
            png = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    finally:
        # Always release the figure so repeated calls don't accumulate figures.
        plt.close(fig)

    rendered = cv2.imdecode(png, cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
    return cv2.cvtColor(rendered, cv2.COLOR_BGR2RGB)
|
|
def plot_overlay(img, masks):
    """Colour each labelled region with a random hue over a grayscale copy of ``img``.

    The image is averaged over its last axis when it has channels, then
    min-max rescaled to [0, 1]. It is brightened by 1.5x (clipped to 1) and
    used as the HSV value channel. Pixels with label ``k > 0`` get saturation
    1 and a hue drawn from ``np.random.rand``, one draw per label 1..max in
    order. Background (label 0) keeps saturation 0, i.e. stays gray.

    Args:
        img: ``(H, W)`` or ``(H, W, C)`` image.
        masks: ``(H, W)`` integer label image, 0 = background.

    Returns:
        uint8 RGB array of shape ``(H, W, 3)``.

    Raises:
        ValueError: if ``masks`` is not on the same ``(H, W)`` grid as ``img``.
    """
    gray = img.astype(np.float32)
    if gray.ndim == 3:
        gray = gray.mean(axis=-1)
    if masks.shape != gray.shape:
        raise ValueError(
            f"masks shape {masks.shape} does not match image shape {gray.shape}"
        )

    # Plain min-max rescale. A percentile (normalize99) step before this would
    # be a positive affine map that min-max undoes exactly, and it produced
    # NaNs when the 1st and 99th percentiles coincide. Guard against a
    # constant image here as well.
    gray -= gray.min()
    peak = gray.max()
    if peak > 0:
        gray /= peak

    hsv = np.zeros((*gray.shape, 3), np.float32)
    hsv[..., 2] = np.clip(gray * 1.5, 0, 1.0)

    n_labels = int(masks.max())
    if n_labels > 0:
        # Drawing all hues at once yields the same random sequence as one
        # np.random.rand() call per label, but needs a single pass over the
        # image instead of one pass per label.
        hues = np.random.rand(n_labels)
        labelled = masks > 0
        label_index = masks[labelled].astype(np.intp) - 1
        hsv[..., 0][labelled] = hues[label_index]
        hsv[..., 1][labelled] = 1.0

    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)
|
|
def normalize99(img, lower=1, upper=99):
    """Linearly rescale ``img`` so its ``lower``/``upper`` percentiles map to 0/1.

    Values outside that percentile range fall outside [0, 1]; no clipping is
    done. The input is not modified.

    Args:
        img: array-like of numbers.
        lower: lower percentile mapped to 0 (default 1).
        upper: upper percentile mapped to 1 (default 99).

    Returns:
        A new floating-point array of the same shape. If the two percentiles
        are equal (e.g. a constant image), an all-zero array is returned
        instead of dividing by zero.
    """
    x = np.asarray(img)
    lo, hi = np.percentile(x, [lower, upper])
    shifted = x - lo  # always a new array, so the caller's data is untouched
    span = hi - lo
    if span == 0:
        return np.zeros_like(shifted)
    return shifted / span
|
|
def image_resize(img, resize=224):
    """Downscale ``img`` so its longer side is at most ``resize`` pixels.

    The aspect ratio is kept, and the shorter side is truncated to an integer
    but never below 1 px. Images already within the limit keep their size.
    The result is always cast to uint8.

    Args:
        img: ``(H, W)`` or ``(H, W, C)`` image array.
        resize: maximum length of the longer side, in pixels.

    Returns:
        uint8 image. When resizing is needed, ``cv2.resize`` is used with its
        default interpolation.
    """
    ny, nx = img.shape[:2]
    # Compare only the spatial dimensions; the channel count must not
    # trigger a resize.
    if max(ny, nx) > resize:
        # Clamp to at least 1 px so extreme aspect ratios don't yield an empty
        # target size, which cv2.resize rejects.
        if ny > nx:
            nx = max(1, int(nx / ny * resize))
            ny = resize
        else:
            ny = max(1, int(ny / nx * resize))
            nx = resize
        img = cv2.resize(img, (nx, ny))  # cv2 expects (width, height)
    return img.astype(np.uint8)
|
|
def cellpose_segment(img_input):
    """Segment an uploaded image with cellpose and build the four app outputs.

    The model runs on a copy downscaled by ``image_resize`` (longer side at
    most 224 px). The masks and flow image are then resized back to the
    input's height/width. Both visualisations are drawn on the full-size
    input, so image and masks always share the same pixel grid (drawing on
    the small image with full-size masks used to raise IndexError in
    ``plot_overlay`` and misplace the outlines).

    Args:
        img_input: image array from the Gradio ``image`` input, or ``None``
            when nothing was uploaded.

    Returns:
        Tuple ``(outlines, overlay, flows, masks)``:
        - ``outlines``: the ``plot_outlines`` rendering;
        - ``overlay``: the ``plot_overlay`` RGB image;
        - ``flows``: ``flows[0]`` from ``model.eval`` at the input resolution;
        - ``masks``: the label image at the input resolution.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if img_input is None:
        raise gr.Error("Please upload an image first.")

    img = image_resize(img_input)
    masks, flows, _ = model.eval(img, channels=[0, 0])
    flows = flows[0]  # presumably cellpose's RGB flow visualisation; confirm

    # Same uint8 cast image_resize applies, but at the original resolution.
    display = img_input.astype(np.uint8)

    height, width = img_input.shape[:2]
    if (height, width) != img.shape[:2]:
        target_size = (width, height)  # cv2.resize takes (width, height)
        # Nearest-neighbour keeps label values intact (no blended labels).
        masks = cv2.resize(masks.astype(np.uint16), target_size, interpolation=cv2.INTER_NEAREST)
        resized_flows = cv2.resize(flows.astype(np.float32), target_size)
        # Clip before the cast so out-of-range values can't wrap around.
        flows = np.clip(resized_flows, 0, 255).astype(np.uint8)

    outpix = plot_outlines(display, masks)
    overlay = plot_overlay(display, masks)
    return outpix, overlay, flows, masks
|
|
| |
# One image input; four image outputs in the order cellpose_segment returns
# them (outlines, overlay, flows, masks).
iface = gr.Interface(
    fn=cellpose_segment,
    inputs="image",
    outputs=[
        gr.Image(label="outlines"),
        gr.Image(label="overlay"),
        gr.Image(label="flows"),
        gr.Image(label="masks"),
    ],
    title="cellpose segmentation",
    description="upload an image, then cellpose will segment it at a max size of 224x224"
)

# Start the server only when executed as a script, so importing this module
# does not block on launch().
if __name__ == "__main__":
    iface.launch()
|
|