kevin510 commited on
Commit
f22e168
·
verified ·
1 Parent(s): a2b4415

Delete modeling_friday.py

Browse files
Files changed (1) hide show
  1. modeling_friday.py +0 -1267
modeling_friday.py DELETED
@@ -1,1267 +0,0 @@
1
# Model Constants
IMAGE_TOKEN = "<image>"          # prompt placeholder that is replaced by image features
IMG_START_TOKEN = "<img_start>"  # delimiter emitted before an image-feature span
IMG_END_TOKEN = "<img_end>"      # delimiter emitted after an image-feature span
IGNORE_INDEX = -100              # label id skipped by the loss (PyTorch CrossEntropyLoss default ignore_index)
PAD_FOR_EOS = -300               # sentinel stand-in for EOS during padding — NOTE(review): not referenced in the visible code; confirm intended use
7
-
8
-
9
-
10
-
11
-
12
- import torch
13
- import torch.nn.functional as F
14
-
15
- from PIL import Image
16
-
17
-
18
- import torch
19
-
20
def mask_token_segment(
    start_id: int,
    end_id: int,
    input_ids: torch.Tensor,
    fill_value: int = -100) -> torch.Tensor:
    """
    Replace *every* token from each `start_id` **through** its matching `end_id`
    (boundaries included) with `fill_value`. Any spans that start with some
    other token are left untouched.

    Works on CUDA, TorchScript, batched via vmap, etc.—no Python loops.

    Args:
        start_id: token id that opens a span.
        end_id: token id that closes a span (first occurrence at/after a start).
        input_ids: 1-D tensor of token ids; never modified in place.
        fill_value: value written over masked positions (default -100).

    Returns:
        A NEW tensor of the same shape/dtype; every return path is a copy.

    Raises:
        ValueError: if `input_ids` is not 1-D.
    """
    if input_ids.dim() != 1:
        raise ValueError("`input_ids` must be 1-D")

    device = input_ids.device
    n = input_ids.size(0)

    # where the *target* start-tokens and end-tokens sit
    start_pos = (input_ids == start_id).nonzero(as_tuple=True)[0]  # ascending
    end_pos = (input_ids == end_id).nonzero(as_tuple=True)[0]      # ascending

    if start_pos.numel() == 0:
        return input_ids.clone()

    # ── pair every start with the first end that comes *after* it ────────────
    # searchsorted gives the insertion index into the (sorted) end positions
    idx_in_end = torch.searchsorted(end_pos, start_pos, right=False)

    have_match = idx_in_end < end_pos.size(0)  # safety: drop unmatched starts
    start_pos = start_pos[have_match]
    end_pos = end_pos[idx_in_end[have_match]]

    # (rare) guard against pathological orderings
    keep = end_pos > start_pos
    start_pos, end_pos = start_pos[keep], end_pos[keep]

    if start_pos.numel() == 0:
        # FIX: the original returned `input_ids` itself on this path, aliasing
        # the caller's tensor while every other path returns a fresh tensor.
        return input_ids.clone()

    # ── differential “scan-line” trick to build the span mask in O(N) ────────
    # +1 at each start index, -1 at the element *after* each end.
    # FIX: use index_add_ instead of `delta[idx] += 1` — plain advanced-index
    # assignment does NOT accumulate duplicate indices, so several starts
    # paired with the *same* end token corrupted the mask. Also widened int8
    # -> int32 so >127 overlapping spans cannot overflow the running sum.
    delta = torch.zeros(n + 1, dtype=torch.int32, device=device)
    delta.index_add_(0, start_pos, torch.ones_like(start_pos, dtype=torch.int32))
    delta.index_add_(0, end_pos + 1, -torch.ones_like(end_pos, dtype=torch.int32))  # +1 safe: delta has length n+1

    inside = torch.cumsum(delta[:-1], dim=0) > 0  # boolean mask, incl. boundaries

    # ── apply ────────────────────────────────────────────────────────────────
    out = input_ids.clone()
    out[inside] = fill_value
    return out
72
-
73
-
74
-
75
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a CPU copy of `param`, gathering it first if it is a
    DeepSpeed ZeRO-3 partitioned parameter (detected via the `ds_id` attr).

    Args:
        param: tensor/parameter, possibly ZeRO-3 partitioned.
        ignore_status: suppress the diagnostic print for NOT_AVAILABLE params.
        name: parameter name, only used in the diagnostic print.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    # Non-partitioned parameter: a plain detached CPU clone suffices.
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, 'no ignore status')
    # Gather the full parameter across ranks before copying it off-device.
    with zero.GatheredParameters([param]):
        param = param.data.detach().cpu().clone()
    return param
87
-
88
-
89
- # Borrowed from peft.util.get_peft_model_state_dict
90
# Borrowed from peft.util.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Build the LoRA state dict from `named_params`, materialising any
    ZeRO-3 shards via `maybe_zero_3`.

    Args:
        named_params: iterable of (name, tensor) pairs, e.g. model.named_parameters().
        bias: which bias terms to include — "none", "all", or "lora_only".

    Returns:
        dict name -> CPU tensor copy.

    Raises:
        NotImplementedError: for an unrecognised `bias` mode.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # FIX: iterate .items() — iterating the dict directly yields only keys
        # and the (k, t) unpack raises. Also test each candidate name `k`
        # rather than the stale `bias_name` left over from the loop above,
        # which kept only (at most) the last LoRA layer's bias.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
113
-
114
-
115
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect every non-LoRA parameter as a CPU tensor copy.

    Args:
        named_params: iterable of (name, tensor) pairs.
        require_grad_only: when True (default), keep only trainable tensors.

    Returns:
        dict name -> gathered CPU tensor (ZeRO-3 shards materialised).
    """
    non_lora = {name: tensor for name, tensor in named_params if "lora_" not in name}
    if require_grad_only:
        non_lora = {name: tensor for name, tensor in non_lora.items() if tensor.requires_grad}
    return {name: maybe_zero_3(tensor, ignore_status=True).cpu() for name, tensor in non_lora.items()}
121
-
122
-
123
def find_all_linear_names(modules):
    """Return the leaf names of all nn.Linear submodules (LoRA targets).

    Args:
        modules: zero-arg callable yielding (name, module) pairs,
            e.g. `model.named_modules`.

    Returns:
        list of unique leaf module names, with 'lm_head' excluded
        (needed for 16-bit).
    """
    linear_leaf_names = {
        name.rsplit('.', 1)[-1]
        for name, module in modules()
        if isinstance(module, torch.nn.Linear)
    }
    # lm_head must stay out of the LoRA target list (needed for 16-bit).
    linear_leaf_names.discard('lm_head')
    return list(linear_leaf_names)
133
-
134
-
135
def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centring it along the short axis.

    Args:
        pil_img: source PIL image.
        background_color: fill colour for the padded border.

    Returns:
        The original image when already square, otherwise a new square
        image of side max(width, height) with `pil_img` pasted centred.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Centre offset is zero along the long axis, (side - short) // 2 otherwise.
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas
147
-
148
def pad_and_stack(img_list, pad_value=0.0):
    """
    img_list : list[Tensor] each (C, H, W) already *normalised*
    pad_value: float or tuple/list of 3 floats (one per channel)
               Use 0.0 if your processor has already centred to mean 0.
    Returns
    -------
    batch : Tensor (B, C, H_max, W_max)

    Each image is placed in the top-left corner of a square canvas of side
    max(H_max, W_max); the remainder is filled with `pad_value`.

    FIX: the docstring always promised per-channel tuple/list pad values,
    but `Tensor.new_full` only accepts a scalar — tuple input raised.
    Per-channel fills are now supported.
    """
    # 1. target square size ---------------------------------------------------
    h_max = max(t.shape[1] for t in img_list)
    w_max = max(t.shape[2] for t in img_list)
    side = max(h_max, w_max)

    # 2. create padded copies -------------------------------------------------
    padded = []
    for img in img_list:
        c, h, w = img.shape
        if isinstance(pad_value, (tuple, list)):
            # broadcast one fill value per channel across the canvas
            fill = torch.tensor(pad_value, dtype=img.dtype, device=img.device).view(c, 1, 1)
            canvas = fill.expand(c, side, side).clone()
        else:
            canvas = img.new_full((c, side, side), pad_value)  # filled with mean/zeros
        canvas[:, :h, :w] = img  # top-left corner
        padded.append(canvas)

    return torch.stack(padded, 0)  # (B,C,H,W)
172
-
173
-
174
-
175
-
176
-
177
- # ------------------------------------------------------------------------------------------
178
- # Copyright (c) 2024 Baifeng Shi.
179
- # All rights reserved.
180
- #
181
- # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
182
- # ------------------------------------------------------------------------------------------
183
-
184
- import torch
185
-
186
def split_chessboard(x, num_split):
    """
    x: b * c * h * w
    Divide x into num_split**2 equal sub-squares and concatenate them along
    the batch dimension, in row-major tile order (inverse of merge_chessboard).
    """
    B, C, H, W = x.shape
    assert H % num_split == 0 and W % num_split == 0
    tile_h, tile_w = H // num_split, W // num_split
    tiles = []
    for row in range(num_split):
        for col in range(num_split):
            tiles.append(x[:, :, row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w])
    return torch.cat(tiles, dim=0)
196
-
197
def merge_chessboard(x, num_split):
    """
    x: b * c * h * w
    Assuming x holds num_split**2 sub-squares concatenated along the batch
    dimension (row-major tile order), stitch them back into the original
    whole square (inverse of split_chessboard).
    """
    B, C, H, W = x.shape
    assert B % (num_split ** 2) == 0
    per_tile = B // (num_split ** 2)
    rows = []
    for row in range(num_split):
        row_tiles = [
            x[(row * num_split + col) * per_tile:(row * num_split + col + 1) * per_tile]
            for col in range(num_split)
        ]
        rows.append(torch.cat(row_tiles, dim=-1))  # stitch columns
    return torch.cat(rows, dim=-2)                 # stack rows
209
-
210
def batched_forward(model, x, batch_size=-1):
    """Run `model` over `x`, optionally in mini-batches.

    batch_size == -1 means a single full-batch call; otherwise `x` is split
    along dim 0 into chunks of `batch_size` and the outputs re-concatenated.
    """
    if batch_size == -1:
        return model(x)
    chunk_outputs = [model(chunk) for chunk in x.split(batch_size)]
    return torch.cat(chunk_outputs, dim=0)
217
-
218
-
219
-
220
-
221
-
222
-
223
-
224
- # ------------------------------------------------------------------------------------------
225
- # Copyright (c) 2024 Baifeng Shi.
226
- # All rights reserved.
227
- #
228
- # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
229
- # ------------------------------------------------------------------------------------------
230
-
231
- import math
232
- import torch
233
- import torch.nn.functional as F
234
- from einops import rearrange
235
- from .utils import split_chessboard, merge_chessboard, batched_forward
236
-
237
def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0,
            output_shape='bnc', split_forward=False):
    """Multi-scale (S2) forward: run `model` on several resized copies of
    `input`, tiling each scale into chessboard splits no larger than
    `max_split_size`, then merge tiles and concatenate per-scale features
    along the channel dimension.

    Args:
        model: callable mapping an image batch to features.
        input: (B, C, H, W) tensor; must be square (H == W).
        scales: per-scale resize factors (used only if img_sizes is None).
        img_sizes: explicit per-scale image sizes; one of scales/img_sizes required.
        max_split_size: max side of each tile; defaults to the input size.
        resize_output_to_idx: index of the scale whose spatial size all other
            scale outputs are interpolated to before concatenation.
        num_prefix_token: number of prefix (e.g. CLS) tokens to strip before
            spatial reshaping and re-attach (averaged over tiles) at the end.
        output_shape: 'bnc' (token sequence, ViT-style) or 'bchw' (feature map).
        split_forward: if True, feed tiles through `model` in batches of B.

    Returns:
        (B, N, C_total) when output_shape == 'bnc', else (B, C_total, H', W'),
        where C_total is the per-scale channel count summed over scales.
    """
    # print(f"Input shape: {input.shape}")

    assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3], "Currently only square images are supported."
    assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token."

    b, c, input_size, _ = input.shape

    # image size for each scale
    assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
    img_sizes = img_sizes or [int(input_size * scale) for scale in scales]

    # prepare multiscale inputs
    max_split_size = max_split_size or input_size  # The maximum size of each split of image. Set as the input size by default
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]  # number of splits each scale
    input_multiscale = []
    for size, num_split in zip(img_sizes, num_splits):
        # bicubic resize is done in float32 for numerical stability, then cast back
        x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype)
        x = split_chessboard(x, num_split=num_split)
        input_multiscale.append(x)

    # run feedforward on each scale
    outs_multiscale = [batched_forward(model, x, b) if split_forward else model(x) for x in input_multiscale]
    if num_prefix_token > 0:
        # peel off prefix tokens; they are re-attached (tile-averaged) below
        outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
        outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]
    if output_shape == 'bnc':
        # token sequence -> square feature map; assumes N is a perfect square
        # (row-major token order) — NOTE(review): confirm for the model used.
        outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=int(out.shape[1] ** 0.5), w=int(out.shape[1] ** 0.5))
                           for out in outs_multiscale]

    # merge outputs of different splits for each scale separately
    outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)]

    # interpolate outputs from different scales and concat together
    output_size = outs_multiscale[resize_output_to_idx].shape[-2]
    out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size,
                                   mode='area').to(outs_multiscale[i].dtype)
                     for i in range(len(outs_multiscale))], dim=1)
    if output_shape == 'bnc':
        out = rearrange(out, 'b c h w -> b (h w) c')
    if num_prefix_token > 0:
        # take the mean of prefix tokens from different splits for each scale
        outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale]
        out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
        out = torch.cat([out_prefix_multiscale, out], dim=1)

    return out
288
-
289
-
290
-
291
-
292
-
293
- import torch
294
- import torch.nn as nn
295
-
296
class MLPAdapter(nn.Module):
    """MLP projection head mapping vision-tower features into the language
    model's embedding space."""

    # supported hidden activations
    _ACTIVATIONS = {'gelu': nn.GELU, 'relu': nn.ReLU}

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, activation='gelu', checkpoint_path=None, device=None, **kwargs):
        """
        Initialize the MLPAdapter with the given dimensions and activation function.

        Args:
            input_dim (int): Input dimension.
            hidden_dim (int): Hidden dimension.
            output_dim (int): Output dimension.
            num_layers (int): Number of hidden Linear+activation stages before
                the final projection (FIX: docstring previously documented a
                nonexistent `layers` argument).
            activation (str): Activation function to use ('gelu' or 'relu').
            checkpoint_path (str, optional): state dict to load (non-strict).
            device: device to build/load the layers on.

        Raises:
            ValueError: if `activation` is not 'gelu' or 'relu'.
        """
        super().__init__()
        self.num_layers = num_layers
        self.activation = activation
        self.output_dim = output_dim

        # Resolve the activation class once instead of duplicating the
        # if/elif in two places (behaviour unchanged: unsupported names
        # still raise ValueError before any layer is used).
        try:
            act_cls = self._ACTIVATIONS[activation]
        except KeyError:
            raise ValueError("Unsupported activation function. Use 'gelu' or 'relu'.") from None

        # First layer, then (num_layers - 1) hidden layers, then the output head.
        layers_list = [nn.Linear(input_dim, hidden_dim, device=device), act_cls()]
        for _ in range(1, num_layers):
            layers_list.append(nn.Linear(hidden_dim, hidden_dim, device=device))
            layers_list.append(act_cls())
        layers_list.append(nn.Linear(hidden_dim, output_dim, device=device))
        self.mlp = nn.Sequential(*layers_list)

        # Load checkpoint if provided.
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load trusted checkpoints.
        if checkpoint_path:
            self.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
            print(f"Loaded MLPAdapter from {checkpoint_path}")

        if device:
            self.to(device)

    def forward(self, x):
        """
        Forward pass through the MLPAdapter.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after passing through the MLP.
        """
        return self.mlp(x)
354
-
355
-
356
-
357
-
358
- import torch
359
- import torch.nn as nn
360
- import torch.nn.functional as F
361
-
362
- import PIL.Image
363
- from typing import List
364
- from friday.util import expand2square, pad_and_stack
365
-
366
- from transformers import AutoModel, AutoImageProcessor
367
- from friday.util.s2wrapper import forward as multiscale_forward
368
-
369
-
370
class FastVitVisionTower(nn.Module):
    """Frozen FastViT vision encoder loaded via transformers AutoModel,
    with PIL preprocessing helpers."""

    def __init__(self, pretrained_model_name_or_path, model_params=None, pad_to_square=True, **kwargs):
        """
        Args:
            pretrained_model_name_or_path: HF hub id or local path.
            model_params (dict, optional): extra kwargs for from_pretrained.
                FIX: default was a shared mutable dict ({}); None is
                normalised here so instances never alias one another.
            pad_to_square: pad PIL images to squares (mean colour) first.
        """
        super().__init__()

        self.is_loaded = False
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.model_params = model_params if model_params is not None else {}
        self.pad_to_square = pad_to_square
        self.load_model()

    @property
    def output_dim(self):
        """Feature dimensionality, or None before the tower is loaded."""
        return self.vision_tower.config.embed_dim if self.vision_tower else None

    def load_model(self):
        """Load the image processor and the (frozen) tower weights; idempotent."""
        if self.is_loaded:
            return
        self.image_processor = AutoImageProcessor.from_pretrained(self.pretrained_model_name_or_path)
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = AutoModel.from_pretrained(
            self.pretrained_model_name_or_path,
            **self.model_params,
        )
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def preprocess_images(self, imgs: List[PIL.Image.Image], pad_and_stack_tensors=True) -> torch.Tensor:
        """Convert PIL images to pixel tensors (optionally padded/stacked
        into one float32 batch on the tower's device)."""
        img_mean = tuple(int(x * 255) for x in self.image_processor.image_mean)
        if self.pad_to_square:
            imgs = [expand2square(img, img_mean) for img in imgs]

        imgs = [self.image_processor(img, do_resize=True, do_center_crop=False, return_tensors="pt")['pixel_values'][0] for img in imgs]

        if pad_and_stack_tensors:
            imgs = pad_and_stack(imgs, pad_value=0.0)
            imgs = imgs.to(dtype=torch.float32, device=self.device)

        return imgs

    def forward(self, images):
        """Encode a batch tensor, or a list of image tensors one at a time."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
                )
                image_features.append(image_feature)
        else:
            image_features = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
            )

        return image_features

    @property
    def dummy_feature(self):
        # FIX: the original read `self.embed_dim`, an attribute that does not
        # exist on this module (AttributeError when used); `self.hidden_size`
        # exposes the same value via config.embed_dim.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # NOTE(review): `cfg_only` is never assigned in this class — the
        # not-loaded branch would raise; confirm delayed-load path.
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.embed_dim

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
452
-
453
-
454
class FastVitVisionTowerS2(FastVitVisionTower):
    """FastViT tower evaluated at multiple image scales (S2 wrapper);
    per-scale features are concatenated along the channel dimension."""

    def __init__(self, pretrained_model_name_or_path, s2_scales, model_params=None, **kwargs):
        """
        Args:
            pretrained_model_name_or_path: HF hub id or local path.
            s2_scales: comma-separated scale sizes, e.g. "512,1024".
            model_params (dict, optional): extra kwargs for from_pretrained.
                FIX: default was a shared mutable dict ({}).
        """
        if model_params is None:
            model_params = {}
        self.s2_scales = list(map(int, s2_scales.split(',')))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]   # tile side for chessboard splits
        self.s2_image_size = self.s2_scales[-1]  # largest scale = processor size

        super().__init__(pretrained_model_name_or_path, model_params)

        self.multiscale_forward = multiscale_forward

    @property
    def output_dim(self):
        # NOTE(review): hard-codes a factor of 2 while hidden_size scales with
        # len(s2_scales) — consistent only when exactly two scales are used.
        return (2*self.vision_tower.config.embed_dim) if self.vision_tower else None

    def load_model(self):
        """Load as the base class does, then resize the processor to the
        largest S2 scale; idempotent."""
        if self.is_loaded:
            return

        super().load_model()
        self.image_processor.size = self.image_processor.crop_size = {
            "height": self.s2_image_size,
            "width": self.s2_image_size
        }

    def forward_feature(self, images):
        """Single-scale encode used as the callback for multiscale_forward;
        resizes tiles to the tower's native image size when needed."""
        image_size = self.vision_tower.config.image_size
        if images.shape[2] != image_size or images.shape[3] != image_size:
            images = F.interpolate(
                images,
                size=(image_size, image_size),
                mode="bilinear",
                align_corners=False,
                antialias=True
            )

        return self.vision_tower(
            images.to(device=self.device, dtype=self.dtype),
        )

    def forward(self, images):
        """Multi-scale encode for a batch tensor or a list of image tensors."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(
                    self.forward_feature,
                    image.unsqueeze(0),
                    img_sizes=self.s2_scales,
                    max_split_size=self.s2_split_size
                )
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(
                self.forward_feature,
                images,
                img_sizes=self.s2_scales,
                max_split_size=self.s2_split_size
            )

        return image_features

    @property
    def hidden_size(self):
        # one embed_dim worth of channels per scale
        return self.config.embed_dim * len(self.s2_scales)
518
-
519
-
520
-
521
-
522
-
523
- import torch
524
- import torch.nn as nn
525
-
526
- import PIL.Image
527
- from typing import List
528
- from friday.util import expand2square, pad_and_stack
529
-
530
- from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig
531
- from friday.util.s2wrapper import forward as multiscale_forward
532
-
533
-
534
class SiglipVisionTower(nn.Module):
    """Frozen SigLIP vision encoder yielding features from the hidden layer
    selected by `select_layer` (-2 = penultimate)."""

    def __init__(self, pretrained_model_name_or_path, model_params=None, pad_to_square=True, **kwargs):
        """
        Args:
            pretrained_model_name_or_path: HF hub id or local path.
            model_params (dict, optional): extra kwargs for from_pretrained.
                FIX: default was a shared mutable dict ({}); None is
                normalised here so instances never alias one another.
            pad_to_square: pad PIL images to squares (mean colour) first.
        """
        super().__init__()

        self.is_loaded = False
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.model_params = model_params if model_params is not None else {}
        self.pad_to_square = pad_to_square
        self.select_layer = -2  # penultimate hidden layer
        self.load_model()

    @property
    def output_dim(self):
        """Feature dimensionality, or None before the tower is loaded."""
        return self.vision_tower.config.hidden_size if self.vision_tower else None

    def load_model(self):
        """Load the image processor and the (frozen) SigLIP weights; idempotent."""
        if self.is_loaded:
            return
        self.image_processor = SiglipImageProcessor.from_pretrained(self.pretrained_model_name_or_path)
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = SiglipVisionModel.from_pretrained(
            self.pretrained_model_name_or_path,
            **self.model_params,
        )
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def preprocess_images(self, imgs: List[PIL.Image.Image], pad_and_stack_tensors=True) -> torch.Tensor:
        """Convert PIL images to pixel tensors (optionally padded/stacked
        into one float32 batch on the tower's device)."""
        img_mean = tuple(int(x * 255) for x in self.image_processor.image_mean)
        if self.pad_to_square:
            imgs = [expand2square(img, img_mean) for img in imgs]
        imgs = [self.image_processor(img, return_tensors="pt")['pixel_values'][0] for img in imgs]

        if pad_and_stack_tensors:
            imgs = pad_and_stack(imgs, pad_value=0.0)
            imgs = imgs.to(dtype=torch.float32, device=self.device)

        return imgs

    def feature_select(self, image_forward_outs):
        """Pick the hidden-state layer configured by `select_layer`."""
        image_features = image_forward_outs.hidden_states[self.select_layer]

        return image_features

    def forward(self, images):
        """Encode a batch tensor, or a list of image tensors one at a time."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                                      output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                                   output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        # zero placeholder feature with the tower's dtype/device
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # NOTE(review): `cfg_only` is never assigned in this class — the
        # not-loaded branch would raise; confirm delayed-load path.
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
620
-
621
-
622
class SiglipVisionTowerS2(SiglipVisionTower):
    """SigLIP tower evaluated at multiple image scales (S2 wrapper);
    per-scale features are concatenated along the channel dimension."""

    def __init__(self, pretrained_model_name_or_path, s2_scales, model_params=None, **kwargs):
        """
        Args:
            pretrained_model_name_or_path: HF hub id or local path.
            s2_scales: comma-separated scale sizes, e.g. "384,768".
            model_params (dict, optional): extra kwargs for from_pretrained.
                FIX: default was a shared mutable dict ({}).
        """
        if model_params is None:
            model_params = {}
        self.s2_scales = list(map(int, s2_scales.split(',')))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]   # tile side for chessboard splits
        self.s2_image_size = self.s2_scales[-1]  # largest scale = processor size

        super().__init__(pretrained_model_name_or_path, model_params)

        self.multiscale_forward = multiscale_forward

        # Resize the processor to the largest S2 scale (also done in
        # load_model; kept here to preserve original behaviour).
        self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    @property
    def output_dim(self):
        # NOTE(review): hard-codes a factor of 2 while hidden_size scales with
        # len(s2_scales) — consistent only when exactly two scales are used.
        return (2*self.vision_tower.config.hidden_size) if self.vision_tower else None

    def load_model(self):
        """Load as the base class does, then resize the processor to the
        largest S2 scale; idempotent."""
        if self.is_loaded:
            return

        super().load_model()
        self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def forward_feature(self, images):
        """Single-scale encode used as the callback for multiscale_forward."""
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                               output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        """Multi-scale encode for a batch tensor or a list of image tensors."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(
                    self.forward_feature,
                    image.unsqueeze(0),
                    img_sizes=self.s2_scales,
                    max_split_size=self.s2_split_size
                )
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(
                self.forward_feature,
                images,
                img_sizes=self.s2_scales,
                max_split_size=self.s2_split_size
            )

        return image_features

    @property
    def hidden_size(self):
        # one hidden_size worth of channels per scale
        return self.config.hidden_size * len(self.s2_scales)
678
-
679
-
680
-
681
-
682
-
683
- from __future__ import annotations
684
-
685
- import torch
686
- import torch.nn as nn
687
- import torch.nn.functional as F
688
- from torchvision import transforms
689
-
690
- from typing import List, Tuple, Optional, Union
691
-
692
- import PIL
693
-
694
- from transformers import AutoTokenizer, AutoConfig
695
- from transformers.modeling_outputs import CausalLMOutputWithPast
696
-
697
-
698
- from friday.model.vision_adapter import MLPAdapter
699
- from friday.model.vision_tower import (
700
- SiglipVisionTower,
701
- SiglipVisionTowerS2,
702
- FastVitVisionTower,
703
- FastVitVisionTowerS2
704
- )
705
- from friday.model.language_model.phi4 import (
706
- Phi3Config,
707
- Phi3Model,
708
- Phi3ForCausalLM
709
- )
710
- from friday.constants import (
711
- IMAGE_TOKEN,
712
- IMG_START_TOKEN,
713
- IMG_END_TOKEN,
714
- IGNORE_INDEX
715
- )
716
-
717
# Token ids the tokenizer reserves for the image placeholder/delimiters.
DEFAULT_CFG_SPECIAL_TOKENS = {
    "image_token_id": 200029,
    "image_start_token_id": 200030,
    "image_end_token_id": 200031,
}
# Default vision tower: FastViT-HD with two S2 scales.
DEFAULT_CFG_VISION_TOWER = {
    "pretrained_model_name_or_path": "kevin510/fast-vit-hd",
    "type": "fastvit",
    "s2_scales": "512,1024",
    "use_s2": True,
    "pad_to_square": True,
    "freeze": False,
    "model_params": { "trust_remote_code": True }
}
# Default vision adapter config (consumed by MLPAdapter(**cfg)).
# FIX: the key was "layers", which MLPAdapter does not accept — it was
# silently swallowed by **kwargs and the adapter always fell back to its
# default num_layers. Renamed to "num_layers"; the value equals that
# default, so the constructed module is unchanged.
DEFAULT_CFG_VISION_ADAPTER = {
    "input_dim": 6144,
    "hidden_dim": 3072,
    "output_dim": 3072,
    "num_layers": 2,
    "activation": "gelu",
    "freeze": False,
}
739
-
740
-
741
class FridayConfig(Phi3Config):
    """Configuration for the Friday multimodal model.

    Starts from the config of a base Phi model (fetched via AutoConfig when
    `base_model_name_or_path` is given), overlays explicit kwargs on top, and
    records three nested config dicts: vision tower, vision adapter, and the
    multimodal special-token ids.
    """

    model_type = "friday"

    def __init__(self,
                 base_model_name_or_path: str | None = "microsoft/Phi-4-mini-reasoning",
                 delay_load=False,
                 tokenizer_model_max_length=None,
                 **kwargs
                 ):
        # Pull the base model's config as a dict of defaults; explicit
        # kwargs win on key collisions.
        base_kwargs = {}
        if base_model_name_or_path is not None:
            base_cfg = AutoConfig.from_pretrained(
                base_model_name_or_path,
                trust_remote_code=True,  # Phi‑4 uses custom code in the repo
            )
            base_kwargs = base_cfg.to_dict()

        merged = {**base_kwargs, **kwargs}
        # Instance attributes are assigned before super().__init__ runs;
        # NOTE(review): presumably so they already exist if the base init
        # inspects/serialises the instance — confirm before reordering.
        self.delay_load = delay_load
        self.tokenizer_model_max_length = tokenizer_model_max_length

        # Overlay user-supplied sub-configs onto module-level defaults.
        # NOTE(review): .copy() is shallow — nested dicts such as
        # DEFAULT_CFG_VISION_TOWER["model_params"] remain shared across
        # instances; confirm this aliasing is intended.
        self._cfg_vision_tower = DEFAULT_CFG_VISION_TOWER.copy()
        if "cfg_vision_tower" in kwargs:
            self._cfg_vision_tower.update(kwargs["cfg_vision_tower"])

        self._cfg_vision_adapter = DEFAULT_CFG_VISION_ADAPTER.copy()
        if "cfg_vision_adapter" in kwargs:
            self._cfg_vision_adapter.update(kwargs["cfg_vision_adapter"])

        self._cfg_special_tokens = DEFAULT_CFG_SPECIAL_TOKENS.copy()
        if "cfg_special_tokens" in kwargs:
            self._cfg_special_tokens.update(kwargs["cfg_special_tokens"])

        # NOTE(review): `merged` still contains any cfg_* keys from kwargs,
        # so they are also forwarded to Phi3Config — confirm intended.
        super().__init__(**merged)

    @property
    def cfg_vision_tower(self):
        # Vision-tower construction kwargs (see DEFAULT_CFG_VISION_TOWER).
        return self._cfg_vision_tower

    @cfg_vision_tower.setter
    def cfg_vision_tower(self, value):
        # Setter MERGES into the existing dict rather than replacing it.
        # NOTE(review): error text says "Name" but the value is a dict.
        if not value:
            raise ValueError("Name cannot be empty")
        self._cfg_vision_tower.update(value)

    @property
    def cfg_vision_adapter(self):
        # Vision-adapter construction kwargs (see DEFAULT_CFG_VISION_ADAPTER).
        return self._cfg_vision_adapter

    @cfg_vision_adapter.setter
    def cfg_vision_adapter(self, value):
        # Setter MERGES into the existing dict rather than replacing it.
        if not value:
            raise ValueError("Name cannot be empty")
        self._cfg_vision_adapter.update(value)

    @property
    def cfg_special_tokens(self):
        # Multimodal special-token ids (see DEFAULT_CFG_SPECIAL_TOKENS).
        return self._cfg_special_tokens

    @cfg_special_tokens.setter
    def cfg_special_tokens(self, value):
        # Setter MERGES into the existing dict rather than replacing it.
        if not value:
            raise ValueError("Name cannot be empty")
        self._cfg_special_tokens.update(value)
807
-
808
-
809
class FridayModel(Phi3Model):
    """Phi-3 transformer backbone extended with a vision tower and an MLP
    projector ("mm_projector") so image features can be mapped into the
    language-model hidden space.

    Vision modules are built eagerly unless ``config.delay_load`` is set, in
    which case ``initialize_vision_modules()`` must be called later.
    """

    config_class = FridayConfig

    def __init__(self, config: FridayConfig):
        super().__init__(config)

        # Plain dicts describing how to build the adapter/tower
        # (see the module-level DEFAULT_CFG_* constants for expected keys).
        self.cfg_vision_adapter = config.cfg_vision_adapter
        self.cfg_vision_tower = config.cfg_vision_tower

        self.vision_tower = None
        self.mm_projector = None
        if not config.delay_load:
            self.initialize_vision_modules()

    def get_vision_tower(self):
        """Return the vision tower module (None until initialised)."""
        return self.vision_tower

    def initialize_vision_modules(self):
        """Instantiate vision tower + projector; no-op when already built."""
        if self.vision_tower is not None:
            return

        # Dispatch on tower type ("siglip" | "fastvit") and on whether the
        # multi-scale S2 wrapper variant is requested.
        if self.cfg_vision_tower.get("type", "siglip").lower() == "siglip":
            if self.cfg_vision_tower.get("use_s2", True):
                self.vision_tower = SiglipVisionTowerS2(**self.cfg_vision_tower)
            else:
                self.vision_tower = SiglipVisionTower(**self.cfg_vision_tower)
        elif self.cfg_vision_tower.get("type", "siglip").lower() == "fastvit":
            if self.cfg_vision_tower.get("use_s2", True):
                self.vision_tower = FastVitVisionTowerS2(**self.cfg_vision_tower)
            else:
                self.vision_tower = FastVitVisionTower(**self.cfg_vision_tower)
        else:
            raise ValueError(f"Unsupported vision tower type: {self.cfg_vision_tower.get('type', 'siglip')}. Supported types are 'siglip' and 'fastvit'.")

        self.vision_tower.load_model()
        self.mm_projector = MLPAdapter(**self.cfg_vision_adapter)

        # Optional freezing controlled by the config dicts.
        if self.cfg_vision_tower.get("freeze", False):
            self.set_vision_tower_requires_grad(False)

        if self.cfg_vision_adapter.get("freeze", False):
            self.set_vision_adapter_requires_grad(False)

    def compute_image_features(self, imgs: torch.Tensor) -> torch.Tensor:
        """Encode `imgs` with the tower and project into LM hidden size.

        A list result from the tower (per-image tensors) is stacked along a
        new dim=1 before projection.
        """
        features = self.vision_tower(imgs)
        if isinstance(features, list):
            features = torch.stack(features, dim=1)
        return self.mm_projector(features)

    def set_vision_tower_requires_grad(self, requires_grad: bool):
        """Freeze/unfreeze every vision-tower parameter.

        Raises ValueError when the tower is not initialised yet.
        """
        if self.vision_tower is not None:
            for param in self.vision_tower.parameters():
                param.requires_grad = requires_grad
        else:
            raise ValueError("Vision tower is not initialized. Please call initialize_vision_modules() first.")

    def set_vision_adapter_requires_grad(self, requires_grad: bool):
        """Freeze/unfreeze every projector parameter.

        Raises ValueError when the projector is not initialised yet.
        """
        if self.mm_projector is not None:
            for param in self.mm_projector.parameters():
                param.requires_grad = requires_grad
        else:
            raise ValueError("Vision adapter is not initialized. Please call initialize_vision_modules() first.")

    def set_vision_tower_dtype(self, dtype: torch.dtype):
        """Cast vision-tower parameter data in place to `dtype`."""
        if self.vision_tower is not None:
            for p in self.vision_tower.parameters():
                p.data = p.data.to(dtype)
        else:
            raise ValueError("Vision tower is not initialized. Please call initialize_vision_modules() first.")

    def set_vision_adapter_dtype(self, dtype: torch.dtype):
        """Cast projector parameter data in place to `dtype`."""
        if self.mm_projector is not None:
            for p in self.mm_projector.parameters():
                p.data = p.data.to(dtype)
        else:
            raise ValueError("Vision adapter is not initialized. Please call initialize_vision_modules() first.")

    def is_vision_tower_frozen(self):
        """True when no vision-tower parameter requires grad."""
        if self.vision_tower is not None:
            return all(not p.requires_grad for p in self.vision_tower.parameters())
        else:
            raise ValueError("Vision tower is not initialized. Please call initialize_vision_modules() first.")

    def is_vision_adapter_frozen(self):
        """True when no projector parameter requires grad."""
        if self.mm_projector is not None:
            return all(not p.requires_grad for p in self.mm_projector.parameters())
        else:
            raise ValueError("Vision adapter is not initialized. Please call initialize_vision_modules() first.")
897
-
898
-
899
class FridayForCausalLM(Phi3ForCausalLM):
    """Causal-LM head over FridayModel with helpers to address the three
    parameter groups (LLM, vision tower, vision adapter) separately for
    freezing, dtype casting, and LoRA bookkeeping.

    NOTE: this class continues beyond the visible portion of the file
    (e.g. get_multimodal_input_embeddings).
    """

    config_class = FridayConfig

    def __init__(self, config: FridayConfig):
        super().__init__(config)

        self.config = config
        # NOTE(review): lm_head and model are re-created after
        # super().__init__() already built them — confirm this interacts
        # correctly with weight tying / post_init.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Special-token ids used to splice image features into the sequence.
        self.image_token_id = config.cfg_special_tokens["image_token_id"]
        self.image_start_id = config.cfg_special_tokens["image_start_token_id"]
        self.image_end_id = config.cfg_special_tokens["image_end_token_id"]

        self.model = FridayModel(config)
        self.post_init()

    def get_model(self) -> FridayModel:
        """Return the underlying FridayModel backbone."""
        return self.model

    def get_vision_tower(self) -> SiglipVisionTower:
        """Return the vision tower (None until initialised)."""
        return self.model.get_vision_tower()

    def get_vision_adapter(self) -> MLPAdapter:
        """Return the vision projector (mm_projector)."""
        return self.model.mm_projector

    def get_llm_parameters(self, exclude_lora: bool = False):
        """Parameters belonging to the language model only — i.e. everything
        outside vision_tower/mm_projector, optionally minus LoRA weights."""
        return [
            p for n, p in self.named_parameters()
            if "vision_tower" not in n and "mm_projector" not in n and (not exclude_lora or ("lora_" not in n))
        ]

    def get_llm_named_modules(self):
        """Named modules outside the vision tower and projector."""
        return {n: m for n, m in self.named_modules() if "vision_tower" not in n and "mm_projector" not in n}

    def set_llm_requires_grad(self, requires_grad: bool, exclude_lora: bool = True):
        """Freeze/unfreeze LLM parameters, skipping vision modules and
        (by default) LoRA A/B matrices."""
        for n, p in self.named_parameters():
            if exclude_lora and ("lora_A" in n or "lora_B" in n):
                continue
            if "vision_tower" in n or "mm_projector" in n:
                continue
            p.requires_grad = requires_grad

    def set_vision_tower_requires_grad(self, requires_grad: bool):
        """Delegate to FridayModel (raises if tower uninitialised)."""
        self.model.set_vision_tower_requires_grad(requires_grad)

    def set_vision_adapter_requires_grad(self, requires_grad: bool):
        """Delegate to FridayModel (raises if adapter uninitialised)."""
        self.model.set_vision_adapter_requires_grad(requires_grad)

    def set_llm_dtype(self, dtype: torch.dtype):
        """Cast all LLM parameters (vision modules excluded) to `dtype`."""
        for p in self.get_llm_parameters():
            p.data = p.data.to(dtype)

    def set_vision_tower_dtype(self, dtype: torch.dtype):
        """Cast vision-tower parameters to `dtype`."""
        self.model.set_vision_tower_dtype(dtype)

    def set_vision_adapter_dtype(self, dtype: torch.dtype):
        """Cast projector parameters to `dtype`."""
        self.model.set_vision_adapter_dtype(dtype)

    def is_llm_frozen(self):
        """True when no LLM parameter requires grad."""
        return all(not p.requires_grad for p in self.get_llm_parameters())

    def is_vision_tower_frozen(self):
        """True when no vision-tower parameter requires grad."""
        return self.model.is_vision_tower_frozen()

    def is_vision_adapter_frozen(self):
        """True when no projector parameter requires grad."""
        return self.model.is_vision_adapter_frozen()

    def initialize_vision_modules(self):
        """Build the vision tower/projector on the backbone (idempotent)."""
        self.model.initialize_vision_modules()
969
-
970
- def get_multimodal_input_embeddings(self, input_ids, image_features, return_labels=True) -> torch.Tensor:
971
- emb_start_image_id = self.model.embed_tokens(torch.tensor([self.image_start_id], device=self.device))
972
- emb_end_image_id = self.model.embed_tokens(torch.tensor([self.image_end_id], device=self.device))
973
- id_ignore = torch.tensor([IGNORE_INDEX], device=self.device)
974
-
975
- # repetition‑penalty safety ????
976
- # input_ids[input_ids == self.image_token_id] = 0
977
-
978
-
979
- # Iterate over each batch item
980
- embeds_list, labels_list = [], []
981
- for batch_id, item_ids in enumerate(input_ids):
982
-
983
- image_token_positions = (item_ids == self.image_token_id).nonzero(as_tuple=True)[0]
984
- if len(image_token_positions) != image_features[batch_id].shape[0]:
985
- raise ValueError(
986
- f"Mismatch between number of image tokens ({len(image_token_positions)}) and number of image features ({image_features[batch_id].shape[0]})"
987
- )
988
-
989
-
990
- cursor = 0
991
- emb_parts, lbl_parts = [], []
992
- for indx_image, image_token_pos in enumerate(image_token_positions):
993
- if image_token_pos > cursor:
994
- span = item_ids[cursor:image_token_pos]
995
- emb_parts.append(self.model.embed_tokens(span))
996
- lbl_parts.append(span)
997
-
998
- # <image_start>
999
- emb_parts.append(emb_start_image_id)
1000
- lbl_parts.append(id_ignore)
1001
-
1002
- # vision embeddings
1003
- image_tokens = image_features[batch_id][indx_image]
1004
- if image_tokens.shape[0] == 1 and image_tokens.ndim == 3:
1005
- image_tokens = image_tokens.squeeze(0)
1006
- emb_parts.append(image_tokens)
1007
- lbl_parts.append(id_ignore.repeat(image_tokens.shape[0]))
1008
-
1009
- # <image_end>
1010
- emb_parts.append(emb_end_image_id)
1011
- lbl_parts.append(id_ignore)
1012
-
1013
- cursor = image_token_pos + 1
1014
-
1015
- # tail text
1016
- if cursor < item_ids.shape[0]:
1017
- tail = item_ids[cursor:]
1018
- emb_parts.append(self.model.embed_tokens(tail))
1019
- lbl_parts.append(tail)
1020
-
1021
- embeds_list.append(torch.cat(emb_parts, dim=0))
1022
- labels_list.append(torch.cat(lbl_parts, dim=0))
1023
-
1024
- return (embeds_list, labels_list) if return_labels else embeds_list
1025
-
1026
- def prepare_inputs_for_multimodal(
1027
- self,
1028
- input_ids: torch.LongTensor,
1029
- images: List[List[PIL.Image.Image]], # B x N
1030
- position_ids: Optional[torch.LongTensor],
1031
- attention_mask: Optional[torch.Tensor],
1032
- past_key_values: Optional[List[torch.FloatTensor]],
1033
- labels: Optional[torch.LongTensor],
1034
- ) -> Tuple[Optional[torch.Tensor], Optional[torch.LongTensor], Optional[torch.Tensor], Optional[List[torch.FloatTensor]], torch.Tensor, Optional[torch.Tensor]]:
1035
-
1036
- # ─────────────────── early return (no image / streaming step) ───────────────────
1037
- # if we have already processed images and are in a streaming step we can skip the multimodal processing
1038
- # but we need to ensure the attention mask and position ids are correct
1039
-
1040
- if past_key_values is not None and attention_mask is not None and input_ids.shape[1] == 1:
1041
- tgt = past_key_values[-1][-1].shape[-2] + 1
1042
- attention_mask = torch.cat(
1043
- [attention_mask,
1044
- torch.ones((attention_mask.size(0),
1045
- tgt - attention_mask.size(1)),
1046
- dtype=attention_mask.dtype,
1047
- device=attention_mask.device)],
1048
- dim=1,
1049
- )
1050
- position_ids = (attention_mask.sum(dim=1, keepdim=True) - 1).long()
1051
-
1052
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
1053
-
1054
- # ─────────────────────────── images: (B, N) ───────────────────────────
1055
- if isinstance(images, list) and isinstance(images[0], list):
1056
- # images is a list of lists, each containing multiple images, B x N
1057
- # e.g. [[img1, img2], [img3, img4]]
1058
- assert len(images) == input_ids.shape[0], f"Batch size mismatch: {len(images)} vs {input_ids.shape[0]}"
1059
- image_features = []
1060
- for sublst_images in images:
1061
- if len(sublst_images) == 0:
1062
- image_features.append(torch.zeros((0, self.get_model().mm_projector.output_dim), device=self.device))
1063
- else:
1064
- if isinstance(sublst_images[0], PIL.Image.Image):
1065
- image_features.append(
1066
- self.model.compute_image_features(
1067
- self.model.vision_tower.preprocess_images(sublst_images, pad_and_stack_tensors=True)
1068
- )
1069
- )
1070
- elif isinstance(sublst_images[0], torch.Tensor):
1071
- # This should be a list of tensors of pre-processed images, [(N X 3 X W x H), ...]
1072
- image_features.append(
1073
- self.model.compute_image_features(sublst_images)
1074
- )
1075
- elif isinstance(images, list) and isinstance(images[0], PIL.Image.Image):
1076
- # images is a list of images for a single batch item, 1 x N
1077
- # e.g. [img1, img2, img3]
1078
- assert input_ids.shape[0] == 1, f"Batch size mismatch: {len(images)} vs {input_ids.shape[0]}"
1079
- image_features = [
1080
- self.model.compute_image_features(
1081
- self.model.vision_tower.preprocess_images(images, pad_and_stack_tensors=True)
1082
- )
1083
- ]
1084
- elif isinstance(images, list) and isinstance(images[0], torch.Tensor):
1085
- # This should be a list of tensors of pre-processed images, [(N X 3 X W x H), ...]
1086
- # The list length should match the batch size
1087
- assert input_ids.shape[0] == len(images), f"Batch size mismatch: {len(images)} vs {input_ids.shape[0]}"
1088
- image_features = [
1089
- self.model.compute_image_features(imgs) for imgs in images
1090
- ]
1091
- elif isinstance(images, PIL.Image.Image):
1092
- # images is a single image, 1 x 1
1093
- # e.g. img1
1094
- assert input_ids.shape[0] == 1, f"Batch size mismatch: {len(images)} vs {input_ids.shape[0]}"
1095
- image_features = [
1096
- self.model.compute_image_features(
1097
- self.model.vision_tower.preprocess_images([images])
1098
- )
1099
- ]
1100
- else:
1101
- raise ValueError(f"Unsupported images format: {type(images)}. Expected list of PIL images, a single PIL image or a Tensor of pre-processed images")
1102
-
1103
- # ─────────────────────────── image_features: (B x N x D) ───────────────────────────
1104
- if isinstance(image_features, list):
1105
- assert input_ids.shape[0] == len(image_features), f"Incorrectly formatted image_features: list length should match batch size"
1106
- assert isinstance(image_features[0], torch.Tensor), f"Incorrectly formatted image_features: list items should be tensors"
1107
- elif isinstance(image_features, torch.Tensor):
1108
- assert input_ids.shape[0] == image_features.shape[0], f"Incorrectly formatted image_features: tensor should match batch size"
1109
-
1110
-
1111
- # ───────────────────────────── pad handling prelims ──────────────────────────────
1112
- if attention_mask is None:
1113
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1114
- else:
1115
- attention_mask = attention_mask.bool()
1116
- if position_ids is None:
1117
- position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
1118
-
1119
- input_ids_nopad = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
1120
- embeds_list, labels_list = self.get_multimodal_input_embeddings(
1121
- input_ids_nopad,
1122
- image_features,
1123
- return_labels=True
1124
- )
1125
-
1126
- # ───────────────────── truncate then pad back to rectangle ──────────────────────
1127
- new_input_embeds = torch.nn.utils.rnn.pad_sequence(
1128
- embeds_list,
1129
- batch_first=True,
1130
- padding_value=0.0
1131
- ).to(dtype=self.dtype)
1132
-
1133
- new_labels = torch.nn.utils.rnn.pad_sequence(
1134
- labels_list,
1135
- batch_first=True,
1136
- padding_value=IGNORE_INDEX
1137
- ).long()
1138
-
1139
- if self.config.tokenizer_model_max_length is not None:
1140
- new_input_embeds = new_input_embeds[:, :self.config.tokenizer_model_max_length]
1141
- new_labels = new_labels[:, :self.config.tokenizer_model_max_length]
1142
-
1143
-
1144
-
1145
-
1146
- # ────────────────────────────── attention mask and position ids ────────────────
1147
-
1148
- attention_mask = (
1149
- torch.arange(new_input_embeds.size(1), device=input_ids.device)
1150
- .unsqueeze(0)
1151
- < torch.tensor([e.size(0) for e in embeds_list],
1152
- device=input_ids.device).unsqueeze(1)
1153
- )
1154
-
1155
- raw_pos = attention_mask.cumsum(dim=1) - 1
1156
- position_ids = raw_pos.masked_fill(~attention_mask, 0).long()
1157
-
1158
- if not self.training:
1159
- new_labels = None
1160
-
1161
- return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
1162
-
1163
-
1164
-
1165
- # ------------------------------------------------------------------
1166
- def forward(
1167
- self,
1168
- input_ids: torch.LongTensor = None,
1169
- attention_mask: Optional[torch.Tensor] = None,
1170
- position_ids: Optional[torch.LongTensor] = None,
1171
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1172
- inputs_embeds: Optional[torch.FloatTensor] = None,
1173
- labels: Optional[torch.LongTensor] = None,
1174
- use_cache: Optional[bool] = None,
1175
- output_attentions: Optional[bool] = None,
1176
- output_hidden_states: Optional[bool] = None,
1177
- return_dict: Optional[bool] = None,
1178
- cache_position: Optional[torch.LongTensor] = None,
1179
- logits_to_keep: Union[int, torch.Tensor] = 0,
1180
- images: Optional[PIL.Image.Image] = None,
1181
- **kwargs: Unpack[KwargsForCausalLM],
1182
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1183
-
1184
- is_multi_modal = images is not None and not (
1185
- (
1186
- isinstance(images, list) and (len(images) == 0 or all(i == [] for i in images))
1187
- )
1188
- )
1189
-
1190
-
1191
- if inputs_embeds is None and is_multi_modal:
1192
- (
1193
- input_ids,
1194
- position_ids,
1195
- attention_mask,
1196
- past_key_values,
1197
- inputs_embeds,
1198
- labels
1199
- ) = self.prepare_inputs_for_multimodal(
1200
- input_ids=input_ids,
1201
- images=images,
1202
- position_ids=position_ids,
1203
- attention_mask=attention_mask,
1204
- past_key_values=past_key_values,
1205
- labels=labels,
1206
- )
1207
-
1208
- if cache_position is not None and inputs_embeds is not None and cache_position.shape[0] != inputs_embeds.shape[1]:
1209
- cache_position = torch.arange(inputs_embeds.shape[1], device=self.device)
1210
-
1211
-
1212
- return Phi3ForCausalLM.forward(
1213
- self,
1214
- input_ids=input_ids,
1215
- attention_mask=attention_mask,
1216
- position_ids=position_ids,
1217
- past_key_values=past_key_values,
1218
- inputs_embeds=inputs_embeds,
1219
- labels=labels,
1220
- use_cache=use_cache,
1221
- output_attentions=output_attentions,
1222
- output_hidden_states=output_hidden_states,
1223
- return_dict=return_dict,
1224
- cache_position=cache_position,
1225
- logits_to_keep=logits_to_keep,
1226
- **kwargs
1227
- )
1228
-
1229
- def print_device_configuration(self):
1230
- print("*************Device Configuration*********")
1231
- if len(self.get_llm_parameters()) > 0:
1232
- llm_device = set({str(p.device) for p in self.get_llm_parameters()})
1233
- llm_dtype = set({p.dtype for p in self.get_llm_parameters()})
1234
- print(f"LLM Parameters:\t\t\tdevice: {llm_device}\tdtype: {llm_dtype}\tfrozen: {self.is_llm_frozen()}")
1235
- else:
1236
- print("LLM parameters have not been initialized")
1237
-
1238
- if self.get_model().vision_tower is not None:
1239
- vt_device = set({str(p.device) for p in self.get_model().vision_tower.parameters()})
1240
- vt_dtype = set({p.dtype for p in self.get_model().vision_tower.parameters()})
1241
- print(f"Vision Tower Parameters:\tdevice: {vt_device}\tdtype: {vt_dtype}\tfrozen: {self.is_vision_tower_frozen()}")
1242
- else:
1243
- print("Vision tower parameters have not been initialized")
1244
-
1245
- if self.get_model().mm_projector is not None:
1246
- mm_device = set({str(p.device) for p in self.get_model().mm_projector.parameters()})
1247
- mm_dtype = set({p.dtype for p in self.get_model().mm_projector.parameters()})
1248
- print(f"MM Projector Parameters:\tdevice: {mm_device}\tdtype: {mm_dtype}\tfrozen: {self.is_vision_adapter_frozen()}")
1249
- else:
1250
- print("MM Projector parameters have not been initialized")
1251
- print("******************************************")
1252
-
1253
-
1254
-
1255
def build_tokenizer(base_model_id: str) -> Tuple[AutoTokenizer, dict]:
    """Load the base tokenizer and ensure the multimodal special tokens exist.

    Fixes over the previous version: the two code paths returned `specials`
    dicts with *different* key schemas (raw token strings when all tokens
    already existed, vs "image"/"start"/"end" after adding them); membership
    was tested against `tok.vocab`, which on fast tokenizers may not include
    previously added tokens (`get_vocab()` does); and the unused binding of
    `add_tokens`' return value is dropped.

    Args:
        base_model_id: model id or path passed to AutoTokenizer.from_pretrained.

    Returns:
        (tokenizer, specials) where specials maps "image", "start" and "end"
        to the corresponding token ids.
    """
    tok = AutoTokenizer.from_pretrained(base_model_id, padding_side="right")
    required = [IMAGE_TOKEN, IMG_START_TOKEN, IMG_END_TOKEN]

    # get_vocab() covers added tokens as well as the base vocabulary.
    vocab = tok.get_vocab()
    if any(t not in vocab for t in required):
        tok.add_tokens(required, special_tokens=True)

    # Phi-style tokenizers ship without a pad token; reuse EOS.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    specials = {
        "image": tok.convert_tokens_to_ids(IMAGE_TOKEN),
        "start": tok.convert_tokens_to_ids(IMG_START_TOKEN),
        "end": tok.convert_tokens_to_ids(IMG_END_TOKEN),
    }
    return tok, specials
1267
-