LETTER / src /demo.py
Sharath33's picture
Upload folder using huggingface_hub
02ac88d verified
from pathlib import Path

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import SiameseNet
# ── Load model ────────────────────────────────────────────────
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve the checkpoint relative to this file instead of the current working
# directory, so the demo starts no matter where it is launched from
# (e.g. `python src/demo.py` from the repository root). When launched from the
# directory containing this file, this is the same location as the former
# "../checkpoints/best.pt".
CHECKPOINT_PATH = Path(__file__).resolve().parent.parent / "checkpoints" / "best.pt"

model = SiameseNet(embedding_dim=128).to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)

# The checkpoint is a mapping; the weights are stored under "model_state".
model.load_state_dict(ckpt["model_state"])
model.eval()  # the demo only runs inference
# ── Transform ─────────────────────────────────────────────────
# Preprocessing pipeline applied to every uploaded image, in this order:
# convert to grayscale, resize to 105x105, convert to a tensor, normalise.
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize(size=(105, 105)),
        transforms.ToTensor(),
        # Single-channel mean / std.
        # NOTE(review): presumably the training-set statistics; confirm.
        transforms.Normalize(mean=[0.9220], std=[0.2256]),
    ]
)
def preprocess(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a one-item batch on the active device.

    The image goes through the module-level ``transform`` pipeline, gains a
    leading batch dimension, and is moved to ``device``.

    Args:
        img: The PIL image to convert.

    Returns:
        A tensor of shape ``[1, 1, 105, 105]`` located on ``device``.
    """
    pixels = transform(img)
    batch = pixels.unsqueeze(0)  # prepend the batch dimension
    return batch.to(device)
# ── Inference ─────────────────────────────────────────────────
def compare_images(
    img1: "Image.Image | None",
    img2: "Image.Image | None",
) -> "tuple[str, float | None]":
    """Compare two images and report whether they look like the same class.

    Both images are preprocessed, embedded with ``model.get_embedding``, and
    compared by cosine similarity. A similarity above 0.5 is reported as
    "Same class", anything else as "Different class".

    Args:
        img1: First image, or ``None`` when nothing has been uploaded.
        img2: Second image, or ``None`` when nothing has been uploaded.

    Returns:
        A ``(html, score)`` pair: an HTML snippet with the verdict and the
        similarity as a percentage, and the raw similarity rounded to four
        decimals. When either image is missing, the HTML asks the user to
        upload both images and the score is ``None``.
    """
    # An empty image slot arrives as None; answer with a hint instead of
    # letting the preprocessing step crash on it.
    if img1 is None or img2 is None:
        notice = (
            "<div style='text-align:center; padding: 16px; color: gray;'>"
            "Please upload both images before comparing."
            "</div>"
        )
        return notice, None

    # Gradients are not needed for inference.
    with torch.no_grad():
        emb1 = model.get_embedding(preprocess(img1))
        emb2 = model.get_embedding(preprocess(img2))
        similarity = F.cosine_similarity(emb1, emb2).item()

    match = similarity > 0.5
    label = "Same class" if match else "Different class"
    conf = f"{similarity * 100:.1f}%"
    colour = "green" if match else "red"
    result = f"""
    <div style='text-align:center; padding: 16px;'>
      <div style='font-size: 28px; font-weight: 600; color: {colour};'>{label}</div>
      <div style='font-size: 16px; color: gray; margin-top: 8px;'>
        Cosine similarity: <strong>{conf}</strong>
      </div>
    </div>
    """
    return result, round(similarity, 4)
# ── UI ────────────────────────────────────────────────────────
# Two image inputs side by side, a button that runs the comparison, and two
# outputs: the formatted verdict and the raw similarity score.
with gr.Blocks(title="Siamese Few-Shot Recognition") as demo:
    # The heading previously contained a mis-decoded em dash.
    gr.Markdown("## Siamese Network — Few-Shot Image Similarity")
    gr.Markdown("Upload two images. The model will tell you if they belong to the same class.")

    with gr.Row():
        # type="pil" makes Gradio pass PIL images to compare_images.
        img1 = gr.Image(type="pil", label="Image 1")
        img2 = gr.Image(type="pil", label="Image 2")

    btn = gr.Button("Compare", variant="primary")
    result_html = gr.HTML()
    result_score = gr.Number(label="Raw similarity score")

    # compare_images returns (html, score), matching the two outputs in order.
    btn.click(
        fn=compare_images,
        inputs=[img1, img2],
        outputs=[result_html, result_score],
    )
# Launch the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=True)  # share=True asks Gradio for a public URL