# TripletVGG11
A Siamese network based on VGG11 architecture trained with triplet loss on CIFAR-10 for learning discriminative image embeddings.
## Model Description
TripletVGG11 is a metric learning model that learns to embed images into a 128-dimensional space where similar images are close together and dissimilar images are far apart. The model uses a VGG11 backbone pre-trained on ImageNet, followed by a linear projection layer and L2 normalization.
**Key Features:**
- Architecture: VGG11-based Siamese network
- Embedding Dimension: 128
- Loss Function: Triplet loss with margin
- Training Dataset: CIFAR-10 (50,000 training images)
- Performance: 0.9597 AUC on validation set
## Model Architecture
```
Input Image (32x32x3)
        ↓
VGG11 Feature Extractor (pretrained on ImageNet)
        ↓
     Flatten
        ↓
Linear(512 → 128)
        ↓
L2 Normalization
        ↓
Embedding Vector (128-dimensional)
```
The model uses cosine similarity to compare embeddings:
- Similar images have cosine similarity close to 1
- Dissimilar images have cosine similarity close to 0 or negative
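Because the embeddings are L2-normalized, cosine similarity reduces to a plain dot product, which is why the usage examples below compare embeddings with `torch.mm`. A minimal sketch of the equivalence:

```python
import torch
import torch.nn.functional as F

# Random stand-ins for model outputs, normalized like the model's embeddings
emb1 = F.normalize(torch.randn(1, 128), p=2, dim=1)
emb2 = F.normalize(torch.randn(1, 128), p=2, dim=1)

# For unit-norm vectors these two are identical:
cos = F.cosine_similarity(emb1, emb2)   # explicit cosine similarity
dot = (emb1 * emb2).sum(dim=1)          # plain dot product
assert torch.allclose(cos, dot, atol=1e-6)
```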
## Training Details
### Training Data
The model was trained on the CIFAR-10 dataset, which covers 10 classes:
- airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- 50,000 training images organized into triplets
- 5% validation split
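As a rough illustration of how such triplets can be drawn from CIFAR-10, a random-sampling dataset might look like the sketch below. The actual sampling strategy used in training is not documented here, so treat this as an assumption:

```python
import random
from torch.utils.data import Dataset
from torchvision import datasets

class CIFAR10Triplets(Dataset):
    """Sketch: sample (anchor, positive, negative) triplets from CIFAR-10."""

    def __init__(self, root="./data", transform=None):
        self.ds = datasets.CIFAR10(root=root, train=True, download=True,
                                   transform=transform)
        # Group image indices by class label for fast positive/negative sampling.
        self.by_class = {}
        for idx, label in enumerate(self.ds.targets):
            self.by_class.setdefault(label, []).append(idx)
        self.classes = list(self.by_class)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        anchor, label = self.ds[i]
        pos_idx = random.choice(self.by_class[label])      # same class (may be i itself)
        neg_label = random.choice([c for c in self.classes if c != label])
        neg_idx = random.choice(self.by_class[neg_label])  # different class
        return anchor, self.ds[pos_idx][0], self.ds[neg_idx][0]
```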
### Hyperparameters
The following hyperparameters were optimized using Optuna (100 trials):
| Parameter | Value |
|---|---|
| Batch Size | 64 |
| Learning Rate | 0.000161 |
| Optimizer | Adam |
| Triplet Margin | 0.265 |
| Embedding Size | 128 |
| Epochs | 15-30 |
| KoLeo Loss | No |
| Gradient Accumulation | 1 |
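For reference, a search like this might be set up with Optuna roughly as follows. The search ranges and the `train_and_evaluate` helper are hypothetical, introduced only for illustration; the actual search space is not documented:

```python
import optuna

def objective(trial):
    # Hypothetical search space; the real ranges are not documented.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    margin = trial.suggest_float("margin", 0.1, 1.0)
    # train_and_evaluate is a placeholder: train with these hyperparameters
    # and return the validation AUC to maximize.
    return train_and_evaluate(lr=lr, batch_size=batch_size, margin=margin)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)
print(study.best_params)
```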
### Training Procedure
- Images from the same class form positive pairs
- Images from different classes form negative pairs
- Triplets (anchor, positive, negative) are constructed
- The model is trained to minimize the triplet loss: `max(0, d(a,p) - d(a,n) + margin)`
- Distance is computed as `d(x, y) = 1 - cosine_similarity(x, y)` (see the sketch below)
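A minimal PyTorch sketch of this loss under the conventions above, assuming the embeddings are already L2-normalized (as in the model's forward pass), so cosine similarity is a dot product:

```python
import torch

def triplet_loss(anchor, positive, negative, margin=0.265):
    """Triplet loss with cosine distance d(x, y) = 1 - x·y on unit-norm embeddings."""
    d_ap = 1.0 - (anchor * positive).sum(dim=1)  # anchor-positive distances
    d_an = 1.0 - (anchor * negative).sum(dim=1)  # anchor-negative distances
    return torch.clamp(d_ap - d_an + margin, min=0.0).mean()
```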
## Performance
| Metric | Value |
|---|---|
| Validation AUC | 0.9597 |
| Intra-class Distance | ~0.22 |
| Inter-class Distance | ~1.04 |
| Separation Margin | ~0.81 |
| Good Triplets Ratio | >93% |
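The distance figures above are cosine distances (`1 - cosine_similarity`). A sketch of how such statistics might be computed over a labeled validation set follows; the exact evaluation protocol is an assumption:

```python
import torch

def distance_stats(embeddings, labels):
    """Mean intra-/inter-class cosine distance for unit-norm embeddings."""
    dist = 1.0 - embeddings @ embeddings.t()           # pairwise cosine distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # same-class mask
    off_diag = ~torch.eye(len(labels), dtype=torch.bool)
    intra = dist[same & off_diag].mean()               # within classes
    inter = dist[~same].mean()                         # across classes
    return intra.item(), inter.item(), (inter - intra).item()  # last = separation margin
```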
## Usage
### Installation
```bash
pip install torch torchvision
```
### Loading the Model
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class VGG11Embedding(nn.Module):
    def __init__(self, embedding_size=128, weights=None):
        super(VGG11Embedding, self).__init__()
        vgg = models.vgg11(weights=weights)
        self.features = vgg.features          # convolutional backbone
        self.linear = nn.Linear(512, embedding_size)

    def forward(self, x):
        x = self.features(x)                  # 3x32x32 -> 512x1x1
        x = torch.flatten(x, 1)
        x = self.linear(x)
        x = F.normalize(x, p=2, dim=1)        # L2-normalize the embedding
        return x

# Load the model
model = VGG11Embedding(embedding_size=128)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()
```
### Computing Embeddings
```python
import torch
from torchvision import transforms
from PIL import Image

# Preprocessing (CIFAR-10 normalization statistics)
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2470, 0.2435, 0.2616]),
])

# Load and preprocess image (convert to RGB in case of grayscale/RGBA inputs)
image = Image.open("image.png").convert("RGB")
image_tensor = transform(image).unsqueeze(0)

# Get embedding
with torch.no_grad():
    embedding = model(image_tensor)

print(f"Embedding shape: {embedding.shape}")  # torch.Size([1, 128])
```
### Computing Similarity
```python
# Compare two images
image1 = transform(Image.open("image1.jpg").convert("RGB")).unsqueeze(0)
image2 = transform(Image.open("image2.jpg").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    emb1 = model(image1)
    emb2 = model(image2)

# Cosine similarity (embeddings are already L2-normalized, so a dot product suffices)
similarity = torch.mm(emb1, emb2.t()).item()
print(f"Similarity: {similarity:.4f}")

# High similarity (> 0.5): similar images
# Low similarity (< 0.3): dissimilar images
```
### Finding Similar Images
```python
# Create embeddings for a database of images
# (image_database: a list of image file paths to index)
database_embeddings = []
for image_path in image_database:
    img = transform(Image.open(image_path).convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        emb = model(img)
    database_embeddings.append(emb)
database_embeddings = torch.cat(database_embeddings, dim=0)

# Query with a new image
query_image = transform(Image.open("query.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    query_embedding = model(query_image)

# Compute similarities (dot products of L2-normalized embeddings)
similarities = torch.mm(query_embedding, database_embeddings.t()).squeeze(0)

# Get top-k most similar images
top_k = 5
top_k_indices = torch.topk(similarities, k=top_k).indices
print(f"Most similar images: {top_k_indices.tolist()}")
```
## Model Card Authors
Adlane Ladjal
## Model Card Contact
For questions or feedback, please open an issue in the GitHub repository.