harryrobert commited on
Commit
ad3aae7
·
verified ·
1 Parent(s): cb5ead9

Upload nav2tex/modeling_latex_ocr.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. nav2tex/modeling_latex_ocr.py +508 -0
nav2tex/modeling_latex_ocr.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from functools import partial
5
+ from torch import nn
6
+ from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import BaseModelOutput
9
+
10
+ try:
11
+ from .configuration_latex_decoder import LaTeXDecoderConfig
12
+ from .configuration_latex_ocr import Nav2TexConfig
13
+ from .modeling_latex_decoder import LaTeXDecoderForCausalLM
14
+ except ImportError:
15
+ from nav2tex.configuration_latex_decoder import LaTeXDecoderConfig
16
+ from nav2tex.configuration_latex_ocr import Nav2TexConfig
17
+ from nav2tex.modeling_latex_decoder import LaTeXDecoderForCausalLM
18
+
19
+ try:
20
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
21
+ from flash_attn.bert_padding import pad_input, unpad_input
22
+ HAS_FLASH_ATTN = True
23
+ except ImportError:
24
+ HAS_FLASH_ATTN = False
25
+
26
+
27
+ def exists(val):
28
+ return val is not None
29
+
30
+
31
+ def divisible_by(numer, denom):
32
+ return (numer % denom) == 0
33
+
34
+
35
+ class LayerNorm(nn.Module):
36
+ def __init__(self, dim):
37
+ super().__init__()
38
+ self.normalized_shape = (dim,)
39
+ self.eps = 1e-5
40
+ self.weight = nn.Parameter(torch.ones(dim))
41
+ self.bias = nn.Parameter(torch.zeros(dim))
42
+
43
+ def forward(self, x):
44
+ return F.layer_norm(
45
+ x.float(), self.normalized_shape,
46
+ self.weight.float(), self.bias.float(), self.eps,
47
+ ).to(x.dtype)
48
+
49
+
50
+ class RMSNorm(nn.Module):
51
+ def __init__(self, heads, dim):
52
+ super().__init__()
53
+ self.scale = dim ** 0.5
54
+ self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
55
+
56
+ def forward(self, x):
57
+ return F.normalize(x, dim=-1) * self.scale * self.gamma.to(x.dtype)
58
+
59
+
60
+ def rotate_half(x):
61
+ x1, x2 = x.chunk(2, dim=-1)
62
+ return torch.cat([-x2, x1], dim=-1)
63
+
64
+
65
+ def apply_2d_rope(q, k, h_idx, w_idx):
66
+ _, _, _, d = q.shape
67
+ if d % 4 != 0:
68
+ raise ValueError(f"apply_2d_rope expects dim_head divisible by 4, got D={d}")
69
+ dim_half = d // 2
70
+ dim_quarter = d // 4
71
+ inv_freq = 1.0 / (10000 ** (torch.arange(dim_quarter, device=q.device).float() / dim_quarter))
72
+ h_theta = h_idx[..., None].float() * inv_freq
73
+ w_theta = w_idx[..., None].float() * inv_freq
74
+ sin_h = torch.cat([h_theta.sin(), h_theta.sin()], dim=-1).to(q.dtype)[:, None, :, :]
75
+ cos_h = torch.cat([h_theta.cos(), h_theta.cos()], dim=-1).to(q.dtype)[:, None, :, :]
76
+ sin_w = torch.cat([w_theta.sin(), w_theta.sin()], dim=-1).to(q.dtype)[:, None, :, :]
77
+ cos_w = torch.cat([w_theta.cos(), w_theta.cos()], dim=-1).to(q.dtype)[:, None, :, :]
78
+
79
+ def rope(x, sin, cos):
80
+ return x * cos + rotate_half(x) * sin
81
+
82
+ q = torch.cat([rope(q[..., :dim_half], sin_h, cos_h), rope(q[..., dim_half:], sin_w, cos_w)], dim=-1)
83
+ k = torch.cat([rope(k[..., :dim_half], sin_h, cos_h), rope(k[..., dim_half:], sin_w, cos_w)], dim=-1)
84
+ return q, k
85
+
86
+
87
+ class FeedForward(nn.Module):
88
+ def __init__(self, dim, hidden_dim, dropout=0.0):
89
+ super().__init__()
90
+ self.net = nn.Sequential(
91
+ LayerNorm(dim),
92
+ nn.Linear(dim, hidden_dim),
93
+ nn.GELU(),
94
+ nn.Dropout(dropout),
95
+ nn.Linear(hidden_dim, dim),
96
+ nn.Dropout(dropout),
97
+ )
98
+
99
+ def forward(self, x):
100
+ return self.net(x)
101
+
102
+
103
+ class Attention(nn.Module):
104
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
105
+ super().__init__()
106
+ inner_dim = dim_head * heads
107
+ self.heads = heads
108
+ self.norm = LayerNorm(dim)
109
+ self.q_norm = RMSNorm(heads, dim_head)
110
+ self.k_norm = RMSNorm(heads, dim_head)
111
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
112
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
113
+ self.attend = nn.Softmax(dim=-1)
114
+ self.dropout = nn.Dropout(dropout)
115
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout))
116
+
117
+ def forward(self, x, mask=None, attn_mask=None, positions=None):
118
+ x = self.norm(x)
119
+ q = self.to_q(x)
120
+ k, v = self.to_kv(x).chunk(2, dim=-1)
121
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v))
122
+ q = self.q_norm(q)
123
+ k = self.k_norm(k)
124
+
125
+ if positions is not None:
126
+ q, k = apply_2d_rope(q, k, positions[0], positions[1])
127
+
128
+ if HAS_FLASH_ATTN and x.is_cuda and attn_mask is None:
129
+ fa_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.bfloat16
130
+ q_ = rearrange(q, "b h n d -> b n h d").contiguous().to(fa_dtype)
131
+ k_ = rearrange(k, "b h n d -> b n h d").contiguous().to(fa_dtype)
132
+ v_ = rearrange(v, "b h n d -> b n h d").contiguous().to(fa_dtype)
133
+ if exists(mask):
134
+ batch, seqlen = mask.shape
135
+ q_unpad, indices, cu_q, max_q, *_ = unpad_input(q_, mask)
136
+ k_unpad, _, cu_k, max_k, *_ = unpad_input(k_, mask)
137
+ v_unpad, _, _, _, *_ = unpad_input(v_, mask)
138
+ out_unpad = flash_attn_varlen_func(
139
+ q_unpad, k_unpad, v_unpad,
140
+ cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
141
+ max_seqlen_q=max_q, max_seqlen_k=max_k,
142
+ dropout_p=self.dropout.p if self.training else 0.0,
143
+ causal=False,
144
+ )
145
+ out = pad_input(out_unpad, indices, batch, seqlen)
146
+ else:
147
+ out = flash_attn_func(
148
+ q_, k_, v_,
149
+ dropout_p=self.dropout.p if self.training else 0.0,
150
+ causal=False,
151
+ )
152
+ out = rearrange(out, "b n h d -> b n (h d)").to(x.dtype)
153
+ else:
154
+ dots = torch.matmul(q, k.transpose(-1, -2))
155
+ if exists(mask):
156
+ dots = dots.masked_fill(~mask[:, None, None, :], -torch.finfo(dots.dtype).max)
157
+ if exists(attn_mask):
158
+ dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
159
+ attn = self.dropout(self.attend(dots))
160
+ out = rearrange(torch.matmul(attn, v), "b h n d -> b n (h d)")
161
+ return self.to_out(out)
162
+
163
+
164
+ class Transformer(nn.Module):
165
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
166
+ super().__init__()
167
+ self.layers = nn.ModuleList([
168
+ nn.ModuleList([Attention(dim, heads, dim_head, dropout), FeedForward(dim, mlp_dim, dropout)])
169
+ for _ in range(depth)
170
+ ])
171
+ self.norm = LayerNorm(dim)
172
+
173
+ def forward(self, x, mask=None, attn_mask=None, positions=None):
174
+ for attn, ff in self.layers:
175
+ x = attn(x, mask=mask, attn_mask=attn_mask, positions=positions) + x
176
+ x = ff(x) + x
177
+ return self.norm(x)
178
+
179
+
180
+ class NaViT_Encoder(nn.Module):
181
+ def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim,
182
+ channels=3, dim_head=64, dropout=0.0, emb_dropout=0.0):
183
+ super().__init__()
184
+ image_height, image_width = image_size
185
+ assert divisible_by(image_height, patch_size)
186
+ assert divisible_by(image_width, patch_size)
187
+ self.patch_size = patch_size
188
+ self.to_patch_embedding = nn.Sequential(
189
+ LayerNorm(channels * patch_size ** 2),
190
+ nn.Linear(channels * patch_size ** 2, dim),
191
+ LayerNorm(dim),
192
+ )
193
+ self.dropout = nn.Dropout(emb_dropout)
194
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
195
+
196
+ @property
197
+ def device(self):
198
+ return next(self.parameters()).device
199
+
200
+ def forward(self, batched_images):
201
+ p = self.patch_size
202
+ device = self.device
203
+ arange = partial(torch.arange, device=device)
204
+ pad_sequence = partial(orig_pad_sequence, batch_first=True)
205
+ batched_sequences, batched_positions = [], []
206
+
207
+ for images in batched_images:
208
+ sequences, positions = [], []
209
+ for image in images:
210
+ _, h, w = image.shape
211
+ ph, pw = h // p, w // p
212
+ seq = rearrange(image, "c (h p1) (w p2) -> (h w) (c p1 p2)", p1=p, p2=p)
213
+ pos = torch.stack(torch.meshgrid(arange(ph), arange(pw), indexing="ij"), dim=-1)
214
+ sequences.append(seq)
215
+ positions.append(rearrange(pos, "h w c -> (h w) c"))
216
+ batched_sequences.append(torch.cat(sequences, dim=0))
217
+ batched_positions.append(torch.cat(positions, dim=0))
218
+
219
+ patches = pad_sequence(batched_sequences)
220
+ patch_positions = pad_sequence(batched_positions)
221
+ lengths = torch.tensor([seq.shape[0] for seq in batched_sequences], device=device)
222
+ mask = torch.arange(patches.shape[1], device=device)[None, :] < lengths[:, None]
223
+ x = self.to_patch_embedding(patches.to(next(self.parameters()).dtype))
224
+ h_idx, w_idx = patch_positions.unbind(dim=-1)
225
+ x = self.dropout(x)
226
+ x = self.transformer(x, mask=mask, positions=(h_idx, w_idx))
227
+ return x, mask
228
+
229
+
230
+ class MLPProjector(nn.Module):
231
+ def __init__(self, vision_hidden_size=1024, llm_hidden_size=512, intermediate_size=2048):
232
+ super().__init__()
233
+ self.norm = nn.LayerNorm(vision_hidden_size)
234
+ self.gate_proj = nn.Linear(vision_hidden_size, intermediate_size, bias=False)
235
+ self.up_proj = nn.Linear(vision_hidden_size, intermediate_size, bias=False)
236
+ self.down_proj = nn.Linear(intermediate_size, llm_hidden_size, bias=False)
237
+
238
+ def forward(self, x):
239
+ x = self.norm(x)
240
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
241
+
242
+
243
+ class VisualEncoder(nn.Module):
244
+ def __init__(self, encoder, bridge, max_visual_tokens):
245
+ super().__init__()
246
+ self.navit = encoder
247
+ self.projector = bridge
248
+ self.max_visual_tokens = max_visual_tokens
249
+
250
+ def forward(self, batched_images):
251
+ x, mask = self.navit(batched_images)
252
+ if x.shape[1] > self.max_visual_tokens:
253
+ x = x[:, :self.max_visual_tokens]
254
+ mask = mask[:, :self.max_visual_tokens]
255
+ return self.projector(x), mask
256
+
257
+
258
+ class CustomDecoder(nn.Module):
259
+ def __init__(self, config: Nav2TexConfig):
260
+ super().__init__()
261
+ dec = config.decoder_arch
262
+ self._model = LaTeXDecoderForCausalLM(
263
+ LaTeXDecoderConfig(
264
+ vocab_size=dec["vocab_size"],
265
+ pad_id=dec["pad_id"],
266
+ bos_id=dec["bos_id"],
267
+ eos_id=dec["eos_id"],
268
+ d_model=dec["d_model"],
269
+ n_heads=dec["n_heads"],
270
+ n_layers=dec["n_layers"],
271
+ d_ff=dec["d_ff"],
272
+ dropout=dec.get("dropout", 0.1),
273
+ max_seq_len=dec["max_seq_len"],
274
+ rope_theta=dec.get("rope_theta", 10000.0),
275
+ tie_weights=dec.get("tie_weights", True),
276
+ )
277
+ )
278
+ self.pad_token_id = self._model.config.pad_id
279
+ self.eos_token_id = self._model.config.eos_id
280
+ self._vocab_size = self._model.config.vocab_size
281
+ self._pad_id = self._model.config.pad_id
282
+ if not config.decoder_weights_tied:
283
+ self.untie_weights()
284
+
285
+ def get_input_embeddings(self):
286
+ return self._model.embed_tokens
287
+
288
+ def tie_weights(self):
289
+ self._model.lm_head.weight = self._model.embed_tokens.weight
290
+
291
+ def untie_weights(self):
292
+ if self.are_weights_tied():
293
+ self._model.lm_head.weight = nn.Parameter(self._model.embed_tokens.weight.detach().clone())
294
+
295
+ def are_weights_tied(self):
296
+ return self._model.lm_head.weight.data_ptr() == self._model.embed_tokens.weight.data_ptr()
297
+
298
+ def _forward_embeds(self, inputs_embeds, attention_mask=None):
299
+ x = self._model.embed_drop(inputs_embeds)
300
+ mask = attention_mask.bool() if attention_mask is not None else None
301
+ for layer in self._model.layers:
302
+ x = layer(x, mask)
303
+ return self._model.lm_head(self._model.norm_final(x))
304
+
305
+ def forward(self, inputs_embeds=None, attention_mask=None, labels=None, **kwargs):
306
+ logits = self._forward_embeds(inputs_embeds, attention_mask)
307
+ loss = None
308
+ if labels is not None:
309
+ shift_logits = logits[:, :-1].contiguous()
310
+ shift_labels = labels[:, 1:].contiguous().masked_fill(
311
+ labels[:, 1:].contiguous() == self._pad_id, -100
312
+ )
313
+ loss = F.cross_entropy(
314
+ shift_logits.view(-1, self._vocab_size),
315
+ shift_labels.view(-1),
316
+ ignore_index=-100,
317
+ )
318
+ return BaseModelOutput(last_hidden_state=logits, hidden_states=(loss,))
319
+
320
+ @torch.no_grad()
321
+ def generate(self, inputs_embeds, attention_mask, max_new_tokens, num_beams=1):
322
+ device = inputs_embeds.device
323
+ batch = inputs_embeds.shape[0]
324
+
325
+ if num_beams > 1:
326
+ # beam search: only supports batch_size=1
327
+ assert batch == 1, "beam search only supports batch_size=1"
328
+ return self._beam_search(inputs_embeds, attention_mask, max_new_tokens, num_beams)
329
+
330
+ return self._greedy_batch(inputs_embeds, attention_mask, max_new_tokens)
331
+
332
+ @torch.no_grad()
333
+ def _greedy_batch(self, inputs_embeds, attention_mask, max_new_tokens):
334
+ """Greedy decoding with true batch support."""
335
+ eos_id = self.eos_token_id
336
+ pad_id = self._pad_id
337
+ device = inputs_embeds.device
338
+ batch = inputs_embeds.shape[0]
339
+ d_model = inputs_embeds.shape[-1]
340
+
341
+ # generated token ids per sample, and finished flags
342
+ gen_ids = [[] for _ in range(batch)]
343
+ finished = torch.zeros(batch, dtype=torch.bool, device=device)
344
+
345
+ cur_embeds = inputs_embeds # (B, vis_len, D)
346
+ cur_mask = attention_mask # (B, vis_len)
347
+
348
+ for _ in range(max_new_tokens):
349
+ logits = self._forward_embeds(cur_embeds, cur_mask) # (B, seq, vocab)
350
+ next_tok = logits[:, -1, :].argmax(dim=-1) # (B,)
351
+
352
+ for i in range(batch):
353
+ if not finished[i]:
354
+ gen_ids[i].append(next_tok[i].item())
355
+ finished |= (next_tok == eos_id)
356
+ if finished.all():
357
+ break
358
+
359
+ tok_emb = self._model.embed_tokens(next_tok.unsqueeze(1)) # (B, 1, D)
360
+ tok_mask = cur_mask.new_ones(batch, 1)
361
+ cur_embeds = torch.cat([cur_embeds, tok_emb], dim=1)
362
+ cur_mask = torch.cat([cur_mask, tok_mask], dim=1)
363
+
364
+ # pad to same length and return (B, max_len)
365
+ max_len = max((len(ids) for ids in gen_ids), default=0)
366
+ if max_len == 0:
367
+ return torch.zeros(batch, 0, dtype=torch.long, device=device)
368
+ out = torch.full((batch, max_len), pad_id, dtype=torch.long, device=device)
369
+ for i, ids in enumerate(gen_ids):
370
+ if ids:
371
+ out[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
372
+ return out
373
+
374
+ @torch.no_grad()
375
+ def _beam_search(self, inputs_embeds, attention_mask, max_new_tokens, num_beams):
376
+ """Original beam search (batch_size=1 only)."""
377
+ eos_id = self.eos_token_id
378
+ device = inputs_embeds.device
379
+ vis_emb = inputs_embeds[0]
380
+ vis_len = vis_emb.shape[0]
381
+ vis_mask = attention_mask[0] if attention_mask is not None else None
382
+ beams = [(0.0, [], False) for _ in range(num_beams)]
383
+
384
+ for _ in range(max_new_tokens):
385
+ all_embeds, all_masks = [], []
386
+ for score, ids, _ in beams:
387
+ tok_emb = self._model.embed_tokens(torch.tensor(ids, device=device, dtype=torch.long)) if ids else None
388
+ seq_emb = torch.cat([vis_emb, tok_emb], dim=0) if tok_emb is not None else vis_emb
389
+ all_embeds.append(seq_emb)
390
+ if vis_mask is not None:
391
+ tok_mask = vis_mask.new_ones(len(ids)) if ids else vis_mask.new_zeros(0)
392
+ all_masks.append(torch.cat([vis_mask, tok_mask]) if ids else vis_mask)
393
+
394
+ max_len = max(e.shape[0] for e in all_embeds)
395
+ d_model = all_embeds[0].shape[-1]
396
+ padded_embeds = vis_emb.new_zeros(num_beams, max_len, d_model)
397
+ padded_mask = vis_mask.new_zeros(num_beams, max_len) if vis_mask is not None else None
398
+ for idx, emb in enumerate(all_embeds):
399
+ padded_embeds[idx, :emb.shape[0]] = emb
400
+ if padded_mask is not None:
401
+ padded_mask[idx, :emb.shape[0]] = all_masks[idx]
402
+
403
+ logits = self._forward_embeds(padded_embeds, padded_mask)
404
+ candidates = []
405
+ for beam_idx, (score, ids, done) in enumerate(beams):
406
+ if done:
407
+ candidates.append((score, ids, True))
408
+ continue
409
+ last_pos = vis_len + len(ids) - 1
410
+ log_p = torch.log_softmax(logits[beam_idx, last_pos, :], dim=-1)
411
+ if len(ids) == 0 and beam_idx > 0:
412
+ log_p = log_p.fill_(-1e9)
413
+ for lp, tok in zip(*map(lambda t: t.tolist(), log_p.topk(num_beams))):
414
+ candidates.append((score + lp, ids + [tok], tok == eos_id))
415
+ candidates.sort(key=lambda x: -x[0])
416
+ beams = candidates[:num_beams]
417
+ if all(done for _, _, done in beams):
418
+ break
419
+
420
+ best_ids = max(beams, key=lambda x: x[0])[1]
421
+ if not best_ids:
422
+ return torch.zeros(1, 0, dtype=torch.long, device=device)
423
+ return torch.tensor(best_ids, dtype=torch.long, device=device).unsqueeze(0)
424
+
425
+
426
+ class Nav2TexModel(PreTrainedModel):
427
+ config_class = Nav2TexConfig
428
+ base_model_prefix = "model"
429
+ main_input_name = "pixel_values"
430
+
431
+ def __init__(self, config: Nav2TexConfig):
432
+ super().__init__(config)
433
+ self.config = config
434
+ self.visual_encoder = VisualEncoder(
435
+ NaViT_Encoder(
436
+ image_size=(config.image_height, config.max_image_width),
437
+ patch_size=config.patch_size,
438
+ dim=config.navit_dim,
439
+ depth=config.navit_depth,
440
+ heads=config.navit_heads,
441
+ mlp_dim=config.navit_mlp_dim,
442
+ dim_head=config.navit_dim_head,
443
+ dropout=config.navit_dropout,
444
+ emb_dropout=config.navit_emb_dropout,
445
+ ),
446
+ MLPProjector(
447
+ vision_hidden_size=config.vision_hidden_size,
448
+ llm_hidden_size=config.llm_hidden_size,
449
+ intermediate_size=config.projector_intermediate_size,
450
+ ),
451
+ max_visual_tokens=config.max_visual_tokens,
452
+ )
453
+ self.decoder = CustomDecoder(config)
454
+ self.post_init()
455
+
456
+ def tie_weights(self, **kwargs):
457
+ if self.config.decoder_weights_tied:
458
+ self.decoder.tie_weights()
459
+ else:
460
+ self.decoder.untie_weights()
461
+
462
+ def _init_weights(self, module):
463
+ return
464
+
465
+ @staticmethod
466
+ def _to_batched_images(pixel_values):
467
+ if isinstance(pixel_values, list):
468
+ return pixel_values
469
+ if isinstance(pixel_values, torch.Tensor):
470
+ return [[img] for img in pixel_values]
471
+ raise TypeError(f"Unsupported pixel_values type: {type(pixel_values)}")
472
+
473
+ def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None, **kwargs):
474
+ batched_images = self._to_batched_images(pixel_values)
475
+ ve, vm = self.visual_encoder(batched_images)
476
+ if input_ids is None:
477
+ return BaseModelOutput(last_hidden_state=ve)
478
+ te = self.decoder.get_input_embeddings()(input_ids)
479
+ inputs_embeds = torch.cat([ve, te], dim=1)
480
+ am = torch.cat([vm.to(dtype=attention_mask.dtype), attention_mask], dim=1)
481
+ lv = torch.full((labels.shape[0], ve.shape[1]), -100, dtype=labels.dtype, device=labels.device)
482
+ out = self.decoder(
483
+ inputs_embeds=inputs_embeds,
484
+ attention_mask=am,
485
+ labels=torch.cat([lv, labels], dim=1),
486
+ )
487
+ return BaseModelOutput(last_hidden_state=out.last_hidden_state, hidden_states=(out.hidden_states[0],))
488
+
489
+ @torch.no_grad()
490
+ def generate(self, pixel_values, max_new_tokens=None, num_beams=None):
491
+ batched_images = self._to_batched_images(pixel_values)
492
+ ve, vm = self.visual_encoder(batched_images)
493
+ batch = ve.shape[0]
494
+ bos_id = self.config.decoder_arch["bos_id"]
495
+ bos_emb = self.decoder.get_input_embeddings()(
496
+ torch.full((batch, 1), bos_id, dtype=torch.long, device=ve.device)
497
+ )
498
+ inputs_embeds = torch.cat([ve, bos_emb], dim=1)
499
+ attention_mask = torch.cat([
500
+ vm.to(dtype=torch.long),
501
+ torch.ones(batch, 1, dtype=torch.long, device=ve.device)
502
+ ], dim=1)
503
+ return self.decoder.generate(
504
+ inputs_embeds=inputs_embeds,
505
+ attention_mask=attention_mask,
506
+ max_new_tokens=max_new_tokens or self.config.max_new_tokens,
507
+ num_beams=num_beams or self.config.num_beams,
508
+ )