# sketch / app.py
# khan994's picture
# Update app.py
# b646581
from fastai.vision.all import *
import cv2
import gradio as gr
import glob
class Hook:
    """Holder for a snapshot of a layer output.

    `hook_func` takes three positional arguments after `self`; the names
    `m`, `i`, `o` look like (module, input, output) as passed to a PyTorch
    forward hook -- TODO confirm, nothing in this file registers the hook.
    Note that an identical `Hook` class is declared again further down this
    file, and that later declaration is the one the name ends up bound to.
    """

    def hook_func(self, m, i, o):
        """Store a detached clone of `o` on `self.stored`.

        `m` and `i` are accepted but not used. Nothing is returned.
        """
        detached = o.detach()
        self.stored = detached.clone()
# --- DataLoader -----------------------------------------------------------
# The data pipeline is rebuilt at start-up so that a learner can be created
# from it; the saved checkpoint is then loaded into that learner.
path = "drawings2"

# Per-item transform: random resized crop to 128 px.
_item_tfms = RandomResizedCrop(128, min_scale=0.7)

# Per-batch transforms: fastai's default augmentations with rotation and
# warping switched off, followed by normalisation with the ImageNet stats.
_batch_tfms = [
    *aug_transforms(max_rotate=0, max_warp=0),
    Normalize.from_stats(*imagenet_stats),
]

# Image -> category problem; items come from `get_image_files`, targets from
# `parent_label`, and 20% of the items are randomly held out for validation.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(valid_pct=0.2),
    item_tfms=_item_tfms,
    batch_tfms=_batch_tfms,
)
dls_augmented = dblock.dataloaders(path, shuffle=True)

# ResNet-152 learner; the checkpoint name is passed to `Learner.load`.
learn = vision_learner(dls_augmented, resnet152)
learn.load("rn152_sketch_9label_mixup_0_3")
class Hook:
    """Keep a detached copy of whatever is passed as the third hook argument.

    This re-declares the `Hook` class that already appears earlier in this
    file with the same body; being the later declaration, it is the one the
    name `Hook` refers to afterwards. The class is not instantiated anywhere
    in the visible code.
    """

    def hook_func(self, m, i, o):
        """Save `o.detach().clone()` as `self.stored`; `m` and `i` are ignored.

        The method has no return value.
        """
        copy_of_output = o.detach().clone()
        self.stored = copy_of_output
def gradcam(img_create):
    """Run the learner on `img_create` and return per-category probabilities.

    Despite the name, this function only calls `learn.predict`; it does not
    compute a Grad-CAM map and does not use `Hook`.

    Returns a dict whose keys are taken from the module-level `categories`
    tuple and whose values are the corresponding entries of the predicted
    probabilities converted to `float`. Pairing is positional, so the order
    of `categories` is assumed to match the model's class order -- TODO
    confirm against the learner's vocab.
    """
    # `predict` yields three values; only the probabilities are used here.
    _decoded, _class_idx, probs = learn.predict(img_create)
    return {name: float(p) for name, p in zip(categories, probs)}
# Category names, paired positionally with the predicted probabilities
# inside `gradcam`.
categories = (
    'balkanlar_osmanli',
    'bursa',
    'cankirievi',
    'diyarbakir',
    'kayseri',
    'kula',
    'ordu',
    'ormana_antalya',
    'pazaryeri',
)

# Disabled alternative with the same body as `gradcam`:
# def classify_img(img):
#     pred, idx, probs = learn.predict(img)
#     return dict(zip(categories, map(float, probs)))

# Gradio input/output components.
# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy Gradio namespaces;
# check the pinned Gradio version before upgrading.
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()

# Disabled alternative that collected example images from disk:
# examples_ = []
# for i in glob.glob("valid/**/*.jpg", recursive=True):
#     examples_.append(i)

# Example image files offered in the UI.
examples = [
    "sf107.jpg",
    "sf27_example3.png",
    "diyarbakir-1.jpg",
    "sf108.jpg",
    "sf135.png",
]

# Build the interface around `gradcam` and start serving it.
demo = gr.Interface(
    fn=gradcam,
    inputs=image,
    outputs=[label],
    examples=examples,
)
demo.launch(inline=False)