File size: 2,508 Bytes
300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 d1b4c6f 300b764 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import gradio as gr
import torch
import pickle
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from sklearn.metrics.pairwise import cosine_similarity
import timm
# Download the precomputed fashion embeddings from the Hugging Face Hub
# dataset repo and deserialize them. recommend() iterates this object and
# reads "category", "gender", "embedding" and "image_path" from each item.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load it from a repository you trust (consider pinning `revision=` to a
# known commit so the artifact cannot change underneath the app).
file_path = hf_hub_download(
    repo_type="dataset",
    repo_id="Elevi7/MatchMe",
    filename="fashion_embeddings.pkl",
)
with open(file_path, "rb") as f:
    embedding_store = pickle.load(f)
# Pretrained timm ResNet-18 switched to inference mode (Module.eval() returns
# the module itself, so it can be chained). It must be the same network that
# produced the stored embeddings, otherwise similarity scores are meaningless.
model = timm.create_model("resnet18", pretrained=True).eval()

# Preprocessing for every uploaded image: resize to a fixed 224x224 and
# convert the PIL image to a tensor.
# NOTE(review): no transforms.Normalize step is applied. Pretrained timm models
# usually expect mean/std-normalized input, but adding it here is only correct
# if the stored embeddings were computed the same way — confirm first.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)
# Extract embedding from uploaded image
def extract_embedding(image):
    """Run *image* through the module-level ``model`` and return its output.

    The image is converted to RGB first, then preprocessed with the
    module-level ``transform`` and evaluated without gradient tracking.

    Args:
        image: A PIL image (any mode; it is converted to RGB).

    Returns:
        numpy.ndarray: The model output for this single image with all
        size-1 dimensions (including the batch dimension) squeezed out.
    """
    # ToTensor keeps the source channel count, so grayscale ("L"), palette
    # ("P") or RGBA uploads would yield a 1- or 4-channel tensor that the
    # network's first convolution rejects. convert("RGB") is a no-op in
    # pixel values for images that are already RGB.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        # NOTE(review): with default settings timm's resnet18 returns
        # classifier logits rather than pooled features; this must match how
        # the stored embeddings were produced — confirm before changing.
        output = model(batch)
    return output.squeeze().numpy()
# Recommendation function with filtering
def recommend(image, category, gender):
    """Return image paths of the stored items most similar to *image*.

    Only items whose "category" and "gender" equal the selected values are
    considered. They are ranked by cosine similarity between the uploaded
    image's embedding and each stored embedding.

    Args:
        image: The uploaded PIL image, or None when nothing was uploaded.
        category: Selected category value (None if the dropdown is unset).
        gender: Selected gender value (None if the dropdown is unset).

    Returns:
        A list of exactly three entries, one per Gradio output image. Matching
        file paths come first, ordered from most to least similar. The list is
        padded with None, which clears the corresponding output, when fewer
        than three items match or no image was given.
    """
    # Must equal the number of output components wired up in the interface;
    # Gradio raises if a handler returns fewer values than there are outputs.
    num_matches = 3

    if image is None:
        return [None] * num_matches

    # Filter embeddings by category and gender
    filtered_store = [
        item for item in embedding_store
        if item.get("category") == category and item.get("gender") == gender
    ]
    if not filtered_store:
        print("No matches found for selected category and gender.")
        # None (not "") clears a filepath image output; an empty string
        # would be treated as a path to load.
        return [None] * num_matches

    # Embed only after filtering, so no model work is done when nothing matches.
    query_embedding = extract_embedding(image)
    all_embeddings = np.array([item["embedding"] for item in filtered_store])
    paths = [item["image_path"] for item in filtered_store]
    similarities = cosine_similarity([query_embedding], all_embeddings)[0]
    # argsort is ascending: take the last k indices and reverse them so the
    # most similar item comes first.
    top_indices = np.argsort(similarities)[-num_matches:][::-1]
    matches = [paths[i] for i in top_indices]
    return matches + [None] * (num_matches - len(matches))
# Gradio interface
# Choices offered in the filter dropdowns. recommend() compares these values
# with each stored item's "category" / "gender" field.
CATEGORIES = [
    "shoes", "tops", "pants", "handbags", "coats_jackets",
    "sunglasses", "shorts", "skirts", "earrings", "necklaces",
]
GENDERS = ["men", "women"]

# Three output images, labelled "Match 1".."Match 3". The count must equal the
# number of values recommend() returns.
demo = gr.Interface(
    fn=recommend,
    inputs=[
        gr.Image(type="pil", label="Upload Clothing Image"),
        gr.Dropdown(choices=CATEGORIES, label="Category"),
        gr.Dropdown(choices=GENDERS, label="Gender"),
    ],
    outputs=[
        gr.Image(type="filepath", label=f"Match {rank}")
        for rank in range(1, 4)
    ],
    title="MatchMe: Fashion Recommender",
    description="Upload a fashion image and get 3 visually similar items. Filter by category and gender."
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (tests, other tools) does not start a web server.
if __name__ == "__main__":
    demo.launch()
|