TripletVGG11

A Siamese network based on the VGG11 architecture, trained with triplet loss on CIFAR-10 to learn discriminative image embeddings.

Model Description

TripletVGG11 is a metric learning model that learns to embed images into a 128-dimensional space where similar images are close together and dissimilar images are far apart. The model uses a VGG11 backbone pre-trained on ImageNet, followed by a linear projection layer and L2 normalization.

Key Features:

  • Architecture: VGG11-based Siamese network
  • Embedding Dimension: 128
  • Loss Function: Triplet loss with margin
  • Training Dataset: CIFAR-10 (50,000 training images)
  • Performance: 0.9597 AUC on validation set

Model Architecture

Input Image (32x32x3)
    ↓
VGG11 Feature Extractor (pretrained on ImageNet)
    ↓
Flatten
    ↓
Linear(512 → 128)
    ↓
L2 Normalization
    ↓
Embedding Vector (128-dimensional)

The model uses cosine similarity to compare embeddings:

  • Similar images have cosine similarity close to 1
  • Dissimilar images have cosine similarity close to 0 or negative
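
As a quick standalone illustration (not part of the model code), cosine similarity between L2-normalized vectors is simply their dot product:

import torch
import torch.nn.functional as F

a = F.normalize(torch.randn(1, 128), dim=1)  # random unit-norm "embedding"
b = F.normalize(torch.randn(1, 128), dim=1)
print((a @ b.t()).item())  # near 0 for unrelated random vectors
print((a @ a.t()).item())  # 1.0 for identical vectors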

Training Details

Training Data

The model was trained on the CIFAR-10 dataset, which contains 10 classes:

  • airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
  • 50,000 training images organized into triplets
  • 5% validation split

Hyperparameters

The following hyperparameters were optimized using Optuna (100 trials):

Parameter               Value
---------------------   --------
Batch Size              64
Learning Rate           0.000161
Optimizer               Adam
Triplet Margin          0.265
Embedding Size          128
Epochs                  15-30
KoLeo Loss              No
Gradient Accumulation   1
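
As an illustration of how such a search can be set up (a hedged sketch; the actual search space and objective are not published here, and train_and_evaluate is a placeholder):

import optuna

def train_and_evaluate(lr, batch_size, margin):
    # Placeholder standing in for the real training loop; it should train the
    # model and return the validation AUC. A random value keeps the sketch runnable.
    import random
    return random.random()

def objective(trial):
    # Hypothetical search space mirroring the tuned parameters above
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    margin = trial.suggest_float("triplet_margin", 0.1, 1.0)
    return train_and_evaluate(lr=lr, batch_size=batch_size, margin=margin)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)
print(study.best_params)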

Training Procedure

  1. Images from the same class form positive pairs
  2. Images from different classes form negative pairs
  3. Triplets (anchor, positive, negative) are constructed from these pairs
  4. The model is trained to minimize the triplet loss: max(0, d(a,p) - d(a,n) + margin)
  5. Distance is computed as d(x,y) = 1 - cosine_similarity(x,y) (see the sketch below)
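
A minimal sketch of this loss under the setup described above (cosine distance, margin 0.265); the helper cosine_triplet_loss is illustrative and not the repository's actual training code:

import torch
import torch.nn.functional as F

def cosine_triplet_loss(anchor, positive, negative, margin=0.265):
    # Inputs are L2-normalized embeddings of shape (N, 128).
    # d(x, y) = 1 - cos_sim(x, y), so each distance lies in [0, 2].
    d_ap = 1 - F.cosine_similarity(anchor, positive, dim=1)
    d_an = 1 - F.cosine_similarity(anchor, negative, dim=1)
    # Hinge: only triplets that violate the margin contribute to the loss
    return F.relu(d_ap - d_an + margin).mean()

# Example with random stand-ins for model outputs
a = F.normalize(torch.randn(8, 128), dim=1)
p = F.normalize(torch.randn(8, 128), dim=1)
n = F.normalize(torch.randn(8, 128), dim=1)
print(cosine_triplet_loss(a, p, n).item())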

Performance

Metric                  Value
---------------------   -------
Validation AUC          0.9597
Intra-class Distance    ~0.22
Inter-class Distance    ~1.04
Separation Margin       ~0.81
Good Triplets Ratio     >93%

Distances are cosine distances (1 - cosine similarity), matching the distance used during training.

Usage

Installation

pip install torch torchvision

Loading the Model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class VGG11Embedding(nn.Module):
    def __init__(self, embedding_size=128, weights=None):
        super().__init__()
        vgg = models.vgg11(weights=weights)
        self.features = vgg.features  # VGG11 conv stack; a 32x32 input yields a 512x1x1 feature map
        self.linear = nn.Linear(512, embedding_size)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)         # (N, 512)
        x = self.linear(x)              # (N, embedding_size)
        x = F.normalize(x, p=2, dim=1)  # L2-normalize so cosine similarity is a dot product
        return x

# Load the trained weights
model = VGG11Embedding(embedding_size=128)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

Computing Embeddings

import torch
from torchvision import transforms
from PIL import Image

# Preprocessing: CIFAR-10 resolution and channel statistics
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2470, 0.2435, 0.2616]),
])

# Load and preprocess the image (convert to RGB to handle grayscale or alpha channels)
image = Image.open("image.png").convert("RGB")
image_tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 32, 32)

# Get embedding
with torch.no_grad():
    embedding = model(image_tensor)

print(f"Embedding shape: {embedding.shape}")  # torch.Size([1, 128])

Computing Similarity

# Compare two images
image1 = transform(Image.open("image1.jpg").convert("RGB")).unsqueeze(0)
image2 = transform(Image.open("image2.jpg").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    emb1 = model(image1)
    emb2 = model(image2)
    
    # Embeddings are L2-normalized, so a dot product gives cosine similarity
    similarity = torch.mm(emb1, emb2.t()).item()
    
    print(f"Similarity: {similarity:.4f}")
    # High similarity (> 0.5): similar images
    # Low similarity (< 0.3): dissimilar images

Finding Similar Images

# image_database: a list of image file paths
database_embeddings = []
for image_path in image_database:
    img = transform(Image.open(image_path).convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        emb = model(img)
    database_embeddings.append(emb)

database_embeddings = torch.cat(database_embeddings, dim=0)  # (N, 128)

# Query with a new image
query_image = transform(Image.open("query.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    query_embedding = model(query_image)

# Compute similarities
similarities = torch.mm(query_embedding, database_embeddings.t()).squeeze(0)

# Get top-k most similar images
top_k = 5
top_k_indices = torch.topk(similarities, k=top_k).indices
print(f"Most similar images: {top_k_indices.tolist()}")

Model Card Authors

Adlane Ladjal

Model Card Contact

For questions or feedback, please open an issue in the GitHub repository.
