from pathlib import Path

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import SiameseNet
|
|
# --- Model loading -----------------------------------------------------------

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve the checkpoint relative to this file instead of the current working
# directory, so the app finds it no matter where it is launched from.
CHECKPOINT_PATH = Path(__file__).resolve().parent.parent / "checkpoints" / "best.pt"

model = SiameseNet(embedding_dim=128).to(device)
# NOTE(review): depending on the PyTorch version, torch.load may unpickle
# arbitrary objects; only load checkpoints you trust (or pass
# weights_only=True if the file holds only tensors and plain containers).
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
# The checkpoint is a dict; the network weights live under "model_state".
model.load_state_dict(ckpt["model_state"])
# Inference only: switch layers such as dropout/batch-norm to eval behavior.
model.eval()
|
|
# --- Preprocessing -----------------------------------------------------------

# Fixed input geometry and normalisation statistics for the network.
# NOTE(review): mean/std are presumably the training-set statistics — confirm
# against the training script before changing them.
_INPUT_SIZE = (105, 105)
_NORM_MEAN = [0.9220]
_NORM_STD = [0.2256]

# Pipeline applied to every uploaded image: convert to grayscale, resize to
# _INPUT_SIZE, convert to a tensor, then normalise with the stats above.
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
|
|
def preprocess(img: Image.Image) -> torch.Tensor:
    """Convert one PIL image into a batch-of-one tensor on ``device``.

    The image goes through the module-level ``transform`` pipeline, gains a
    leading batch dimension, and is moved to the inference device.
    """
    single = transform(img)
    batched = single.unsqueeze(0)  # prepend the batch axis
    return batched.to(device)
|
|
# --- Inference ---------------------------------------------------------------

def compare_images(img1: Image.Image, img2: Image.Image, threshold: float = 0.5):
    """Decide whether two images belong to the same class.

    Both images are embedded with the Siamese network and compared with
    cosine similarity. A score strictly above ``threshold`` counts as a match.

    Args:
        img1: First uploaded image; may be ``None`` if the user has not
            uploaded anything in that slot.
        img2: Second uploaded image (same convention as ``img1``).
        threshold: Cosine-similarity cut-off for declaring "Same class".
            Defaults to 0.5, the value the app has always used.

    Returns:
        An ``(html, score)`` pair: an HTML snippet with the verdict, and the
        raw cosine similarity rounded to 4 decimal places.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of a traceback from the preprocessing step.
    """
    if img1 is None or img2 is None:
        raise gr.Error("Please upload both images before comparing.")

    with torch.no_grad():
        emb1 = model.get_embedding(preprocess(img1))
        emb2 = model.get_embedding(preprocess(img2))
        similarity = F.cosine_similarity(emb1, emb2).item()

    match = similarity > threshold
    label = "Same class" if match else "Different class"
    colour = "green" if match else "red"
    # Cosine similarity lies in [-1, 1] and is not a probability, so show it
    # as a signed decimal rather than a percentage.
    score_text = f"{similarity:.3f}"

    result = f"""
    <div style='text-align:center; padding: 16px;'>
        <div style='font-size: 28px; font-weight: 600; color: {colour};'>{label}</div>
        <div style='font-size: 16px; color: gray; margin-top: 8px;'>
            Cosine similarity: <strong>{score_text}</strong>
        </div>
    </div>
    """
    return result, round(similarity, 4)
|
|
# --- User interface ----------------------------------------------------------

with gr.Blocks(title="Siamese Few-Shot Recognition") as demo:
    # Heading: the original "β" was a mis-decoded em dash (UTF-8 "—" read as
    # CP1253), restored here.
    gr.Markdown("## Siamese Network — Few-Shot Image Similarity")
    gr.Markdown("Upload two images. The model will tell you if they belong to the same class.")

    # Two side-by-side upload slots; type="pil" hands PIL images to the callback.
    with gr.Row():
        img1 = gr.Image(type="pil", label="Image 1")
        img2 = gr.Image(type="pil", label="Image 2")

    btn = gr.Button("Compare", variant="primary")

    # Outputs: the formatted verdict and the raw similarity score.
    result_html = gr.HTML()
    result_score = gr.Number(label="Raw similarity score")

    # compare_images returns (html, score), matching the outputs order below.
    btn.click(
        fn=compare_images,
        inputs=[img1, img2],
        outputs=[result_html, result_score],
    )
|
|
def main() -> None:
    """Start the Gradio demo server.

    ``share=True`` asks Gradio for a temporary public link in addition to the
    local server.
    """
    demo.launch(share=True)


if __name__ == "__main__":
    main()