# mctrack-reid / load_reid.py
# blank4hd's picture
# Upload Re-ID model checkpoints and model card
# 1d7cec6 verified
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
class ReIDModel(nn.Module):
    """Person re-identification embedding network.

    Pipeline: ResNet-50 trunk (stride of the first ``layer4`` block forced
    to 1) -> global average pooling -> BNNeck (``BatchNorm1d``) -> bias-free
    linear projection -> L2 normalisation.

    Args:
        num_classes: Size of the identity-classification head. When positive,
            a bias-free ``classifier`` linear layer is created so its weights
            exist under that name in the state dict; otherwise ``classifier``
            is ``None``. ``forward`` never calls it.
        embedding_dim: Width of the embedding returned by ``forward``.
    """

    def __init__(self, num_classes: int = 751, embedding_dim: int = 256):
        super().__init__()

        # Build an un-pretrained ResNet-50 and remove the spatial
        # down-sampling of its last stage by setting both strided convs of
        # the first layer4 block to stride 1.
        resnet = models.resnet50(weights=None)
        first_block = resnet.layer4[0]
        first_block.conv2.stride = (1, 1)
        first_block.downsample[0].stride = (1, 1)

        # Keep everything except the last two children of the torchvision
        # ResNet (its own pooling and fully-connected head).
        trunk_layers = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Channel count produced by the trunk.
        self.feature_dim = 2048

        # BNNeck: batch norm over pooled features, with its bias frozen.
        self.bnneck = nn.BatchNorm1d(self.feature_dim)
        self.bnneck.bias.requires_grad_(False)
        nn.init.constant_(self.bnneck.weight, 1.0)
        nn.init.constant_(self.bnneck.bias, 0.0)

        # Projection from trunk features to the embedding space.
        self.embedding_layer = nn.Linear(self.feature_dim, embedding_dim, bias=False)
        nn.init.kaiming_normal_(self.embedding_layer.weight, mode="fan_out")

        # Optional identity classifier; not used by forward().
        self.num_classes = num_classes
        if num_classes <= 0:
            self.classifier = None
        else:
            self.classifier = nn.Linear(self.feature_dim, num_classes, bias=False)
            nn.init.normal_(self.classifier.weight, std=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Input batch passed straight to the backbone.

        Returns:
            Tensor of shape ``(N, embedding_dim)`` whose rows are
            L2-normalised.
        """
        pooled = self.global_pool(self.backbone(x))
        neck_out = self.bnneck(torch.flatten(pooled, 1))
        projected = self.embedding_layer(neck_out)
        return F.normalize(projected, p=2, dim=1)
def load_model(checkpoint_path: str | Path, device: str = "cpu") -> ReIDModel:
    """Build a ``ReIDModel`` and populate it from a checkpoint file.

    The checkpoint may be either a raw state dict or a mapping that holds the
    state dict under the ``"state_dict"`` key. Loading is non-strict: key
    mismatches are reported on stdout as counts instead of raising.

    Args:
        checkpoint_path: Location of the checkpoint file.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        The model in evaluation mode, placed on ``device``.
    """
    model = ReIDModel(num_classes=751, embedding_dim=256)

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only checkpoints from a trusted source should be loaded.
    checkpoint = torch.load(str(checkpoint_path), map_location=device, weights_only=False)

    if "state_dict" in checkpoint:
        weights = checkpoint["state_dict"]
    else:
        weights = checkpoint

    report = model.load_state_dict(weights, strict=False)
    missing_keys = report.missing_keys
    unexpected_keys = report.unexpected_keys
    if missing_keys:
        print(f"Warning: missing keys: {len(missing_keys)} keys")
    if unexpected_keys:
        print(f"Warning: unexpected keys: {len(unexpected_keys)} keys")

    # eval() and to() both return the module itself.
    return model.eval().to(device)
# Inference-time preprocessing: resize to a fixed 256x128 (height, width)
# crop, convert to a tensor, then normalise each channel with the commonly
# used ImageNet mean / std values.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(size=(256, 128)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-item batch ready for the model.

    Non-RGB images are converted to RGB first, then the module-level
    ``_PREPROCESS`` pipeline is applied and a leading batch axis is added.

    Args:
        image: Source image in any PIL mode.

    Returns:
        The preprocessed tensor with an extra dimension of size 1 in front.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    tensor = _PREPROCESS(rgb_image)
    return tensor.unsqueeze(0)
def cosine_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
    """Return the cosine similarity of one pair of embeddings.

    Each argument may be a single embedding given either as a 1-D vector of
    shape ``(D,)`` or as a one-row batch of shape ``(1, D)``; the two forms
    can be mixed. 1-D inputs are promoted to ``(1, D)`` because
    ``F.cosine_similarity`` reduces over ``dim=1`` and would otherwise fail
    on vectors that have no such dimension.

    Args:
        emb_a: First embedding.
        emb_b: Second embedding.

    Returns:
        The similarity as a Python float, nominally in ``[-1, 1]``.

    Raises:
        An error from ``Tensor.item()`` if the inputs describe more than one
        pair (for example batches with several rows).
    """
    if emb_a.dim() == 1:
        emb_a = emb_a.unsqueeze(0)
    if emb_b.dim() == 1:
        emb_b = emb_b.unsqueeze(0)
    return float(F.cosine_similarity(emb_a, emb_b, dim=1).item())
if __name__ == "__main__":
    # Demo entry point: fetch the checkpoint from the Hugging Face Hub, load
    # it, and, when two image paths are given on the command line, report
    # how similar the two images' embeddings are.
    import sys

    # Imported here so the hub client is only needed when run as a script.
    from huggingface_hub import hf_hub_download

    print("Downloading model from Hugging Face...")
    checkpoint_file = hf_hub_download(repo_id="blank4hd/mctrack-reid", filename="best_60ep.pth")
    model = load_model(checkpoint_file)
    print(f"Model loaded. Embedding dim: 256, parameter count: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

    if len(sys.argv) < 3:
        print("Usage: python load_reid.py <image_a.jpg> <image_b.jpg>")
    else:
        # Open both files first, then preprocess, then embed one at a time.
        pictures = [Image.open(path) for path in sys.argv[1:3]]
        inputs = [preprocess_image(picture) for picture in pictures]
        with torch.no_grad():
            embeddings = [model(batch) for batch in inputs]

        sim = cosine_similarity(embeddings[0], embeddings[1])
        print(f"Cosine similarity: {sim:.4f}")

        # Coarse verdict from fixed similarity thresholds.
        if sim > 0.7:
            verdict = "Likely SAME person"
        elif sim > 0.4:
            verdict = "Possibly same person (uncertain)"
        else:
            verdict = "Likely DIFFERENT people"
        print(verdict)