# LeafSeg / app.py — uploaded to Hugging Face Spaces (commit b61e6b6)
#!/usr/bin/env python
# coding: utf-8
# In[8]:
import glob
import re
import sqlite3 as sql

import gradio as gr
from fastai.vision.all import *
import skimage
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
def get_files(path):
    """Collect one camera-7 image per scene folder under *path*.

    Args:
        path: a directory object whose ``.ls()`` (fastai pathlib extension)
            yields per-scene sub-folders.

    Returns:
        list[str]: the first ``*camera7*.jpg`` match found in each folder;
        folders with no such file are silently skipped.
    """
    imgs = []
    for folder in path.ls():
        matches = glob.glob(f'{folder}/*camera7*.jpg')
        # Explicit emptiness check instead of the original bare `except:`,
        # which also swallowed KeyboardInterrupt/SystemExit.
        if matches:
            imgs.append(matches[0])
    return imgs
def get_mask(filename):
    """Build the class-labelled segmentation mask for *filename*.

    Looks up the scene's treatment code in the ``celine-3.db`` SQLite
    database, loads the matching ``.npy`` probability mask, binarises it at
    0.5 and multiplies it by the treatment's 1-based class index (so 0 stays
    background and 1..4 encode the species class).

    NOTE(review): relies on a module-level ``path`` that is not defined in
    this file (notebook leftover) — confirm where the mask folders live.
    """
    stem = Path(Path(filename).name).stem
    # Raw string: "\d" is a regex escape, not a valid Python string escape.
    scene_id = re.findall(r"\d+", stem)[0]
    conn = sql.connect('celine-3.db')
    try:
        cur = conn.execute(
            "SELECT trt FROM scenes WHERE scene_id == ?", (scene_id,)
        )
        trt = cur.fetchall()[0][0]
    finally:
        # The original leaked the connection; sqlite3's context manager only
        # manages transactions, so close explicitly.
        conn.close()
    trt_list = ['PCCCC', 'PCFCF', 'PFEFE', 'PHNHN']
    mask = np.load(glob.glob(f'{path/Path(filename).stem}/*.npy')[0]).squeeze()
    mask = np.where(mask > 0.5, 1, 0).astype(np.uint8)
    # 1-based class id of the scene's treatment.
    mask *= trt_list.index(trt) + 1
    return mask
@patch
def __call__(self:DiceLoss, pred, targ):
    # Monkey-patched fastai DiceLoss: the usual dice loss is wrapped in
    # log-cosh (log((e^x + e^-x)/2)) for a smoother optimisation surface.
    targ = self._one_hot(targ, pred.shape[self.axis])
    pred, targ = TensorBase(pred), TensorBase(targ)
    assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs'
    pred = self.activation(pred)
    # Reduce over every spatial dimension, keeping (batch, class).
    reduce_dims = list(range(2, len(pred.shape)))
    intersection = torch.sum(pred*targ, dim=reduce_dims)
    if self.square_in_union:
        denom = torch.sum(pred**2+targ, dim=reduce_dims)
    else:
        denom = torch.sum(pred+targ, dim=reduce_dims)
    dice_score = (2. * intersection + self.smooth)/(denom + self.smooth)
    per_item = 1-dice_score
    x = per_item.flatten().mean() if self.reduction == "mean" else per_item.flatten().sum()
    # log-cosh of the dice loss.
    return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0)
class CombinedLoss:
    """Weighted sum of focal loss and (patched log-cosh) dice loss.

    Total loss is ``focal + alpha * dice``; ``decodes``/``activation`` give
    fastai the hooks it needs to turn raw logits into class maps.
    """
    def __init__(self, axis=1, smooth=1., alpha=1., reduction=None):
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)
    def __call__(self, pred, targ):
        focal = self.focal_loss(pred, targ)
        dice = self.dice_loss(pred, targ)
        return focal + self.alpha * dice
    def decodes(self, x):
        # Hard class labels from the raw logits.
        return x.argmax(dim=self.axis)
    def activation(self, x):
        # Per-pixel class probabilities.
        return F.softmax(x, dim=self.axis)
# Load the exported fastai segmentation learner. The pickle references the
# custom CombinedLoss / patched DiceLoss above, so those must be defined
# before this line runs.
# NOTE(review): load_learner unpickles arbitrary code — only ship trusted .pkl files.
learn = load_learner('resnet50_dice906.pkl')
def predict(img):
    """Segment *img* and return the rendered result as a numpy image.

    Runs the loaded learner on the input, draws the predicted class mask
    with a discrete rainbow colourbar (background + 4 species), saves the
    figure to ``image.png`` and returns it as an array for gradio.
    """
    img = PILImage.create(img)
    inp, preds, pred_idx, probs = learn.predict(img, with_input=True)
    cmap = mpl.cm.rainbow
    # Class ids: 0=BG, 1=Cabbage, 2=Kale, 3=Fennel, 4=Bean.
    bounds = [0, 1, 2, 3, 4, 5]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(preds, cmap='rainbow', vmax=4)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='3%', pad=0.05)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm, cmap), ax=ax, cax=cax, label="")
    # Fixed user-facing typo: "Cabage" -> "Cabbage".
    cbar.ax.set_yticklabels(['', 'BG', 'Cabbage', 'Kale', 'Fennel', 'Bean'], fontsize=15)
    ax.set_axis_off()
    plt.savefig("image.png", dpi=100)
    # Close the figure so repeated gradio requests don't leak memory.
    plt.close(fig)
    pred = PILImage.create("image.png")
    return np.array(pred)
title = "Multispecies Canopies Segmenter"
# User-facing description: fixed grammar and the trailing "BackGround/" typo.
description = ("A multispecies canopy segmentation model that segments your images "
               "according to the species your plants belong to. It can identify Beans, "
               "Cabbages, Kale & Fennel. BG stands for Background.")
examples = ['192.168.42.17-00610-AFECF1.jpg',
            "192.168.42.19-00951-AFECF1.jpg",
            "192.168.42.19-01010-AFECF2",  # NOTE(review): missing .jpg extension — confirm file name
            "Registered_images_190_camera7.jpg"]
interpretation = 'default'  # defined but never passed to the Interface below
# Launch the gradio UI; share=True exposes a public tunnel link.
gr.Interface(fn=predict, inputs="image", outputs="image", title=title, description=description,
             examples=examples).launch(share=True)
# In[ ]:
# In[ ]: