itayitay123 commited on
Commit
69c19c8
·
verified ·
1 Parent(s): d2b0774

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -14
app.py CHANGED
@@ -6,6 +6,7 @@ Automatically generated by Colab.
6
  Original file is located at
7
  https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
8
  """
 
9
  from datasets import load_dataset
10
  from PIL import Image, ImageChops
11
  from transformers import CLIPProcessor, CLIPModel
@@ -18,17 +19,17 @@ from diffusers import StableDiffusionImg2ImgPipeline
18
  # Device setup
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # Load model and processor
22
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
23
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
24
 
25
  # Load dataset
26
  dataset = load_dataset("lirus18/deepfashion", split="train")
27
 
28
- # Precompute image vectors
29
  image_vectors = []
30
  image_indices = []
31
- N = 500 # use a smaller subset to avoid long loading
32
 
33
  for i in range(N):
34
  img = dataset[i]['image'].convert("RGB")
@@ -48,24 +49,30 @@ def find_similar(user_image, top_k=3, exclude_index=None):
48
 
49
  sims = cosine_similarity(query_vec, image_vectors)[0]
50
  if exclude_index is not None:
51
- sims[exclude_index] = -1 # Exclude identical
52
 
53
  top_idx = sims.argsort()[-top_k:][::-1]
54
  return [dataset[image_indices[i]]['image'] for i in top_idx]
55
 
56
- # Load Stable Diffusion pipeline
57
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
58
  "runwayml/stable-diffusion-v1-5",
59
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
60
  ).to(device)
 
61
 
62
- def generate_outfit_from_image(input_image):
63
  prompt = "fashion outfit design inspired by the clothing item"
64
  init_image = input_image.resize((512, 512))
65
- result = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
66
- return result.images[0]
 
 
 
 
 
67
 
68
- # Main recommendation function
69
  def recommend_from_upload(uploaded_image):
70
  uploaded_image = uploaded_image.convert("RGB")
71
  closest_idx = None
@@ -76,8 +83,18 @@ def recommend_from_upload(uploaded_image):
76
  break
77
 
78
  similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
79
- generated_img = generate_outfit_from_image(uploaded_image)
80
- return [uploaded_image] + similar_imgs + [generated_img]
 
 
 
 
 
 
 
 
 
 
81
 
82
  # Gradio Interface
83
  demo = gr.Interface(
@@ -88,12 +105,13 @@ demo = gr.Interface(
88
  gr.Image(label="Similar Item 1"),
89
  gr.Image(label="Similar Item 2"),
90
  gr.Image(label="Similar Item 3"),
91
- gr.Image(label="Generated New Outfit"),
92
  ],
93
  title="👗 Fashion Outfit Recommender",
94
- description="Upload a clothing image to see 3 similar outfits and 1 AI-generated one!"
 
95
  )
96
 
97
- # Only launch if main
98
  if __name__ == "__main__":
99
  demo.launch()
 
 
6
  Original file is located at
7
  https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
8
  """
9
+
10
  from datasets import load_dataset
11
  from PIL import Image, ImageChops
12
  from transformers import CLIPProcessor, CLIPModel
 
19
  # Device setup
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
+ # Load CLIP model
23
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
24
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
25
 
26
  # Load dataset
27
  dataset = load_dataset("lirus18/deepfashion", split="train")
28
 
29
+ # Embed a subset of images
30
  image_vectors = []
31
  image_indices = []
32
+ N = 500
33
 
34
  for i in range(N):
35
  img = dataset[i]['image'].convert("RGB")
 
49
 
50
  sims = cosine_similarity(query_vec, image_vectors)[0]
51
  if exclude_index is not None:
52
+ sims[exclude_index] = -1
53
 
54
  top_idx = sims.argsort()[-top_k:][::-1]
55
  return [dataset[image_indices[i]]['image'] for i in top_idx]
56
 
57
+ # Load Stable Diffusion
58
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
59
  "runwayml/stable-diffusion-v1-5",
60
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
61
+ low_cpu_mem_usage=True
62
  ).to(device)
63
+ pipe.enable_attention_slicing()
64
 
65
+ def generate_outfits(input_image, n=10):
66
  prompt = "fashion outfit design inspired by the clothing item"
67
  init_image = input_image.resize((512, 512))
68
+ generated_images = []
69
+
70
+ for _ in range(n):
71
+ result = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
72
+ generated_images.append(result.images[0])
73
+
74
+ return generated_images
75
 
 
76
  def recommend_from_upload(uploaded_image):
77
  uploaded_image = uploaded_image.convert("RGB")
78
  closest_idx = None
 
83
  break
84
 
85
  similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
86
+ generated_imgs = generate_outfits(uploaded_image, n=10)
87
+
88
+ return [uploaded_image] + similar_imgs + generated_imgs
89
+
90
+ # 5 clickable example images (must be uploaded to the repo)
91
+ example_paths = [
92
+ ["fashion_examples/example1.jpg"],
93
+ ["fashion_examples/example2.jpg"],
94
+ ["fashion_examples/example3.jpg"],
95
+ ["fashion_examples/example4.jpg"],
96
+ ["fashion_examples/example5.jpg"]
97
+ ]
98
 
99
  # Gradio Interface
100
  demo = gr.Interface(
 
105
  gr.Image(label="Similar Item 1"),
106
  gr.Image(label="Similar Item 2"),
107
  gr.Image(label="Similar Item 3"),
108
+ gr.Gallery(label="AI-Generated Outfits (x10)").style(grid=(5, 2), height="auto"),
109
  ],
110
  title="👗 Fashion Outfit Recommender",
111
+ description="Upload a clothing image to get 3 similar items from the dataset and 10 AI-generated outfit designs.",
112
+ examples=example_paths
113
  )
114
 
 
115
  if __name__ == "__main__":
116
  demo.launch()
117
+