LevyJonas committed · verified
Commit 70e850c · Parent: 135d679

Create pipeline.py

Files changed (1): pipeline.py ADDED (+336 −0)
@@ -0,0 +1,336 @@
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image

DATA_ROOT = Path("sat_land_patches")   # local dataset folder
EMB_DIR = Path("embeddings_part3")     # where Part 3 outputs are

# Load best embeddings + metadata saved as .npy + .csv
DB_E = np.load(EMB_DIR / "best_embeddings.npy").astype(np.float32)  # (N, D), L2-normalized
db_meta = pd.read_csv(EMB_DIR / "best_metadata.csv")                # id,label,filename,model_id

DB_labels = db_meta["label"].values
DB_files = db_meta["filename"].values

print("DB:", DB_E.shape, "| labels:", len(np.unique(DB_labels)))
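
# Sanity check: retrieval below treats dot products as cosine similarities,
# which only holds if every row of DB_E is unit-norm (as Part 3 is expected
# to have saved them). Cheap guard before anything else runs:
assert np.allclose(np.linalg.norm(DB_E, axis=1), 1.0, atol=1e-3), \
    "DB embeddings are not L2-normalized; normalize them before retrieval."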

# ===============================================================================================

import torch
from transformers import AutoImageProcessor, Dinov2Model

# Same model we selected in Part 3
EMB_MODEL_ID = "facebook/dinov2-small"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

processor = AutoImageProcessor.from_pretrained(EMB_MODEL_ID)
embedder = Dinov2Model.from_pretrained(EMB_MODEL_ID).to(device)
embedder.eval()

def embed_query_image(img: Image.Image) -> np.ndarray:
    """Return L2-normalized embedding vector for a query PIL image."""
    img = img.convert("RGB")
    inputs = processor(images=[img], return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.inference_mode():
        if device == "cuda":
            with torch.autocast("cuda", dtype=torch.float16):
                out = embedder(pixel_values=pixel_values)
        else:
            out = embedder(pixel_values=pixel_values)

    # CLS token of the last hidden state serves as the global image descriptor.
    v = out.last_hidden_state[:, 0, :].float().cpu().numpy()[0]
    v = v / (np.linalg.norm(v) + 1e-12)
    return v.astype(np.float32)
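
# Optional self-test of the embedder on one database image (assumes the first
# entry of DB_files exists under DATA_ROOT). For dinov2-small the vector
# should be 384-dimensional, and the norm should come back ~1.0:
_v = embed_query_image(Image.open(DATA_ROOT / DB_files[0]))
print("query embedding:", _v.shape, "| norm:", round(float(np.linalg.norm(_v)), 4))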

# ===============================================================================================

from collections import Counter

# NOTE: retrieve_topk and majority_label here are first-pass helpers; both are
# redefined further down (a dict-returning retrieve_topk and _majority_label),
# and those later definitions are the ones the pipeline actually uses.

def retrieve_topk(query_vec: np.ndarray, k: int = 5):
    """Cosine similarity = dot product because vectors are L2-normalized."""
    k = int(max(0, min(5, k)))  # cap at [0, 5]
    if k == 0:
        return [], [], []
    sims = DB_E @ query_vec
    idx = np.argsort(-sims)[:k]
    return idx, sims[idx], DB_labels[idx]

def majority_label(labels):
    if len(labels) == 0:
        return None
    return Counter(labels.tolist()).most_common(1)[0][0]

def load_db_image(rel_path: str) -> Image.Image:
    return Image.open(DATA_ROOT / rel_path).convert("RGB")

# ===============================================================================================

# Uses sd-turbo for both text-to-image and image-to-image.
# Notes: keep steps low for speed (1-2). Requires a CUDA GPU as written.

from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

GEN_MODEL_ID = "stabilityai/sd-turbo"

txt2img = StableDiffusionPipeline.from_pretrained(
    GEN_MODEL_ID, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
txt2img.set_progress_bar_config(disable=True)

img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    GEN_MODEL_ID, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
img2img.set_progress_bar_config(disable=True)
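
# The fp16 weights, .to("cuda") and autocast calls all assume a GPU. A minimal
# CPU fallback sketch, if ever needed, would load without the fp16 variant:
#
#   txt2img = StableDiffusionPipeline.from_pretrained(GEN_MODEL_ID).to("cpu")
#
# but expect it to be far slower; everything below assumes the GPU path.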

# Optional speed-ups (safe to ignore if not available)
try:
    txt2img.enable_xformers_memory_efficient_attention()
    img2img.enable_xformers_memory_efficient_attention()
except Exception:
    pass

# ===============================================================================================

# PROMPTS must exist: dict[label] -> prompt
assert "PROMPTS" in globals(), "PROMPTS dict not found. Paste your PROMPTS (30 labels) before running Part 4."
NEGATIVE = "cartoon, illustration, anime, text, watermark, logo, low quality, blurry, distorted, unrealistic"
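
# For reference, PROMPTS is expected to map each of the 30 Part-3 labels to a
# text prompt. The keys below are hypothetical placeholders, not the real labels:
#
#   PROMPTS = {
#       "forest":      "aerial satellite view of dense forest, realistic remote sensing RGB patch",
#       "residential": "aerial satellite view of a residential area, realistic remote sensing RGB patch",
#       # ... one entry per label ...
#   }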

# ===============================================================================================

import math
from collections import Counter
import torch
from PIL import Image

def _cap_0_5(x):
    """Cap an integer to the range [0, 5]."""
    return int(max(0, min(5, int(x))))

def _majority_label(arr):
    """Return majority label from an array of labels (or None)."""
    if len(arr) == 0:
        return None
    return Counter(arr.tolist()).most_common(1)[0][0]

def retrieve_topk(query_vec, k=5):
    """
    Retrieve top-k most similar items from DB using cosine similarity.
    Cosine similarity = dot product because vectors are L2-normalized.
    Returns: list of dicts with image, label, similarity, filename.
    """
    k = _cap_0_5(k)
    if k == 0:
        return []

    sims = DB_E @ query_vec
    idx = np.argsort(-sims)[:k]

    results = []
    for i in idx:
        rel = DB_files[i]
        results.append({
            "img": load_db_image(rel),
            "label": DB_labels[i],
            "sim": float(sims[i]),
            "filename": rel
        })
    return results

def _safe_img2img_steps(strength, user_steps):
    """
    Diffusers img2img requires at least 1 effective denoising step:
        effective = int(num_inference_steps * strength) >= 1
    If not, tensors become empty and you get the reshape error.
    This function chooses a safe num_inference_steps automatically.
    """
    strength = float(strength)
    strength = max(1e-3, min(1.0, strength))  # keep in (0, 1]

    steps = int(user_steps)
    steps = max(1, min(6, steps))  # keep small for turbo

    # Ensure effective steps >= 1
    if int(steps * strength) < 1:
        steps = int(math.ceil(1.0 / strength))

    # Clamp again to keep runtime bounded (still safe)
    steps = max(2, min(6, steps))
    return steps, strength
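
# Worked example of the guard above: strength=0.35 with user_steps=1 gives
# int(1 * 0.35) == 0 effective steps, so steps are raised to ceil(1 / 0.35) == 3,
# which yields int(3 * 0.35) == 1 effective denoising step.
print(_safe_img2img_steps(0.35, 1))  # -> (3, 0.35)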

def run_search_and_generate(
    user_img: Image.Image,
    k_retrieve: int = 2,
    n_i2i: int = 1,
    n_t2i: int = 1,
    steps_t2i: int = 1,
    strength_i2i: float = 0.35,
    gen_size: int = 512,
    seed: int = 123
):
    """
    Pipeline:
      1) Embed input image (DINOv2)
      2) Retrieve top-k similar images from DB
      3) Choose prompt based on majority retrieved label
      4) Generate n_i2i images using img2img
      5) Generate n_t2i images using txt2img

    Returns:
      retrieved: list[dict] (each dict has img/label/sim/filename)
      gen_i2i: list[PIL.Image]
      gen_t2i: list[PIL.Image]
      info: dict (prompt/labels/params)
    """
    # --- Cap counts to [0, 5] for app safety ---
    k_retrieve = _cap_0_5(k_retrieve)
    n_i2i = _cap_0_5(n_i2i)
    n_t2i = _cap_0_5(n_t2i)

    # --- Embed query image ---
    q_vec = embed_query_image(user_img)

    # --- Retrieve ---
    retrieved = retrieve_topk(q_vec, k=k_retrieve)

    # Decide label/prompt from retrieval results
    retrieved_labels = np.array([r["label"] for r in retrieved]) if len(retrieved) else np.array([])
    maj_label = _majority_label(retrieved_labels) if len(retrieved_labels) else None

    prompt = PROMPTS.get(
        maj_label,
        "Satellite-like RGB patch, realistic remote sensing, top-down view"
    )

    # Prepare init image for img2img
    init_img = user_img.convert("RGB").resize((gen_size, gen_size))

    # --- Generate (img2img) ---
    gen_i2i = []
    if n_i2i > 0:
        # steps_t2i doubles as the baseline step count here; _safe_img2img_steps
        # raises it as needed so that at least one denoising step survives.
        safe_steps_i2i, safe_strength = _safe_img2img_steps(strength_i2i, steps_t2i)

        for i in range(n_i2i):
            g = torch.Generator("cuda").manual_seed(seed + 10 * i)
            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
                im = img2img(
                    prompt=prompt,
                    negative_prompt=NEGATIVE,
                    image=init_img,
                    strength=safe_strength,
                    num_inference_steps=safe_steps_i2i,
                    guidance_scale=0.0,
                    generator=g
                ).images[0]
            gen_i2i.append(im)

    # --- Generate (txt2img) ---
    gen_t2i = []
    if n_t2i > 0:
        # sd-turbo is designed for 1-2 steps
        steps_txt = max(1, min(2, int(steps_t2i)))

        for i in range(n_t2i):
            g = torch.Generator("cuda").manual_seed(seed + 100 + 10 * i)
            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
                im = txt2img(
                    prompt=prompt,
                    negative_prompt=NEGATIVE,
                    num_inference_steps=steps_txt,
                    guidance_scale=0.0,
                    height=gen_size,
                    width=gen_size,
                    generator=g
                ).images[0]
            gen_t2i.append(im)

    info = {
        "majority_label_from_retrieval": maj_label,
        "used_prompt": prompt,
        "k_retrieve": k_retrieve,
        "n_img2img": n_i2i,
        "n_txt2img": n_t2i,
        "steps_txt2img": max(1, min(2, int(steps_t2i))),
        "requested_strength_img2img": float(strength_i2i),
        "gen_size": gen_size,
        "seed": seed
    }

    return retrieved, gen_i2i, gen_t2i, info

# ===============================================================================================

import random
import matplotlib.pyplot as plt
from PIL import Image

# --- Pick a demo input image from your dataset ---
demo_rel = DB_files[random.randrange(len(DB_files))]
user_img = load_db_image(demo_rel)

# --- Run pipeline (you can change these 0-5 values) ---
k_retrieve = 2  # 0..5 images from database
n_i2i = 2       # 0..5 new images via image-to-image
n_t2i = 2       # 0..5 new images via text-to-image

retrieved, gen_i2i, gen_t2i, info = run_search_and_generate(
    user_img=user_img,
    k_retrieve=k_retrieve,
    n_i2i=n_i2i,
    n_t2i=n_t2i,
    steps_t2i=1,         # txt2img steps (1-2 recommended for sd-turbo)
    strength_i2i=0.35,   # img2img strength (0.25-0.60 is typical)
    gen_size=512,
    seed=42
)

print("=== PIPELINE INFO ===")
for k, v in info.items():
    print(f"{k}: {v}")

# --- Helper to show a gallery in one row ---
def show_row(images, titles, fig_w=16, fig_h=3, suptitle=None):
    n = len(images)
    if n == 0:
        print(suptitle or "No images to show.")
        return
    plt.figure(figsize=(fig_w, fig_h))
    for i, (im, t) in enumerate(zip(images, titles), 1):
        ax = plt.subplot(1, n, i)
        ax.imshow(im)
        ax.set_title(t, fontsize=9)
        ax.axis("off")
    if suptitle:
        plt.suptitle(suptitle, fontsize=12)
    plt.tight_layout()
    plt.show()

# 1) Show input image
show_row(
    images=[user_img],
    titles=[f"USER INPUT\n{demo_rel}"],
    fig_w=6,
    fig_h=4,
    suptitle="User Input"
)

# 2) Show retrieved images
if len(retrieved) > 0:
    ret_imgs = [r["img"] for r in retrieved]
    ret_titles = [f"{r['label']}\ncos={r['sim']:.3f}" for r in retrieved]
    show_row(ret_imgs, ret_titles, fig_w=3.2 * len(ret_imgs), fig_h=3, suptitle="Top-K Retrieved from Database")
else:
    print("No retrieval results (k_retrieve=0).")

# 3) Show generated img2img images
if len(gen_i2i) > 0:
    titles = [f"img2img #{i+1}" for i in range(len(gen_i2i))]
    show_row(gen_i2i, titles, fig_w=3.2 * len(gen_i2i), fig_h=3, suptitle="Generated (Image-to-Image)")
else:
    print("No img2img generated (n_i2i=0).")

# 4) Show generated txt2img images
if len(gen_t2i) > 0:
    titles = [f"txt2img #{i+1}" for i in range(len(gen_t2i))]
    show_row(gen_t2i, titles, fig_w=3.2 * len(gen_t2i), fig_h=3, suptitle="Generated (Text-to-Image)")
else:
    print("No txt2img generated (n_t2i=0).")