itayitay123 commited on
Commit
1eb91cf
·
verified ·
1 Parent(s): e6c787f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -13
app.py CHANGED
@@ -26,10 +26,10 @@ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
26
  # Load dataset
27
  dataset = load_dataset("lirus18/deepfashion", split="train")
28
 
29
- # Embed a subset of images
30
  image_vectors = []
31
  image_indices = []
32
- N = 500
33
 
34
  for i in range(N):
35
  img = dataset[i]['image'].convert("RGB")
@@ -41,7 +41,7 @@ for i in range(N):
41
 
42
  image_vectors = np.array(image_vectors)
43
 
44
- # Find similar images
45
  def find_similar(user_image, top_k=3, exclude_index=None):
46
  inputs = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
47
  with torch.no_grad():
@@ -52,7 +52,7 @@ def find_similar(user_image, top_k=3, exclude_index=None):
52
  sims[exclude_index] = -1
53
 
54
  top_idx = sims.argsort()[-top_k:][::-1]
55
- return [dataset[image_indices[i]]['image'] for i in top_idx]
56
 
57
  # Load Stable Diffusion
58
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
@@ -73,8 +73,11 @@ def generate_outfits(input_image, n=10):
73
 
74
  return generated_images
75
 
 
76
  def recommend_from_upload(uploaded_image):
77
  uploaded_image = uploaded_image.convert("RGB")
 
 
78
  closest_idx = None
79
  for i in range(len(image_indices)):
80
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
@@ -82,12 +85,28 @@ def recommend_from_upload(uploaded_image):
82
  closest_idx = i
83
  break
84
 
85
- similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
86
- generated_imgs = generate_outfits(uploaded_image, n=3)
 
 
 
87
 
88
- return [uploaded_image] + similar_imgs + generated_imgs
 
 
 
 
 
 
 
 
 
 
89
 
90
- # 5 clickable example images (must be uploaded to the repo)
 
 
 
91
  example_paths = [
92
  ["example1.jpg"],
93
  ["example2.jpg"],
@@ -96,8 +115,6 @@ example_paths = [
96
  ["example5.jpg"]
97
  ]
98
 
99
-
100
-
101
  # Gradio Interface
102
  demo = gr.Interface(
103
  fn=recommend_from_upload,
@@ -107,14 +124,14 @@ demo = gr.Interface(
107
  gr.Image(label="Similar Item 1"),
108
  gr.Image(label="Similar Item 2"),
109
  gr.Image(label="Similar Item 3"),
110
- gr.Gallery(label="AI-Generated Outfits (x10)", columns=5, rows=2),
111
-
112
  ],
113
  title="👗 Fashion Outfit Recommender",
114
- description="Upload a clothing image to get 3 similar items from the dataset and 10 AI-generated outfit designs.",
115
  examples=example_paths
116
  )
117
 
118
  if __name__ == "__main__":
119
  demo.launch()
120
 
 
 
26
  # Load dataset
27
  dataset = load_dataset("lirus18/deepfashion", split="train")
28
 
29
+ # Embed a subset of dataset images
30
  image_vectors = []
31
  image_indices = []
32
+ N = 500 # You can increase this if performance is okay
33
 
34
  for i in range(N):
35
  img = dataset[i]['image'].convert("RGB")
 
41
 
42
  image_vectors = np.array(image_vectors)
43
 
44
+ # Similarity search
45
  def find_similar(user_image, top_k=3, exclude_index=None):
46
  inputs = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
47
  with torch.no_grad():
 
52
  sims[exclude_index] = -1
53
 
54
  top_idx = sims.argsort()[-top_k:][::-1]
55
+ return [dataset[image_indices[i]]['image'] for i in top_idx], query_vec
56
 
57
  # Load Stable Diffusion
58
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
 
73
 
74
  return generated_images
75
 
76
+ # Main function
77
  def recommend_from_upload(uploaded_image):
78
  uploaded_image = uploaded_image.convert("RGB")
79
+
80
+ # Check if the uploaded image already exists in the dataset
81
  closest_idx = None
82
  for i in range(len(image_indices)):
83
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
 
85
  closest_idx = i
86
  break
87
 
88
+ # Find 3 similar dataset items + get uploaded image embedding
89
+ similar_imgs, query_vec = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
90
+
91
+ # Generate 10 new outfit images
92
+ generated_imgs = generate_outfits(uploaded_image, n=10)
93
 
94
+ # Find the most relevant generated image
95
+ best_score = -1
96
+ best_generated_img = None
97
+ for img in generated_imgs:
98
+ inputs = processor(images=img, return_tensors="pt").to(device)
99
+ with torch.no_grad():
100
+ emb = model.get_image_features(**inputs).cpu().numpy()
101
+ sim = cosine_similarity(query_vec, emb)[0][0]
102
+ if sim > best_score:
103
+ best_score = sim
104
+ best_generated_img = img
105
 
106
+ # Final output: input + 3 similar + 1 most relevant generated
107
+ return [uploaded_image] + similar_imgs + [best_generated_img]
108
+
109
+ # Example images (make sure they are in the root folder)
110
  example_paths = [
111
  ["example1.jpg"],
112
  ["example2.jpg"],
 
115
  ["example5.jpg"]
116
  ]
117
 
 
 
118
  # Gradio Interface
119
  demo = gr.Interface(
120
  fn=recommend_from_upload,
 
124
  gr.Image(label="Similar Item 1"),
125
  gr.Image(label="Similar Item 2"),
126
  gr.Image(label="Similar Item 3"),
127
+ gr.Image(label="Best AI-Generated Outfit")
 
128
  ],
129
  title="👗 Fashion Outfit Recommender",
130
+ description="Upload a clothing image to get 3 similar items from the dataset and 1 best AI-generated outfit design.",
131
  examples=example_paths
132
  )
133
 
134
  if __name__ == "__main__":
135
  demo.launch()
136
 
137
+