kevin510 commited on
Commit
3ab8676
Β·
verified Β·
1 Parent(s): f22e168

Upload modeling_friday.py

Browse files
Files changed (1) hide show
  1. modeling_friday.py +1243 -0
modeling_friday.py ADDED
@@ -0,0 +1,1243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Model Constants
IMAGE_TOKEN = "<image>"  # prompt placeholder replaced by image features
IMG_START_TOKEN = "<img_start>"  # marks the beginning of an image-feature span
IMG_END_TOKEN = "<img_end>"  # marks the end of an image-feature span
IGNORE_INDEX = -100  # label value ignored by PyTorch cross-entropy loss
PAD_FOR_EOS = -300  # sentinel label value — NOTE(review): exact semantics not shown in this file; confirm against the training loop
7
+
8
+
9
+
10
+
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from PIL import Image
16
+
17
+
18
+ import torch
19
+
20
def mask_token_segment(
        start_id: int,
        end_id: int,
        input_ids: torch.Tensor,
        fill_value: int = -100) -> torch.Tensor:
    """
    Replace *every* token from each `start_id` **through** its matching `end_id`
    (boundaries included) with `fill_value`. Spans that never close (a start
    with no later end) are left untouched.

    Args:
        start_id: token id that opens a span.
        end_id: token id that closes a span.
        input_ids: 1-D tensor of token ids.
        fill_value: value written over masked positions (default -100,
            the PyTorch cross-entropy ignore index).

    Returns:
        A new tensor (never an alias of `input_ids`) with spans masked.

    Raises:
        ValueError: if `input_ids` is not 1-D.

    Works on CUDA etc. — no Python loops.
    """
    if input_ids.dim() != 1:
        raise ValueError("`input_ids` must be 1-D")

    device = input_ids.device
    n = input_ids.size(0)

    # positions of the start-tokens and end-tokens (ascending)
    start_pos = (input_ids == start_id).nonzero(as_tuple=True)[0]
    end_pos = (input_ids == end_id).nonzero(as_tuple=True)[0]

    if start_pos.numel() == 0:
        return input_ids.clone()

    # ── pair every start with the first end that comes *after* it ────────────
    # searchsorted gives the insertion index into the (sorted) end positions
    idx_in_end = torch.searchsorted(end_pos, start_pos, right=False)

    have_match = idx_in_end < end_pos.size(0)  # safety: drop unmatched starts
    start_pos = start_pos[have_match]
    end_pos = end_pos[idx_in_end[have_match]]

    # (rare) guard against pathological orderings
    keep = end_pos > start_pos
    start_pos, end_pos = start_pos[keep], end_pos[keep]

    if start_pos.numel() == 0:
        # BUGFIX: return a clone, consistent with the other early return —
        # callers must never receive an alias of the input.
        return input_ids.clone()

    # ── differential "scan-line" trick to build the span mask in O(N) ────────
    # +1 at each start index, -1 just after each end. scatter_add_ is used
    # (not advanced-index `+=`) so duplicate indices accumulate correctly:
    # two starts paired with the same end must contribute two -1's, otherwise
    # the mask would run past the end token. int32 avoids int8 overflow for
    # deeply repeated spans.
    delta = torch.zeros(n + 1, dtype=torch.int32, device=device)
    delta.scatter_add_(0, start_pos, torch.ones_like(start_pos, dtype=torch.int32))
    delta.scatter_add_(0, end_pos + 1, -torch.ones_like(end_pos, dtype=torch.int32))

    inside = torch.cumsum(delta[:-1], dim=0) > 0  # boolean mask, incl. boundaries

    # ── apply ────────────────────────────────────────────────────────────────
    out = input_ids.clone()
    out[inside] = fill_value
    return out
72
+
73
+
74
+
75
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of `param`.

    If `param` is a DeepSpeed ZeRO-3 partitioned parameter (has `ds_id`),
    gather it across ranks first; otherwise just detach/clone locally.
    `name` is only used in the diagnostic print when the parameter is
    unexpectedly unavailable.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # ZeRO-3 path: warn if the shard is not resident unless told to ignore.
    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, 'no ignore status')
    with zero.GatheredParameters([param]):
        gathered = param.data.detach().cpu().clone()
    return gathered
87
+
88
+
89
+ # Borrowed from peft.util.get_peft_model_state_dict
90
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA-related state dict from `named_params`.

    Args:
        named_params: iterable of (name, tensor) pairs (e.g. model.named_parameters()).
        bias: which bias tensors to include — "none" (LoRA weights only),
            "all" (LoRA weights plus every bias), or "lora_only" (LoRA weights
            plus only the biases belonging to LoRA-adapted modules).

    Returns:
        dict mapping parameter name -> gathered CPU tensor (ZeRO-3 safe).

    Raises:
        NotImplementedError: for any other `bias` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # module prefix + "bias" is the bias that belongs to this LoRA layer
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # BUGFIX: iterate .items() — iterating the dict directly yields keys
        # only, so `for k, t in maybe_lora_bias:` raised at runtime.
        # BUGFIX: test each candidate key `k`, not the stale `bias_name`
        # left over from the previous loop.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
113
+
114
+
115
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect every non-LoRA parameter as a CPU state dict (ZeRO-3 safe).

    Args:
        named_params: iterable of (name, tensor) pairs.
        require_grad_only: if True, keep only trainable tensors.
    """
    selected = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        selected = {k: t for k, t in selected.items() if t.requires_grad}
    return {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in selected.items()}
121
+
122
+
123
def find_all_linear_names(modules):
    """Return leaf names of every nn.Linear module, for LoRA targeting.

    Args:
        modules: a callable yielding (name, module) pairs, e.g.
            `model.named_modules`.

    Returns:
        list of unique leaf module names, with 'lm_head' excluded
        (needed for 16-bit training).
    """
    linear_leaves = set()
    for full_name, mod in modules():
        if isinstance(mod, torch.nn.Linear):
            parts = full_name.split('.')
            linear_leaves.add(parts[-1] if len(parts) > 1 else parts[0])

    linear_leaves.discard('lm_head')  # needed for 16-bit
    return list(linear_leaves)
133
+
134
+
135
def expand2square(pil_img, background_color):
    """Pad a PIL image to a square, centering it on `background_color`.

    The square's side is the larger of width/height; the original image is
    pasted centered along the short axis. Returns the input unchanged when
    it is already square.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        canvas.paste(pil_img, (0, (side - height) // 2))
    else:
        canvas.paste(pil_img, ((side - width) // 2, 0))
    return canvas
147
+
148
def pad_and_stack(img_list, pad_value=0.0):
    """Stack variable-size image tensors into one square padded batch.

    Parameters
    ----------
    img_list : list[Tensor]
        Each (C, H, W), already *normalised*.
    pad_value : float
        Fill value for the padding region; use 0.0 if the processor has
        already centred to mean 0.

    Returns
    -------
    Tensor of shape (B, C, S, S) where S = max(H_max, W_max); each image
    sits in the top-left corner of its canvas.
    """
    # target square side: largest height or width across the batch
    side = max(max(t.shape[1], t.shape[2]) for t in img_list)

    canvases = []
    for img in img_list:
        channels, height, width = img.shape
        canvas = img.new_full((channels, side, side), pad_value)
        canvas[:, :height, :width] = img  # top-left corner
        canvases.append(canvas)

    return torch.stack(canvases, 0)
172
+
173
+
174
+
175
+
176
+
177
+ # ------------------------------------------------------------------------------------------
178
+ # Copyright (c) 2024 Baifeng Shi.
179
+ # All rights reserved.
180
+ #
181
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
182
+ # ------------------------------------------------------------------------------------------
183
+
184
+ import torch
185
+
186
def split_chessboard(x, num_split):
    """Cut a (B, C, H, W) batch into num_split**2 equal tiles.

    Tiles are taken row-major (top-left first) and concatenated along the
    batch dimension, giving (B * num_split**2, C, H/num_split, W/num_split).
    H and W must be divisible by `num_split`.
    """
    B, C, H, W = x.shape
    assert H % num_split == 0 and W % num_split == 0
    tile_h, tile_w = H // num_split, W // num_split
    tiles = []
    for row in range(num_split):
        for col in range(num_split):
            tiles.append(x[:, :,
                           row * tile_h:(row + 1) * tile_h,
                           col * tile_w:(col + 1) * tile_w])
    return torch.cat(tiles, dim=0)
196
+
197
def merge_chessboard(x, num_split):
    """Inverse of split_chessboard.

    Assumes `x` holds num_split**2 row-major tiles concatenated along the
    batch dimension; reassembles them into the original whole square of
    shape (B / num_split**2, C, H * num_split, W * num_split).
    """
    B, C, H, W = x.shape
    assert B % (num_split ** 2) == 0
    sub_batch = B // (num_split ** 2)
    bands = []
    for row in range(num_split):
        row_tiles = [
            x[(row * num_split + col) * sub_batch:(row * num_split + col + 1) * sub_batch]
            for col in range(num_split)
        ]
        bands.append(torch.cat(row_tiles, dim=-1))  # stitch columns
    return torch.cat(bands, dim=-2)                 # stack rows
209
+
210
def batched_forward(model, x, batch_size=-1):
    """Run `model` on `x`, optionally chunked along dim 0.

    batch_size == -1 means a single full-batch call; otherwise `x` is split
    into chunks of `batch_size` and the per-chunk outputs are concatenated
    back along dim 0.
    """
    if batch_size == -1:
        return model(x)
    return torch.cat([model(chunk) for chunk in x.split(batch_size)], dim=0)
217
+
218
+
219
+
220
+
221
+
222
+
223
+
224
+ # ------------------------------------------------------------------------------------------
225
+ # Copyright (c) 2024 Baifeng Shi.
226
+ # All rights reserved.
227
+ #
228
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
229
+ # ------------------------------------------------------------------------------------------
230
+
231
+ import math
232
+ import torch
233
+ import torch.nn.functional as F
234
+ from einops import rearrange
235
+
236
def multiscale_forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0,
                       output_shape='bnc', split_forward=False):
    """S2 (Scaling-on-Scales) multiscale feature extraction.

    Resizes `input` to each requested scale, tiles scales larger than
    `max_split_size` into a chessboard of sub-squares, runs `model` on every
    tile, reassembles each scale, resizes all scales to a common spatial size
    and concatenates them along the channel dimension.

    Args:
        model: callable mapping an image batch to features.
        input: (B, C, H, W) tensor; H must equal W.
        scales: multipliers of the input size (mutually exclusive with img_sizes).
        img_sizes: absolute image sizes per scale.
        max_split_size: largest tile side before chessboard splitting
            (defaults to the input size).
        resize_output_to_idx: index of the scale whose spatial size the
            merged output is resized to.
        num_prefix_token: number of prefix (e.g. CLS) tokens in 'bnc' outputs.
        output_shape: 'bnc' (ViT token outputs) or 'bchw' (ConvNet maps).
        split_forward: if True, run tiles through `model` in sub-batches of B.

    Returns:
        Concatenated multiscale features, 'bnc' or 'bchw' to match
        `output_shape`; prefix tokens (averaged over tiles per scale,
        concatenated over scales) are prepended when num_prefix_token > 0.
    """
    # print(f"Input shape: {input.shape}")

    assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3], "Currently only square images are supported."
    assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token."

    b, c, input_size, _ = input.shape

    # image size for each scale
    assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
    img_sizes = img_sizes or [int(input_size * scale) for scale in scales]

    # prepare multiscale inputs
    max_split_size = max_split_size or input_size  # The maximum size of each split of image. Set as the input size by default
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]  # number of splits each scale
    input_multiscale = []
    for size, num_split in zip(img_sizes, num_splits):
        # interpolate in float32 for numerical stability, then cast back
        x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype)
        x = split_chessboard(x, num_split=num_split)
        input_multiscale.append(x)

    # run feedforward on each scale
    outs_multiscale = [batched_forward(model, x, b) if split_forward else model(x) for x in input_multiscale]
    if num_prefix_token > 0:
        outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
        outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]
    if output_shape == 'bnc':
        # token grid assumed square — NOTE(review): requires a perfect-square token count
        outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=int(out.shape[1] ** 0.5), w=int(out.shape[1] ** 0.5))
                           for out in outs_multiscale]

    # merge outputs of different splits for each scale separately
    outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)]

    # interpolate outputs from different scales and concat together
    output_size = outs_multiscale[resize_output_to_idx].shape[-2]
    out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size,
                                   mode='area').to(outs_multiscale[i].dtype)
                     for i in range(len(outs_multiscale))], dim=1)
    if output_shape == 'bnc':
        out = rearrange(out, 'b c h w -> b (h w) c')
    if num_prefix_token > 0:
        # take the mean of prefix tokens from different splits for each scale
        outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale]
        out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
        out = torch.cat([out_prefix_multiscale, out], dim=1)

    return out
287
+
288
+
289
+
290
+
291
+
292
+ import torch
293
+ import torch.nn as nn
294
+
295
class MLPAdapter(nn.Module):
    """MLP projector mapping vision features into the LLM embedding space.

    Architecture: Linear(input_dim -> hidden_dim) + activation, followed by
    (num_layers - 1) x [Linear(hidden_dim -> hidden_dim) + activation], and a
    final Linear(hidden_dim -> output_dim) — i.e. num_layers + 1 Linear
    layers in total.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, activation='gelu', checkpoint_path=None, device=None, **kwargs):
        """Build the MLP.

        Args:
            input_dim (int): Input dimension.
            hidden_dim (int): Hidden dimension.
            output_dim (int): Output dimension.
            num_layers (int): Number of hidden layers.
            activation (str): 'gelu' or 'relu'.
            checkpoint_path (str | None): optional state-dict file to load.
            device: optional device for the layers and loaded weights.

        Raises:
            ValueError: for an unsupported activation name.
        """
        super().__init__()
        self.num_layers = num_layers
        self.activation = activation
        self.output_dim = output_dim

        stack = [nn.Linear(input_dim, hidden_dim, device=device)]
        if activation == 'gelu':
            act_factory = nn.GELU
        elif activation == 'relu':
            act_factory = nn.ReLU
        else:
            raise ValueError("Unsupported activation function. Use 'gelu' or 'relu'.")
        stack.append(act_factory())

        # hidden -> hidden blocks
        for _ in range(num_layers - 1):
            stack.append(nn.Linear(hidden_dim, hidden_dim, device=device))
            stack.append(act_factory())

        # final projection
        stack.append(nn.Linear(hidden_dim, output_dim, device=device))
        self.mlp = nn.Sequential(*stack)

        if checkpoint_path:
            # NOTE(review): torch.load without weights_only — only load trusted checkpoints.
            self.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
            print(f"Loaded MLPAdapter from {checkpoint_path}")

        if device:
            self.to(device)

    def forward(self, x):
        """Project `x` through the MLP and return the result."""
        return self.mlp(x)
353
+
354
+
355
+
356
+
357
+ import torch
358
+ import torch.nn as nn
359
+ import torch.nn.functional as F
360
+
361
+ import PIL.Image
362
+ from typing import List
363
+
364
+ from transformers import AutoModel, AutoImageProcessor
365
+
366
+
367
class FastVitVisionTower(nn.Module):
    """Vision encoder wrapper around a pretrained FastViT model.

    Loads the model and image processor from `pretrained_model_name_or_path`,
    freezes the encoder weights, and provides preprocessing plus feature
    extraction for the multimodal pipeline.
    """

    def __init__(self, pretrained_model_name_or_path, model_params={}, pad_to_square=True, **kwargs):
        super().__init__()

        self.is_loaded = False
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.model_params = model_params
        self.pad_to_square = pad_to_square
        self.load_model()

    @property
    def output_dim(self):
        # None until load_model() has populated self.vision_tower
        return self.vision_tower.config.embed_dim if self.vision_tower else None

    def load_model(self):
        """Load the image processor and frozen FastViT weights (idempotent)."""
        if self.is_loaded:
            return
        self.image_processor = AutoImageProcessor.from_pretrained(self.pretrained_model_name_or_path)
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = AutoModel.from_pretrained(
            self.pretrained_model_name_or_path,
            **self.model_params,
        )
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def preprocess_images(self, imgs: List[PIL.Image.Image], pad_and_stack_tensors=True) -> torch.Tensor:
        """Convert PIL images to normalised pixel tensors (optionally batched)."""
        # pad with the per-channel mean colour so padding is ~zero after normalisation
        img_mean = tuple(int(x * 255) for x in self.image_processor.image_mean)
        if self.pad_to_square:
            imgs = [expand2square(img, img_mean) for img in imgs]

        imgs = [self.image_processor(img, do_resize=True, do_center_crop=False, return_tensors="pt")['pixel_values'][0] for img in imgs]

        if pad_and_stack_tensors:
            imgs = pad_and_stack(imgs, pad_value=0.0)
            imgs = imgs.to(dtype=torch.float32, device=self.device)

        return imgs

    def forward(self, images):
        """Encode a batch tensor, or a list of per-image tensors (kept per-image)."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
                )
                image_features.append(image_feature)
        else:
            image_features = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
            )

        return image_features

    @property
    def dummy_feature(self):
        # BUGFIX: was `self.embed_dim`, an attribute that never exists on this
        # class (AttributeError at runtime); use `self.hidden_size`, matching
        # SiglipVisionTower.dummy_feature.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            # NOTE(review): self.cfg_only is never assigned in this class —
            # this branch would raise AttributeError; confirm intended.
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.embed_dim

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
449
+
450
+
451
class FastVitVisionTowerS2(FastVitVisionTower):
    """FastViT tower with S2 (Scaling-on-Scales) multiscale forward.

    `s2_scales` is a comma-separated string of absolute image sizes,
    e.g. "512,1024"; each input is encoded at every scale and the
    per-scale features are channel-concatenated.
    """

    def __init__(self, pretrained_model_name_or_path, s2_scales, model_params={}, **kwargs):
        # Parse scales BEFORE super().__init__ — the base constructor calls
        # load_model(), which reads self.s2_image_size below.
        self.s2_scales = list(map(int, s2_scales.split(',')))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]   # tile size = smallest scale
        self.s2_image_size = self.s2_scales[-1]  # processor target = largest scale

        super().__init__(pretrained_model_name_or_path, model_params)

        self.multiscale_forward = multiscale_forward

    @property
    def output_dim(self):
        # NOTE(review): hard-coded factor 2 assumes exactly two s2 scales —
        # confirm for configs with a different scale count.
        return (2*self.vision_tower.config.embed_dim) if self.vision_tower else None

    def load_model(self):
        """Load the base tower, then retarget the processor to the largest scale."""
        if self.is_loaded:
            return

        super().load_model()
        self.image_processor.size = self.image_processor.crop_size = {
            "height": self.s2_image_size,
            "width": self.s2_image_size
        }

    def forward_feature(self, images):
        """Encode one (tiled) batch, resizing to the model's native size first."""
        image_size = self.vision_tower.config.image_size
        if images.shape[2] != image_size or images.shape[3] != image_size:
            images = F.interpolate(
                images,
                size=(image_size, image_size),
                mode="bilinear",
                align_corners=False,
                antialias=True
            )

        return self.vision_tower(
            images.to(device=self.device, dtype=self.dtype),
        )

    def forward(self, images):
        """Multiscale-encode a batch tensor, or a list of per-image tensors."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(
                    self.forward_feature,
                    image.unsqueeze(0),
                    img_sizes=self.s2_scales,
                    max_split_size=self.s2_split_size
                )
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(
                self.forward_feature,
                images,
                img_sizes=self.s2_scales,
                max_split_size=self.s2_split_size
            )

        return image_features

    @property
    def hidden_size(self):
        # one embed_dim-wide slice per scale, channel-concatenated
        return self.config.embed_dim * len(self.s2_scales)
515
+
516
+
517
+
518
+
519
+
520
+ import torch
521
+ import torch.nn as nn
522
+
523
+ import PIL.Image
524
+ from typing import List
525
+
526
+ from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig
527
+
528
+
529
class SiglipVisionTower(nn.Module):
    """Vision encoder wrapper around a pretrained SigLIP model.

    Loads the model and image processor from `pretrained_model_name_or_path`,
    freezes the encoder weights, and provides preprocessing plus feature
    extraction for the multimodal pipeline.
    """

    def __init__(self, pretrained_model_name_or_path, model_params={}, pad_to_square=True, **kwargs):
        super().__init__()

        self.is_loaded = False
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.model_params = model_params
        self.pad_to_square = pad_to_square
        # -2: use the penultimate hidden state as image features
        self.select_layer = -2
        self.load_model()

    @property
    def output_dim(self):
        # None until load_model() has populated self.vision_tower
        return self.vision_tower.config.hidden_size if self.vision_tower else None

    def load_model(self):
        """Load the image processor and frozen SigLIP weights (idempotent)."""
        if self.is_loaded:
            return
        self.image_processor = SiglipImageProcessor.from_pretrained(self.pretrained_model_name_or_path)
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = SiglipVisionModel.from_pretrained(
            self.pretrained_model_name_or_path,
            **self.model_params,
        )
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def preprocess_images(self, imgs: List[PIL.Image.Image], pad_and_stack_tensors=True) -> torch.Tensor:
        """Convert PIL images to normalised pixel tensors (optionally batched)."""
        # pad with the per-channel mean colour so padding is ~zero after normalisation
        img_mean = tuple(int(x * 255) for x in self.image_processor.image_mean)
        if self.pad_to_square:
            imgs = [expand2square(img, img_mean) for img in imgs]
        imgs = [self.image_processor(img, return_tensors="pt")['pixel_values'][0] for img in imgs]

        if pad_and_stack_tensors:
            imgs = pad_and_stack(imgs, pad_value=0.0)
            imgs = imgs.to(dtype=torch.float32, device=self.device)

        return imgs

    def feature_select(self, image_forward_outs):
        """Pick the hidden-state layer (self.select_layer) used as features."""
        image_features = image_forward_outs.hidden_states[self.select_layer]

        return image_features

    def forward(self, images):
        """Encode a batch tensor, or a list of per-image tensors (kept per-image)."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                                      output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                                   output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        # zero placeholder with the feature width of one token
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            # NOTE(review): self.cfg_only is never assigned in this class —
            # this branch would raise AttributeError; confirm intended.
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
615
+
616
+
617
class SiglipVisionTowerS2(SiglipVisionTower):
    """SigLIP tower with S2 (Scaling-on-Scales) multiscale forward.

    `s2_scales` is a comma-separated string of absolute image sizes,
    e.g. "384,768"; each input is encoded at every scale and the
    per-scale features are channel-concatenated.
    """

    def __init__(self, pretrained_model_name_or_path, s2_scales, model_params={}, **kwargs):
        # Parse scales BEFORE super().__init__ — the base constructor calls
        # load_model(), which reads the s2 attributes below.
        self.s2_scales = list(map(int, s2_scales.split(',')))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]   # tile size = smallest scale
        self.s2_image_size = self.s2_scales[-1]  # processor target = largest scale

        super().__init__(pretrained_model_name_or_path, model_params)

        self.multiscale_forward = multiscale_forward

        # NOTE(review): duplicated in load_model(); harmless but redundant.
        self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    @property
    def output_dim(self):
        # NOTE(review): hard-coded factor 2 assumes exactly two s2 scales —
        # confirm for configs with a different scale count.
        return (2*self.vision_tower.config.hidden_size) if self.vision_tower else None

    def load_model(self):
        """Load the base tower, then retarget the processor to the largest scale."""
        if self.is_loaded:
            return

        super().load_model()
        self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def forward_feature(self, images):
        """Encode one (tiled) batch and select the feature layer."""
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                               output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        """Multiscale-encode a batch tensor, or a list of per-image tensors."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(
                    self.forward_feature,
                    image.unsqueeze(0),
                    img_sizes=self.s2_scales,
                    max_split_size=self.s2_split_size
                )
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(
                self.forward_feature,
                images,
                img_sizes=self.s2_scales,
                max_split_size=self.s2_split_size
            )

        return image_features

    @property
    def hidden_size(self):
        # one hidden_size-wide slice per scale, channel-concatenated
        return self.config.hidden_size * len(self.s2_scales)
673
+
674
+
675
+
676
+
677
+
678
+ from __future__ import annotations
679
+
680
+ import torch
681
+ import torch.nn as nn
682
+ import torch.nn.functional as F
683
+ from torchvision import transforms
684
+
685
+ from typing import List, Tuple, Optional, Union
686
+
687
+ import PIL
688
+
689
+ from transformers import AutoTokenizer, AutoConfig
690
+ from transformers.modeling_outputs import CausalLMOutputWithPast
691
+
692
+
693
# Default token ids for the image placeholder / boundary special tokens.
DEFAULT_CFG_SPECIAL_TOKENS = {
    "image_token_id": 200029,
    "image_start_token_id": 200030,
    "image_end_token_id": 200031,
}
# Default vision-tower settings (FastViT-HD with two-scale S2).
DEFAULT_CFG_VISION_TOWER = {
    "pretrained_model_name_or_path": "kevin510/fast-vit-hd",
    "type": "fastvit",
    "s2_scales": "512,1024",
    "use_s2": True,
    "pad_to_square": True,
    "freeze": False,
    "model_params": { "trust_remote_code": True }
}
# Default MLP adapter settings (vision features -> LLM embedding space).
# NOTE(review): MLPAdapter's parameter is `num_layers`; the "layers" key here
# is swallowed by **kwargs and the adapter falls back to its default of 2 —
# it happens to match, but renaming the key would make this effective.
DEFAULT_CFG_VISION_ADAPTER = {
    "input_dim": 6144,
    "hidden_dim": 3072,
    "output_dim": 3072,
    "layers": 2,
    "activation": "gelu",
    "freeze": False,
}
715
+
716
+
717
class FridayConfig(Phi3Config):
    """Phi-3 configuration extended with Friday's multimodal settings.

    Merges the base LLM config (fetched from `base_model_name_or_path`)
    with vision-tower, vision-adapter and special-token sub-configs,
    each seeded from module-level defaults and overridable via kwargs.
    """

    model_type = "friday"

    def __init__(self,
                 base_model_name_or_path: str | None = "microsoft/Phi-4-mini-reasoning",
                 delay_load=False,
                 tokenizer_model_max_length=None,
                 **kwargs
                 ):
        """Build the config.

        Args:
            base_model_name_or_path: HF repo of the base LLM whose config is
                merged in (skipped when None).
            delay_load: if True, FridayModel defers vision-module creation.
            tokenizer_model_max_length: optional tokenizer max length override.
            **kwargs: may include cfg_vision_tower / cfg_vision_adapter /
                cfg_special_tokens dicts; remaining keys override base config.
        """
        base_kwargs = {}
        if base_model_name_or_path is not None:
            base_cfg = AutoConfig.from_pretrained(
                base_model_name_or_path,
                trust_remote_code=True,  # Phi-4 uses custom code in the repo
            )
            base_kwargs = base_cfg.to_dict()

        # explicit kwargs take precedence over the fetched base config
        merged = {**base_kwargs, **kwargs}
        self.delay_load = delay_load
        self.tokenizer_model_max_length = tokenizer_model_max_length

        self._cfg_vision_tower = DEFAULT_CFG_VISION_TOWER.copy()
        if "cfg_vision_tower" in kwargs:
            self._cfg_vision_tower.update(kwargs["cfg_vision_tower"])

        self._cfg_vision_adapter = DEFAULT_CFG_VISION_ADAPTER.copy()
        if "cfg_vision_adapter" in kwargs:
            self._cfg_vision_adapter.update(kwargs["cfg_vision_adapter"])

        self._cfg_special_tokens = DEFAULT_CFG_SPECIAL_TOKENS.copy()
        if "cfg_special_tokens" in kwargs:
            self._cfg_special_tokens.update(kwargs["cfg_special_tokens"])

        super().__init__(**merged)

    @property
    def cfg_vision_tower(self):
        return self._cfg_vision_tower

    @cfg_vision_tower.setter
    def cfg_vision_tower(self, value):
        # Setters merge into the existing dict rather than replacing it.
        if not value:
            # BUGFIX: was the copy-pasted, misleading "Name cannot be empty".
            raise ValueError("cfg_vision_tower update must be a non-empty mapping")
        self._cfg_vision_tower.update(value)

    @property
    def cfg_vision_adapter(self):
        return self._cfg_vision_adapter

    @cfg_vision_adapter.setter
    def cfg_vision_adapter(self, value):
        if not value:
            # BUGFIX: was the copy-pasted, misleading "Name cannot be empty".
            raise ValueError("cfg_vision_adapter update must be a non-empty mapping")
        self._cfg_vision_adapter.update(value)

    @property
    def cfg_special_tokens(self):
        return self._cfg_special_tokens

    @cfg_special_tokens.setter
    def cfg_special_tokens(self, value):
        if not value:
            # BUGFIX: was the copy-pasted, misleading "Name cannot be empty".
            raise ValueError("cfg_special_tokens update must be a non-empty mapping")
        self._cfg_special_tokens.update(value)
783
+
784
+
785
+ class FridayModel(Phi3Model):
786
+ config_class = FridayConfig
787
+
788
+ def __init__(self, config: FridayConfig):
789
+ super().__init__(config)
790
+
791
+ self.cfg_vision_adapter = config.cfg_vision_adapter
792
+ self.cfg_vision_tower = config.cfg_vision_tower
793
+
794
+ self.vision_tower = None
795
+ self.mm_projector = None
796
+ if not config.delay_load:
797
+ self.initialize_vision_modules()
798
+
799
+ def get_vision_tower(self):
800
+ return self.vision_tower
801
+
802
+ def initialize_vision_modules(self):
803
+ if self.vision_tower is not None:
804
+ return
805
+
806
+ if self.cfg_vision_tower.get("type", "siglip").lower() == "siglip":
807
+ if self.cfg_vision_tower.get("use_s2", True):
808
+ self.vision_tower = SiglipVisionTowerS2(**self.cfg_vision_tower)
809
+ else:
810
+ self.vision_tower = SiglipVisionTower(**self.cfg_vision_tower)
811
+ elif self.cfg_vision_tower.get("type", "siglip").lower() == "fastvit":
812
+ if self.cfg_vision_tower.get("use_s2", True):
813
+ self.vision_tower = FastVitVisionTowerS2(**self.cfg_vision_tower)
814
+ else:
815
+ self.vision_tower = FastVitVisionTower(**self.cfg_vision_tower)
816
+ else:
817
+ raise ValueError(f"Unsupported vision tower type: {self.cfg_vision_tower.get('type', 'siglip')}. Supported types are 'siglip' and 'fastvit'.")
818
+
819
+ self.vision_tower.load_model()
820
+ self.mm_projector = MLPAdapter(**self.cfg_vision_adapter)
821
+
822
+ if self.cfg_vision_tower.get("freeze", False):
823
+ self.set_vision_tower_requires_grad(False)
824
+
825
+ if self.cfg_vision_adapter.get("freeze", False):
826
+ self.set_vision_adapter_requires_grad(False)
827
+
828
+ def compute_image_features(self, imgs: torch.Tensor) -> torch.Tensor:
829
+ features = self.vision_tower(imgs)
830
+ if isinstance(features, list):
831
+ features = torch.stack(features, dim=1)
832
+ return self.mm_projector(features)
833
+
834
+ def set_vision_tower_requires_grad(self, requires_grad: bool):
835
+ if self.vision_tower is not None:
836
+ for param in self.vision_tower.parameters():
837
+ param.requires_grad = requires_grad
838
+ else:
839
+ raise ValueError("Vision tower is not initialized. Please call initialize_vision_modules() first.")
840
+
841
+ def set_vision_adapter_requires_grad(self, requires_grad: bool):
842
+ if self.mm_projector is not None:
843
+ for param in self.mm_projector.parameters():
844
+ param.requires_grad = requires_grad
845
+ else:
846
+ raise ValueError("Vision adapter is not initialized. Please call initialize_vision_modules() first.")
847
+
848
def set_vision_tower_dtype(self, dtype: torch.dtype):
    """Cast all vision-tower parameter data in place to `dtype`."""
    if self.vision_tower is None:
        raise ValueError("Vision tower is not initialized. Please call initialize_vision_modules() first.")
    for param in self.vision_tower.parameters():
        param.data = param.data.to(dtype)
855
def set_vision_adapter_dtype(self, dtype: torch.dtype):
    """Cast all projector parameter data in place to `dtype`."""
    if self.mm_projector is None:
        raise ValueError("Vision adapter is not initialized. Please call initialize_vision_modules() first.")
    for param in self.mm_projector.parameters():
        param.data = param.data.to(dtype)
862
def is_vision_tower_frozen(self):
    """Return True iff no vision-tower parameter requires gradients."""
    if self.vision_tower is None:
        raise ValueError("Vision tower is not initialized. Please call initialize_vision_modules() first.")
    return not any(p.requires_grad for p in self.vision_tower.parameters())
868
def is_vision_adapter_frozen(self):
    """Return True iff no projector parameter requires gradients."""
    if self.mm_projector is None:
        raise ValueError("Vision adapter is not initialized. Please call initialize_vision_modules() first.")
    return not any(p.requires_grad for p in self.mm_projector.parameters())
875
class FridayForCausalLM(Phi3ForCausalLM):
    """Multimodal (vision + language) causal LM.

    Wraps a Phi-3 language model around a ``FridayModel`` backbone whose
    vision tower and MLP adapter turn images into embedding sequences that
    are spliced into the text stream at ``<image>`` placeholder positions.
    """

    config_class = FridayConfig

    def __init__(self, config: FridayConfig):
        super().__init__(config)

        self.config = config
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # IDs of the image placeholder token and the image-span delimiter
        # tokens, as registered in config.cfg_special_tokens.
        self.image_token_id = config.cfg_special_tokens["image_token_id"]
        self.image_start_id = config.cfg_special_tokens["image_start_token_id"]
        self.image_end_id = config.cfg_special_tokens["image_end_token_id"]

        self.model = FridayModel(config)
        self.post_init()

    def get_model(self) -> FridayModel:
        """Return the underlying FridayModel backbone."""
        return self.model

    def get_vision_tower(self) -> SiglipVisionTower:
        """Return the vision tower held by the backbone."""
        return self.model.get_vision_tower()

    def get_vision_adapter(self) -> MLPAdapter:
        """Return the vision-to-LLM projection module."""
        return self.model.mm_projector

    def get_llm_parameters(self, exclude_lora: bool = False):
        """Parameters of the language model only (vision modules excluded).

        With ``exclude_lora=True``, parameters whose name contains ``lora_``
        are filtered out as well.
        """
        return [
            p for n, p in self.named_parameters()
            if "vision_tower" not in n and "mm_projector" not in n and (not exclude_lora or ("lora_" not in n))
        ]

    def get_llm_named_modules(self):
        """Named modules of the language model only (vision modules excluded)."""
        return {n: m for n, m in self.named_modules() if "vision_tower" not in n and "mm_projector" not in n}

    def set_llm_requires_grad(self, requires_grad: bool, exclude_lora: bool = True):
        """Freeze/unfreeze LLM weights; LoRA A/B matrices are skipped by default."""
        for n, p in self.named_parameters():
            if exclude_lora and ("lora_A" in n or "lora_B" in n):
                continue
            if "vision_tower" in n or "mm_projector" in n:
                continue
            p.requires_grad = requires_grad

    def set_vision_tower_requires_grad(self, requires_grad: bool):
        """Delegate to the backbone's vision-tower freeze/unfreeze."""
        self.model.set_vision_tower_requires_grad(requires_grad)

    def set_vision_adapter_requires_grad(self, requires_grad: bool):
        """Delegate to the backbone's projector freeze/unfreeze."""
        self.model.set_vision_adapter_requires_grad(requires_grad)

    def set_llm_dtype(self, dtype: torch.dtype):
        """Cast LLM (non-vision) parameter data in place to ``dtype``."""
        for p in self.get_llm_parameters():
            p.data = p.data.to(dtype)

    def set_vision_tower_dtype(self, dtype: torch.dtype):
        """Cast vision-tower parameter data in place to ``dtype``."""
        self.model.set_vision_tower_dtype(dtype)

    def set_vision_adapter_dtype(self, dtype: torch.dtype):
        """Cast projector parameter data in place to ``dtype``."""
        self.model.set_vision_adapter_dtype(dtype)

    def is_llm_frozen(self):
        """True iff no LLM (non-vision) parameter requires gradients."""
        return all(not p.requires_grad for p in self.get_llm_parameters())

    def is_vision_tower_frozen(self):
        """True iff no vision-tower parameter requires gradients."""
        return self.model.is_vision_tower_frozen()

    def is_vision_adapter_frozen(self):
        """True iff no projector parameter requires gradients."""
        return self.model.is_vision_adapter_frozen()

    def initialize_vision_modules(self):
        """Build the vision tower and projector on the backbone."""
        self.model.initialize_vision_modules()

    def get_multimodal_input_embeddings(self, input_ids, image_features, return_labels=True) -> torch.Tensor:
        """Interleave text-token embeddings with projected image features.

        Args:
            input_ids: per-batch-item 1-D LongTensors (unpadded token ids).
            image_features: per-batch-item feature tensors; item ``batch_id``
                must provide one feature row per ``<image>`` placeholder in
                that item's ids.
            return_labels: also return label sequences where every inserted
                image span (including start/end delimiters) is IGNORE_INDEX.

        Returns:
            ``(embeds_list, labels_list)`` or just ``embeds_list``.

        Raises:
            ValueError: if the number of ``<image>`` placeholders does not
                match the number of feature rows for a batch item.
        """
        emb_start_image_id = self.model.embed_tokens(torch.tensor([self.image_start_id], device=self.device))
        emb_end_image_id = self.model.embed_tokens(torch.tensor([self.image_end_id], device=self.device))
        id_ignore = torch.tensor([IGNORE_INDEX], device=self.device)

        # Iterate over each batch item
        embeds_list, labels_list = [], []
        for batch_id, item_ids in enumerate(input_ids):

            image_token_positions = (item_ids == self.image_token_id).nonzero(as_tuple=True)[0]
            if len(image_token_positions) != image_features[batch_id].shape[0]:
                raise ValueError(
                    f"Mismatch between number of image tokens ({len(image_token_positions)}) and number of image features ({image_features[batch_id].shape[0]})"
                )

            cursor = 0
            emb_parts, lbl_parts = [], []
            for indx_image, image_token_pos in enumerate(image_token_positions):
                # text preceding this image placeholder
                if image_token_pos > cursor:
                    span = item_ids[cursor:image_token_pos]
                    emb_parts.append(self.model.embed_tokens(span))
                    lbl_parts.append(span)

                # <image_start>
                emb_parts.append(emb_start_image_id)
                lbl_parts.append(id_ignore)

                # vision embeddings (drop a singleton leading dim if present)
                image_tokens = image_features[batch_id][indx_image]
                if image_tokens.shape[0] == 1 and image_tokens.ndim == 3:
                    image_tokens = image_tokens.squeeze(0)
                emb_parts.append(image_tokens)
                lbl_parts.append(id_ignore.repeat(image_tokens.shape[0]))

                # <image_end>
                emb_parts.append(emb_end_image_id)
                lbl_parts.append(id_ignore)

                cursor = image_token_pos + 1

            # tail text after the last image placeholder
            if cursor < item_ids.shape[0]:
                tail = item_ids[cursor:]
                emb_parts.append(self.model.embed_tokens(tail))
                lbl_parts.append(tail)

            embeds_list.append(torch.cat(emb_parts, dim=0))
            labels_list.append(torch.cat(lbl_parts, dim=0))

        return (embeds_list, labels_list) if return_labels else embeds_list

    def prepare_inputs_for_multimodal(
        self,
        input_ids: torch.LongTensor,
        images: List[List[PIL.Image.Image]],  # B x N
        position_ids: Optional[torch.LongTensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[List[torch.FloatTensor]],
        labels: Optional[torch.LongTensor],
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.LongTensor], Optional[torch.Tensor], Optional[List[torch.FloatTensor]], torch.Tensor, Optional[torch.Tensor]]:
        """Turn (ids, images) into padded multimodal input embeddings.

        Returns the standard 6-tuple
        ``(input_ids, position_ids, attention_mask, past_key_values,
        inputs_embeds, labels)`` where exactly one of ``input_ids`` /
        ``inputs_embeds`` is non-None.
        """

        # ─────────────────── early return (no image / streaming step) ───────────────────
        # If images were already consumed and we are decoding one token at a
        # time, only extend the attention mask / position ids.
        if past_key_values is not None and attention_mask is not None and input_ids.shape[1] == 1:
            # NOTE(review): indexing past_key_values[-1][-1] assumes the
            # legacy tuple-of-tuples KV format — confirm against the Cache
            # objects newer transformers versions pass here.
            tgt = past_key_values[-1][-1].shape[-2] + 1
            attention_mask = torch.cat(
                [attention_mask,
                 torch.ones((attention_mask.size(0),
                             tgt - attention_mask.size(1)),
                            dtype=attention_mask.dtype,
                            device=attention_mask.device)],
                dim=1,
            )
            position_ids = (attention_mask.sum(dim=1, keepdim=True) - 1).long()

            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        # ─────────────────────────── images: (B, N) ───────────────────────────
        if isinstance(images, list) and isinstance(images[0], list):
            # images is a list of lists, each containing multiple images, B x N
            # e.g. [[img1, img2], [img3, img4]]
            assert len(images) == input_ids.shape[0], f"Batch size mismatch: {len(images)} vs {input_ids.shape[0]}"
            image_features = []
            for sublst_images in images:
                if len(sublst_images) == 0:
                    # No images for this item: zero-length feature tensor so the
                    # per-item placeholder count check still holds.
                    image_features.append(torch.zeros((0, self.get_model().mm_projector.output_dim), device=self.device))
                else:
                    if isinstance(sublst_images[0], PIL.Image.Image):
                        image_features.append(
                            self.model.compute_image_features(
                                self.model.vision_tower.preprocess_images(sublst_images, pad_and_stack_tensors=True)
                            )
                        )
                    elif isinstance(sublst_images[0], torch.Tensor):
                        # This should be a list of tensors of pre-processed images, [(N X 3 X W x H), ...]
                        image_features.append(
                            self.model.compute_image_features(sublst_images)
                        )
                    else:
                        # Previously this case was silently skipped, which left
                        # image_features shorter than the batch and surfaced
                        # later as a confusing batch-size assertion.
                        raise ValueError(
                            f"Unsupported image element type: {type(sublst_images[0])}. Expected PIL images or pre-processed tensors."
                        )
        elif isinstance(images, list) and isinstance(images[0], PIL.Image.Image):
            # images is a list of images for a single batch item, 1 x N
            # e.g. [img1, img2, img3]
            assert input_ids.shape[0] == 1, f"Batch size mismatch: {len(images)} vs {input_ids.shape[0]}"
            image_features = [
                self.model.compute_image_features(
                    self.model.vision_tower.preprocess_images(images, pad_and_stack_tensors=True)
                )
            ]
        elif isinstance(images, list) and isinstance(images[0], torch.Tensor):
            # This should be a list of tensors of pre-processed images, [(N X 3 X W x H), ...]
            # The list length should match the batch size
            assert input_ids.shape[0] == len(images), f"Batch size mismatch: {len(images)} vs {input_ids.shape[0]}"
            image_features = [
                self.model.compute_image_features(imgs) for imgs in images
            ]
        elif isinstance(images, PIL.Image.Image):
            # images is a single image, 1 x 1
            # (The original assert message called len() on a PIL image, which
            # itself raises TypeError; a single image means one item.)
            assert input_ids.shape[0] == 1, f"Batch size mismatch: 1 vs {input_ids.shape[0]}"
            image_features = [
                self.model.compute_image_features(
                    self.model.vision_tower.preprocess_images([images])
                )
            ]
        else:
            raise ValueError(f"Unsupported images format: {type(images)}. Expected list of PIL images, a single PIL image or a Tensor of pre-processed images")

        # ─────────────────────────── image_features: (B x N x D) ───────────────────────────
        if isinstance(image_features, list):
            assert input_ids.shape[0] == len(image_features), f"Incorrectly formatted image_features: list length should match batch size"
            assert isinstance(image_features[0], torch.Tensor), f"Incorrectly formatted image_features: list items should be tensors"
        elif isinstance(image_features, torch.Tensor):
            assert input_ids.shape[0] == image_features.shape[0], f"Incorrectly formatted image_features: tensor should match batch size"

        # ───────────────────────────── pad handling prelims ──────────────────────────────
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

        # Strip padding before splicing image embeddings into each sequence.
        input_ids_nopad = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
        embeds_list, labels_list = self.get_multimodal_input_embeddings(
            input_ids_nopad,
            image_features,
            return_labels=True
        )

        # ───────────────────── truncate then pad back to rectangle ──────────────────────
        new_input_embeds = torch.nn.utils.rnn.pad_sequence(
            embeds_list,
            batch_first=True,
            padding_value=0.0
        ).to(dtype=self.dtype)

        new_labels = torch.nn.utils.rnn.pad_sequence(
            labels_list,
            batch_first=True,
            padding_value=IGNORE_INDEX
        ).long()

        if self.config.tokenizer_model_max_length is not None:
            new_input_embeds = new_input_embeds[:, :self.config.tokenizer_model_max_length]
            new_labels = new_labels[:, :self.config.tokenizer_model_max_length]

        # ────────────────────────────── attention mask and position ids ────────────────
        # Rebuild mask from true (pre-padding) lengths of each spliced sequence.
        attention_mask = (
            torch.arange(new_input_embeds.size(1), device=input_ids.device)
            .unsqueeze(0)
            < torch.tensor([e.size(0) for e in embeds_list],
                           device=input_ids.device).unsqueeze(1)
        )

        raw_pos = attention_mask.cumsum(dim=1) - 1
        position_ids = raw_pos.masked_fill(~attention_mask, 0).long()

        if not self.training:
            new_labels = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[PIL.Image.Image] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Standard Phi-3 forward, optionally splicing image features first.

        When ``images`` is provided (and non-empty), token ids are converted
        to ``inputs_embeds`` with image features interleaved before delegating
        to ``Phi3ForCausalLM.forward``.
        """

        # Empty list / list-of-empty-lists counts as "no images".
        is_multi_modal = images is not None and not (
            (
                isinstance(images, list) and (len(images) == 0 or all(i == [] for i in images))
            )
        )

        if inputs_embeds is None and is_multi_modal:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_for_multimodal(
                input_ids=input_ids,
                images=images,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                labels=labels,
            )

        # Splicing image embeddings changed the sequence length, so any
        # caller-supplied cache_position is stale; rebuild it.
        if cache_position is not None and inputs_embeds is not None and cache_position.shape[0] != inputs_embeds.shape[1]:
            cache_position = torch.arange(inputs_embeds.shape[1], device=self.device)

        return Phi3ForCausalLM.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs
        )

    def print_device_configuration(self):
        """Print device / dtype / frozen status of LLM, tower and projector."""
        print("*************Device Configuration*********")
        if len(self.get_llm_parameters()) > 0:
            llm_device = {str(p.device) for p in self.get_llm_parameters()}
            llm_dtype = {p.dtype for p in self.get_llm_parameters()}
            print(f"LLM Parameters:\t\t\tdevice: {llm_device}\tdtype: {llm_dtype}\tfrozen: {self.is_llm_frozen()}")
        else:
            print("LLM parameters have not been initialized")

        if self.get_model().vision_tower is not None:
            vt_device = {str(p.device) for p in self.get_model().vision_tower.parameters()}
            vt_dtype = {p.dtype for p in self.get_model().vision_tower.parameters()}
            print(f"Vision Tower Parameters:\tdevice: {vt_device}\tdtype: {vt_dtype}\tfrozen: {self.is_vision_tower_frozen()}")
        else:
            print("Vision tower parameters have not been initialized")

        if self.get_model().mm_projector is not None:
            mm_device = {str(p.device) for p in self.get_model().mm_projector.parameters()}
            mm_dtype = {p.dtype for p in self.get_model().mm_projector.parameters()}
            print(f"MM Projector Parameters:\tdevice: {mm_device}\tdtype: {mm_dtype}\tfrozen: {self.is_vision_adapter_frozen()}")
        else:
            print("MM Projector parameters have not been initialized")
        print("******************************************")
1231
def build_tokenizer(base_model_id: str) -> Tuple[AutoTokenizer, dict]:
    """Load the base tokenizer and make sure the image special tokens exist.

    Args:
        base_model_id: HF hub id or local path of the base model tokenizer.

    Returns:
        (tokenizer, specials) where specials maps "image"/"start"/"end" to the
        token ids of IMAGE_TOKEN / IMG_START_TOKEN / IMG_END_TOKEN. Previously
        the keys were the raw token strings when the tokens were already
        present but "image"/"start"/"end" when they had to be added; callers
        now get one consistent schema on both paths.
    """
    tok = AutoTokenizer.from_pretrained(base_model_id, padding_side="right")
    needed = [IMAGE_TOKEN, IMG_START_TOKEN, IMG_END_TOKEN]
    vocab = tok.get_vocab()
    if any(t not in vocab for t in needed):
        # add_tokens skips tokens that already exist, so this is safe even
        # when only some of the three are missing.
        tok.add_tokens(needed, special_tokens=True)
        tok.pad_token = tok.eos_token
    specials = {
        "image": tok.convert_tokens_to_ids(IMAGE_TOKEN),
        "start": tok.convert_tokens_to_ids(IMG_START_TOKEN),
        "end": tok.convert_tokens_to_ids(IMG_END_TOKEN),
    }
    return tok, specials