Elevi7 commited on
Commit
d1b4c6f
·
verified ·
1 Parent(s): 859b29d

Update app.py to add category and gender filtering

Browse files

Implemented dropdowns for selecting clothing category and gender. Added filtering logic in the recommend function to return only results matching user selections. Also handled the case where no matches are found.

Files changed (1) hide show
  1. app.py +42 -13
app.py CHANGED
@@ -6,47 +6,76 @@ from PIL import Image
6
  from torchvision import transforms
7
  from huggingface_hub import hf_hub_download
8
  from sklearn.metrics.pairwise import cosine_similarity
 
 
 
 
 
 
 
 
9
 
10
- # Step 1: Load the precomputed fashion embeddings
11
- file_path = hf_hub_download(repo_id="Elevi7/MatchMe", filename="fashion_embeddings.pkl", repo_type="dataset")
12
  with open(file_path, "rb") as f:
13
  embedding_store = pickle.load(f)
14
 
15
- # Step 2: Load the same model used to create the embeddings
16
- # (update this if you used something different, e.g. CLIP, ResNet, etc.)
17
- import timm
18
  model = timm.create_model("resnet18", pretrained=True)
19
  model.eval()
20
 
 
21
  transform = transforms.Compose([
22
  transforms.Resize((224, 224)),
23
  transforms.ToTensor()
24
  ])
25
 
 
26
  def extract_embedding(image):
27
  image = transform(image).unsqueeze(0)
28
  with torch.no_grad():
29
  embedding = model(image).squeeze().numpy()
30
  return embedding
31
 
32
- def recommend(image):
 
33
  query_embedding = extract_embedding(image)
34
 
35
- # Get embeddings and paths
36
- all_embeddings = np.array([item["embedding"] for item in embedding_store])
37
- paths = [item["image_path"] for item in embedding_store]
 
 
 
 
 
 
 
 
 
38
 
39
  similarities = cosine_similarity([query_embedding], all_embeddings)[0]
40
- top_indices = np.argsort(similarities)[-3:][::-1] # Top 3 matches
41
 
42
  return [paths[i] for i in top_indices]
43
 
 
44
  demo = gr.Interface(
45
  fn=recommend,
46
- inputs=gr.Image(type="pil"),
47
- outputs=[gr.Image(type="filepath", label=f"Match {i+1}") for i in range(3)],
 
 
 
 
 
 
 
 
 
 
 
 
48
  title="MatchMe: Fashion Recommender",
49
- description="Upload a fashion image and get 3 visually similar items."
50
  )
51
 
52
  demo.launch()
 
6
  from torchvision import transforms
7
  from huggingface_hub import hf_hub_download
8
  from sklearn.metrics.pairwise import cosine_similarity
9
+ import timm
10
+
11
+ # Load the precomputed fashion embeddings
12
+ file_path = hf_hub_download(
13
+ repo_id="Elevi7/MatchMe",
14
+ filename="fashion_embeddings.pkl",
15
+ repo_type="dataset"
16
+ )
17
 
 
 
18
  with open(file_path, "rb") as f:
19
  embedding_store = pickle.load(f)
20
 
21
# Load the pretrained model (same one used for embeddings)
model = timm.create_model("resnet18", pretrained=True)
model.eval()  # inference mode: freezes dropout / batch-norm statistics

# Define image transformation pipeline.
# 224x224 is the standard ResNet-18 input size; no normalization is
# applied, so query embeddings are only comparable if the stored
# embeddings were produced with this exact pipeline — TODO confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
30
 
31
def extract_embedding(image):
    """Compute a feature vector for one PIL image.

    Runs the module-level ``transform`` pipeline, adds a batch axis,
    and forwards through the module-level ``model`` with gradients
    disabled. Returns the squeezed output as a NumPy array.
    """
    batch = transform(image).unsqueeze(0)  # (1, C, H, W)
    with torch.no_grad():
        features = model(batch)
    return features.squeeze().numpy()
37
 
38
# Recommendation function with filtering
def recommend(image, category, gender):
    """Return the file paths of the top-3 visually similar items.

    Parameters:
        image: PIL image uploaded by the user.
        category: category string selected in the dropdown.
        gender: gender string selected in the dropdown.

    Always returns a list of exactly 3 entries, padded with None when
    fewer than 3 items match — the Gradio interface declares 3 image
    outputs, so the return length must always be 3.
    """
    query_embedding = extract_embedding(image)

    # Keep only catalog items matching the requested category and gender.
    filtered_store = [
        item for item in embedding_store
        if item.get("category") == category and item.get("gender") == gender
    ]

    if not filtered_store:
        print("No matches found for selected category and gender.")
        # None (not "") so gr.Image(type="filepath") shows an empty slot
        # instead of trying to load a file at path "".
        return [None, None, None]

    all_embeddings = np.array([item["embedding"] for item in filtered_store])
    paths = [item["image_path"] for item in filtered_store]

    similarities = cosine_similarity([query_embedding], all_embeddings)[0]
    top_indices = np.argsort(similarities)[-3:][::-1]  # Top 3, best first

    results = [paths[i] for i in top_indices]
    # Pad when only 1-2 items survive the filter; without this the
    # return length would not match the 3 declared outputs.
    results.extend([None] * (3 - len(results)))
    return results
59
 
60
# Gradio interface: image upload plus two dropdown filters in,
# three matched images out.
_category_choices = [
    "shoes", "tops", "pants", "handbags", "coats_jackets",
    "sunglasses", "shorts", "skirts", "earrings", "necklaces",
]

demo = gr.Interface(
    fn=recommend,
    inputs=[
        gr.Image(type="pil", label="Upload Clothing Image"),
        gr.Dropdown(choices=_category_choices, label="Category"),
        gr.Dropdown(choices=["men", "women"], label="Gender"),
    ],
    outputs=[
        gr.Image(type="filepath", label=f"Match {slot}")
        for slot in (1, 2, 3)
    ],
    title="MatchMe: Fashion Recommender",
    description="Upload a fashion image and get 3 visually similar items. Filter by category and gender."
)

demo.launch()