aadarsh99 commited on
Commit
7e33e45
·
1 Parent(s): b0822ac

added plm file

Browse files
plm_adapter_lora_with_image_input_only_text_positions.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # plm_adapter.py
2
+ import torch, torch.nn as nn
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoProcessor,
6
+ Qwen2_5_VLForConditionalGeneration, # <- Qwen2.5-VL
7
+ )
8
+
9
+ from peft import LoraConfig, get_peft_model, TaskType, PeftModel
10
+
11
+
12
+
13
+ class PLMLanguageAdapter(nn.Module):
14
+ """
15
+ Uses Qwen/Qwen2.5-VL-7B-Instruct as a multimodal encoder and
16
+ projects features to SAM2's decoder dims.
17
+ Produces:
18
+ sparse: [B, N_text_tokens, C]
19
+ dense: [B, C, H, W] (text-conditioned bias map)
20
+ """
21
+ def __init__(
22
+ self,
23
+ model_name="Qwen/Qwen2.5-VL-7B-Instruct",
24
+ transformer_dim=256,
25
+ n_sparse_tokens=0,
26
+ use_dense_bias=True,
27
+ dtype=torch.bfloat16,
28
+ device="cuda",
29
+ # ---- LoRA knobs ----
30
+ use_lora=True,
31
+ lora_r=16,
32
+ lora_alpha=32,
33
+ lora_dropout=0.05,
34
+ lora_bias="none",
35
+ lora_target_modules="auto",
36
+ gradient_checkpointing=False,
37
+ # ---- NEW ----
38
+ use_image_input=True,
39
+ max_txt_len=256, # cap token length to save memory
40
+ ):
41
+ super().__init__()
42
+
43
+ self.max_txt_len = max_txt_len
44
+
45
+ # --- tokenizer & (optional) processor ---
46
+ self.tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
47
+ self.tok.padding_side = "right"
48
+
49
+ # Cache which token IDs are *not* plain text (special, image placeholders, etc.)
50
+ self._non_text_token_ids = None
51
+ self._init_non_text_token_ids()
52
+
53
+ # self.processor = AutoProcessor.from_pretrained(model_name) if use_image_input else None
54
+ # replace your AutoProcessor line with:
55
+ min_pix = (28*20) * (28*20) # 560x560
56
+ # min_pix = 512*512
57
+ max_pix = min_pix
58
+ self.processor = AutoProcessor.from_pretrained(
59
+ model_name, min_pixels=min_pix, max_pixels=max_pix
60
+ ) if use_image_input else None
61
+ if self.processor is not None and hasattr(self.processor, "image_processor"):
62
+ ip = self.processor.image_processor
63
+ # turn on resizing
64
+ try:
65
+ ip.do_resize = True
66
+ except Exception:
67
+ pass
68
+ # prefer explicit H/W dict (works across most processors)
69
+ try:
70
+ ip.size = {"height": 256, "width": 256}
71
+ except Exception:
72
+ # fallbacks for processors that expect a single int or 'shortest_edge'
73
+ try:
74
+ ip.size = 256
75
+ except Exception:
76
+ try:
77
+ ip.size = {"shortest_edge": 256}
78
+ except Exception:
79
+ pass
80
+
81
+ # --- backbone: Qwen2.5-VL conditional generation model ---
82
+ self.backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
83
+ model_name, dtype=dtype, device_map=None
84
+ ).to(device)
85
+
86
+ # Start frozen; LoRA will re-enable a tiny set
87
+ for p in self.backbone.parameters():
88
+ p.requires_grad = False
89
+
90
+ # Wire up LoRA (optional)
91
+ self.peft_enabled = False
92
+ if use_lora:
93
+ target_modules = self._infer_lora_targets(self.backbone) if lora_target_modules == "auto" else lora_target_modules
94
+ if len(target_modules) == 0:
95
+ raise RuntimeError("Could not find any LoRA target modules; set `lora_target_modules` explicitly.")
96
+ self.lora_cfg = LoraConfig(
97
+ r=lora_r,
98
+ lora_alpha=lora_alpha,
99
+ lora_dropout=lora_dropout,
100
+ target_modules=target_modules,
101
+ bias=lora_bias,
102
+ task_type=TaskType.FEATURE_EXTRACTION,
103
+ )
104
+ self.backbone = get_peft_model(self.backbone, self.lora_cfg)
105
+ self.peft_enabled = True
106
+
107
+ if gradient_checkpointing and hasattr(self.backbone, "gradient_checkpointing_enable"):
108
+ try:
109
+ if hasattr(self.backbone, "config") and hasattr(self.backbone.config, "use_cache"):
110
+ self.backbone.config.use_cache = False
111
+ except Exception:
112
+ pass
113
+ self.backbone.gradient_checkpointing_enable()
114
+ if hasattr(self.backbone, "enable_input_require_grads"):
115
+ self.backbone.enable_input_require_grads()
116
+
117
+ # Hidden size on the text side (Qwen2.5-VL has text_config)
118
+ cfg = getattr(self.backbone, "config", None)
119
+ D_t = getattr(getattr(cfg, "text_config", None), "hidden_size", None)
120
+ if D_t is None:
121
+ raise RuntimeError("Could not infer text hidden_size from model config.")
122
+
123
+ self.to_sparse = nn.Linear(D_t, transformer_dim)
124
+ self.to_dense = nn.Sequential(
125
+ nn.Linear(D_t, transformer_dim),
126
+ nn.SiLU(),
127
+ nn.Linear(transformer_dim, transformer_dim),
128
+ ) if use_dense_bias else None
129
+
130
+ nn.init.xavier_uniform_(self.to_sparse.weight); nn.init.zeros_(self.to_sparse.bias)
131
+ if self.to_dense is not None:
132
+ for m in self.to_dense:
133
+ if isinstance(m, nn.Linear):
134
+ nn.init.xavier_uniform_(m.weight)
135
+ nn.init.zeros_(m.bias)
136
+
137
+ self.n_sparse_tokens = n_sparse_tokens
138
+ self.use_dense_bias = use_dense_bias
139
+ self.scale = nn.Parameter(torch.tensor(1.0))
140
+ self.txt_norm = nn.LayerNorm(D_t)
141
+ self.temp = nn.Parameter(torch.tensor(1.0))
142
+
143
+ # ensure module dtypes/devices match
144
+ self.to(device=device, dtype=dtype)
145
+
146
+
147
+ # ---- token filters -------------------------------------------------------
148
+ def _init_non_text_token_ids(self):
149
+ """
150
+ Build a list of token IDs that should NOT count as text positions.
151
+ Includes:
152
+ - all special tokens (BOS/EOS, role tokens, etc.)
153
+ - added vocab entries that look like image/vision placeholders
154
+ """
155
+ ids = set(getattr(self.tok, "all_special_ids", []) or [])
156
+ # Grab added vocab and heuristically include any image/vision markers
157
+ try:
158
+ added = getattr(self.tok, "get_added_vocab", lambda: {})()
159
+ for tok, tid in added.items():
160
+ tl = tok.lower()
161
+ if any(s in tl for s in ("image", "vision", "<img", "picture", "video")):
162
+ ids.add(int(tid))
163
+ except Exception:
164
+ pass
165
+ # store as a 1D LongTensor; move to device on use
166
+ if len(ids) == 0:
167
+ # keep a sentinel so equality checks never broadcast against empty
168
+ ids = {-(10**9)}
169
+ self._non_text_token_ids = torch.tensor(sorted(ids), dtype=torch.long)
170
+
171
+ def _text_positions_mask(self, ids: torch.Tensor, attn: torch.Tensor, eos_pos: torch.Tensor) -> torch.Tensor:
172
+ """
173
+ Return [B, T] mask where True = positions that correspond to *plain text* tokens.
174
+ We exclude:
175
+ - padding (already excluded by attn)
176
+ - EOS position
177
+ - any token ID in _non_text_token_ids (special/image placeholders)
178
+ """
179
+ device = ids.device
180
+ bad = self._non_text_token_ids.to(device) # [K]
181
+ is_bad = (ids.unsqueeze(-1) == bad.view(1, 1, -1)).any(dim=-1) # [B, T]
182
+ idxs = torch.arange(ids.shape[1], device=device).unsqueeze(0).expand_as(ids)
183
+ return (attn.bool() & ~is_bad & (idxs != eos_pos.unsqueeze(1))) # [B, T]
184
+
185
+
186
+ # ---- LoRA utilities -----------------------------------------------------
187
+ def _infer_lora_targets(self, model: nn.Module):
188
+ """
189
+ Heuristic for LLaMA/decoder stacks:
190
+ prefer attention proj + MLP proj layers.
191
+ Returns base names that PEFT will match in module paths.
192
+ """
193
+ common = ["q_proj", "k_proj", "v_proj", "o_proj", # attn
194
+ "wq", "wk", "wv", "wo", # alt naming
195
+ "gate_proj", "up_proj", "down_proj"] # MLP
196
+ # Keep only those that actually occur
197
+ present = set()
198
+ for name, _ in model.named_modules():
199
+ base = name.split(".")[-1].lower()
200
+ for t in common:
201
+ if base == t:
202
+ present.add(t)
203
+ # If nothing matches (unusual naming), fall back to all Linear in attention/MLP blocks
204
+ if not present:
205
+ for name, mod in model.named_modules():
206
+ if isinstance(mod, nn.Linear) and any(s in name.lower() for s in ["attn", "attention", "mlp", "ffn"]):
207
+ present.add(name.split(".")[-1])
208
+ return sorted(list(present))
209
+
210
+ # --- text-only ---
211
+ def encode_text(self, texts: list[str]):
212
+ toks = self.tok(
213
+ texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_txt_len
214
+ )
215
+ toks = {k: v.to(self.backbone.device) for k, v in toks.items()}
216
+ out = self.backbone(
217
+ input_ids=toks["input_ids"],
218
+ attention_mask=toks["attention_mask"],
219
+ return_dict=True,
220
+ output_hidden_states=True, # <-- required for Qwen2.5-VL
221
+ use_cache=False, # <-- safer with LoRA / grad ckpt
222
+ )
223
+ seq = self._final_token_features(out) # [B, T, D_t]
224
+ attn = toks["attention_mask"].bool()
225
+ ids = toks["input_ids"].long()
226
+ return seq, attn, ids
227
+
228
+
229
+ # Add inside PLMLanguageAdapter
230
+ def _final_token_features(self, out):
231
+ # Prefer hidden_states[-1] (decoder-only models usually don't return last_hidden_state)
232
+ hs = getattr(out, "hidden_states", None)
233
+ if hs is not None and len(hs) > 0:
234
+ return hs[-1]
235
+ lh = getattr(out, "last_hidden_state", None)
236
+ if lh is not None:
237
+ return lh
238
+ raise RuntimeError(
239
+ "Model output has neither last_hidden_state nor hidden_states. "
240
+ "Pass output_hidden_states=True to the forward call."
241
+ )
242
+
243
+
244
+ # --- batched V+L (your Point 1 version) ---
245
+ def encode_text_image(self, texts: list[str], image_paths: list[str]):
246
+ assert self.processor is not None and len(texts) == len(image_paths) and len(texts) > 0
247
+ device = self.backbone.device
248
+ proj_dtype = self.to_sparse.weight.dtype
249
+
250
+ def truncate_text(txt: str) -> str:
251
+ toks = self.tok(txt or "", return_tensors="pt", padding=False, truncation=True,
252
+ max_length=getattr(self, "max_txt_len", 256), add_special_tokens=False)
253
+ return self.tok.decode(toks["input_ids"][0], skip_special_tokens=True)
254
+
255
+ conversations = [[{
256
+ "role": "user",
257
+ "content": [{"type": "image", "url": p}, {"type": "text", "text": truncate_text(t)}],
258
+ }] for t, p in zip(texts, image_paths)]
259
+
260
+ inputs = self.processor.apply_chat_template(
261
+ conversations,
262
+ add_generation_prompt=False,
263
+ tokenize=True,
264
+ return_dict=True,
265
+ return_tensors="pt",
266
+ padding=True,
267
+ truncation=False, # keep image tokens intact
268
+ images_kwargs={
269
+ "do_resize": True,
270
+ "size": {"height": 256, "width": 256},
271
+ "disable_grouping": False, # allow efficient vision batching
272
+ },
273
+ # pad_to_multiple_of=8, # uncomment if your processor supports it
274
+ )
275
+
276
+ for k, v in list(inputs.items()):
277
+ if torch.is_tensor(v):
278
+ inputs[k] = v.to(device, non_blocking=True)
279
+
280
+ out = self.backbone(
281
+ **inputs,
282
+ return_dict=True,
283
+ output_hidden_states=True, # <-- required
284
+ use_cache=False, # <-- safer with LoRA / grad ckpt
285
+ )
286
+
287
+ seq = self._final_token_features(out).to(proj_dtype) # [B, T, D_t]
288
+ attn = inputs["attention_mask"].to(torch.bool) # [B, T]
289
+ ids = inputs["input_ids"].to(torch.long) # [B, T]
290
+ return seq, attn, ids
291
+
292
+
293
+
294
+ def forward(self, texts: list[str], H: int, W: int, image_paths: list[str] | None = None):
295
+ import time
296
+ # start = time.time()
297
+ # Route to V+L or text-only encoder
298
+ if image_paths is not None and self.processor is not None:
299
+ seq, attn, ids = self.encode_text_image(texts, image_paths) # [B, T, D_t]
300
+ else:
301
+ seq, attn, ids = self.encode_text(texts) # [B, T, D_t]
302
+
303
+ B, T, D_t = seq.shape
304
+ device = seq.device
305
+
306
+ # print("Shape of seq:", seq.shape)
307
+
308
+ # match projection dtype
309
+ proj_dtype = self.to_sparse.weight.dtype
310
+ seq = seq.to(proj_dtype)
311
+
312
+ # Normalize token embeddings
313
+ seq = self.txt_norm(seq)
314
+
315
+ # ---- find EOS per sequence ----
316
+ eos_id = self.tok.eos_token_id
317
+ if eos_id is None:
318
+ eos_mask = torch.zeros_like(ids, dtype=torch.bool, device=device)
319
+ else:
320
+ eos_mask = (ids == eos_id).to(device)
321
+
322
+ idxs = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
323
+ valid_counts = attn.long().sum(dim=1)
324
+ fallback = (valid_counts - 1).clamp(min=0)
325
+ eos_pos = torch.where(eos_mask, idxs, torch.full_like(idxs, -1)).amax(dim=1)
326
+ eos_pos = torch.where(eos_pos >= 0, eos_pos, fallback) # [B]
327
+
328
+ # Dense = EOS vector
329
+ eos_vec = seq[torch.arange(B, device=device), eos_pos] # [B, D_t]
330
+
331
+ # ---- sparse = TEXT token positions only (exclude image + all special + EOS) ----
332
+ non_eos_mask = self._text_positions_mask(ids, attn, eos_pos) # [B, T]
333
+ if self.n_sparse_tokens > 0:
334
+ N = self.n_sparse_tokens
335
+ else:
336
+ N = int(non_eos_mask.sum(dim=1).max().item())
337
+ if N == 0:
338
+ N = 1
339
+
340
+ idx_mat = torch.full((B, N), -1, device=device, dtype=torch.long)
341
+ for b in range(B):
342
+ pos = torch.nonzero(non_eos_mask[b], as_tuple=False).squeeze(-1)
343
+ take = pos[:N]
344
+ idx_mat[b, :take.numel()] = take
345
+
346
+ safe_idx = idx_mat.clamp(min=0)
347
+ sparse_tok = seq[torch.arange(B, device=device).unsqueeze(-1), safe_idx] # [B, N, D_t]
348
+ valid_mask = (idx_mat >= 0).unsqueeze(-1).to(sparse_tok.dtype)
349
+ sparse_tok = sparse_tok * valid_mask # zero-pad
350
+
351
+ # Project
352
+ sparse = self.to_sparse(sparse_tok) * self.scale # [B, N, C]
353
+
354
+ # Dense projection from EOS only
355
+ if self.use_dense_bias:
356
+ bias = self.to_dense(eos_vec) * self.temp.clamp(min=0.01) # [B, C]
357
+ dense = bias.unsqueeze(-1).unsqueeze(-1).expand(B, bias.shape[-1], H, W)
358
+ else:
359
+ C = self.to_sparse.out_features
360
+ dense = torch.zeros(B, C, H, W, device=device, dtype=proj_dtype)
361
+
362
+ # end = time.time()
363
+ # print(f"PLM Adapter forward time: {end - start:.3f} seconds")
364
+
365
+ return sparse, dense
366
+
367
+
368
+
369
+ # -------- Save / Load LoRA adapters only --------
370
+ def save_lora(self, out_dir: str):
371
+ """
372
+ Saves only the LoRA adapters (and PEFT config). Use with PeftModel.
373
+ """
374
+ if not self.peft_enabled:
375
+ raise RuntimeError("LoRA is not enabled.")
376
+ self.backbone.save_pretrained(out_dir)
377
+
378
+ def load_lora(self, adapter_dir: str):
379
+ """
380
+ Loads adapters onto the *current* backbone weights.
381
+ """
382
+ if PeftModel is None:
383
+ raise ImportError("peft is not installed. `pip install peft`")
384
+ # If already a PeftModel, just load the new adapter weights.
385
+ if isinstance(self.backbone, PeftModel):
386
+ self.backbone.load_adapter(adapter_dir, adapter_name="default", is_trainable=True)
387
+ self.backbone.set_adapter("default")
388
+ else:
389
+ self.backbone = PeftModel.from_pretrained(self.backbone, adapter_dir, is_trainable=True)
390
+ self.peft_enabled = True