LevyJonas committed on
Commit
e56fe42
·
verified ·
1 Parent(s): 76f544a

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +39 -16
pipeline.py CHANGED
@@ -75,49 +75,72 @@ def _safe_i2i_steps(strength: float, user_steps: int):
75
  steps_i2i = min(6, steps_i2i) # keep fast
76
  return steps, steps_i2i, strength
77
 
78
- def run_search_and_generate(user_img: Image.Image, user_prompt: str,
79
  k_retrieve=2, n_i2i=2, n_t2i=2,
80
  strength_i2i=0.35, steps=1, gen_size=512, seed=42):
81
- k_retrieve, n_i2i, n_t2i = _cap(k_retrieve), _cap(n_i2i), _cap(n_t2i)
 
 
 
 
82
  prompt = (user_prompt or "").strip()
83
  if not prompt:
84
  raise ValueError("Please enter a prompt (required for generation).")
85
 
86
- q = embed_image(user_img)
 
 
 
 
 
 
 
87
  retrieved = retrieve(q, k_retrieve)
88
 
 
 
 
89
  steps_txt, steps_i2i, strength = _safe_i2i_steps(strength_i2i, steps)
90
- init = user_img.convert("RGB").resize((gen_size, gen_size))
91
 
92
  gen_i2i, gen_t2i = [], []
 
93
  for i in range(n_i2i):
94
  g = torch.Generator(device).manual_seed(seed + 10*i)
95
  with torch.inference_mode():
96
  if device == "cuda":
97
  with torch.autocast("cuda", dtype=torch.float16):
98
- gen_i2i.append(img2img(prompt=prompt, negative_prompt=NEG, image=init,
99
- strength=strength, num_inference_steps=steps_i2i,
100
- guidance_scale=0.0, generator=g).images[0])
 
 
101
  else:
102
- gen_i2i.append(img2img(prompt=prompt, negative_prompt=NEG, image=init,
103
- strength=strength, num_inference_steps=steps_i2i,
104
- guidance_scale=0.0, generator=g).images[0])
 
 
105
 
106
  for i in range(n_t2i):
107
  g = torch.Generator(device).manual_seed(seed + 100 + 10*i)
108
  with torch.inference_mode():
109
  if device == "cuda":
110
  with torch.autocast("cuda", dtype=torch.float16):
111
- gen_t2i.append(txt2img(prompt=prompt, negative_prompt=NEG,
112
- num_inference_steps=steps_txt, guidance_scale=0.0,
113
- height=gen_size, width=gen_size, generator=g).images[0])
 
 
114
  else:
115
- gen_t2i.append(txt2img(prompt=prompt, negative_prompt=NEG,
116
- num_inference_steps=steps_txt, guidance_scale=0.0,
117
- height=gen_size, width=gen_size, generator=g).images[0])
 
 
118
 
119
  info = {
120
  "prompt": prompt,
 
121
  "k_retrieve": k_retrieve,
122
  "n_img2img": n_i2i,
123
  "n_txt2img": n_t2i,
 
75
  steps_i2i = min(6, steps_i2i) # keep fast
76
  return steps, steps_i2i, strength
77
 
78
+ def run_search_and_generate(user_imgs, user_prompt: str,
79
  k_retrieve=2, n_i2i=2, n_t2i=2,
80
  strength_i2i=0.35, steps=1, gen_size=512, seed=42):
81
+ # user_imgs: list of PIL images (1..4), some may be None
82
+ imgs = [im for im in (user_imgs or []) if im is not None]
83
+ if len(imgs) == 0:
84
+ raise ValueError("Please upload at least 1 image.")
85
+
86
  prompt = (user_prompt or "").strip()
87
  if not prompt:
88
  raise ValueError("Please enter a prompt (required for generation).")
89
 
90
+ k_retrieve, n_i2i, n_t2i = _cap(k_retrieve), _cap(n_i2i), _cap(n_t2i)
91
+
92
+ # --- embed each image and average embeddings ---
93
+ vecs = [embed_image(im) for im in imgs]
94
+ q = np.mean(np.stack(vecs, axis=0), axis=0)
95
+ q = (q / (np.linalg.norm(q) + 1e-12)).astype(np.float32)
96
+
97
+ # --- retrieval based on averaged embedding ---
98
  retrieved = retrieve(q, k_retrieve)
99
 
100
+ # --- choose init image for img2img (first provided) ---
101
+ init = imgs[0].convert("RGB").resize((gen_size, gen_size))
102
+
103
  steps_txt, steps_i2i, strength = _safe_i2i_steps(strength_i2i, steps)
 
104
 
105
  gen_i2i, gen_t2i = [], []
106
+
107
  for i in range(n_i2i):
108
  g = torch.Generator(device).manual_seed(seed + 10*i)
109
  with torch.inference_mode():
110
  if device == "cuda":
111
  with torch.autocast("cuda", dtype=torch.float16):
112
+ gen_i2i.append(img2img(
113
+ prompt=prompt, negative_prompt=NEG, image=init,
114
+ strength=strength, num_inference_steps=steps_i2i,
115
+ guidance_scale=0.0, generator=g
116
+ ).images[0])
117
  else:
118
+ gen_i2i.append(img2img(
119
+ prompt=prompt, negative_prompt=NEG, image=init,
120
+ strength=strength, num_inference_steps=steps_i2i,
121
+ guidance_scale=0.0, generator=g
122
+ ).images[0])
123
 
124
  for i in range(n_t2i):
125
  g = torch.Generator(device).manual_seed(seed + 100 + 10*i)
126
  with torch.inference_mode():
127
  if device == "cuda":
128
  with torch.autocast("cuda", dtype=torch.float16):
129
+ gen_t2i.append(txt2img(
130
+ prompt=prompt, negative_prompt=NEG,
131
+ num_inference_steps=steps_txt, guidance_scale=0.0,
132
+ height=gen_size, width=gen_size, generator=g
133
+ ).images[0])
134
  else:
135
+ gen_t2i.append(txt2img(
136
+ prompt=prompt, negative_prompt=NEG,
137
+ num_inference_steps=steps_txt, guidance_scale=0.0,
138
+ height=gen_size, width=gen_size, generator=g
139
+ ).images[0])
140
 
141
  info = {
142
  "prompt": prompt,
143
+ "num_user_images": len(imgs),
144
  "k_retrieve": k_retrieve,
145
  "n_img2img": n_i2i,
146
  "n_txt2img": n_t2i,