Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| # In[8]: | |
| import gradio as gr | |
| from fastai.vision.all import * | |
| import skimage | |
| import matplotlib.pyplot as plt | |
| import matplotlib as mpl | |
| from mpl_toolkits.axes_grid1 import make_axes_locatable | |
def get_files(path, pattern='*camera7*.jpg'):
    """Collect at most one image file per entry of ``path.ls()``.

    For every entry listed by ``path.ls()`` the first file matching *pattern*
    inside it is kept; entries without any match are skipped silently.

    Args:
        path: directory object exposing ``.ls()`` (e.g. a fastcore ``Path``).
        pattern: glob pattern matched inside each listed entry; defaults to
            the camera-7 JPEGs used originally.

    Returns:
        list[str]: one matching file path per entry that had a match.
    """
    imgs = []
    for folder in path.ls():
        # Escape the folder part so names containing glob metacharacters
        # ("[", "*", "?") are matched literally; only *pattern* is a wildcard.
        matches = glob.glob(f'{glob.escape(str(folder))}/{pattern}')
        # Explicit check replaces the old bare `except:`, which also hid
        # unrelated errors. glob order is filesystem-dependent; first hit wins.
        if matches:
            imgs.append(matches[0])
    return imgs
def get_mask(filename, db_path='celine-3.db', mask_root=None):
    """Build the integer segmentation mask for one scene image.

    The scene id is the first run of digits in the file stem. It is looked up
    in the ``scenes`` table of the SQLite database *db_path* to get the
    treatment code ``trt``. That code is mapped to a class index, namely its
    1-based position in ``trt_list``. The first ``.npy`` file in
    ``<mask_root>/<file stem>/`` is squeezed and thresholded at 0.5. The
    result is multiplied by that index, so background pixels are 0.

    Args:
        filename: path of the scene image.
        db_path: SQLite database containing a ``scenes(scene_id, trt)`` table.
        mask_root: directory holding one sub-folder of ``.npy`` masks per
            image stem. When None, the module-level global ``path`` is used,
            as before.

    Returns:
        numpy.ndarray (uint8) whose values are 0 or the class index.

    Raises:
        ValueError: no digits in the file name, or an unknown treatment code.
        LookupError: the scene id has no row in ``scenes``.
        FileNotFoundError: no ``.npy`` mask exists for this image.
    """
    # The original called an un-imported `sql` name; import locally instead.
    import sqlite3
    from contextlib import closing

    trt_list = ['PCCCC', 'PCFCF', 'PFEFE', 'PHNHN']

    stem = Path(filename).stem
    match = re.search(r"\d+", stem)
    if match is None:
        raise ValueError(f"no scene id (digits) found in file name {filename!r}")
    scene_id = match.group()

    # closing() guarantees the connection is released even if the query
    # fails; sqlite3's own context manager only commits/rolls back.
    with closing(sqlite3.connect(db_path)) as conn:
        row = conn.execute(
            "SELECT trt FROM scenes WHERE scene_id == ?", (scene_id,)
        ).fetchone()
    if row is None:
        raise LookupError(f"scene {scene_id} not found in {db_path}")
    trt = row[0]
    if trt not in trt_list:
        raise ValueError(f"unknown treatment code {trt!r} for scene {scene_id}")
    label = trt_list.index(trt) + 1

    if mask_root is None:
        # NOTE(review): `path` is a global from the training notebook and is
        # not defined in this file; pass mask_root explicitly instead.
        mask_root = path
    mask_dir = Path(mask_root) / stem
    npy_files = glob.glob(f'{glob.escape(str(mask_dir))}/*.npy')
    if not npy_files:
        raise FileNotFoundError(f"no .npy mask found in {mask_dir}")

    raw = np.load(npy_files[0]).squeeze()
    mask = (raw > 0.5).astype(np.uint8)
    mask *= label
    return mask
def __call__(self:DiceLoss, pred, targ):
    """Log-cosh Dice loss.

    *targ* is passed through ``self._one_hot`` with the class count taken from
    ``pred.shape[self.axis]``. After the activation, soft Dice is computed
    with sums over every dim from 2 onward. That leaves dims 0 and 1 —
    presumably batch and class; confirm against ``self.axis``.
    ``1 - dice`` is flattened and reduced by mean when
    ``self.reduction == "mean"``, and by sum otherwise. The result ``x`` is
    returned as ``log((e^x + e^-x) / 2)``, i.e. ``log(cosh(x))``.

    NOTE(review): the ``self: DiceLoss`` annotation is the fastcore ``@patch``
    idiom, but no ``@patch`` decorator is applied. As written, this is a plain
    module-level function and DiceLoss is not modified — confirm intent.
    """
    n_classes = pred.shape[self.axis]
    one_hot = self._one_hot(targ, n_classes)
    pred = TensorBase(pred)
    targ = TensorBase(one_hot)
    assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs'

    probs = self.activation(pred)
    reduce_dims = list(range(2, len(probs.shape)))
    overlap = torch.sum(probs * targ, dim=reduce_dims)
    if self.square_in_union:
        denom = torch.sum(probs**2 + targ, dim=reduce_dims)
    else:
        denom = torch.sum(probs + targ, dim=reduce_dims)
    dice = (2. * overlap + self.smooth) / (denom + self.smooth)

    per_item = (1 - dice).flatten()
    if self.reduction == "mean":
        x = per_item.mean()
    else:
        x = per_item.sum()
    return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0)
class CombinedLoss:
    """Focal loss plus an ``alpha``-weighted Dice loss ("Dice and Focal combined").

    The app never instantiates this class directly. It is presumably needed so
    the pickled learner can be unpickled by ``load_learner``. Therefore keep
    the class name, attribute names and method names unchanged (confirm).
    """

    def __init__(self, axis=1, smooth=1., alpha=1., reduction=None):
        # store_attr() copies the constructor arguments onto self.
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        # NOTE(review): `reduction` is stored but never forwarded to either loss.
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        """Return ``focal(pred, targ) + alpha * dice(pred, targ)``."""
        focal = self.focal_loss(pred, targ)
        dice = self.dice_loss(pred, targ)
        return focal + self.alpha * dice

    def decodes(self, x):
        """Map per-class scores to class indices along ``self.axis``."""
        return x.argmax(dim=self.axis)

    def activation(self, x):
        """Softmax over ``self.axis``."""
        return F.softmax(x, dim=self.axis)
# Load the exported fastai Learner used by predict() below.
# NOTE(review): load_learner unpickles this file, so it must come from a trusted
# source. Unpickling presumably also needs the custom objects defined above
# (get_mask, CombinedLoss, ...) to be resolvable in this module — confirm.
learn = load_learner('resnet50_dice906.pkl')
def predict(img):
    """Segment *img* with the global ``learn`` and render a colour-coded map.

    The decoded class-index mask is drawn with the rainbow colormap. A
    colorbar sits beside it, with one colour bin per class, labelled with the
    class names. The figure is rasterised in memory and returned as an RGB
    array for Gradio.

    Args:
        img: anything ``PILImage.create`` accepts (Gradio passes an array).

    Returns:
        numpy.ndarray: the rendered 1000x1000 figure as an RGB image.
    """
    import io
    from matplotlib.figure import Figure

    # Legend label for class index i (0 = background).
    class_names = ['BG', 'Cabbage', 'Kale', 'Fennel', 'Bean']
    n_classes = len(class_names)

    # predict() returns (decoded target, pred_idx, probs); only the decoded
    # mask is needed, so the input/probability outputs are not requested.
    mask, _, _ = learn.predict(PILImage.create(img))

    cmap = mpl.cm.rainbow
    # One colour bin per class: boundaries 0, 1, ..., n_classes.
    norm = mpl.colors.BoundaryNorm(list(range(n_classes + 1)), cmap.N)

    # A standalone Figure (not pyplot) is not registered globally, so it is
    # freed after each request instead of accumulating open figures.
    fig = Figure(figsize=(10, 10))
    ax = fig.subplots()
    # Share the colorbar's norm so image colours always match the legend, even
    # when some classes (including background) are absent from the mask.
    ax.imshow(mask, cmap=cmap, norm=norm)
    cax = make_axes_locatable(ax).append_axes('right', size='3%', pad=0.05)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, label="")
    # Place each class name at the centre of its colour bin, not on a boundary.
    cbar.set_ticks([i + 0.5 for i in range(n_classes)])
    cbar.ax.set_yticklabels(class_names, fontsize=15)
    ax.set_axis_off()

    # Render to memory: a shared "image.png" on disk races between requests.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=100)
    # PILImage.create opens with mode RGB, matching the previous output.
    return np.array(PILImage.create(buf.getvalue()))
# Gradio interface for the segmentation demo.
title = "Multispecies Canopies Segmenter"
description = ("A multispecies canopies segmentation model that segments your images "
               "according to the species your plants belong to. It can identify Beans, "
               "Cabbages, Kale & Fennel. BG stands for BackGround.")
# NOTE(review): the third example has no ".jpg" extension, unlike the others.
# Confirm that this file name exists as written.
examples = ['192.168.42.17-00610-AFECF1.jpg',
            "192.168.42.19-00951-AFECF1.jpg",
            "192.168.42.19-01010-AFECF2",
            "Registered_images_190_camera7.jpg"]

demo = gr.Interface(fn=predict, inputs="image", outputs="image",
                    title=title, description=description, examples=examples)

if __name__ == "__main__":
    # share=True requests a public link when run locally; presumably ignored
    # (with a warning) on hosted Spaces — confirm for the deployed Gradio version.
    demo.launch(share=True)
| # In[ ]: | |
| # In[ ]: | |