draft2 / app.py
khan994's picture
Update app.py
bf17d6c
raw
history blame
1.55 kB
from fastai.vision.all import *
import gradio as gr
import glob
import timm
from timm.models import convnext
# timm registry name of the backbone architecture to instantiate.
convnext_model = 'convnext_tiny_in22k'
# Build the architecture by name. No `pretrained=` flag is passed, so timm's
# default for that argument applies.
model_architecture = timm.create_model(model_name=convnext_model)
import torch
class FastaiConvNext(torch.nn.Module):
    """Adapter module that exposes a wrapped model's forward pass unchanged.

    The wrapped model is registered as the child module ``features``. That
    attribute name becomes part of the module's state (parameter keys are
    prefixed with ``features.``), so keep it stable.

    NOTE(review): presumably this class is defined here so that the pickled
    learner loaded later in this script can be unpickled (pickle resolves
    classes by name). Confirm before renaming or removing it.
    """

    def __init__(self, original_model):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a child module.
        self.features = original_model

    def forward(self, x):
        # Pure delegation: the output is exactly the wrapped model's output.
        return self.features(x)
# Wrap the bare timm architecture in the FastaiConvNext adapter.
# NOTE(review): `model` is not referenced anywhere else in this script; it looks
# like leftover training-time setup. Confirm before removing.
model = FastaiConvNext(original_model=model_architecture)
#class Hook():
# def hook_func(self, m, i, o): self.stored = o.detach().clone()
#learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)
#categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu', 'selanik', 'sozopol_suzebolu', 'tiran', 'varna')
# Load the exported fastai Learner for CPU inference.
# NOTE(review): load_learner unpickles the file, so only load trusted artifacts.
learn = load_learner("convnext_mixup_0_33.pkl", cpu=True)

# Class labels in the same order as the probability vector returned by
# learn.predict. They are read from the learner's own vocab instead of a
# hand-maintained tuple, so the labels cannot drift out of sync with the
# model's output indices (e.g. after retraining with a different class set).
# For this model the vocab is expected to be the 15 location names,
# 'arbanasi' through 'varna'.
categories = tuple(learn.dls.vocab)
def classify_img(img):
    """Classify *img* with the loaded learner.

    Returns a ``{category: probability}`` dict (probabilities as Python
    floats), which is the output format fed to the Gradio Label component.
    """
    # Only the probability vector (third element of predict's result) is used.
    _label, _index, probs = learn.predict(img)
    return {name: float(p) for name, p in zip(categories, probs)}
# Gradio UI: an image input with shape=(128, 128) and a Label output that
# displays the category -> probability dict returned by classify_img.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy namespaces. Migrating
# to gr.Image / gr.Label may also require dropping `shape=` depending on the
# installed Gradio version, so pin the version before changing this.
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()

# Example image paths passed to the Interface. (The former `examples_` list,
# built by recursively globbing valid/**/*.jpg, was never used and has been
# removed to avoid the wasted directory walk at startup.)
examples = [
    "filibe-1-1.jpg",
    "ohrid-3-1.jpg",
    "varna-1-1.jpg",
]

demo = gr.Interface(fn=classify_img, inputs=image, outputs=label, examples=examples)
demo.launch(inline=False)