# Origin: LETTER/src/fp_sanity_check.py (uploaded by Sharath33 via huggingface_hub, commit 02ac88d)
# Quick sanity check — run this before writing a single line of training-loop code.
import torch
from model import SiameseNet
from loss import ContrastiveLoss
def main() -> None:
    """Smoke-test SiameseNet + ContrastiveLoss on one random fake batch.

    Builds the model and loss, runs a single forward pass on random images
    and random binary labels, and prints the embedding shapes, the mean
    embedding norm, the loss value and the range of ``dist``. It then checks
    that both embeddings have shape ``(batch_size, embedding_dim)`` and that
    the loss is finite.

    Raises:
        RuntimeError: if either embedding has an unexpected shape or the
            loss is NaN/inf.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 32
    embedding_dim = 128
    # Per-sample image shape of the fake batch; keep in sync with the DataLoader.
    image_shape = (1, 105, 105)

    model = SiameseNet(embedding_dim=embedding_dim).to(device)
    criterion = ContrastiveLoss(margin=1.0)

    # Fake a batch matching your DataLoader output shape. Tensors are created
    # directly on `device` rather than built on CPU and copied over.
    img1 = torch.randn(batch_size, *image_shape, device=device)
    img2 = torch.randn(batch_size, *image_shape, device=device)
    labels = torch.randint(0, 2, (batch_size,), device=device).float()

    emb1, emb2 = model(img1, img2)
    loss, dist = criterion(emb1, emb2, labels)

    print(f"emb1 shape : {emb1.shape}")  # expected [batch_size, embedding_dim]
    print(f"emb2 shape : {emb2.shape}")  # expected [batch_size, embedding_dim]
    # ~1.0 only if SiameseNet L2-normalizes its embeddings — TODO confirm.
    print(f"emb1 norm : {emb1.norm(dim=1).mean():.4f}")
    print(f"loss : {loss.item():.4f}")
    # Separator is required, otherwise min and max run together ("0.1231.456").
    print(f"dist range : {dist.min():.3f} to {dist.max():.3f}")

    # Explicit raises rather than `assert`, so the checks still run under `python -O`.
    expected_shape = (batch_size, embedding_dim)
    for name, emb in (("emb1", emb1), ("emb2", emb2)):
        if tuple(emb.shape) != expected_shape:
            raise RuntimeError(
                f"{name} has shape {tuple(emb.shape)}, expected {expected_shape}"
            )
    if not torch.isfinite(loss).all():
        raise RuntimeError(f"loss is not finite: {loss.item()}")

    print("Sanity check passed")


if __name__ == "__main__":
    main()