# BuVision / app.py
# Author: sherab65
# Commit message: "app.py"
# Commit: 21227aa (verified)
import torch
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image
import gradio as gr
# Human-readable labels, indexed by the classifier's output position
# (predict() looks up class_names[argmax]). The order must therefore match
# the class-index order the checkpoint was trained with.
# NOTE(review): the list is in plain string-sorted order ('Nu.1000' before
# 'Nu.20'), presumably as produced by an ImageFolder-style loader at training
# time — confirm against the training script before reordering.
class_names = [f"Nu.{value}" for value in (1, 10, 100, 1000, 20, 5, 50, 500)]
# Run inference on the CPU only; map_location below relocates any tensors
# that were saved from a GPU so the checkpoint loads on a CPU-only host.
device = torch.device('cpu')

# Build the ResNet-18 architecture with randomly initialised parameters
# (no ImageNet download); the trained parameters are loaded below.
# `weights=None` is the documented replacement for the deprecated
# `pretrained=False` argument (torchvision >= 0.13).
model = resnet18(weights=None)

# Replace the 1000-way ImageNet head with one output per entry in class_names.
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))

# Load the fine-tuned state dict. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint file cannot execute
# arbitrary code; a state dict (all load_state_dict accepts) loads unchanged.
model.load_state_dict(
    torch.load("currency_model.pth", map_location=device, weights_only=True)
)

# Move to the target device and switch batch-norm layers to inference mode.
model.to(device)
model.eval()
# Preprocessing applied to every uploaded image before inference: resize it to
# exactly 256x256 (aspect ratio is not preserved), then convert it to a float
# tensor in [0, 1].
# NOTE(review): no Normalize() step is applied — presumably this mirrors the
# training pipeline; confirm against the training script.
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)
def predict(image):
    """Classify a banknote photo and return its denomination label.

    Args:
        image: A PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        One of ``class_names`` (e.g. ``'Nu.100'``), or a short prompt asking
        for an image when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None for an empty image input; without this guard the
        # .convert() call below raises AttributeError and the UI shows an error.
        return "Please upload an image of a currency note."

    # Normalise the mode (RGBA, greyscale, palette) to 3 channels, matching
    # the 3-channel input of the ResNet stem.
    rgb_image = image.convert("RGB")
    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Index of the highest-scoring class for the single image in the batch.
    best_index = logits.argmax(dim=1).item()
    return class_names[best_index]
# Web UI: one PIL image in, the predicted denomination label out as text.
image_input = gr.Image(type="pil")
label_output = gr.Text()

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Bhutanese Currency Detector",
    description="Upload a currency note image to identify its value.",
)

# Start the Gradio server (runs at import time, as in the original script).
interface.launch()