# NOTE(review): recovered from a Hugging Face Spaces page that displayed
# "Runtime error" — the markdown-table residue has been stripped so this
# file is runnable Python again.
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms, models, datasets
# ----------------------------
# Settings
# ----------------------------
MODEL_PTH = "banknote_cnn_resnet18.pth"  # put your model here
DATA_DIR = "notes_clean"                 # folder used during training
TARGET_SIZE = 224
TOP_K = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------------
# Recover class mapping
# ----------------------------
def _recover_class_mapping(data_dir):
    """Rebuild the ``class_name -> index`` mapping used at training time.

    torchvision's ImageFolder assigns indices to the *sorted* list of class
    sub-directories, so scanning the directory names directly reproduces the
    exact same mapping without indexing (and validating) every image file —
    much faster at startup, and it fails with a clear message instead of an
    opaque crash when the folder is missing from the deployment.
    """
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(
            f"Training data folder '{data_dir}' not found; "
            "it is required to recover the class mapping."
        )
    # Same ordering rule as torchvision.datasets.ImageFolder.find_classes.
    classes = sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"No class sub-folders found in '{data_dir}'.")
    return {cls_name: i for i, cls_name in enumerate(classes)}

class_to_idx = _recover_class_mapping(DATA_DIR)
idx_to_class = {v: k for k, v in class_to_idx.items()}
print("Recovered class mapping:", idx_to_class)
# ----------------------------
# Load model
# ----------------------------
num_classes = len(class_to_idx)

# Rebuild the architecture used in training, then restore the saved weights.
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# weights_only=True: the checkpoint is a plain state dict, so refuse to
# unpickle arbitrary objects (safer, and the default since torch 2.6).
state_dict = torch.load(MODEL_PTH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode: fixes dropout / batch-norm behaviour
print("Model loaded")
# ----------------------------
# Preprocessing
# ----------------------------
# ImageNet channel statistics — the values the ResNet-18 backbone expects.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
def resize_and_pad(img: Image.Image, target_size: int = TARGET_SIZE) -> Image.Image:
    """Resize *img* so its longest side equals *target_size*, keeping the
    aspect ratio, then centre it on a black square canvas.

    Returns a new ``target_size`` x ``target_size`` RGB image.
    """
    img = img.convert("RGB")
    w, h = img.size
    scale = target_size / max(w, h)
    # round() (not int()) so float error cannot truncate the long side to
    # target_size - 1; clamp to >= 1 px for extreme aspect ratios, since
    # PIL rejects zero-sized resize targets.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    img_resized = img.resize((new_w, new_h), Image.BILINEAR)
    canvas = Image.new("RGB", (target_size, target_size))  # black padding
    paste_x = (target_size - new_w) // 2
    paste_y = (target_size - new_h) // 2
    canvas.paste(img_resized, (paste_x, paste_y))
    return canvas
# Tensor conversion (HWC uint8 -> CHW float in [0, 1]) followed by
# ImageNet normalisation; applied after resize_and_pad.
to_tensor_and_norm = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
# ----------------------------
# Prediction function
# ----------------------------
def predict(image: Image.Image):
    """Classify an uploaded banknote image.

    Returns (original RGB image, preprocessed image, human-readable text)
    for the three Gradio output components. Guards against the case where
    the user hits Submit with no image — Gradio passes ``None`` then, and
    the original code crashed with an AttributeError.
    """
    if image is None:
        return None, None, "Please upload an image first."

    orig = image.convert("RGB")
    pre = resize_and_pad(orig, TARGET_SIZE)
    tensor = to_tensor_and_norm(pre).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]

    # Top-k indices, highest probability first; cap k at the number of
    # classes so a tiny label set cannot make the slice misleading.
    k = min(TOP_K, len(probs))
    topk_idx = np.argsort(probs)[::-1][:k]
    preds = [(int(i), idx_to_class[int(i)], float(probs[int(i)])) for i in topk_idx]

    # Build the readable summary shown in the textbox.
    text_lines = [
        f"Predicted index: {preds[0][0]}",
        f"Predicted label: {preds[0][1]}",
        f"Top-{TOP_K}:",
    ]
    for idx, label, p in preds:
        text_lines.append(f" {idx} {label} — {p*100:.2f}%")
    return orig, pre, "\n".join(text_lines)
# ----------------------------
# Gradio interface
# ----------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Banknote"),
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Preprocessed Image"),
        gr.Textbox(label="Prediction Info"),
    ],
    title="Pakistan Banknote Classifier",
    description="Upload a Pakistani banknote. Preprocessing + prediction shown.",
)

# Launch only when run as a script, so the module can be imported
# (e.g. by tests or another app) without starting a web server.
if __name__ == "__main__":
    iface.launch()