|
|
from fastai.vision.all import * |
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import glob |
|
|
|
|
|
class Hook:
    """Keep a snapshot of a module's output.

    ``hook_func``'s signature matches PyTorch's forward-hook convention
    (module, inputs, output). NOTE(review): nothing visible in this file
    registers it — presumably intended for a Grad-CAM feature.
    """

    def hook_func(self, m, i, o):
        # m (module) and i (inputs) are unused; only the output is kept.
        # detach() drops the autograd link, clone() gives the snapshot its
        # own storage so later in-place edits to o don't leak into it.
        snapshot = o.detach()
        self.stored = snapshot.clone()
|
|
|
|
|
|
|
|
|
|
|
# Data and model setup.
#
# The DataBlock is rebuilt here so that vision_learner can infer the class
# count/vocabulary from the "drawings2" folder tree and so that
# learn.predict has the item/batch transform pipeline (128px crop,
# ImageNet normalisation) the checkpoint was trained with.
path = "drawings2"

dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,  # label = name of the image's parent folder
    splitter=RandomSplitter(valid_pct=0.2),
    item_tfms=RandomResizedCrop(128, min_scale=0.7),
    batch_tfms=[
        *aug_transforms(max_rotate=0, max_warp=0),
        Normalize.from_stats(*imagenet_stats),
    ],
)

dls_augmented = dblock.dataloaders(path, shuffle=True)

# pretrained=False: every weight is overwritten by learn.load() below, so
# downloading the ImageNet resnet152 checkpoint at startup is wasted time
# and network traffic (and fails outright on an offline host). Normalize is
# already in batch_tfms above, so the transform pipeline is unaffected.
learn = vision_learner(dls_augmented, resnet152, pretrained=False)

# NOTE(review): fastai resolves this name relative to learn.path /
# learn.model_dir (by default drawings2/models/<name>.pth) — confirm the
# checkpoint is deployed there.
learn.load("rn152_sketch_9label_mixup_0_3")
|
|
|
|
|
class Hook(object):
    """Store a detached, independent copy of a module's forward output.

    NOTE(review): an identical ``Hook`` class is defined earlier in this
    file; this later definition is the one bound to ``Hook`` at runtime.
    """

    def hook_func(self, m, i, o):
        # Arguments follow PyTorch's forward-hook order (module, inputs,
        # output); only the output is recorded.
        detached = o.detach()
        self.stored = detached.clone()
|
|
|
|
|
def gradcam(img_create):
    """Classify an image and return a {class name: probability} mapping.

    Despite its name, this function does not compute a Grad-CAM heat map;
    it only runs ``learn.predict`` and pairs each probability with the
    matching entry of the module-level ``categories`` tuple.

    Args:
        img_create: the image delivered by the Gradio ``Image`` input; it is
            passed unchanged to ``learn.predict``.

    Returns:
        dict mapping each name in ``categories`` to its probability as a
        Python ``float`` (the form a Gradio ``Label`` output displays).

    Raises:
        ValueError: if the model yields a different number of probabilities
            than there are names in ``categories``. Without this check
            ``zip`` would silently drop the surplus entries and the UI would
            show an incomplete (and likely mislabelled) result.
    """
    # learn.predict returns (decoded label, class index, probabilities);
    # only the probability vector is used here.
    _, _, probs = learn.predict(img_create)
    if len(probs) != len(categories):
        raise ValueError(
            f"model produced {len(probs)} class probabilities but "
            f"{len(categories)} category names are configured"
        )
    return dict(zip(categories, map(float, probs)))
|
|
|
|
|
# Class names paired, by position, with the probabilities gradcam() gets
# from learn.predict. NOTE(review): the order must match the model's output
# order; these are in alphabetical order, which appears to mirror the sorted
# folder-name vocabulary fastai builds — confirm against learn.dls.vocab.
categories = (
    'balkanlar_osmanli',
    'bursa',
    'cankirievi',
    'diyarbakir',
    'kayseri',
    'kula',
    'ordu',
    'ormana_antalya',
    'pazaryeri',
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI --------------------------------------------------------------
# NOTE(review): gr.inputs / gr.outputs are deprecated Gradio 3.x namespaces
# (removed in Gradio 4); migrate to gr.Image / gr.Label when upgrading.

# Input component; shape=(128, 128) asks Gradio to resize the image before
# it reaches gradcam().
image = gr.inputs.Image(shape=(128, 128))

# Output component that renders the {class name: probability} dict
# returned by gradcam().
label = gr.outputs.Label()

# One-click example files shown under the interface (paths relative to the
# working directory).
examples = [
    "sf107.jpg",
    "sf27_example3.png",
    "diyarbakir-1.jpg",
    "sf108.jpg",
    "sf135.png",
]

demo = gr.Interface(
    fn=gradcam,
    inputs=image,
    outputs=[label],
    examples=examples,
)

# inline=False: don't try to embed the UI in a notebook cell output.
demo.launch(inline=False)