image-retrieval / utils /get_embeddings.py
nampham1106's picture
first commit
ab9b7a8
import os
from tqdm.auto import tqdm
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Preprocessing applied to every image before it is fed to the model:
# resize to a fixed 224x224 and convert the PIL image into a tensor.
# NOTE(review): no mean/std normalization step is applied here -- confirm
# this is intended for the pretrained weights used elsewhere in this file.
_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([_resize_step, _to_tensor_step])
def preprocess_image(image_path):
    """Load an image and convert it into a tensor via the module ``transform``.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (typically a
            filesystem path).

    Returns:
        The result of applying the module-level ``transform`` to the
        RGB-converted image.
    """
    # Image.open is lazy and keeps the underlying file handle open; the
    # context manager releases it deterministically, even if decoding fails.
    with Image.open(image_path) as img:
        # convert() loads the pixel data and returns a new image object, so
        # the result remains usable after the source file is closed.
        rgb_img = img.convert('RGB')
    return transform(rgb_img)
def create_resnet18_model():
    """Build a ResNet-18 feature extractor from ImageNet-pretrained weights.

    The torchvision ResNet-18 is instantiated with the ``IMAGENET1K_V1``
    weights, its last child module is dropped, and the remaining children
    are chained into an ``nn.Sequential``.

    Returns:
        An ``nn.Sequential`` holding every child of the ResNet-18 except the
        final one. The model is returned as constructed: it is not moved to
        any device and its train/eval mode is not changed here.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    feature_layers = list(backbone.children())
    # Discard the last child so the network emits features rather than the
    # output of its final layer.
    feature_layers.pop()
    return nn.Sequential(*feature_layers)
def extract_features(model, processed_image):
    """Run one image through ``model`` and return the output as Python data.

    Args:
        model: A ``torch.nn.Module``. It is switched to eval mode in place.
        processed_image: Tensor for a single image, without a batch
            dimension (one is added here).

    Returns:
        The model output with all size-1 dimensions squeezed away, converted
        with ``Tensor.tolist()`` -- i.e. a (possibly nested) list of numbers,
        or a bare number if every dimension has size 1.
    """
    # Place the input on the device that actually holds the model's weights.
    # Sending it to the global ``device`` unconditionally fails with a device
    # mismatch when CUDA is available but the model was never moved there.
    # Models without parameters keep the original behaviour (global device).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    batch = processed_image.unsqueeze(dim=0).to(target_device)
    model.eval()
    with torch.inference_mode():
        prediction = model(batch)
    return prediction.squeeze().tolist()