itayitay123 commited on
Commit
eabc0b5
·
verified ·
1 Parent(s): 1eb91cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -26
app.py CHANGED
@@ -29,7 +29,7 @@ dataset = load_dataset("lirus18/deepfashion", split="train")
29
  # Embed a subset of dataset images
30
  image_vectors = []
31
  image_indices = []
32
- N = 500 # You can increase this if performance is okay
33
 
34
  for i in range(N):
35
  img = dataset[i]['image'].convert("RGB")
@@ -62,7 +62,8 @@ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
62
  ).to(device)
63
  pipe.enable_attention_slicing()
64
 
65
- def generate_outfits(input_image, n=10):
 
66
  prompt = "fashion outfit design inspired by the clothing item"
67
  init_image = input_image.resize((512, 512))
68
  generated_images = []
@@ -73,11 +74,11 @@ def generate_outfits(input_image, n=10):
73
 
74
  return generated_images
75
 
76
- # Main function
77
  def recommend_from_upload(uploaded_image):
78
  uploaded_image = uploaded_image.convert("RGB")
79
 
80
- # Check if the uploaded image already exists in the dataset
81
  closest_idx = None
82
  for i in range(len(image_indices)):
83
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
@@ -85,13 +86,13 @@ def recommend_from_upload(uploaded_image):
85
  closest_idx = i
86
  break
87
 
88
- # Find 3 similar dataset items + get uploaded image embedding
89
  similar_imgs, query_vec = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
90
 
91
- # Generate 10 new outfit images
92
- generated_imgs = generate_outfits(uploaded_image, n=10)
93
 
94
- # Find the most relevant generated image
95
  best_score = -1
96
  best_generated_img = None
97
  for img in generated_imgs:
@@ -103,10 +104,9 @@ def recommend_from_upload(uploaded_image):
103
  best_score = sim
104
  best_generated_img = img
105
 
106
- # Final output: input + 3 similar + 1 most relevant generated
107
  return [uploaded_image] + similar_imgs + [best_generated_img]
108
 
109
- # Example images (make sure they are in the root folder)
110
  example_paths = [
111
  ["example1.jpg"],
112
  ["example2.jpg"],
@@ -115,23 +115,33 @@ example_paths = [
115
  ["example5.jpg"]
116
  ]
117
 
118
- # Gradio Interface
119
- demo = gr.Interface(
120
- fn=recommend_from_upload,
121
- inputs=gr.Image(type="pil", label="Upload a clothing item"),
122
- outputs=[
123
- gr.Image(label="Your Input"),
124
- gr.Image(label="Similar Item 1"),
125
- gr.Image(label="Similar Item 2"),
126
- gr.Image(label="Similar Item 3"),
127
- gr.Image(label="Best AI-Generated Outfit")
128
- ],
129
- title="👗 Fashion Outfit Recommender",
130
- description="Upload a clothing image to get 3 similar items from the dataset and 1 best AI-generated outfit design.",
131
- examples=example_paths
132
- )
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  if __name__ == "__main__":
135
  demo.launch()
136
 
137
-
 
29
  # Embed a subset of dataset images
30
  image_vectors = []
31
  image_indices = []
32
+ N = 500
33
 
34
  for i in range(N):
35
  img = dataset[i]['image'].convert("RGB")
 
62
  ).to(device)
63
  pipe.enable_attention_slicing()
64
 
65
+ # Generate outfits (2 only)
66
+ def generate_outfits(input_image, n=2):
67
  prompt = "fashion outfit design inspired by the clothing item"
68
  init_image = input_image.resize((512, 512))
69
  generated_images = []
 
74
 
75
  return generated_images
76
 
77
+ # Main recommendation function
78
  def recommend_from_upload(uploaded_image):
79
  uploaded_image = uploaded_image.convert("RGB")
80
 
81
+ # Check if the uploaded image exists in dataset
82
  closest_idx = None
83
  for i in range(len(image_indices)):
84
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
 
86
  closest_idx = i
87
  break
88
 
89
+ # Get 3 similar items + embedding
90
  similar_imgs, query_vec = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
91
 
92
+ # Generate 2 synthetic outfits
93
+ generated_imgs = generate_outfits(uploaded_image, n=2)
94
 
95
+ # Select most similar generated image
96
  best_score = -1
97
  best_generated_img = None
98
  for img in generated_imgs:
 
104
  best_score = sim
105
  best_generated_img = img
106
 
 
107
  return [uploaded_image] + similar_imgs + [best_generated_img]
108
 
109
+ # Example image paths (must exist in root folder)
110
  example_paths = [
111
  ["example1.jpg"],
112
  ["example2.jpg"],
 
115
  ["example5.jpg"]
116
  ]
117
 
118
+ # Gradio Interface with button
119
+ with gr.Blocks() as demo:
120
+ gr.Markdown("## 👗 Fashion Outfit Recommender")
121
+ gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 AI-generated outfit design.")
122
+
123
+ with gr.Row():
124
+ image_input = gr.Image(type="pil", label="Upload a clothing item")
125
+
126
+ generate_btn = gr.Button("Generate Recommendations")
127
+
128
+ with gr.Row():
129
+ output1 = gr.Image(label="Your Input")
130
+ output2 = gr.Image(label="Similar Item 1")
131
+ output3 = gr.Image(label="Similar Item 2")
132
+ output4 = gr.Image(label="Similar Item 3")
133
+ output5 = gr.Image(label="Best AI-Generated Outfit")
134
+
135
+ examples = gr.Examples(
136
+ examples=example_paths,
137
+ inputs=image_input,
138
+ label="Try an Example"
139
+ )
140
+
141
+ generate_btn.click(fn=recommend_from_upload,
142
+ inputs=image_input,
143
+ outputs=[output1, output2, output3, output4, output5])
144
 
145
  if __name__ == "__main__":
146
  demo.launch()
147