| from fastai.vision.all import * | |
| import gradio as gr | |
| import glob | |
| import torch | |
| import timm | |
| from timm.models import convnext | |
# timm registry name of the ConvNeXt backbone this app was built around.
convnext_model = 'convnext_tiny_in22k'

# Instantiate the bare backbone by name. No pretrained weights are requested
# here, so presumably this only builds the network structure (per timm's
# create_model defaults) — TODO confirm it is still needed at all.
model_architecture = timm.create_model(model_name=convnext_model)
class FastaiConvNext(torch.nn.Module):
    """Pass-through ``torch.nn.Module`` that delegates to a wrapped network.

    The wrapped network is registered as the ``features`` child module, so its
    parameters appear under the ``features.`` prefix in ``state_dict()``.

    NOTE(review): keep the attribute name ``features`` — presumably the pickled
    learner loaded elsewhere in this file was saved with this layout, and the
    class must stay defined under this name for unpickling to find it.
    """

    def __init__(self, original_model):
        super(FastaiConvNext, self).__init__()
        # Assigning a Module attribute registers it as a submodule.
        self.features = original_model

    def forward(self, x):
        # No extra layers: the output is exactly the wrapped network's output.
        return self.features(x)
# NOTE(review): `model` is not referenced anywhere else in this file, and
# load_learner below restores the learner's own network, so this wrapper
# instance (and the timm backbone it holds) looks unused — confirm before
# removing it.
model = FastaiConvNext(model_architecture)

# Restore the exported fastai Learner. load_learner unpickles the file, so any
# custom classes recorded in it (presumably FastaiConvNext above) must already
# be defined, and the file must come from a trusted source.
learn = load_learner("convnext_mixup_0_33.pkl")

# Class labels as they were at export time; used only as a fallback when the
# learner does not expose a vocab.
_FALLBACK_CATEGORIES = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula',
                        'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid',
                        'pristina', 'safranbolu', 'selanik', 'sozopol_suzebolu',
                        'tiran', 'varna')

# learn.predict returns probabilities in the order of the learner's vocab, so
# take the labels from the learner itself. A hand-maintained tuple silently
# mislabels every prediction once the class list or its order changes.
_vocab = getattr(learn.dls, 'vocab', None)
categories = tuple(str(c) for c in _vocab) if _vocab else _FALLBACK_CATEGORIES
def classify_img(img):
    """Classify one image and return a ``{category: probability}`` mapping.

    Args:
        img: the value delivered by the Gradio image input; it is forwarded
            unchanged to ``learn.predict``.

    Returns:
        A dict pairing each entry of the module-level ``categories`` with the
        matching probability from the learner, converted to ``float``. Pairing
        stops at the shorter of the two sequences (``zip`` semantics).
    """
    # Only the probability vector is needed; the decoded label and index are ignored.
    _decoded, _index, probabilities = learn.predict(img)
    return {name: float(p) for name, p in zip(categories, probabilities)}
# Gradio UI: one image input, a Label output showing per-class confidences,
# and bundled example images. shape=(128, 128) asks the input component to
# crop/resize uploads to that size before classify_img sees them.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and removed
# in 4.0; on newer Gradio switch to gr.Image / gr.Label (gr.Image there no
# longer accepts `shape`).
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()
examples = [
    "filibe-1-1.jpg",
    "ohrid-3-1.jpg",
    "varna-1-1.jpg",
]

demo = gr.Interface(
    fn=classify_img,
    inputs=image,
    outputs=label,
    examples=examples,
)

# inline=False: do not embed the app in notebook output; serve it normally.
demo.launch(inline=False)