LevyJonas committed on
Commit
511b795
·
verified ·
1 Parent(s): e2a1d87

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +104 -311
pipeline.py CHANGED
@@ -1,336 +1,129 @@
1
- DATA_ROOT = Path("sat_land_patches") # local dataset folder
2
- EMB_DIR = Path("embeddings_part3") # where Part 3 outputs are
3
-
4
- # Load best embeddings + metadata saved as .npy + .csv
5
- DB_E = np.load(EMB_DIR / "best_embeddings.npy").astype(np.float32) # (N,D) L2-normalized
6
- db_meta = pd.read_csv(EMB_DIR / "best_metadata.csv") # id,label,filename,model_id
7
-
8
- DB_labels = db_meta["label"].values
9
- DB_files = db_meta["filename"].values
10
-
11
- print("DB:", DB_E.shape, "| labels:", len(np.unique(DB_labels)))
12
 
13
- # ===============================================================================================
 
 
14
 
15
  import torch
16
  from transformers import AutoImageProcessor, Dinov2Model
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Same model we selected in Part 3
19
  EMB_MODEL_ID = "facebook/dinov2-small"
 
 
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
- print("Device:", device)
23
 
24
- processor = AutoImageProcessor.from_pretrained(EMB_MODEL_ID)
25
- embedder = Dinov2Model.from_pretrained(EMB_MODEL_ID).to(device)
26
- embedder.eval()
27
 
28
- def embed_query_image(img: Image.Image) -> np.ndarray:
29
- """Return L2-normalized embedding vector for a query PIL image."""
30
- img = img.convert("RGB")
31
- inputs = processor(images=[img], return_tensors="pt")
32
- pixel_values = inputs["pixel_values"].to(device)
 
 
 
33
 
 
 
 
 
 
 
 
34
  with torch.inference_mode():
35
  if device == "cuda":
36
  with torch.autocast("cuda", dtype=torch.float16):
37
- out = embedder(pixel_values=pixel_values)
38
  else:
39
- out = embedder(pixel_values=pixel_values)
40
-
41
  v = out.last_hidden_state[:, 0, :].float().cpu().numpy()[0]
42
- v = v / (np.linalg.norm(v) + 1e-12)
43
- return v.astype(np.float32)
44
 
45
- # ===============================================================================================
46
-
47
- from collections import Counter
48
-
49
- def retrieve_topk(query_vec: np.ndarray, k: int = 5):
50
- """Cosine similarity = dot product because vectors are L2-normalized."""
51
- k = int(max(0, min(5, k))) # cap at 5
52
- if k == 0:
53
- return [], [], []
54
  sims = DB_E @ query_vec
55
  idx = np.argsort(-sims)[:k]
56
- return idx, sims[idx], DB_labels[idx]
57
-
58
- def majority_label(labels):
59
- if len(labels) == 0:
60
- return None
61
- return Counter(labels.tolist()).most_common(1)[0][0]
62
-
63
- def load_db_image(rel_path: str) -> Image.Image:
64
- return Image.open(DATA_ROOT / rel_path).convert("RGB")
65
-
66
- # ===============================================================================================
67
-
68
- # Uses sd-turbo for both text-to-image and image-to-image.
69
- # Notes: keep steps low for speed (1-2). Works best on GPU.
70
-
71
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
72
-
73
- GEN_MODEL_ID = "stabilityai/sd-turbo"
74
-
75
- txt2img = StableDiffusionPipeline.from_pretrained(
76
- GEN_MODEL_ID, torch_dtype=torch.float16, variant="fp16"
77
- ).to("cuda")
78
- txt2img.set_progress_bar_config(disable=True)
79
-
80
- img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
81
- GEN_MODEL_ID, torch_dtype=torch.float16, variant="fp16"
82
- ).to("cuda")
83
- img2img.set_progress_bar_config(disable=True)
84
-
85
- # Optional speed-ups (safe to ignore if not available)
86
- try:
87
- txt2img.enable_xformers_memory_efficient_attention()
88
- img2img.enable_xformers_memory_efficient_attention()
89
- except Exception:
90
- pass
91
-
92
- # ===============================================================================================
93
-
94
- # PROMPTS must exist: dict[label] -> prompt
95
- assert "PROMPTS" in globals(), "PROMPTS dict not found. Paste your PROMPTS (30 labels) before running Part 4."
96
- NEGATIVE = "cartoon, illustration, anime, text, watermark, logo, low quality, blurry, distorted, unrealistic"
97
-
98
- # ===============================================================================================
99
-
100
- import math
101
- from collections import Counter
102
- import torch
103
- from PIL import Image
104
-
105
- def _cap_0_5(x):
106
- """Cap an integer to the range [0, 5]."""
107
- return int(max(0, min(5, int(x))))
108
-
109
- def _majority_label(arr):
110
- """Return majority label from an array of labels (or None)."""
111
- if len(arr) == 0:
112
- return None
113
- return Counter(arr.tolist()).most_common(1)[0][0]
114
-
115
- def retrieve_topk(query_vec, k=5):
116
- """
117
- Retrieve top-k most similar items from DB using cosine similarity.
118
- Cosine similarity = dot product because vectors are L2-normalized.
119
- Returns: list of dicts with image, label, similarity, filename.
120
- """
121
- k = _cap_0_5(k)
122
- if k == 0:
123
- return []
124
-
125
- sims = DB_E @ query_vec
126
- idx = np.argsort(-sims)[:k]
127
-
128
- results = []
129
- for i in idx:
130
- rel = DB_files[i]
131
- results.append({
132
- "img": load_db_image(rel),
133
- "label": DB_labels[i],
134
- "sim": float(sims[i]),
135
- "filename": rel
136
- })
137
- return results
138
-
139
- def _safe_img2img_steps(strength, user_steps):
140
- """
141
- Diffusers img2img requires at least 1 effective denoising step:
142
- effective = int(num_inference_steps * strength) >= 1
143
- If not, tensors become empty and you get the reshape error.
144
- This function chooses a safe num_inference_steps automatically.
145
- """
146
- strength = float(strength)
147
- strength = max(1e-3, min(1.0, strength)) # keep in (0,1]
148
-
149
- steps = int(user_steps)
150
- steps = max(1, min(6, steps)) # keep small for turbo
151
-
152
- # Ensure effective steps >= 1
153
- if int(steps * strength) < 1:
154
- steps = int(math.ceil(1.0 / strength))
155
-
156
- # Clamp again to keep runtime bounded (still safe)
157
- steps = max(2, min(6, steps))
158
- return steps, strength
159
-
160
- def run_search_and_generate(
161
- user_img: Image.Image,
162
- k_retrieve: int = 2,
163
- n_i2i: int = 1,
164
- n_t2i: int = 1,
165
- steps_t2i: int = 1,
166
- strength_i2i: float = 0.35,
167
- gen_size: int = 512,
168
- seed: int = 123
169
- ):
170
- """
171
- Pipeline:
172
- 1) Embed input image (DINOv2)
173
- 2) Retrieve top-k similar images from DB
174
- 3) Choose prompt based on majority retrieved label
175
- 4) Generate n_i2i images using img2img
176
- 5) Generate n_t2i images using txt2img
177
-
178
- Returns:
179
- retrieved: list[dict] (each dict has img/label/sim/filename)
180
- gen_i2i: list[PIL.Image]
181
- gen_t2i: list[PIL.Image]
182
- info: dict (prompt/labels/params)
183
- """
184
- # --- Cap counts to [0,5] for app safety ---
185
- k_retrieve = _cap_0_5(k_retrieve)
186
- n_i2i = _cap_0_5(n_i2i)
187
- n_t2i = _cap_0_5(n_t2i)
188
-
189
- # --- Embed query image ---
190
- q_vec = embed_query_image(user_img)
191
-
192
- # --- Retrieve ---
193
- retrieved = retrieve_topk(q_vec, k=k_retrieve)
194
-
195
- # Decide label/prompt from retrieval results
196
- retrieved_labels = np.array([r["label"] for r in retrieved]) if len(retrieved) else np.array([])
197
- maj_label = _majority_label(retrieved_labels) if len(retrieved_labels) else None
198
-
199
- prompt = PROMPTS.get(
200
- maj_label,
201
- "Satellite-like RGB patch, realistic remote sensing, top-down view"
202
- )
203
-
204
- # Prepare init image for img2img
205
- init_img = user_img.convert("RGB").resize((gen_size, gen_size))
206
-
207
- # --- Generate (img2img) ---
208
- gen_i2i = []
209
- if n_i2i > 0:
210
- safe_steps_i2i, safe_strength = _safe_img2img_steps(strength_i2i, steps_t2i)
211
-
212
- for i in range(n_i2i):
213
- g = torch.Generator("cuda").manual_seed(seed + 10*i)
214
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
215
- im = img2img(
216
- prompt=prompt,
217
- negative_prompt=NEGATIVE,
218
- image=init_img,
219
- strength=safe_strength,
220
- num_inference_steps=safe_steps_i2i,
221
- guidance_scale=0.0,
222
- generator=g
223
- ).images[0]
224
- gen_i2i.append(im)
225
-
226
- # --- Generate (txt2img) ---
227
- gen_t2i = []
228
- if n_t2i > 0:
229
- # sd-turbo is designed for 1–2 steps
230
- steps_txt = max(1, min(2, int(steps_t2i)))
231
-
232
- for i in range(n_t2i):
233
- g = torch.Generator("cuda").manual_seed(seed + 100 + 10*i)
234
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
235
- im = txt2img(
236
- prompt=prompt,
237
- negative_prompt=NEGATIVE,
238
- num_inference_steps=steps_txt,
239
- guidance_scale=0.0,
240
- height=gen_size,
241
- width=gen_size,
242
- generator=g
243
- ).images[0]
244
- gen_t2i.append(im)
245
 
246
  info = {
247
- "majority_label_from_retrieval": maj_label,
248
- "used_prompt": prompt,
249
  "k_retrieve": k_retrieve,
250
  "n_img2img": n_i2i,
251
  "n_txt2img": n_t2i,
252
- "steps_txt2img": max(1, min(2, int(steps_t2i))),
253
- "requested_strength_img2img": float(strength_i2i),
254
- "gen_size": gen_size,
255
- "seed": seed
256
  }
257
-
258
- return retrieved, gen_i2i, gen_t2i, info
259
-
260
- # ===============================================================================================
261
-
262
- import random
263
- import matplotlib.pyplot as plt
264
- from PIL import Image
265
-
266
- # --- Pick a demo input image from your dataset ---
267
- demo_rel = DB_files[random.randrange(len(DB_files))]
268
- user_img = load_db_image(demo_rel)
269
-
270
- # --- Run pipeline (you can change these 0-5 values) ---
271
- k_retrieve = 2 # 0..5 images from database
272
- n_i2i = 2 # 0..5 new images via image-to-image
273
- n_t2i = 2 # 0..5 new images via text-to-image
274
-
275
- retrieved, gen_i2i, gen_t2i, info = run_search_and_generate(
276
- user_img=user_img,
277
- k_retrieve=k_retrieve,
278
- n_i2i=n_i2i,
279
- n_t2i=n_t2i,
280
- steps_t2i=1, # txt2img steps (1-2 recommended for sd-turbo)
281
- strength_i2i=0.35, # img2img strength (0.25-0.60 is typical)
282
- gen_size=512,
283
- seed=42
284
- )
285
-
286
- print("=== PIPELINE INFO ===")
287
- for k, v in info.items():
288
- print(f"{k}: {v}")
289
-
290
- # --- Helper to show a gallery in one row ---
291
- def show_row(images, titles, fig_w=16, fig_h=3, suptitle=None):
292
- n = len(images)
293
- if n == 0:
294
- print(suptitle or "No images to show.")
295
- return
296
- plt.figure(figsize=(fig_w, fig_h))
297
- for i, (im, t) in enumerate(zip(images, titles), 1):
298
- ax = plt.subplot(1, n, i)
299
- ax.imshow(im)
300
- ax.set_title(t, fontsize=9)
301
- ax.axis("off")
302
- if suptitle:
303
- plt.suptitle(suptitle, fontsize=12)
304
- plt.tight_layout()
305
- plt.show()
306
-
307
- # 1) Show input image
308
- show_row(
309
- images=[user_img],
310
- titles=[f"USER INPUT\n{demo_rel}"],
311
- fig_w=6,
312
- fig_h=4,
313
- suptitle="User Input"
314
- )
315
-
316
- # 2) Show retrieved images
317
- if len(retrieved) > 0:
318
- ret_imgs = [r["img"] for r in retrieved]
319
- ret_titles = [f"{r['label']}\ncos={r['sim']:.3f}" for r in retrieved]
320
- show_row(ret_imgs, ret_titles, fig_w=3.2*len(ret_imgs), fig_h=3, suptitle="Top-K Retrieved from Database")
321
- else:
322
- print("No retrieval results (k_retrieve=0).")
323
-
324
- # 3) Show generated img2img images
325
- if len(gen_i2i) > 0:
326
- titles = [f"img2img #{i+1}" for i in range(len(gen_i2i))]
327
- show_row(gen_i2i, titles, fig_w=3.2*len(gen_i2i), fig_h=3, suptitle="Generated (Image-to-Image)")
328
- else:
329
- print("No img2img generated (n_i2i=0).")
330
-
331
- # 4) Show generated txt2img images
332
- if len(gen_t2i) > 0:
333
- titles = [f"txt2img #{i+1}" for i in range(len(gen_t2i))]
334
- show_row(gen_t2i, titles, fig_w=3.2*len(gen_t2i), fig_h=3, suptitle="Generated (Text-to-Image)")
335
- else:
336
- print("No txt2img generated (n_t2i=0).")
 
1
+ # pipeline.py
2
+ import math
3
+ from pathlib import Path
 
 
 
 
 
 
 
 
4
 
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PIL import Image
8
 
9
  import torch
10
  from transformers import AutoImageProcessor, Dinov2Model
11
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ HF_DATASET_ID = "LevyJonas/sat_land_patches"
15
+ CACHE_DIR = Path("hf_cache"); CACHE_DIR.mkdir(exist_ok=True, parents=True)
16
+
17
+ EMB_DIR = Path("embeddings_part3")
18
+ DB_E = np.load(EMB_DIR / "best_embeddings.npy").astype(np.float32)
19
+ META = pd.read_csv(EMB_DIR / "best_metadata.csv")
20
+ DB_FILES = META["filename"].values
21
+ DB_LABELS = META["label"].values
22
 
 
23
  EMB_MODEL_ID = "facebook/dinov2-small"
24
+ GEN_MODEL_ID = "stabilityai/sd-turbo"
25
+ NEG = "cartoon, illustration, anime, text, watermark, logo, low quality, blurry, distorted, unrealistic"
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
28
 
29
+ # --- embedder ---
30
+ proc = AutoImageProcessor.from_pretrained(EMB_MODEL_ID)
31
+ emb = Dinov2Model.from_pretrained(EMB_MODEL_ID).to(device).eval()
32
 
33
+ # --- generators ---
34
+ dtype = torch.float16 if device == "cuda" else torch.float32
35
+ txt2img = StableDiffusionPipeline.from_pretrained(GEN_MODEL_ID, torch_dtype=dtype, variant="fp16" if device=="cuda" else None).to(device)
36
+ img2img = StableDiffusionImg2ImgPipeline.from_pretrained(GEN_MODEL_ID, torch_dtype=dtype, variant="fp16" if device=="cuda" else None).to(device)
37
+ txt2img.set_progress_bar_config(disable=True)
38
+ img2img.set_progress_bar_config(disable=True)
39
+
40
+ def _cap(x): return int(max(0, min(5, int(x))))
41
 
def load_from_hf(rel_path: str) -> Image.Image:
    """Download *rel_path* from the HF dataset repo (cache-aware) and open it as RGB.

    NOTE(review): `local_dir_use_symlinks` is deprecated/ignored in recent
    huggingface_hub releases — confirm the installed version before removing it.
    """
    local_path = hf_hub_download(
        repo_id=HF_DATASET_ID,
        repo_type="dataset",
        filename=rel_path,
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )
    return Image.open(local_path).convert("RGB")
def embed_image(pil_img: Image.Image) -> np.ndarray:
    """Embed a PIL image with DINOv2 and return its L2-normalized CLS vector (float32)."""
    pixels = proc(images=[pil_img.convert("RGB")], return_tensors="pt")["pixel_values"].to(device)
    with torch.inference_mode():
        if device == "cuda":
            # fp16 autocast only on GPU; CPU path runs in full precision.
            with torch.autocast("cuda", dtype=torch.float16):
                out = emb(pixel_values=pixels)
        else:
            out = emb(pixel_values=pixels)
    cls_vec = out.last_hidden_state[:, 0, :].float().cpu().numpy()[0]
    norm = np.linalg.norm(cls_vec) + 1e-12  # epsilon guards against a zero vector
    return (cls_vec / norm).astype(np.float32)
 
57
 
58
+ def retrieve(query_vec: np.ndarray, k: int):
59
+ k = _cap(k)
60
+ if k == 0: return []
 
 
 
 
 
 
61
  sims = DB_E @ query_vec
62
  idx = np.argsort(-sims)[:k]
63
+ return [{
64
+ "img": load_from_hf(DB_FILES[i]),
65
+ "label": DB_LABELS[i],
66
+ "sim": float(sims[i]),
67
+ "filename": DB_FILES[i],
68
+ } for i in idx]
69
+
70
+ def _safe_i2i_steps(strength: float, user_steps: int):
71
+ strength = float(max(1e-3, min(1.0, strength)))
72
+ steps = int(max(1, min(2, user_steps))) # user slider 1..2
73
+ # ensure int(steps_i2i * strength) >= 1
74
+ steps_i2i = max(2, int(math.ceil(1.0 / strength)))
75
+ steps_i2i = min(6, steps_i2i) # keep fast
76
+ return steps, steps_i2i, strength
def _autocast_if_cuda():
    """Return a CUDA fp16 autocast context on GPU, a no-op context on CPU."""
    from contextlib import nullcontext
    if device == "cuda":
        return torch.autocast("cuda", dtype=torch.float16)
    return nullcontext()


def run_search_and_generate(user_img: Image.Image, user_prompt: str,
                            k_retrieve=2, n_i2i=2, n_t2i=2,
                            strength_i2i=0.35, steps=1, gen_size=512, seed=42):
    """Full pipeline: embed -> retrieve -> generate with the user's prompt.

    Args:
        user_img: query image (any mode; converted to RGB internally).
        user_prompt: text prompt used for both generators; must be non-empty.
        k_retrieve / n_i2i / n_t2i: counts, each clamped to 0..5.
        strength_i2i: img2img denoising strength (clamped to a safe range).
        steps: requested txt2img steps (clamped to 1..2 for sd-turbo).
        gen_size: square output size for generated images.
        seed: base RNG seed; each image gets a distinct derived seed.

    Returns:
        (retrieved, gen_i2i, gen_t2i, info) — retrieval hits, img2img images,
        txt2img images, and a dict of the parameters actually used.

    Raises:
        ValueError: if the prompt is empty after stripping whitespace.
    """
    k_retrieve, n_i2i, n_t2i = _cap(k_retrieve), _cap(n_i2i), _cap(n_t2i)
    prompt = (user_prompt or "").strip()
    if not prompt:
        raise ValueError("Please enter a prompt (required for generation.)".replace(".)", ").") if False else "Please enter a prompt (required for generation).")

    q = embed_image(user_img)
    retrieved = retrieve(q, k_retrieve)

    steps_txt, steps_i2i, strength = _safe_i2i_steps(strength_i2i, steps)
    init = user_img.convert("RGB").resize((gen_size, gen_size))

    # --- image-to-image ---
    gen_i2i = []
    for i in range(n_i2i):
        g = torch.Generator(device).manual_seed(seed + 10 * i)
        with torch.inference_mode(), _autocast_if_cuda():
            gen_i2i.append(img2img(prompt=prompt, negative_prompt=NEG, image=init,
                                   strength=strength, num_inference_steps=steps_i2i,
                                   guidance_scale=0.0, generator=g).images[0])

    # --- text-to-image ---
    gen_t2i = []
    for i in range(n_t2i):
        g = torch.Generator(device).manual_seed(seed + 100 + 10 * i)
        with torch.inference_mode(), _autocast_if_cuda():
            gen_t2i.append(txt2img(prompt=prompt, negative_prompt=NEG,
                                   num_inference_steps=steps_txt, guidance_scale=0.0,
                                   height=gen_size, width=gen_size, generator=g).images[0])

    info = {
        "prompt": prompt,
        "k_retrieve": k_retrieve,
        "n_img2img": n_i2i,
        "n_txt2img": n_t2i,
        "strength_i2i": strength,
        "steps_txt2img": steps_txt,
        "steps_img2img": steps_i2i,
        "dataset": HF_DATASET_ID,
    }
    return retrieved, gen_i2i, gen_t2i, info