"""Gradio demo that classifies images with an exported fastai ConvNeXt learner."""
from fastai.vision.all import *
import gradio as gr
import glob
import torch
import timm
from timm.models import convnext

# timm architecture name used to build the ConvNeXt backbone below.
convnext_model = 'convnext_tiny_in22k'
model_architecture = timm.create_model(convnext_model)


class FastaiConvNext(torch.nn.Module):
    """Thin ``nn.Module`` wrapper that forwards its input to a wrapped model.

    NOTE(review): this class is presumably defined here so that
    ``load_learner`` can unpickle a learner whose model was wrapped in
    ``FastaiConvNext`` at training time (pickle resolves classes by name).
    Confirm before removing or renaming it.
    """

    def __init__(self, original_model):
        super().__init__()
        # The wrapped model; ``forward`` delegates to it unchanged.
        self.features = original_model

    def forward(self, x):
        """Return the output of the wrapped model applied to ``x``."""
        x = self.features(x)
        return x


# NOTE(review): ``model`` is never used below -- prediction goes through
# ``learn`` -- so building it only costs start-up time and memory.
model = FastaiConvNext(model_architecture)

# ``load_learner`` unpickles the file: only ever load a trusted, locally
# shipped export here, never a user-supplied file.
learn = load_learner("convnext_mixup_0_33.pkl")

# Display labels, zipped *positionally* with the probabilities returned by
# ``learn.predict``. NOTE(review): this order must match the learner's
# vocabulary (``learn.dls.vocab``), otherwise predictions are mislabeled --
# confirm when re-exporting the model.
categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula',
              'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid', 'pristina',
              'safranbolu', 'selanik', 'sozopol_suzebolu', 'tiran', 'varna')


def classify_img(img):
    """Classify ``img`` and return a ``{category: probability}`` mapping.

    Args:
        img: The image passed in by the Gradio image input component.

    Returns:
        A dict mapping each name in ``categories`` to the corresponding
        probability from ``learn.predict``, converted to a plain ``float``
        so that Gradio's ``Label`` output can serialize it.
    """
    # ``learn.predict`` returns three values; only the probabilities are used.
    _, _, probs = learn.predict(img)
    return dict(zip(categories, map(float, probs)))


# NOTE(review): ``gr.inputs`` / ``gr.outputs`` are the legacy Gradio
# namespaces (removed in Gradio 4); pin ``gradio<4`` or migrate to
# ``gr.Image`` / ``gr.Label``.
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()
examples = ["filibe-1-1.jpg", "ohrid-3-1.jpg", "varna-1-1.jpg"]

demo = gr.Interface(fn=classify_img, inputs=image, outputs=label,
                    examples=examples)
demo.launch(inline=False)