# MatchMe / app.py
# Last commit by Elevi7: "Update app.py to add category and gender filtering" (d1b4c6f, verified)
import gradio as gr
import torch
import pickle
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from sklearn.metrics.pairwise import cosine_similarity
import timm
# Fetch the precomputed catalogue embeddings from the "Elevi7/MatchMe" dataset
# repo on the Hugging Face Hub and unpickle them into `embedding_store`.
# `recommend` later reads "category", "gender", "embedding" and "image_path"
# from each entry.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file, and
# this file is downloaded over the network at startup. It is only acceptable
# while the dataset repo is fully trusted; consider a non-executable format
# (e.g. .npz for vectors + JSON for metadata).
file_path = hf_hub_download(
    repo_type="dataset",
    repo_id="Elevi7/MatchMe",
    filename="fashion_embeddings.pkl",
)
with open(file_path, "rb") as f:
    embedding_store = pickle.load(f)
# Pretrained ResNet-18 from timm, switched to inference mode (`.eval()` returns
# the module itself, so it can be chained). This must be the same model that
# produced the stored embeddings.
# NOTE(review): the classifier head is left in place, so the forward pass
# returns the model's class-logit vector rather than pooled features. Only
# change this (e.g. num_classes=0) if fashion_embeddings.pkl is regenerated the
# same way.
model = timm.create_model("resnet18", pretrained=True).eval()
# Preprocessing for every query image: resize to a fixed 224x224 and convert
# the PIL image to a float tensor.
# NOTE(review): no mean/std normalization is applied. The pipeline must stay
# identical to the one used to build the stored embeddings, so only add
# transforms.Normalize together with regenerating them.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Extract embedding from uploaded image
def extract_embedding(image):
    """Run one PIL image through the module-level `transform` and `model`.

    Args:
        image: A PIL image of any mode.

    Returns:
        The model output for this single image as a NumPy array, with the
        batch dimension squeezed away.
    """
    # ToTensor keeps the source channel count, so an RGBA, grayscale or
    # palette image would give the 3-channel ResNet a 4- or 1-channel tensor
    # and fail in its first convolution. Normalize to RGB first.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add a leading batch dim of size 1
    with torch.no_grad():  # inference only; no autograd graph needed
        embedding = model(batch).squeeze().numpy()
    return embedding
# Recommendation function with filtering
def recommend(image, category, gender):
    """Return image paths of the catalogue items most similar to ``image``.

    Only entries of ``embedding_store`` whose ``"category"`` and ``"gender"``
    equal the selected values are considered. They are ranked by cosine
    similarity between their ``"embedding"`` and the embedding of ``image``.

    Args:
        image: Uploaded PIL image, or None when nothing has been uploaded.
        category: Value selected in the category dropdown.
        gender: Value selected in the gender dropdown.

    Returns:
        A list with exactly 3 entries, one per output image component, most
        similar first. Slots with no match are None, which clears the
        corresponding Gradio image.
    """
    num_outputs = 3  # must equal the number of output components in the Interface
    no_results = [None] * num_outputs

    if image is None:
        return no_results

    # Filter embeddings by category and gender
    filtered_store = [
        item for item in embedding_store
        if item.get("category") == category and item.get("gender") == gender
    ]
    if not filtered_store:
        print("No matches found for selected category and gender.")
        # None (not "") is how a filepath Image output is cleared; an empty
        # string would be treated as a (missing) file path.
        return no_results

    # Only embed the query once we know there are candidates to compare with.
    query_embedding = extract_embedding(image)
    all_embeddings = np.array([item["embedding"] for item in filtered_store])
    similarities = cosine_similarity([query_embedding], all_embeddings)[0]
    top_indices = np.argsort(similarities)[-num_outputs:][::-1]  # best first
    matches = [filtered_store[i]["image_path"] for i in top_indices]

    # Gradio needs one value per output component, so pad when fewer
    # candidates than outputs survived the filter.
    return matches + [None] * (num_outputs - len(matches))
# Gradio interface: an image upload plus category/gender dropdowns as inputs,
# and three ranked match images as outputs. The number of output components
# must equal the number of values `recommend` returns (3).
demo = gr.Interface(
    fn=recommend,
    inputs=[
        gr.Image(type="pil", label="Upload Clothing Image"),
        gr.Dropdown(
            choices=["shoes", "tops", "pants", "handbags", "coats_jackets",
                     "sunglasses", "shorts", "skirts", "earrings", "necklaces"],
            label="Category"
        ),
        gr.Dropdown(choices=["men", "women"], label="Gender")
    ],
    outputs=[
        gr.Image(type="filepath", label="Match 1"),
        gr.Image(type="filepath", label="Match 2"),
        gr.Image(type="filepath", label="Match 3")
    ],
    title="MatchMe: Fashion Recommender",
    description="Upload a fashion image and get 3 visually similar items. Filter by category and gender."
)

# Start the (blocking) server only when this file is run as a script, e.g.
# `python app.py`, so importing the module to reuse `demo` or `recommend`
# does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()