Spaces:
Sleeping
Sleeping
Update pipeline.py
Browse files- pipeline.py +39 -16
pipeline.py
CHANGED
|
@@ -75,49 +75,72 @@ def _safe_i2i_steps(strength: float, user_steps: int):
|
|
| 75 |
steps_i2i = min(6, steps_i2i) # keep fast
|
| 76 |
return steps, steps_i2i, strength
|
| 77 |
|
| 78 |
-
def run_search_and_generate(
|
| 79 |
k_retrieve=2, n_i2i=2, n_t2i=2,
|
| 80 |
strength_i2i=0.35, steps=1, gen_size=512, seed=42):
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
prompt = (user_prompt or "").strip()
|
| 83 |
if not prompt:
|
| 84 |
raise ValueError("Please enter a prompt (required for generation).")
|
| 85 |
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
retrieved = retrieve(q, k_retrieve)
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
steps_txt, steps_i2i, strength = _safe_i2i_steps(strength_i2i, steps)
|
| 90 |
-
init = user_img.convert("RGB").resize((gen_size, gen_size))
|
| 91 |
|
| 92 |
gen_i2i, gen_t2i = [], []
|
|
|
|
| 93 |
for i in range(n_i2i):
|
| 94 |
g = torch.Generator(device).manual_seed(seed + 10*i)
|
| 95 |
with torch.inference_mode():
|
| 96 |
if device == "cuda":
|
| 97 |
with torch.autocast("cuda", dtype=torch.float16):
|
| 98 |
-
gen_i2i.append(img2img(
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
else:
|
| 102 |
-
gen_i2i.append(img2img(
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
| 105 |
|
| 106 |
for i in range(n_t2i):
|
| 107 |
g = torch.Generator(device).manual_seed(seed + 100 + 10*i)
|
| 108 |
with torch.inference_mode():
|
| 109 |
if device == "cuda":
|
| 110 |
with torch.autocast("cuda", dtype=torch.float16):
|
| 111 |
-
gen_t2i.append(txt2img(
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
else:
|
| 115 |
-
gen_t2i.append(txt2img(
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
| 118 |
|
| 119 |
info = {
|
| 120 |
"prompt": prompt,
|
|
|
|
| 121 |
"k_retrieve": k_retrieve,
|
| 122 |
"n_img2img": n_i2i,
|
| 123 |
"n_txt2img": n_t2i,
|
|
|
|
| 75 |
steps_i2i = min(6, steps_i2i) # keep fast
|
| 76 |
return steps, steps_i2i, strength
|
| 77 |
|
| 78 |
+
def run_search_and_generate(user_imgs, user_prompt: str,
|
| 79 |
k_retrieve=2, n_i2i=2, n_t2i=2,
|
| 80 |
strength_i2i=0.35, steps=1, gen_size=512, seed=42):
|
| 81 |
+
# user_imgs: list of PIL images (1..4), some may be None
|
| 82 |
+
imgs = [im for im in (user_imgs or []) if im is not None]
|
| 83 |
+
if len(imgs) == 0:
|
| 84 |
+
raise ValueError("Please upload at least 1 image.")
|
| 85 |
+
|
| 86 |
prompt = (user_prompt or "").strip()
|
| 87 |
if not prompt:
|
| 88 |
raise ValueError("Please enter a prompt (required for generation).")
|
| 89 |
|
| 90 |
+
k_retrieve, n_i2i, n_t2i = _cap(k_retrieve), _cap(n_i2i), _cap(n_t2i)
|
| 91 |
+
|
| 92 |
+
# --- embed each image and average embeddings ---
|
| 93 |
+
vecs = [embed_image(im) for im in imgs]
|
| 94 |
+
q = np.mean(np.stack(vecs, axis=0), axis=0)
|
| 95 |
+
q = (q / (np.linalg.norm(q) + 1e-12)).astype(np.float32)
|
| 96 |
+
|
| 97 |
+
# --- retrieval based on averaged embedding ---
|
| 98 |
retrieved = retrieve(q, k_retrieve)
|
| 99 |
|
| 100 |
+
# --- choose init image for img2img (first provided) ---
|
| 101 |
+
init = imgs[0].convert("RGB").resize((gen_size, gen_size))
|
| 102 |
+
|
| 103 |
steps_txt, steps_i2i, strength = _safe_i2i_steps(strength_i2i, steps)
|
|
|
|
| 104 |
|
| 105 |
gen_i2i, gen_t2i = [], []
|
| 106 |
+
|
| 107 |
for i in range(n_i2i):
|
| 108 |
g = torch.Generator(device).manual_seed(seed + 10*i)
|
| 109 |
with torch.inference_mode():
|
| 110 |
if device == "cuda":
|
| 111 |
with torch.autocast("cuda", dtype=torch.float16):
|
| 112 |
+
gen_i2i.append(img2img(
|
| 113 |
+
prompt=prompt, negative_prompt=NEG, image=init,
|
| 114 |
+
strength=strength, num_inference_steps=steps_i2i,
|
| 115 |
+
guidance_scale=0.0, generator=g
|
| 116 |
+
).images[0])
|
| 117 |
else:
|
| 118 |
+
gen_i2i.append(img2img(
|
| 119 |
+
prompt=prompt, negative_prompt=NEG, image=init,
|
| 120 |
+
strength=strength, num_inference_steps=steps_i2i,
|
| 121 |
+
guidance_scale=0.0, generator=g
|
| 122 |
+
).images[0])
|
| 123 |
|
| 124 |
for i in range(n_t2i):
|
| 125 |
g = torch.Generator(device).manual_seed(seed + 100 + 10*i)
|
| 126 |
with torch.inference_mode():
|
| 127 |
if device == "cuda":
|
| 128 |
with torch.autocast("cuda", dtype=torch.float16):
|
| 129 |
+
gen_t2i.append(txt2img(
|
| 130 |
+
prompt=prompt, negative_prompt=NEG,
|
| 131 |
+
num_inference_steps=steps_txt, guidance_scale=0.0,
|
| 132 |
+
height=gen_size, width=gen_size, generator=g
|
| 133 |
+
).images[0])
|
| 134 |
else:
|
| 135 |
+
gen_t2i.append(txt2img(
|
| 136 |
+
prompt=prompt, negative_prompt=NEG,
|
| 137 |
+
num_inference_steps=steps_txt, guidance_scale=0.0,
|
| 138 |
+
height=gen_size, width=gen_size, generator=g
|
| 139 |
+
).images[0])
|
| 140 |
|
| 141 |
info = {
|
| 142 |
"prompt": prompt,
|
| 143 |
+
"num_user_images": len(imgs),
|
| 144 |
"k_retrieve": k_retrieve,
|
| 145 |
"n_img2img": n_i2i,
|
| 146 |
"n_txt2img": n_t2i,
|