Update app.py to add category and gender filtering
Browse filesImplemented dropdowns for selecting clothing category and gender. Added filtering logic in the recommend function to return only results matching user selections. Also handled the case where no matches are found.
app.py
CHANGED
|
@@ -6,47 +6,76 @@ from PIL import Image
|
|
| 6 |
from torchvision import transforms
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
# Step 1: Load the precomputed fashion embeddings
|
| 11 |
-
file_path = hf_hub_download(repo_id="Elevi7/MatchMe", filename="fashion_embeddings.pkl", repo_type="dataset")
|
| 12 |
with open(file_path, "rb") as f:
|
| 13 |
embedding_store = pickle.load(f)
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
# (update this if you used something different, e.g. CLIP, ResNet, etc.)
|
| 17 |
-
import timm
|
| 18 |
model = timm.create_model("resnet18", pretrained=True)
|
| 19 |
model.eval()
|
| 20 |
|
|
|
|
| 21 |
transform = transforms.Compose([
|
| 22 |
transforms.Resize((224, 224)),
|
| 23 |
transforms.ToTensor()
|
| 24 |
])
|
| 25 |
|
|
|
|
| 26 |
def extract_embedding(image):
|
| 27 |
image = transform(image).unsqueeze(0)
|
| 28 |
with torch.no_grad():
|
| 29 |
embedding = model(image).squeeze().numpy()
|
| 30 |
return embedding
|
| 31 |
|
| 32 |
-
|
|
|
|
| 33 |
query_embedding = extract_embedding(image)
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
similarities = cosine_similarity([query_embedding], all_embeddings)[0]
|
| 40 |
-
top_indices = np.argsort(similarities)[-3:][::-1] # Top 3
|
| 41 |
|
| 42 |
return [paths[i] for i in top_indices]
|
| 43 |
|
|
|
|
| 44 |
demo = gr.Interface(
|
| 45 |
fn=recommend,
|
| 46 |
-
inputs=
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
title="MatchMe: Fashion Recommender",
|
| 49 |
-
description="Upload a fashion image and get 3 visually similar items."
|
| 50 |
)
|
| 51 |
|
| 52 |
demo.launch()
|
|
|
|
| 6 |
from torchvision import transforms
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 9 |
+
import timm
|
| 10 |
+
|
| 11 |
+
# Load the precomputed fashion embeddings
|
| 12 |
+
file_path = hf_hub_download(
|
| 13 |
+
repo_id="Elevi7/MatchMe",
|
| 14 |
+
filename="fashion_embeddings.pkl",
|
| 15 |
+
repo_type="dataset"
|
| 16 |
+
)
|
| 17 |
|
|
|
|
|
|
|
| 18 |
with open(file_path, "rb") as f:
|
| 19 |
embedding_store = pickle.load(f)
|
| 20 |
|
| 21 |
+
# Load the pretrained model (same one used for embeddings)
|
|
|
|
|
|
|
| 22 |
model = timm.create_model("resnet18", pretrained=True)
|
| 23 |
model.eval()
|
| 24 |
|
| 25 |
+
# Define image transformation pipeline
|
| 26 |
transform = transforms.Compose([
|
| 27 |
transforms.Resize((224, 224)),
|
| 28 |
transforms.ToTensor()
|
| 29 |
])
|
| 30 |
|
| 31 |
+
# Extract embedding from uploaded image
|
| 32 |
def extract_embedding(image):
|
| 33 |
image = transform(image).unsqueeze(0)
|
| 34 |
with torch.no_grad():
|
| 35 |
embedding = model(image).squeeze().numpy()
|
| 36 |
return embedding
|
| 37 |
|
| 38 |
+
# Recommendation function with filtering
|
| 39 |
+
def recommend(image, category, gender):
|
| 40 |
query_embedding = extract_embedding(image)
|
| 41 |
|
| 42 |
+
# Filter embeddings by category and gender
|
| 43 |
+
filtered_store = [
|
| 44 |
+
item for item in embedding_store
|
| 45 |
+
if item.get("category") == category and item.get("gender") == gender
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
if len(filtered_store) == 0:
|
| 49 |
+
print("No matches found for selected category and gender.")
|
| 50 |
+
return ["", "", ""] # Return empty images
|
| 51 |
+
|
| 52 |
+
all_embeddings = np.array([item["embedding"] for item in filtered_store])
|
| 53 |
+
paths = [item["image_path"] for item in filtered_store]
|
| 54 |
|
| 55 |
similarities = cosine_similarity([query_embedding], all_embeddings)[0]
|
| 56 |
+
top_indices = np.argsort(similarities)[-3:][::-1] # Top 3
|
| 57 |
|
| 58 |
return [paths[i] for i in top_indices]
|
| 59 |
|
| 60 |
+
# Gradio interface
|
| 61 |
demo = gr.Interface(
|
| 62 |
fn=recommend,
|
| 63 |
+
inputs=[
|
| 64 |
+
gr.Image(type="pil", label="Upload Clothing Image"),
|
| 65 |
+
gr.Dropdown(
|
| 66 |
+
choices=["shoes", "tops", "pants", "handbags", "coats_jackets",
|
| 67 |
+
"sunglasses", "shorts", "skirts", "earrings", "necklaces"],
|
| 68 |
+
label="Category"
|
| 69 |
+
),
|
| 70 |
+
gr.Dropdown(choices=["men", "women"], label="Gender")
|
| 71 |
+
],
|
| 72 |
+
outputs=[
|
| 73 |
+
gr.Image(type="filepath", label="Match 1"),
|
| 74 |
+
gr.Image(type="filepath", label="Match 2"),
|
| 75 |
+
gr.Image(type="filepath", label="Match 3")
|
| 76 |
+
],
|
| 77 |
title="MatchMe: Fashion Recommender",
|
| 78 |
+
description="Upload a fashion image and get 3 visually similar items. Filter by category and gender."
|
| 79 |
)
|
| 80 |
|
| 81 |
demo.launch()
|