#!/usr/bin/env python
# coding: utf-8
"""Gradio demo for a fastai multispecies-canopy segmentation model.

``get_files``, ``get_mask``, the patched ``DiceLoss.__call__`` and
``CombinedLoss`` are presumably kept at module level so that the learner
pickled in ``resnet50_dice906.pkl`` can be unpickled by ``load_learner``
(TODO confirm against the training notebook).
"""
import glob
import io
import re
import sqlite3 as sql
from contextlib import closing
from pathlib import Path

import gradio as gr
from fastai.vision.all import *
import skimage
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable


def get_files(path):
    """Return the first ``*camera7*.jpg`` image found in each child of *path*.

    *path* must provide ``.ls()`` (a fastcore-patched ``Path``). Children with
    no matching image -- including plain files -- are skipped. Which match is
    "first" follows ``glob.glob``'s unspecified ordering.
    """
    imgs = []
    for folder in path.ls():
        matches = glob.glob(f'{folder}/*camera7*.jpg')
        if matches:
            imgs.append(matches[0])
    return imgs


def get_mask(filename, root=None, db_path='celine-3.db'):
    """Build the integer class mask for one training image.

    The scene id is the first run of digits in the image's file stem. Its
    treatment code (``trt``) is read from the ``scenes`` table of the SQLite
    database at *db_path*. The first ``.npy`` file in ``<root>/<image stem>/``
    is squeezed, thresholded at 0.5, and every foreground pixel is set to the
    1-based position of ``trt`` in ``trt_list`` (background stays 0).

    Args:
        filename: path of the image whose mask is wanted.
        root: directory holding one mask folder per image stem. When omitted,
            the module-level global ``path`` is used -- NOTE(review): this file
            never defines ``path``; it must be supplied by the caller's
            environment (e.g. the training notebook). TODO confirm.
        db_path: SQLite database containing the ``scenes`` table.

    Returns:
        A ``uint8`` numpy array with values in ``{0, 1, 2, 3, 4}``.

    Raises:
        ValueError: no digits in the file stem, or an unknown treatment code.
        LookupError: the scene id is absent from the ``scenes`` table.
        FileNotFoundError: no ``.npy`` file in the mask folder.
    """
    stem = Path(filename).stem
    found = re.search(r'\d+', stem)
    if found is None:
        raise ValueError(f'no scene id (digit run) in file name {filename!r}')
    scene_id = found.group()

    # closing() is required: using a sqlite3 connection as a context manager
    # only commits/rolls back a transaction, it does not close the connection.
    with closing(sql.connect(db_path)) as conn:
        row = conn.execute(
            "SELECT trt FROM scenes WHERE scene_id == ?", [scene_id]
        ).fetchone()
    if row is None:
        raise LookupError(f'scene_id {scene_id} not found in {db_path}')
    trt = row[0]

    trt_list = ['PCCCC', 'PCFCF', 'PFEFE', 'PHNHN']
    if trt not in trt_list:
        raise ValueError(f'unknown treatment {trt!r} for scene {scene_id}')

    # Only touch the `path` global when no explicit root was given.
    mask_dir = Path(path if root is None else root) / stem
    mask_files = glob.glob(f'{mask_dir}/*.npy')
    if not mask_files:
        raise FileNotFoundError(f'no .npy mask file in {mask_dir}')

    mask = np.load(mask_files[0]).squeeze()
    mask = np.where(mask > 0.5, 1, 0).astype(np.uint8)
    mask *= trt_list.index(trt) + 1
    return mask


@patch
def __call__(self: DiceLoss, pred, targ):
    """Log-cosh Dice loss; replaces fastai's ``DiceLoss.__call__``.

    One-hot encodes *targ*, applies ``self.activation`` to *pred*, and computes
    the soft Dice score per item and class over all dimensions after the
    first two. ``x = 1 - dice`` is then reduced with ``mean`` when
    ``self.reduction == "mean"`` and ``sum`` otherwise. The return value is
    ``log((exp(x) + exp(-x)) / 2)``, i.e. ``log(cosh(x))``.
    """
    targ = self._one_hot(targ, pred.shape[self.axis])
    pred, targ = TensorBase(pred), TensorBase(targ)
    assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs'
    pred = self.activation(pred)
    sum_dims = list(range(2, len(pred.shape)))
    inter = torch.sum(pred * targ, dim=sum_dims)
    union = (torch.sum(pred**2 + targ, dim=sum_dims) if self.square_in_union
             else torch.sum(pred + targ, dim=sum_dims))
    dice_score = (2. * inter + self.smooth) / (union + self.smooth)
    x = ((1 - dice_score).flatten().mean() if self.reduction == "mean"
         else (1 - dice_score).flatten().sum())
    return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0)


class CombinedLoss:
    "Dice and Focal combined"

    def __init__(self, axis=1, smooth=1., alpha=1., reduction=None):
        # store_attr() saves every argument as an attribute of the same name;
        # `reduction` is stored but not used by this class.
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        """Return focal loss plus ``alpha`` times the (patched) Dice loss."""
        return self.focal_loss(pred, targ) + self.alpha * self.dice_loss(pred, targ)

    def decodes(self, x):
        """Turn per-class scores into class ids along ``self.axis``."""
        return x.argmax(dim=self.axis)

    def activation(self, x):
        """Softmax over ``self.axis``."""
        return F.softmax(x, dim=self.axis)


learn = load_learner('resnet50_dice906.pkl')


def predict(img):
    """Segment *img* and return an RGB rendering of the predicted class mask.

    Args:
        img: anything ``PILImage.create`` accepts. Gradio's ``"image"`` input
            passes a numpy array (presumably HxWx3 uint8 -- TODO confirm).

    Returns:
        A numpy RGB array of a 10x10 in figure saved at 100 dpi (1000x1000 px).
        It shows the mask drawn with a discrete rainbow colormap and a
        colorbar labelled BG / Cabbage / Kale / Fennel / Bean.
    """
    img = PILImage.create(img)
    # learn.predict returns (decoded target, decoded preds, raw preds); the
    # decoded target is the class-id mask.
    mask, _, _ = learn.predict(img)

    class_labels = ['BG', 'Cabbage', 'Kale', 'Fennel', 'Bean']
    bounds = [0, 1, 2, 3, 4, 5]  # one bin [k, k+1) per class id k
    cmap = mpl.cm.rainbow
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

    # A standalone Figure (no pyplot) is not registered in pyplot's global
    # figure list, so nothing accumulates across calls.
    fig = Figure(figsize=(10, 10))
    ax = fig.subplots()
    # Same cmap/norm as the colorbar so pixel colours match the legend
    # exactly. 'nearest' avoids blending neighbouring class ids into colours
    # of classes that are not present.
    ax.imshow(np.asarray(mask), cmap=cmap, norm=norm, interpolation='nearest')
    cax = make_axes_locatable(ax).append_axes('right', size='3%', pad=0.05)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm, cmap), cax=cax, label="")
    # Put each label at the centre of its colour band rather than on a boundary.
    cbar.set_ticks([b + 0.5 for b in bounds[:-1]])
    cbar.set_ticklabels(class_labels)
    cbar.ax.tick_params(labelsize=15)
    ax.set_axis_off()

    # Render in memory: a shared "image.png" on disk would race between
    # concurrent requests.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    return np.array(PILImage.create(buf))


title = "Multispecies Canopies Segmenter"
description = "A multispecies canopies Segmentation model that will segment your images according to the species your plants belong. It can identify Beans, Cabbages ,Kale & Fennel BG stands for BackGround/"
# NOTE(review): the third entry has no ".jpg" extension, unlike its siblings --
# confirm the example file really exists under that name.
examples = ['192.168.42.17-00610-AFECF1.jpg', "192.168.42.19-00951-AFECF1.jpg", "192.168.42.19-01010-AFECF2", "Registered_images_190_camera7.jpg"]
interpretation = 'default'  # NOTE(review): not passed to gr.Interface; currently unused
demo = gr.Interface(fn=predict, inputs="image", outputs="image", title=title,
                    description=description, examples=examples)
demo.launch(share=True)