# draft2 / app.py
# Committed by khan994 -- commit 4031ad5: "Update app.py"
from fastai.vision.all import *
import gradio as gr
import glob
import torch
import timm
from timm.models import convnext
# timm registry name of the ConvNeXt backbone used by this app.
convnext_model = 'convnext_tiny_in22k'

# Instantiate only the architecture; no `pretrained` argument is passed here.
model_architecture = timm.create_model(model_name=convnext_model)
class FastaiConvNext(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a backbone network.

    The wrapped network is registered as the ``features`` submodule and
    every forward pass is delegated to it unchanged.

    NOTE(review): presumably this class must stay importable under this
    exact name (with a ``features`` child) so the pickled learner loaded by
    this app can be restored -- confirm before renaming anything.
    """

    def __init__(self, original_model):
        super().__init__()
        # Assigning a Module attribute registers it as a child submodule.
        self.features = original_model

    def forward(self, x):
        """Run *x* through the wrapped backbone and return its output."""
        return self.features(x)
# Wrap the bare timm architecture in FastaiConvNext.
# NOTE(review): `model` is not referenced anywhere else in this file; it may
# be dead code -- confirm before removing.
model = FastaiConvNext(model_architecture)

# Restore the exported fastai Learner from disk. load_learner unpickles the
# file, so it must only ever be pointed at a trusted artifact.
learn = load_learner("convnext_mixup_0_33.pkl")

# Class names, in alphabetical order. classify_img pairs these positionally
# with the learner's probability vector, so this order must match the
# learner's class vocabulary (presumably learn.dls.vocab -- confirm).
categories = (
    'arbanasi',
    'filibe',
    'gjirokoster',
    'iskodra',
    'kula',
    'kuzguncuk',
    'larissa_ampelakia',
    'mardin',
    'ohrid',
    'pristina',
    'safranbolu',
    'selanik',
    'sozopol_suzebolu',
    'tiran',
    'varna',
)
def classify_img(img):
    """Score *img* against every class using the loaded learner.

    Args:
        img: the image value handed over by the Gradio input component.

    Returns:
        A dict mapping each name in ``categories`` to its predicted
        probability as a plain Python float (the format the Label output
        component displays).
    """
    # learn.predict yields a 3-tuple; only the probability vector is needed.
    _decoded, _index, probabilities = learn.predict(img)
    return {name: float(score) for name, score in zip(categories, probabilities)}
# Input: uploaded image, resized by Gradio to 128x128 before classify_img runs.
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces and
# were removed in Gradio 4 -- keep gradio pinned below 4 or migrate to
# gr.Image / gr.Label (gr.Image has no `shape` argument there).
image = gr.inputs.Image(shape=(128, 128))

# Output: class-name -> probability mapping rendered as a label widget.
label = gr.outputs.Label()

# Sample images offered in the UI; paths are relative to the working directory.
examples = [
    "filibe-1-1.jpg",
    "ohrid-3-1.jpg",
    "varna-1-1.jpg",
]

demo = gr.Interface(
    fn=classify_img,
    inputs=image,
    outputs=label,
    examples=examples,
)

# Start the web server when this script is executed.
demo.launch(inline=False)