File size: 2,837 Bytes
02ac88d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from pathlib import Path

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import SiameseNet

# ── Load model ────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNet(embedding_dim=128).to(device)

# Resolve the checkpoint relative to this file (i.e. "<this dir>/../checkpoints/
# best.pt") instead of the current working directory, so the app starts no
# matter which directory it is launched from.
CHECKPOINT_PATH = Path(__file__).resolve().parent.parent / "checkpoints" / "best.pt"

# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# If the checkpoint holds only tensors/plain containers, consider passing
# weights_only=True (not verified against the training script's save format).
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
# The checkpoint is a dict; the network weights live under "model_state".
model.load_state_dict(ckpt["model_state"])
model.eval()  # inference mode for layers such as dropout / batch-norm

# ── Transform ─────────────────────────────────────────────────
# Preprocessing pipeline applied to every uploaded image before embedding:
# grayscale -> fixed 105x105 resize -> tensor -> normalisation with the
# single-channel mean/std below (presumably the training-set statistics;
# keep in sync with the training pipeline).
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.9220], std=[0.2256]),
    ]
)

def preprocess(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a one-image batch on ``device``.

    The image goes through ``transform`` and then gains a leading batch
    dimension, giving a tensor of shape ``[1, 1, 105, 105]``.
    """
    single = transform(img)        # [1, 105, 105]
    batched = single.unsqueeze(0)  # prepend batch axis
    return batched.to(device)

# ── Inference ─────────────────────────────────────────────────
def _render_result(similarity: float, threshold: float = 0.5) -> str:
    """Build the HTML verdict card for a cosine-similarity score.

    Scores strictly above *threshold* are labelled "Same class" in green;
    everything else is labelled "Different class" in red.
    """
    match = similarity > threshold
    label = "Same class" if match else "Different class"
    colour = "green" if match else "red"
    # Cosine similarity lies in [-1, 1], so this "percentage" can be negative.
    conf = f"{similarity * 100:.1f}%"

    return f"""
    <div style='text-align:center; padding: 16px;'>
      <div style='font-size: 28px; font-weight: 600; color: {colour};'>{label}</div>
      <div style='font-size: 16px; color: gray; margin-top: 8px;'>
        Cosine similarity: <strong>{conf}</strong>
      </div>
    </div>
    """


def compare_images(img1: Image.Image, img2: Image.Image, threshold: float = 0.5):
    """Embed two images with the Siamese model and report their similarity.

    Args:
        img1: First image (PIL, as delivered by the ``gr.Image(type="pil")`` input).
        img2: Second image, same format as *img1*.
        threshold: Cosine-similarity cut-off above which the pair is called
            "Same class". Defaults to 0.5, the previously hard-coded value.

    Returns:
        ``(html, score)``: an HTML verdict card for ``gr.HTML`` and the raw
        cosine similarity rounded to 4 decimals for ``gr.Number``.

    Raises:
        gr.Error: If either image is ``None``. The user then sees a readable
            message instead of a traceback from the transform pipeline.
    """
    if img1 is None or img2 is None:
        raise gr.Error("Please upload both images before comparing.")

    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        emb1 = model.get_embedding(preprocess(img1))
        emb2 = model.get_embedding(preprocess(img2))
        similarity = F.cosine_similarity(emb1, emb2).item()

    return _render_result(similarity, threshold), round(similarity, 4)

# ── UI ────────────────────────────────────────────────────────
with gr.Blocks(title="Siamese Few-Shot Recognition") as demo:
    # Page header and usage hint.
    gr.Markdown("## Siamese Network — Few-Shot Image Similarity")
    gr.Markdown("Upload two images. The model will tell you if they belong to the same class.")

    # Two side-by-side uploads; type="pil" hands compare_images PIL images.
    with gr.Row():
        img1 = gr.Image(type="pil", label="Image 1")
        img2 = gr.Image(type="pil", label="Image 2")

    btn = gr.Button("Compare", variant="primary")

    # Outputs, in the same order as compare_images' return tuple: the HTML
    # verdict card, then the rounded raw cosine similarity.
    result_html  = gr.HTML()
    result_score = gr.Number(label="Raw similarity score")

    btn.click(fn=compare_images, inputs=[img1, img2],
              outputs=[result_html, result_score])

if __name__ == "__main__":
    # Launch the Gradio server. share=True additionally requests a public
    # share URL, so anyone with the link can reach the app.
    demo.launch(share=True)