"""Gradio demo for a Siamese network: judge whether two images share a class.

Loads a trained ``SiameseNet`` checkpoint once at import time and serves a
two-image web UI that reports the cosine similarity of the two embeddings.
"""

from pathlib import Path

import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

from model import SiameseNet

# ── Load model ────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve the checkpoint relative to this file instead of the current working
# directory, so the app starts correctly no matter where it is launched from.
# (Equivalent to the old "../checkpoints/best.pt" when run from this file's dir.)
CHECKPOINT_PATH = Path(__file__).resolve().parent.parent / "checkpoints" / "best.pt"

# Cosine similarity strictly above this value is reported as "Same class".
SIMILARITY_THRESHOLD = 0.5

model = SiameseNet(embedding_dim=128).to(device)
# NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt["model_state"])
model.eval()

# ── Transform ─────────────────────────────────────────────────
# NOTE(review): the 105x105 size and the mean/std below presumably match the
# preprocessing used at training time — keep them in sync with the training code.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize([0.9220], [0.2256]),
])


def preprocess(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a batched model input on ``device``."""
    return transform(img).unsqueeze(0).to(device)  # [1, 1, 105, 105]


# ── Inference ─────────────────────────────────────────────────
def compare_images(img1: Image.Image, img2: Image.Image):
    """Compare two images and report whether the model thinks they share a class.

    Args:
        img1: First PIL image from the Gradio input, or ``None`` if not uploaded.
        img2: Second PIL image from the Gradio input, or ``None`` if not uploaded.

    Returns:
        A ``(html, score)`` tuple: an HTML snippet with the coloured verdict and
        the similarity as a percentage, and the raw cosine similarity rounded to
        4 decimals. ``score`` is ``None`` when either image is missing.
    """
    # Gradio passes None for an empty image slot; the transform would crash on it.
    if img1 is None or img2 is None:
        return "<p>Please upload both images before comparing.</p>", None

    with torch.no_grad():
        emb1 = model.get_embedding(preprocess(img1))
        emb2 = model.get_embedding(preprocess(img2))
        # Assumes get_embedding returns [1, embedding_dim] so this yields one value.
        similarity = F.cosine_similarity(emb1, emb2).item()

    match = similarity > SIMILARITY_THRESHOLD
    label = "Same class" if match else "Different class"
    conf = f"{similarity * 100:.1f}%"
    colour = "green" if match else "red"

    # Rendered by gr.HTML, so plain newlines would collapse; use block elements.
    result = (
        f"<div style='color:{colour}; font-size:1.5em; font-weight:bold;'>{label}</div>"
        f"<div>Cosine similarity: {conf}</div>"
    )
    return result, round(similarity, 4)


# ── UI ────────────────────────────────────────────────────────
with gr.Blocks(title="Siamese Few-Shot Recognition") as demo:
    gr.Markdown("## Siamese Network — Few-Shot Image Similarity")
    gr.Markdown("Upload two images. The model will tell you if they belong to the same class.")

    with gr.Row():
        img1 = gr.Image(type="pil", label="Image 1")
        img2 = gr.Image(type="pil", label="Image 2")

    btn = gr.Button("Compare", variant="primary")
    result_html = gr.HTML()
    result_score = gr.Number(label="Raw similarity score")

    btn.click(fn=compare_images, inputs=[img1, img2], outputs=[result_html, result_score])

if __name__ == "__main__":
    # share=True creates a public Gradio URL: anyone with the link can use the app.
    demo.launch(share=True)