itayitay123 commited on
Commit
1d60bf1
·
verified ·
1 Parent(s): 2f8ff64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -6,38 +6,41 @@ Automatically generated by Colab.
6
  Original file is located at
7
  https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
8
  """
9
-
10
  from datasets import load_dataset
11
- from PIL import Image
12
  from transformers import CLIPProcessor, CLIPModel
13
  from sklearn.metrics.pairwise import cosine_similarity
14
  import torch
15
  import numpy as np
16
  import gradio as gr
 
17
 
18
- # Load dataset and model
19
- dataset = load_dataset("lirus18/deepfashion", split="train")
20
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 
 
21
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
22
 
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- model.to(device)
25
 
26
- # Precompute vectors
27
  image_vectors = []
28
  image_indices = []
29
- N = 500
30
 
31
  for i in range(N):
32
- image = dataset[i]['image'].convert("RGB")
33
- inputs = processor(images=image, return_tensors="pt").to(device)
34
  with torch.no_grad():
35
- embedding = model.get_image_features(**inputs)
36
- image_vectors.append(embedding.cpu().numpy().squeeze())
37
  image_indices.append(i)
38
 
39
  image_vectors = np.array(image_vectors)
40
 
 
41
  def find_similar(user_image, top_k=3, exclude_index=None):
42
  inputs = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
43
  with torch.no_grad():
@@ -45,14 +48,12 @@ def find_similar(user_image, top_k=3, exclude_index=None):
45
 
46
  sims = cosine_similarity(query_vec, image_vectors)[0]
47
  if exclude_index is not None:
48
- sims[exclude_index] = -1
49
 
50
  top_idx = sims.argsort()[-top_k:][::-1]
51
  return [dataset[image_indices[i]]['image'] for i in top_idx]
52
 
53
- # Placeholder for Stable Diffusion (optional)
54
- from diffusers import StableDiffusionImg2ImgPipeline
55
-
56
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
57
  "runwayml/stable-diffusion-v1-5",
58
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -61,11 +62,10 @@ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
61
  def generate_outfit_from_image(input_image):
62
  prompt = "fashion outfit design inspired by the clothing item"
63
  init_image = input_image.resize((512, 512))
64
- generated = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
65
- return generated.images[0]
66
-
67
- from PIL import ImageChops
68
 
 
69
  def recommend_from_upload(uploaded_image):
70
  uploaded_image = uploaded_image.convert("RGB")
71
  closest_idx = None
@@ -77,7 +77,6 @@ def recommend_from_upload(uploaded_image):
77
 
78
  similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
79
  generated_img = generate_outfit_from_image(uploaded_image)
80
-
81
  return [uploaded_image] + similar_imgs + [generated_img]
82
 
83
  # Gradio Interface
@@ -95,4 +94,6 @@ demo = gr.Interface(
95
  description="Upload a clothing image to see 3 similar outfits and 1 AI-generated one!"
96
  )
97
 
98
- demo.launch()
 
 
 
6
  Original file is located at
7
  https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
8
  """
 
9
  from datasets import load_dataset
10
+ from PIL import Image, ImageChops
11
  from transformers import CLIPProcessor, CLIPModel
12
  from sklearn.metrics.pairwise import cosine_similarity
13
  import torch
14
  import numpy as np
15
  import gradio as gr
16
+ from diffusers import StableDiffusionImg2ImgPipeline
17
 
18
+ # Device setup
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ # Load model and processor
22
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
23
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
24
 
25
+ # Load dataset
26
+ dataset = load_dataset("lirus18/deepfashion", split="train")
27
 
28
+ # Precompute image vectors
29
  image_vectors = []
30
  image_indices = []
31
+ N = 500 # use a smaller subset to avoid long loading
32
 
33
  for i in range(N):
34
+ img = dataset[i]['image'].convert("RGB")
35
+ inputs = processor(images=img, return_tensors="pt").to(device)
36
  with torch.no_grad():
37
+ emb = model.get_image_features(**inputs)
38
+ image_vectors.append(emb.cpu().numpy().squeeze())
39
  image_indices.append(i)
40
 
41
  image_vectors = np.array(image_vectors)
42
 
43
+ # Find similar images
44
  def find_similar(user_image, top_k=3, exclude_index=None):
45
  inputs = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
46
  with torch.no_grad():
 
48
 
49
  sims = cosine_similarity(query_vec, image_vectors)[0]
50
  if exclude_index is not None:
51
+ sims[exclude_index] = -1 # Exclude identical
52
 
53
  top_idx = sims.argsort()[-top_k:][::-1]
54
  return [dataset[image_indices[i]]['image'] for i in top_idx]
55
 
56
+ # Load Stable Diffusion pipeline
 
 
57
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
58
  "runwayml/stable-diffusion-v1-5",
59
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
62
  def generate_outfit_from_image(input_image):
63
  prompt = "fashion outfit design inspired by the clothing item"
64
  init_image = input_image.resize((512, 512))
65
+ result = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
66
+ return result.images[0]
 
 
67
 
68
+ # Main recommendation function
69
  def recommend_from_upload(uploaded_image):
70
  uploaded_image = uploaded_image.convert("RGB")
71
  closest_idx = None
 
77
 
78
  similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
79
  generated_img = generate_outfit_from_image(uploaded_image)
 
80
  return [uploaded_image] + similar_imgs + [generated_img]
81
 
82
  # Gradio Interface
 
94
  description="Upload a clothing image to see 3 similar outfits and 1 AI-generated one!"
95
  )
96
 
97
+ # Only launch if main
98
+ if __name__ == "__main__":
99
+ demo.launch()