|
|
import torch |
|
|
from torchvision import transforms |
|
|
from torchvision.models import resnet18 |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Output labels of the classifier head: the model's output index i maps to
# class_names[i]; len(class_names) also sizes the final Linear layer.
# NOTE(review): the order looks lexicographic (folder-name sorted, e.g. as
# produced by an ImageFolder-style loader) — it must match the class-index
# order used at training time; confirm against the training script.
class_names = [
    "Nu.1",
    "Nu.10",
    "Nu.100",
    "Nu.1000",
    "Nu.20",
    "Nu.5",
    "Nu.50",
    "Nu.500",
]
|
|
|
|
|
|
|
|
# All inference runs on the CPU; map_location below lets a checkpoint that was
# saved from a GPU be loaded here as well.
device = torch.device('cpu')

# ResNet-18 backbone built without pretrained weights. Its classifier head is
# replaced so it emits one logit per entry in class_names.
# `weights=None` is the torchvision>=0.13 replacement for the deprecated
# `pretrained=False` argument.
model = resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))

# Restore the fine-tuned parameters. The checkpoint only needs to hold a
# state_dict, so weights_only=True is sufficient. It restricts unpickling to
# tensors and primitive containers, so a tampered .pth file cannot execute
# arbitrary code during load.
state_dict = torch.load("currency_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)

model.to(device)
# Inference mode: BatchNorm layers use their stored running statistics
# instead of per-batch statistics.
model.eval()
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference.
# NOTE(review): there is no Normalize step here. This pipeline must match the
# preprocessing used when currency_model.pth was trained — confirm against the
# training script.
_preprocess_steps = [
    transforms.Resize((256, 256)),  # fixed 256x256; aspect ratio is not preserved
    transforms.ToTensor(),          # PIL image -> float CHW tensor scaled to [0, 1]
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
|
|
|
|
def predict(image):
    """Classify an uploaded banknote image and return its denomination label.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The ``class_names`` entry of the highest-scoring class. If ``image``
        is ``None``, a short prompt asking the user to upload an image.
    """
    # Gradio passes None when the image field is empty. Without this guard,
    # ``.convert`` raises AttributeError and the UI shows a generic error.
    if image is None:
        return "Please upload an image."

    # Force 3 channels. Greyscale, palette or RGBA uploads would otherwise yield
    # a tensor whose channel count does not match the ResNet input layer.
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)  # add batch dim -> (1, C, H, W)

    with torch.no_grad():
        logits = model(batch)

    # Index of the largest logit per row (first one on ties, like torch.max).
    predicted_index = logits.argmax(dim=1).item()
    return class_names[predicted_index]
|
|
|
|
|
|
|
|
# Gradio UI: a single PIL image in, the predicted denomination label out as text.
image_input = gr.Image(type="pil")
label_output = gr.Text()

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Bhutanese Currency Detector",
    description="Upload a currency note image to identify its value.",
)

# Start the web server. This runs at import time, as in the original script.
interface.launch()
|
|
|