# TripletVGG11 / example.py
# Author: adlito — "Upload folder using huggingface_hub" (commit aae5634)
# Example: compute an embedding for a single image with VGG11Embedding.
from PIL import Image
import torch
from torchvision import transforms
from model import VGG11Embedding
# Preprocessing pipeline: resize to a fixed 32x32 input, convert to a tensor,
# then normalize each of the three channels with fixed mean/std values.
# NOTE(review): these statistics look like the commonly used CIFAR-10
# per-channel values — confirm they match what the model was trained with.
_CHANNEL_MEAN = [0.4914, 0.4822, 0.4465]
_CHANNEL_STD = [0.2470, 0.2435, 0.2616]
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
# Load the image and force 3-channel RGB. The Normalize step takes exactly
# three mean/std values, so grayscale, palette, or RGBA PNGs would otherwise
# fail with a channel-count mismatch. The `with` block closes the underlying
# file once convert() has produced an independent in-memory copy.
with Image.open("image.png") as img:
    image = img.convert("RGB")
# Add a leading batch dimension so the model receives a batch of one image.
image_tensor = transform(image).unsqueeze(0)

model = VGG11Embedding(embedding_size=128)
# NOTE(review): no trained checkpoint is loaded here, so the embedding comes
# from whatever weights the constructor creates (presumably random init).
# Load the trained state_dict before relying on the output.
# Switch to evaluation mode so mode-dependent layers (dropout, batch norm),
# if the model has any, use their inference behavior.
model.eval()

# Get embedding; gradient tracking is unnecessary for inference.
with torch.no_grad():
    embedding = model(image_tensor)
print(f"Embedding shape: {embedding.shape}")  # expected: torch.Size([1, 128])