harryrobert committed on
Commit
cb5ead9
·
verified ·
1 Parent(s): 187408a

Upload folder using huggingface_hub

Browse files
__pycache__/configuration_latex_decoder.cpython-312.pyc CHANGED
Binary files a/__pycache__/configuration_latex_decoder.cpython-312.pyc and b/__pycache__/configuration_latex_decoder.cpython-312.pyc differ
 
__pycache__/configuration_latex_ocr.cpython-312.pyc CHANGED
Binary files a/__pycache__/configuration_latex_ocr.cpython-312.pyc and b/__pycache__/configuration_latex_ocr.cpython-312.pyc differ
 
__pycache__/modeling_latex_decoder.cpython-312.pyc CHANGED
Binary files a/__pycache__/modeling_latex_decoder.cpython-312.pyc and b/__pycache__/modeling_latex_decoder.cpython-312.pyc differ
 
__pycache__/modeling_latex_ocr.cpython-312.pyc CHANGED
Binary files a/__pycache__/modeling_latex_ocr.cpython-312.pyc and b/__pycache__/modeling_latex_ocr.cpython-312.pyc differ
 
__pycache__/tokenization_latex_ocr.cpython-312.pyc CHANGED
Binary files a/__pycache__/tokenization_latex_ocr.cpython-312.pyc and b/__pycache__/tokenization_latex_ocr.cpython-312.pyc differ
 
configuration_latex_decoder.py CHANGED
@@ -1,48 +1,48 @@
1
- from transformers import PretrainedConfig
2
-
3
-
4
- class LaTeXDecoderConfig(PretrainedConfig):
5
- model_type = "latex_decoder"
6
-
7
- def __init__(
8
- self,
9
- vocab_size: int = 8192,
10
- pad_id: int = 0,
11
- bos_id: int = 2,
12
- eos_id: int = 3,
13
- d_model: int = 512,
14
- n_heads: int = 8,
15
- n_layers: int = 6,
16
- d_ff: int = 1408,
17
- dropout: float = 0.1,
18
- max_seq_len: int = 200,
19
- rope_theta: float = 10000.0,
20
- tie_weights: bool = False,
21
- **kwargs,
22
- ):
23
- kwargs.pop("pad_token_id", None)
24
- kwargs.pop("bos_token_id", None)
25
- kwargs.pop("eos_token_id", None)
26
- super().__init__(
27
- pad_token_id=pad_id,
28
- bos_token_id=bos_id,
29
- eos_token_id=eos_id,
30
- **kwargs,
31
- )
32
- self.vocab_size = vocab_size
33
- self.pad_id = pad_id
34
- self.bos_id = bos_id
35
- self.eos_id = eos_id
36
- self.d_model = d_model
37
- self.n_heads = n_heads
38
- self.n_layers = n_layers
39
- self.d_ff = d_ff
40
- self.dropout = dropout
41
- self.max_seq_len = max_seq_len
42
- self.rope_theta = rope_theta
43
- self.tie_weights = tie_weights
44
-
45
- @property
46
- def head_dim(self) -> int:
47
- assert self.d_model % self.n_heads == 0
48
- return self.d_model // self.n_heads
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class LaTeXDecoderConfig(PretrainedConfig):
5
+ model_type = "latex_decoder"
6
+
7
+ def __init__(
8
+ self,
9
+ vocab_size: int = 8192,
10
+ pad_id: int = 0,
11
+ bos_id: int = 2,
12
+ eos_id: int = 3,
13
+ d_model: int = 512,
14
+ n_heads: int = 8,
15
+ n_layers: int = 6,
16
+ d_ff: int = 1408,
17
+ dropout: float = 0.1,
18
+ max_seq_len: int = 200,
19
+ rope_theta: float = 10000.0,
20
+ tie_weights: bool = False,
21
+ **kwargs,
22
+ ):
23
+ kwargs.pop("pad_token_id", None)
24
+ kwargs.pop("bos_token_id", None)
25
+ kwargs.pop("eos_token_id", None)
26
+ super().__init__(
27
+ pad_token_id=pad_id,
28
+ bos_token_id=bos_id,
29
+ eos_token_id=eos_id,
30
+ **kwargs,
31
+ )
32
+ self.vocab_size = vocab_size
33
+ self.pad_id = pad_id
34
+ self.bos_id = bos_id
35
+ self.eos_id = eos_id
36
+ self.d_model = d_model
37
+ self.n_heads = n_heads
38
+ self.n_layers = n_layers
39
+ self.d_ff = d_ff
40
+ self.dropout = dropout
41
+ self.max_seq_len = max_seq_len
42
+ self.rope_theta = rope_theta
43
+ self.tie_weights = tie_weights
44
+
45
+ @property
46
+ def head_dim(self) -> int:
47
+ assert self.d_model % self.n_heads == 0
48
+ return self.d_model // self.n_heads
configuration_latex_ocr.py CHANGED
@@ -1,66 +1,66 @@
1
- from transformers import PretrainedConfig
2
-
3
-
4
- class Nav2TexConfig(PretrainedConfig):
5
- model_type = "nav2tex"
6
-
7
- def __init__(
8
- self,
9
- patch_size: int = 16,
10
- image_height: int = 64,
11
- max_image_width: int = 1024,
12
- max_image_height: int = 640,
13
- resize_in_dataset: bool = True,
14
- max_token_len: int = 200,
15
- navit_dim: int = 512,
16
- navit_depth: int = 8,
17
- navit_heads: int = 8,
18
- navit_dim_head: int = 64,
19
- navit_mlp_dim: int = 2048,
20
- navit_dropout: float = 0.0,
21
- navit_emb_dropout: float = 0.0,
22
- vision_hidden_size: int = 512,
23
- llm_hidden_size: int = 512,
24
- projector_intermediate_size: int = 1024,
25
- max_visual_tokens: int = 256,
26
- max_new_tokens: int = 200,
27
- num_beams: int = 4,
28
- decoder_arch: dict | None = None,
29
- decoder_weights_tied: bool = False,
30
- **kwargs,
31
- ):
32
- super().__init__(**kwargs)
33
- self.patch_size = patch_size
34
- self.image_height = image_height
35
- self.max_image_width = max_image_width
36
- self.max_image_height = max_image_height
37
- self.resize_in_dataset = resize_in_dataset
38
- self.max_token_len = max_token_len
39
- self.navit_dim = navit_dim
40
- self.navit_depth = navit_depth
41
- self.navit_heads = navit_heads
42
- self.navit_dim_head = navit_dim_head
43
- self.navit_mlp_dim = navit_mlp_dim
44
- self.navit_dropout = navit_dropout
45
- self.navit_emb_dropout = navit_emb_dropout
46
- self.vision_hidden_size = vision_hidden_size
47
- self.llm_hidden_size = llm_hidden_size
48
- self.projector_intermediate_size = projector_intermediate_size
49
- self.max_visual_tokens = max_visual_tokens
50
- self.max_new_tokens = max_new_tokens
51
- self.num_beams = num_beams
52
- self.decoder_arch = decoder_arch or {
53
- "vocab_size": 2046,
54
- "pad_id": 0,
55
- "bos_id": 2,
56
- "eos_id": 3,
57
- "d_model": 512,
58
- "n_heads": 8,
59
- "n_layers": 6,
60
- "d_ff": 1408,
61
- "dropout": 0.1,
62
- "max_seq_len": 200,
63
- "rope_theta": 10000.0,
64
- "tie_weights": True,
65
- }
66
- self.decoder_weights_tied = decoder_weights_tied
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class Nav2TexConfig(PretrainedConfig):
5
+ model_type = "nav2tex"
6
+
7
+ def __init__(
8
+ self,
9
+ patch_size: int = 16,
10
+ image_height: int = 64,
11
+ max_image_width: int = 1024,
12
+ max_image_height: int = 640,
13
+ resize_in_dataset: bool = True,
14
+ max_token_len: int = 200,
15
+ navit_dim: int = 512,
16
+ navit_depth: int = 8,
17
+ navit_heads: int = 8,
18
+ navit_dim_head: int = 64,
19
+ navit_mlp_dim: int = 2048,
20
+ navit_dropout: float = 0.0,
21
+ navit_emb_dropout: float = 0.0,
22
+ vision_hidden_size: int = 512,
23
+ llm_hidden_size: int = 512,
24
+ projector_intermediate_size: int = 1024,
25
+ max_visual_tokens: int = 256,
26
+ max_new_tokens: int = 200,
27
+ num_beams: int = 4,
28
+ decoder_arch: dict | None = None,
29
+ decoder_weights_tied: bool = False,
30
+ **kwargs,
31
+ ):
32
+ super().__init__(**kwargs)
33
+ self.patch_size = patch_size
34
+ self.image_height = image_height
35
+ self.max_image_width = max_image_width
36
+ self.max_image_height = max_image_height
37
+ self.resize_in_dataset = resize_in_dataset
38
+ self.max_token_len = max_token_len
39
+ self.navit_dim = navit_dim
40
+ self.navit_depth = navit_depth
41
+ self.navit_heads = navit_heads
42
+ self.navit_dim_head = navit_dim_head
43
+ self.navit_mlp_dim = navit_mlp_dim
44
+ self.navit_dropout = navit_dropout
45
+ self.navit_emb_dropout = navit_emb_dropout
46
+ self.vision_hidden_size = vision_hidden_size
47
+ self.llm_hidden_size = llm_hidden_size
48
+ self.projector_intermediate_size = projector_intermediate_size
49
+ self.max_visual_tokens = max_visual_tokens
50
+ self.max_new_tokens = max_new_tokens
51
+ self.num_beams = num_beams
52
+ self.decoder_arch = decoder_arch or {
53
+ "vocab_size": 2046,
54
+ "pad_id": 0,
55
+ "bos_id": 2,
56
+ "eos_id": 3,
57
+ "d_model": 512,
58
+ "n_heads": 8,
59
+ "n_layers": 6,
60
+ "d_ff": 1408,
61
+ "dropout": 0.1,
62
+ "max_seq_len": 200,
63
+ "rope_theta": 10000.0,
64
+ "tie_weights": True,
65
+ }
66
+ self.decoder_weights_tied = decoder_weights_tied
modeling_latex_decoder.py CHANGED
@@ -1,202 +1,202 @@
1
- # update v2
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from typing import Optional
7
-
8
- from transformers import PreTrainedModel
9
- from transformers.modeling_outputs import CausalLMOutput
10
-
11
- try:
12
- from .configuration_latex_decoder import LaTeXDecoderConfig
13
- except ImportError:
14
- from latex_ocr.configuration_latex_decoder import LaTeXDecoderConfig
15
-
16
-
17
- class RMSNorm(nn.Module):
18
- def __init__(self, d_model: int, eps: float = 1e-6):
19
- super().__init__()
20
- self.eps = eps
21
- self.weight = nn.Parameter(torch.ones(d_model))
22
-
23
- def forward(self, x: torch.Tensor) -> torch.Tensor:
24
- rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
25
- return x / rms * self.weight
26
-
27
-
28
- def _build_rope_cache(seq_len, head_dim, theta=10000.0, device=None, dtype=torch.float32):
29
- half = head_dim // 2
30
- inv_freq = 1.0 / (theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
31
- pos = torch.arange(seq_len, device=device, dtype=torch.float32)
32
- freqs = torch.outer(pos, inv_freq)
33
- emb = torch.cat([freqs, freqs], dim=-1)
34
- return emb.cos().to(dtype), emb.sin().to(dtype)
35
-
36
-
37
- def _rotate_half(x: torch.Tensor) -> torch.Tensor:
38
- half = x.shape[-1] // 2
39
- x1, x2 = x[..., :half], x[..., half:]
40
- return torch.cat([-x2, x1], dim=-1)
41
-
42
-
43
- def apply_rope(q, k, cos, sin):
44
- cos = cos.unsqueeze(0).unsqueeze(0)
45
- sin = sin.unsqueeze(0).unsqueeze(0)
46
- return q * cos + _rotate_half(q) * sin, k * cos + _rotate_half(k) * sin
47
-
48
-
49
- class CausalSelfAttention(nn.Module):
50
- def __init__(self, cfg: LaTeXDecoderConfig):
51
- super().__init__()
52
- self.n_heads = cfg.n_heads
53
- self.head_dim = cfg.head_dim
54
- self.d_model = cfg.d_model
55
- self.dropout_p = cfg.dropout
56
- self.rope_theta = cfg.rope_theta
57
-
58
- self.qkv_proj = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
59
- self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
60
- self._rope_cache: dict = {}
61
-
62
- def _get_rope(self, seq_len, device, dtype):
63
- key = (seq_len, str(device), dtype)
64
- if key not in self._rope_cache:
65
- self._rope_cache[key] = _build_rope_cache(seq_len, self.head_dim, self.rope_theta, device, dtype)
66
- return self._rope_cache[key]
67
-
68
- def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
69
- B, T, C = x.shape
70
- q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
71
-
72
- q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
73
- k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
74
- v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
75
-
76
- cos, sin = self._get_rope(T, x.device, q.dtype)
77
- q, k = apply_rope(q, k, cos, sin)
78
-
79
- dropout_p = self.dropout_p if self.training else 0.0
80
-
81
- if attention_mask is not None:
82
- causal = torch.triu(torch.full((T, T), float("-inf"), device=x.device, dtype=q.dtype), diagonal=1)
83
- pad = (~attention_mask).unsqueeze(1).unsqueeze(2)
84
- attn_bias = causal.unsqueeze(0).unsqueeze(0).expand(B, 1, T, T).clone()
85
- attn_bias = attn_bias.masked_fill(pad, float("-inf"))
86
- out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p, is_causal=False)
87
- else:
88
- out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
89
-
90
- return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))
91
-
92
-
93
- class SwiGLUFFN(nn.Module):
94
- def __init__(self, cfg: LaTeXDecoderConfig):
95
- super().__init__()
96
- self.gate_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
97
- self.up_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
98
- self.down_proj = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
99
- self.dropout = nn.Dropout(cfg.dropout)
100
-
101
- def forward(self, x: torch.Tensor) -> torch.Tensor:
102
- return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
103
-
104
-
105
- class TransformerBlock(nn.Module):
106
- def __init__(self, cfg: LaTeXDecoderConfig):
107
- super().__init__()
108
- self.norm1 = RMSNorm(cfg.d_model)
109
- self.attn = CausalSelfAttention(cfg)
110
- self.norm2 = RMSNorm(cfg.d_model)
111
- self.ffn = SwiGLUFFN(cfg)
112
- self.drop = nn.Dropout(cfg.dropout)
113
-
114
- def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
115
- x = x + self.drop(self.attn(self.norm1(x), attention_mask))
116
- x = x + self.drop(self.ffn(self.norm2(x)))
117
- return x
118
-
119
-
120
- class LaTeXDecoderForCausalLM(PreTrainedModel):
121
- config_class = LaTeXDecoderConfig
122
- base_model_prefix = "model"
123
- supports_gradient_checkpointing = False
124
-
125
- def __init__(self, config: LaTeXDecoderConfig):
126
- super().__init__(config)
127
-
128
- self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
129
- self.embed_drop = nn.Dropout(config.dropout)
130
- self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
131
- self.norm_final = RMSNorm(config.d_model)
132
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
133
-
134
- if config.tie_weights:
135
- self.lm_head.weight = self.embed_tokens.weight
136
-
137
- self.post_init()
138
-
139
- def _init_weights(self, module: nn.Module):
140
- if isinstance(module, nn.Linear):
141
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
142
- if module.bias is not None:
143
- nn.init.zeros_(module.bias)
144
- elif isinstance(module, nn.Embedding):
145
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
146
-
147
- def forward(
148
- self,
149
- input_ids: torch.Tensor,
150
- attention_mask: Optional[torch.Tensor] = None,
151
- labels: Optional[torch.Tensor] = None,
152
- **kwargs,
153
- ) -> CausalLMOutput:
154
- x = self.embed_drop(self.embed_tokens(input_ids))
155
- for layer in self.layers:
156
- x = layer(x, attention_mask)
157
- logits = self.lm_head(self.norm_final(x))
158
-
159
- loss = None
160
- if labels is not None:
161
- shift_logits = logits[:, :-1, :].contiguous()
162
- shift_labels = labels[:, 1:].contiguous()
163
- shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_id, -100)
164
- loss = F.cross_entropy(
165
- shift_logits.view(-1, self.config.vocab_size),
166
- shift_labels.view(-1),
167
- ignore_index=-100,
168
- )
169
-
170
- return CausalLMOutput(loss=loss, logits=logits)
171
-
172
- @torch.inference_mode()
173
- def generate(
174
- self,
175
- prompt_ids: torch.Tensor,
176
- max_new_tokens: int = 200,
177
- temperature: float = 1.0,
178
- top_p: float = 0.9,
179
- eos_id: Optional[int] = None,
180
- ) -> torch.Tensor:
181
- eos = eos_id if eos_id is not None else self.config.eos_id
182
- generated = prompt_ids.clone()
183
-
184
- for _ in range(max_new_tokens):
185
- ctx = generated[:, -self.config.max_seq_len:]
186
- logits = self.forward(ctx).logits[:, -1, :]
187
-
188
- if temperature == 0.0:
189
- next_id = logits.argmax(dim=-1, keepdim=True)
190
- else:
191
- probs = F.softmax(logits / temperature, dim=-1)
192
- sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
193
- cumsum = sorted_probs.cumsum(dim=-1)
194
- sorted_probs[cumsum - sorted_probs > top_p] = 0.0
195
- sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
196
- next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
197
-
198
- generated = torch.cat([generated, next_id], dim=-1)
199
- if next_id.item() == eos:
200
- break
201
-
202
- return generated
 
1
+ # update v2
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Optional
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import CausalLMOutput
10
+
11
+ try:
12
+ from .configuration_latex_decoder import LaTeXDecoderConfig
13
+ except ImportError:
14
+ from latex_ocr.configuration_latex_decoder import LaTeXDecoderConfig
15
+
16
+
17
+ class RMSNorm(nn.Module):
18
+ def __init__(self, d_model: int, eps: float = 1e-6):
19
+ super().__init__()
20
+ self.eps = eps
21
+ self.weight = nn.Parameter(torch.ones(d_model))
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+ rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
25
+ return x / rms * self.weight
26
+
27
+
28
+ def _build_rope_cache(seq_len, head_dim, theta=10000.0, device=None, dtype=torch.float32):
29
+ half = head_dim // 2
30
+ inv_freq = 1.0 / (theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
31
+ pos = torch.arange(seq_len, device=device, dtype=torch.float32)
32
+ freqs = torch.outer(pos, inv_freq)
33
+ emb = torch.cat([freqs, freqs], dim=-1)
34
+ return emb.cos().to(dtype), emb.sin().to(dtype)
35
+
36
+
37
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
38
+ half = x.shape[-1] // 2
39
+ x1, x2 = x[..., :half], x[..., half:]
40
+ return torch.cat([-x2, x1], dim=-1)
41
+
42
+
43
+ def apply_rope(q, k, cos, sin):
44
+ cos = cos.unsqueeze(0).unsqueeze(0)
45
+ sin = sin.unsqueeze(0).unsqueeze(0)
46
+ return q * cos + _rotate_half(q) * sin, k * cos + _rotate_half(k) * sin
47
+
48
+
49
+ class CausalSelfAttention(nn.Module):
50
+ def __init__(self, cfg: LaTeXDecoderConfig):
51
+ super().__init__()
52
+ self.n_heads = cfg.n_heads
53
+ self.head_dim = cfg.head_dim
54
+ self.d_model = cfg.d_model
55
+ self.dropout_p = cfg.dropout
56
+ self.rope_theta = cfg.rope_theta
57
+
58
+ self.qkv_proj = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
59
+ self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
60
+ self._rope_cache: dict = {}
61
+
62
+ def _get_rope(self, seq_len, device, dtype):
63
+ key = (seq_len, str(device), dtype)
64
+ if key not in self._rope_cache:
65
+ self._rope_cache[key] = _build_rope_cache(seq_len, self.head_dim, self.rope_theta, device, dtype)
66
+ return self._rope_cache[key]
67
+
68
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
69
+ B, T, C = x.shape
70
+ q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
71
+
72
+ q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
73
+ k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
74
+ v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
75
+
76
+ cos, sin = self._get_rope(T, x.device, q.dtype)
77
+ q, k = apply_rope(q, k, cos, sin)
78
+
79
+ dropout_p = self.dropout_p if self.training else 0.0
80
+
81
+ if attention_mask is not None:
82
+ causal = torch.triu(torch.full((T, T), float("-inf"), device=x.device, dtype=q.dtype), diagonal=1)
83
+ pad = (~attention_mask).unsqueeze(1).unsqueeze(2)
84
+ attn_bias = causal.unsqueeze(0).unsqueeze(0).expand(B, 1, T, T).clone()
85
+ attn_bias = attn_bias.masked_fill(pad, float("-inf"))
86
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p, is_causal=False)
87
+ else:
88
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
89
+
90
+ return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))
91
+
92
+
93
+ class SwiGLUFFN(nn.Module):
94
+ def __init__(self, cfg: LaTeXDecoderConfig):
95
+ super().__init__()
96
+ self.gate_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
97
+ self.up_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
98
+ self.down_proj = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
99
+ self.dropout = nn.Dropout(cfg.dropout)
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
103
+
104
+
105
+ class TransformerBlock(nn.Module):
106
+ def __init__(self, cfg: LaTeXDecoderConfig):
107
+ super().__init__()
108
+ self.norm1 = RMSNorm(cfg.d_model)
109
+ self.attn = CausalSelfAttention(cfg)
110
+ self.norm2 = RMSNorm(cfg.d_model)
111
+ self.ffn = SwiGLUFFN(cfg)
112
+ self.drop = nn.Dropout(cfg.dropout)
113
+
114
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
115
+ x = x + self.drop(self.attn(self.norm1(x), attention_mask))
116
+ x = x + self.drop(self.ffn(self.norm2(x)))
117
+ return x
118
+
119
+
120
+ class LaTeXDecoderForCausalLM(PreTrainedModel):
121
+ config_class = LaTeXDecoderConfig
122
+ base_model_prefix = "model"
123
+ supports_gradient_checkpointing = False
124
+
125
+ def __init__(self, config: LaTeXDecoderConfig):
126
+ super().__init__(config)
127
+
128
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
129
+ self.embed_drop = nn.Dropout(config.dropout)
130
+ self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
131
+ self.norm_final = RMSNorm(config.d_model)
132
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
133
+
134
+ if config.tie_weights:
135
+ self.lm_head.weight = self.embed_tokens.weight
136
+
137
+ self.post_init()
138
+
139
+ def _init_weights(self, module: nn.Module):
140
+ if isinstance(module, nn.Linear):
141
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
142
+ if module.bias is not None:
143
+ nn.init.zeros_(module.bias)
144
+ elif isinstance(module, nn.Embedding):
145
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
146
+
147
+ def forward(
148
+ self,
149
+ input_ids: torch.Tensor,
150
+ attention_mask: Optional[torch.Tensor] = None,
151
+ labels: Optional[torch.Tensor] = None,
152
+ **kwargs,
153
+ ) -> CausalLMOutput:
154
+ x = self.embed_drop(self.embed_tokens(input_ids))
155
+ for layer in self.layers:
156
+ x = layer(x, attention_mask)
157
+ logits = self.lm_head(self.norm_final(x))
158
+
159
+ loss = None
160
+ if labels is not None:
161
+ shift_logits = logits[:, :-1, :].contiguous()
162
+ shift_labels = labels[:, 1:].contiguous()
163
+ shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_id, -100)
164
+ loss = F.cross_entropy(
165
+ shift_logits.view(-1, self.config.vocab_size),
166
+ shift_labels.view(-1),
167
+ ignore_index=-100,
168
+ )
169
+
170
+ return CausalLMOutput(loss=loss, logits=logits)
171
+
172
+ @torch.inference_mode()
173
+ def generate(
174
+ self,
175
+ prompt_ids: torch.Tensor,
176
+ max_new_tokens: int = 200,
177
+ temperature: float = 1.0,
178
+ top_p: float = 0.9,
179
+ eos_id: Optional[int] = None,
180
+ ) -> torch.Tensor:
181
+ eos = eos_id if eos_id is not None else self.config.eos_id
182
+ generated = prompt_ids.clone()
183
+
184
+ for _ in range(max_new_tokens):
185
+ ctx = generated[:, -self.config.max_seq_len:]
186
+ logits = self.forward(ctx).logits[:, -1, :]
187
+
188
+ if temperature == 0.0:
189
+ next_id = logits.argmax(dim=-1, keepdim=True)
190
+ else:
191
+ probs = F.softmax(logits / temperature, dim=-1)
192
+ sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
193
+ cumsum = sorted_probs.cumsum(dim=-1)
194
+ sorted_probs[cumsum - sorted_probs > top_p] = 0.0
195
+ sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
196
+ next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
197
+
198
+ generated = torch.cat([generated, next_id], dim=-1)
199
+ if next_id.item() == eos:
200
+ break
201
+
202
+ return generated
modeling_latex_ocr.py CHANGED
@@ -1,508 +1,508 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from einops import rearrange
4
- from functools import partial
5
- from torch import nn
6
- from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
7
- from transformers import PreTrainedModel
8
- from transformers.modeling_outputs import BaseModelOutput
9
-
10
- try:
11
- from .configuration_latex_decoder import LaTeXDecoderConfig
12
- from .configuration_latex_ocr import Nav2TexConfig
13
- from .modeling_latex_decoder import LaTeXDecoderForCausalLM
14
- except ImportError:
15
- from nav2tex.configuration_latex_decoder import LaTeXDecoderConfig
16
- from nav2tex.configuration_latex_ocr import Nav2TexConfig
17
- from nav2tex.modeling_latex_decoder import LaTeXDecoderForCausalLM
18
-
19
- try:
20
- from flash_attn import flash_attn_func, flash_attn_varlen_func
21
- from flash_attn.bert_padding import pad_input, unpad_input
22
- HAS_FLASH_ATTN = True
23
- except ImportError:
24
- HAS_FLASH_ATTN = False
25
-
26
-
27
- def exists(val):
28
- return val is not None
29
-
30
-
31
- def divisible_by(numer, denom):
32
- return (numer % denom) == 0
33
-
34
-
35
- class LayerNorm(nn.Module):
36
- def __init__(self, dim):
37
- super().__init__()
38
- self.normalized_shape = (dim,)
39
- self.eps = 1e-5
40
- self.weight = nn.Parameter(torch.ones(dim))
41
- self.bias = nn.Parameter(torch.zeros(dim))
42
-
43
- def forward(self, x):
44
- return F.layer_norm(
45
- x.float(), self.normalized_shape,
46
- self.weight.float(), self.bias.float(), self.eps,
47
- ).to(x.dtype)
48
-
49
-
50
- class RMSNorm(nn.Module):
51
- def __init__(self, heads, dim):
52
- super().__init__()
53
- self.scale = dim ** 0.5
54
- self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
55
-
56
- def forward(self, x):
57
- return F.normalize(x, dim=-1) * self.scale * self.gamma.to(x.dtype)
58
-
59
-
60
- def rotate_half(x):
61
- x1, x2 = x.chunk(2, dim=-1)
62
- return torch.cat([-x2, x1], dim=-1)
63
-
64
-
65
- def apply_2d_rope(q, k, h_idx, w_idx):
66
- _, _, _, d = q.shape
67
- if d % 4 != 0:
68
- raise ValueError(f"apply_2d_rope expects dim_head divisible by 4, got D={d}")
69
- dim_half = d // 2
70
- dim_quarter = d // 4
71
- inv_freq = 1.0 / (10000 ** (torch.arange(dim_quarter, device=q.device).float() / dim_quarter))
72
- h_theta = h_idx[..., None].float() * inv_freq
73
- w_theta = w_idx[..., None].float() * inv_freq
74
- sin_h = torch.cat([h_theta.sin(), h_theta.sin()], dim=-1).to(q.dtype)[:, None, :, :]
75
- cos_h = torch.cat([h_theta.cos(), h_theta.cos()], dim=-1).to(q.dtype)[:, None, :, :]
76
- sin_w = torch.cat([w_theta.sin(), w_theta.sin()], dim=-1).to(q.dtype)[:, None, :, :]
77
- cos_w = torch.cat([w_theta.cos(), w_theta.cos()], dim=-1).to(q.dtype)[:, None, :, :]
78
-
79
- def rope(x, sin, cos):
80
- return x * cos + rotate_half(x) * sin
81
-
82
- q = torch.cat([rope(q[..., :dim_half], sin_h, cos_h), rope(q[..., dim_half:], sin_w, cos_w)], dim=-1)
83
- k = torch.cat([rope(k[..., :dim_half], sin_h, cos_h), rope(k[..., dim_half:], sin_w, cos_w)], dim=-1)
84
- return q, k
85
-
86
-
87
- class FeedForward(nn.Module):
88
- def __init__(self, dim, hidden_dim, dropout=0.0):
89
- super().__init__()
90
- self.net = nn.Sequential(
91
- LayerNorm(dim),
92
- nn.Linear(dim, hidden_dim),
93
- nn.GELU(),
94
- nn.Dropout(dropout),
95
- nn.Linear(hidden_dim, dim),
96
- nn.Dropout(dropout),
97
- )
98
-
99
- def forward(self, x):
100
- return self.net(x)
101
-
102
-
103
- class Attention(nn.Module):
104
- def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
105
- super().__init__()
106
- inner_dim = dim_head * heads
107
- self.heads = heads
108
- self.norm = LayerNorm(dim)
109
- self.q_norm = RMSNorm(heads, dim_head)
110
- self.k_norm = RMSNorm(heads, dim_head)
111
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
112
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
113
- self.attend = nn.Softmax(dim=-1)
114
- self.dropout = nn.Dropout(dropout)
115
- self.to_out = nn.Sequential(nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout))
116
-
117
- def forward(self, x, mask=None, attn_mask=None, positions=None):
118
- x = self.norm(x)
119
- q = self.to_q(x)
120
- k, v = self.to_kv(x).chunk(2, dim=-1)
121
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v))
122
- q = self.q_norm(q)
123
- k = self.k_norm(k)
124
-
125
- if positions is not None:
126
- q, k = apply_2d_rope(q, k, positions[0], positions[1])
127
-
128
- if HAS_FLASH_ATTN and x.is_cuda and attn_mask is None:
129
- fa_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.bfloat16
130
- q_ = rearrange(q, "b h n d -> b n h d").contiguous().to(fa_dtype)
131
- k_ = rearrange(k, "b h n d -> b n h d").contiguous().to(fa_dtype)
132
- v_ = rearrange(v, "b h n d -> b n h d").contiguous().to(fa_dtype)
133
- if exists(mask):
134
- batch, seqlen = mask.shape
135
- q_unpad, indices, cu_q, max_q, *_ = unpad_input(q_, mask)
136
- k_unpad, _, cu_k, max_k, *_ = unpad_input(k_, mask)
137
- v_unpad, _, _, _, *_ = unpad_input(v_, mask)
138
- out_unpad = flash_attn_varlen_func(
139
- q_unpad, k_unpad, v_unpad,
140
- cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
141
- max_seqlen_q=max_q, max_seqlen_k=max_k,
142
- dropout_p=self.dropout.p if self.training else 0.0,
143
- causal=False,
144
- )
145
- out = pad_input(out_unpad, indices, batch, seqlen)
146
- else:
147
- out = flash_attn_func(
148
- q_, k_, v_,
149
- dropout_p=self.dropout.p if self.training else 0.0,
150
- causal=False,
151
- )
152
- out = rearrange(out, "b n h d -> b n (h d)").to(x.dtype)
153
- else:
154
- dots = torch.matmul(q, k.transpose(-1, -2))
155
- if exists(mask):
156
- dots = dots.masked_fill(~mask[:, None, None, :], -torch.finfo(dots.dtype).max)
157
- if exists(attn_mask):
158
- dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
159
- attn = self.dropout(self.attend(dots))
160
- out = rearrange(torch.matmul(attn, v), "b h n d -> b n (h d)")
161
- return self.to_out(out)
162
-
163
-
164
- class Transformer(nn.Module):
165
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
166
- super().__init__()
167
- self.layers = nn.ModuleList([
168
- nn.ModuleList([Attention(dim, heads, dim_head, dropout), FeedForward(dim, mlp_dim, dropout)])
169
- for _ in range(depth)
170
- ])
171
- self.norm = LayerNorm(dim)
172
-
173
- def forward(self, x, mask=None, attn_mask=None, positions=None):
174
- for attn, ff in self.layers:
175
- x = attn(x, mask=mask, attn_mask=attn_mask, positions=positions) + x
176
- x = ff(x) + x
177
- return self.norm(x)
178
-
179
-
180
- class NaViT_Encoder(nn.Module):
181
- def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim,
182
- channels=3, dim_head=64, dropout=0.0, emb_dropout=0.0):
183
- super().__init__()
184
- image_height, image_width = image_size
185
- assert divisible_by(image_height, patch_size)
186
- assert divisible_by(image_width, patch_size)
187
- self.patch_size = patch_size
188
- self.to_patch_embedding = nn.Sequential(
189
- LayerNorm(channels * patch_size ** 2),
190
- nn.Linear(channels * patch_size ** 2, dim),
191
- LayerNorm(dim),
192
- )
193
- self.dropout = nn.Dropout(emb_dropout)
194
- self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
195
-
196
- @property
197
- def device(self):
198
- return next(self.parameters()).device
199
-
200
- def forward(self, batched_images):
201
- p = self.patch_size
202
- device = self.device
203
- arange = partial(torch.arange, device=device)
204
- pad_sequence = partial(orig_pad_sequence, batch_first=True)
205
- batched_sequences, batched_positions = [], []
206
-
207
- for images in batched_images:
208
- sequences, positions = [], []
209
- for image in images:
210
- _, h, w = image.shape
211
- ph, pw = h // p, w // p
212
- seq = rearrange(image, "c (h p1) (w p2) -> (h w) (c p1 p2)", p1=p, p2=p)
213
- pos = torch.stack(torch.meshgrid(arange(ph), arange(pw), indexing="ij"), dim=-1)
214
- sequences.append(seq)
215
- positions.append(rearrange(pos, "h w c -> (h w) c"))
216
- batched_sequences.append(torch.cat(sequences, dim=0))
217
- batched_positions.append(torch.cat(positions, dim=0))
218
-
219
- patches = pad_sequence(batched_sequences)
220
- patch_positions = pad_sequence(batched_positions)
221
- lengths = torch.tensor([seq.shape[0] for seq in batched_sequences], device=device)
222
- mask = torch.arange(patches.shape[1], device=device)[None, :] < lengths[:, None]
223
- x = self.to_patch_embedding(patches.to(next(self.parameters()).dtype))
224
- h_idx, w_idx = patch_positions.unbind(dim=-1)
225
- x = self.dropout(x)
226
- x = self.transformer(x, mask=mask, positions=(h_idx, w_idx))
227
- return x, mask
228
-
229
-
230
class MLPProjector(nn.Module):
    """Gated (SiLU) MLP that maps vision features to the decoder width.

    The input is layer-normalised, passed through parallel gate and up
    projections, combined multiplicatively, then projected down.
    """

    def __init__(self, vision_hidden_size=1024, llm_hidden_size=512, intermediate_size=2048):
        super().__init__()
        in_features, mid_features = vision_hidden_size, intermediate_size
        self.norm = nn.LayerNorm(in_features)
        self.gate_proj = nn.Linear(in_features, mid_features, bias=False)
        self.up_proj = nn.Linear(in_features, mid_features, bias=False)
        self.down_proj = nn.Linear(mid_features, llm_hidden_size, bias=False)

    def forward(self, x):
        normed = self.norm(x)
        gate = F.silu(self.gate_proj(normed))
        value = self.up_proj(normed)
        return self.down_proj(gate * value)
241
-
242
-
243
class VisualEncoder(nn.Module):
    """Run the image encoder, cap the token count, then project the tokens.

    Args:
        encoder: module returning ``(tokens, mask)`` for a batch of images.
        bridge: module applied to the (possibly truncated) tokens.
        max_visual_tokens: upper bound on the sequence length kept.
    """

    def __init__(self, encoder, bridge, max_visual_tokens):
        super().__init__()
        self.navit = encoder
        self.projector = bridge
        self.max_visual_tokens = max_visual_tokens

    def forward(self, batched_images):
        tokens, valid = self.navit(batched_images)
        limit = self.max_visual_tokens
        if tokens.shape[1] > limit:
            # Truncate tokens and mask together so they stay aligned.
            tokens, valid = tokens[:, :limit], valid[:, :limit]
        projected = self.projector(tokens)
        return projected, valid
256
-
257
-
258
class CustomDecoder(nn.Module):
    """Wrap ``LaTeXDecoderForCausalLM`` so it can be driven by input embeddings.

    The wrapped model's ``embed_drop``, ``layers``, ``norm_final`` and
    ``lm_head`` are called directly, so visual embeddings can be prepended to
    token embeddings. Greedy (batched) and beam-search (batch size 1) decoding
    are provided; both re-run the full sequence at every step (no KV cache).
    """

    def __init__(self, config: "Nav2TexConfig"):
        super().__init__()
        dec = config.decoder_arch
        self._model = LaTeXDecoderForCausalLM(
            LaTeXDecoderConfig(
                vocab_size=dec["vocab_size"],
                pad_id=dec["pad_id"],
                bos_id=dec["bos_id"],
                eos_id=dec["eos_id"],
                d_model=dec["d_model"],
                n_heads=dec["n_heads"],
                n_layers=dec["n_layers"],
                d_ff=dec["d_ff"],
                dropout=dec.get("dropout", 0.1),
                max_seq_len=dec["max_seq_len"],
                rope_theta=dec.get("rope_theta", 10000.0),
                tie_weights=dec.get("tie_weights", True),
            )
        )
        model_config = self._model.config
        self.pad_token_id = model_config.pad_id
        self.eos_token_id = model_config.eos_id
        self._vocab_size = model_config.vocab_size
        self._pad_id = model_config.pad_id
        if not config.decoder_weights_tied:
            self.untie_weights()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped decoder."""
        return self._model.embed_tokens

    def tie_weights(self):
        """Make the output head share its weight tensor with the token embedding."""
        self._model.lm_head.weight = self._model.embed_tokens.weight

    def untie_weights(self):
        """Give the output head its own copy of the embedding weights if they are shared."""
        if self.are_weights_tied():
            self._model.lm_head.weight = nn.Parameter(self._model.embed_tokens.weight.detach().clone())

    def are_weights_tied(self):
        """Return True when head and embedding weights share the same storage pointer."""
        return self._model.lm_head.weight.data_ptr() == self._model.embed_tokens.weight.data_ptr()

    def _forward_embeds(self, inputs_embeds, attention_mask=None):
        """Run the decoder stack on embeddings and return vocabulary logits."""
        x = self._model.embed_drop(inputs_embeds)
        mask = attention_mask.bool() if attention_mask is not None else None
        # NOTE(review): any causal masking presumably happens inside the
        # layers; it is not visible from this wrapper.
        for layer in self._model.layers:
            x = layer(x, mask)
        return self._model.lm_head(self._model.norm_final(x))

    def forward(self, inputs_embeds=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Labels are shifted by one position; positions equal to the pad id are
        ignored (replaced with -100, as are entries already set to -100).

        Returns:
            ``BaseModelOutput`` with the logits in ``last_hidden_state`` and the
            loss (or ``None``) as the only element of ``hidden_states``.

        Raises:
            ValueError: if ``inputs_embeds`` is not provided.
        """
        if inputs_embeds is None:
            raise ValueError("CustomDecoder.forward requires inputs_embeds")
        logits = self._forward_embeds(inputs_embeds, attention_mask)
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            shift_labels = shift_labels.masked_fill(shift_labels == self._pad_id, -100)
            loss = F.cross_entropy(
                shift_logits.view(-1, self._vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return BaseModelOutput(last_hidden_state=logits, hidden_states=(loss,))

    @torch.no_grad()
    def generate(self, inputs_embeds, attention_mask, max_new_tokens, num_beams=1):
        """Decode token ids from prefix embeddings.

        Uses beam search when ``num_beams > 1`` (batch size 1 only), otherwise
        batched greedy decoding.

        Raises:
            ValueError: if beam search is requested with a batch size other than 1.
        """
        if num_beams > 1:
            # Explicit raise instead of assert so the check survives ``python -O``.
            if inputs_embeds.shape[0] != 1:
                raise ValueError("beam search only supports batch_size=1")
            return self._beam_search(inputs_embeds, attention_mask, max_new_tokens, num_beams)

        return self._greedy_batch(inputs_embeds, attention_mask, max_new_tokens)

    @torch.no_grad()
    def _greedy_batch(self, inputs_embeds, attention_mask, max_new_tokens):
        """Greedy decoding with true batch support.

        Returns a ``(batch, max_len)`` long tensor; each row holds the generated
        ids up to and including EOS, right-padded with the pad id.
        """
        eos_id = self.eos_token_id
        pad_id = self._pad_id
        device = inputs_embeds.device
        batch = inputs_embeds.shape[0]

        # generated token ids per sample, and finished flags
        gen_ids = [[] for _ in range(batch)]
        finished = torch.zeros(batch, dtype=torch.bool, device=device)

        cur_embeds = inputs_embeds
        cur_mask = attention_mask

        for _ in range(max_new_tokens):
            logits = self._forward_embeds(cur_embeds, cur_mask)
            next_tok = logits[:, -1, :].argmax(dim=-1)

            # One host transfer per step instead of one ``.item()`` per sample.
            for ids, tok, done in zip(gen_ids, next_tok.tolist(), finished.tolist()):
                if not done:
                    ids.append(tok)
            finished |= (next_tok == eos_id)
            if finished.all():
                break

            tok_emb = self._model.embed_tokens(next_tok.unsqueeze(1))
            cur_embeds = torch.cat([cur_embeds, tok_emb], dim=1)
            if cur_mask is not None:
                cur_mask = torch.cat([cur_mask, cur_mask.new_ones(batch, 1)], dim=1)

        # pad to same length and return (B, max_len); max_len == 0 yields (B, 0)
        max_len = max((len(ids) for ids in gen_ids), default=0)
        out = torch.full((batch, max_len), pad_id, dtype=torch.long, device=device)
        for i, ids in enumerate(gen_ids):
            if ids:
                out[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
        return out

    @torch.no_grad()
    def _beam_search(self, inputs_embeds, attention_mask, max_new_tokens, num_beams):
        """Beam search for a single sample (batch_size=1 only).

        Beams are ``(cumulative_log_prob, token_ids, done)`` tuples. Scores are
        not length-normalised. Returns the best beam as a ``(1, length)`` tensor.
        """
        eos_id = self.eos_token_id
        device = inputs_embeds.device
        vis_emb = inputs_embeds[0]
        vis_len = vis_emb.shape[0]
        vis_mask = attention_mask[0] if attention_mask is not None else None
        beams = [(0.0, [], False) for _ in range(num_beams)]

        for _ in range(max_new_tokens):
            all_embeds, all_masks = [], []
            for _, ids, _ in beams:
                if ids:
                    tok_emb = self._model.embed_tokens(torch.tensor(ids, device=device, dtype=torch.long))
                    seq_emb = torch.cat([vis_emb, tok_emb], dim=0)
                else:
                    seq_emb = vis_emb
                all_embeds.append(seq_emb)
                if vis_mask is not None:
                    all_masks.append(torch.cat([vis_mask, vis_mask.new_ones(len(ids))]) if ids else vis_mask)

            # Right-pad every beam to the longest one so they run as one batch.
            max_len = max(e.shape[0] for e in all_embeds)
            d_model = all_embeds[0].shape[-1]
            padded_embeds = vis_emb.new_zeros(num_beams, max_len, d_model)
            padded_mask = vis_mask.new_zeros(num_beams, max_len) if vis_mask is not None else None
            for idx, emb in enumerate(all_embeds):
                padded_embeds[idx, :emb.shape[0]] = emb
                if padded_mask is not None:
                    padded_mask[idx, :emb.shape[0]] = all_masks[idx]

            logits = self._forward_embeds(padded_embeds, padded_mask)
            candidates = []
            for beam_idx, (score, ids, done) in enumerate(beams):
                if done:
                    candidates.append((score, ids, True))
                    continue
                last_pos = vis_len + len(ids) - 1
                log_p = torch.log_softmax(logits[beam_idx, last_pos, :], dim=-1)
                if len(ids) == 0 and beam_idx > 0:
                    # All beams start identical; suppress the duplicates so the
                    # first step expands only beam 0.
                    log_p = torch.full_like(log_p, -1e9)
                top_lp, top_tok = log_p.topk(num_beams)
                for lp, tok in zip(top_lp.tolist(), top_tok.tolist()):
                    candidates.append((score + lp, ids + [tok], tok == eos_id))
            candidates.sort(key=lambda item: item[0], reverse=True)
            beams = candidates[:num_beams]
            if all(done for _, _, done in beams):
                break

        best_ids = max(beams, key=lambda item: item[0])[1]
        if not best_ids:
            return torch.zeros(1, 0, dtype=torch.long, device=device)
        return torch.tensor(best_ids, dtype=torch.long, device=device).unsqueeze(0)
424
-
425
-
426
class Nav2TexModel(PreTrainedModel):
    """Image-to-LaTeX model: NaViT visual encoder + projector + LaTeX decoder.

    Visual tokens are projected to the decoder width and prepended to the
    token embeddings; the decoder then predicts the token sequence.
    """

    config_class = Nav2TexConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"

    def __init__(self, config: Nav2TexConfig):
        super().__init__(config)
        self.config = config
        self.visual_encoder = VisualEncoder(
            NaViT_Encoder(
                image_size=(config.image_height, config.max_image_width),
                patch_size=config.patch_size,
                dim=config.navit_dim,
                depth=config.navit_depth,
                heads=config.navit_heads,
                mlp_dim=config.navit_mlp_dim,
                dim_head=config.navit_dim_head,
                dropout=config.navit_dropout,
                emb_dropout=config.navit_emb_dropout,
            ),
            MLPProjector(
                vision_hidden_size=config.vision_hidden_size,
                llm_hidden_size=config.llm_hidden_size,
                intermediate_size=config.projector_intermediate_size,
            ),
            max_visual_tokens=config.max_visual_tokens,
        )
        self.decoder = CustomDecoder(config)
        self.post_init()

    def tie_weights(self, **kwargs):
        """Tie or untie the decoder head according to ``config.decoder_weights_tied``.

        Extra keyword arguments are accepted and ignored so the override stays
        callable when the base class passes additional options.
        """
        if self.config.decoder_weights_tied:
            self.decoder.tie_weights()
        else:
            self.decoder.untie_weights()

    def _init_weights(self, module):
        """Intentionally a no-op: sub-modules keep their own initialisation."""
        return

    @staticmethod
    def _to_batched_images(pixel_values):
        """Normalise ``pixel_values`` to a list of image groups.

        A tensor is split along its first axis into one-image groups. A list is
        used as-is, except that a bare 3-D image tensor inside it is wrapped
        into its own one-image group.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(pixel_values, list):
            return [
                [item] if isinstance(item, torch.Tensor) and item.dim() == 3 else item
                for item in pixel_values
            ]
        if isinstance(pixel_values, torch.Tensor):
            return [[img] for img in pixel_values]
        raise TypeError(f"Unsupported pixel_values type: {type(pixel_values)}")

    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Encode images and, when ``input_ids`` is given, run the decoder.

        Without ``input_ids`` only the projected visual tokens are returned.
        Otherwise the result holds the logits in ``last_hidden_state`` and the
        loss (``None`` when ``labels`` is omitted) in ``hidden_states[0]``.
        A missing ``attention_mask`` is treated as "attend to every token".
        """
        batched_images = self._to_batched_images(pixel_values)
        ve, vm = self.visual_encoder(batched_images)
        if input_ids is None:
            return BaseModelOutput(last_hidden_state=ve)
        te = self.decoder.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([ve, te], dim=1)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        am = torch.cat([vm.to(dtype=attention_mask.dtype), attention_mask], dim=1)
        full_labels = None
        if labels is not None:
            # Visual positions never contribute to the loss.
            lv = torch.full((labels.shape[0], ve.shape[1]), -100, dtype=labels.dtype, device=labels.device)
            full_labels = torch.cat([lv, labels], dim=1)
        out = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=am,
            labels=full_labels,
        )
        return BaseModelOutput(last_hidden_state=out.last_hidden_state, hidden_states=(out.hidden_states[0],))

    @torch.no_grad()
    def generate(self, pixel_values, max_new_tokens=None, num_beams=None):
        """Generate token ids for the given images.

        ``max_new_tokens`` falls back to the config value only when it is
        ``None``, so an explicit 0 is honoured. ``num_beams`` falls back to the
        config value when it is ``None`` or 0.
        """
        batched_images = self._to_batched_images(pixel_values)
        ve, vm = self.visual_encoder(batched_images)
        batch = ve.shape[0]
        bos_id = self.config.decoder_arch["bos_id"]
        bos_emb = self.decoder.get_input_embeddings()(
            torch.full((batch, 1), bos_id, dtype=torch.long, device=ve.device)
        )
        inputs_embeds = torch.cat([ve, bos_emb], dim=1)
        attention_mask = torch.cat([
            vm.to(dtype=torch.long),
            torch.ones(batch, 1, dtype=torch.long, device=ve.device),
        ], dim=1)
        if max_new_tokens is None:
            max_new_tokens = self.config.max_new_tokens
        return self.decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams or self.config.num_beams,
        )
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from functools import partial
5
+ from torch import nn
6
+ from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import BaseModelOutput
9
+
10
+ try:
11
+ from .configuration_latex_decoder import LaTeXDecoderConfig
12
+ from .configuration_latex_ocr import Nav2TexConfig
13
+ from .modeling_latex_decoder import LaTeXDecoderForCausalLM
14
+ except ImportError:
15
+ from nav2tex.configuration_latex_decoder import LaTeXDecoderConfig
16
+ from nav2tex.configuration_latex_ocr import Nav2TexConfig
17
+ from nav2tex.modeling_latex_decoder import LaTeXDecoderForCausalLM
18
+
19
+ try:
20
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
21
+ from flash_attn.bert_padding import pad_input, unpad_input
22
+ HAS_FLASH_ATTN = True
23
+ except ImportError:
24
+ HAS_FLASH_ATTN = False
25
+
26
+
27
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
29
+
30
+
31
def divisible_by(numer, denom):
    """Return True when ``numer`` is an exact multiple of ``denom``."""
    remainder = numer % denom
    return remainder == 0
33
+
34
+
35
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, computed in float32.

    The input, weight and bias are upcast to float32 for the normalisation
    and the result is cast back to the input dtype.
    """

    def __init__(self, dim):
        super().__init__()
        self.normalized_shape = (dim,)
        self.eps = 1e-5
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        input_dtype = x.dtype
        normed = F.layer_norm(
            x.to(torch.float32),
            self.normalized_shape,
            weight=self.weight.to(torch.float32),
            bias=self.bias.to(torch.float32),
            eps=self.eps,
        )
        return normed.to(input_dtype)
48
+
49
+
50
class RMSNorm(nn.Module):
    """Per-head normalisation of the last dimension with a learned gain.

    Each vector is L2-normalised, multiplied by ``sqrt(dim)`` and then by a
    learned gain of shape ``(heads, 1, dim)``.
    """

    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        gain = self.gamma.to(x.dtype)
        return unit * self.scale * gain
58
+
59
+
60
def rotate_half(x):
    """Split the last dimension in two halves ``(a, b)`` and return ``(-b, a)``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
63
+
64
+
65
def apply_2d_rope(q, k, h_idx, w_idx, theta=10000.0):
    """Apply 2-D rotary position embedding to queries and keys.

    The first half of the head dimension is rotated by the row index and the
    second half by the column index.

    Args:
        q, k: 4-D tensors whose last dimension is the head dimension; the
            indexing below assumes a ``(batch, heads, tokens, dim)`` layout.
        h_idx, w_idx: 2-D ``(batch, tokens)`` row / column indices.
        theta: base of the frequency schedule (default matches the previous
            hard-coded value of 10000).

    Returns:
        The rotated ``(q, k)`` pair, same shapes and dtypes as the inputs.

    Raises:
        ValueError: if the head dimension is not divisible by 4.
    """
    _, _, _, d = q.shape
    if d % 4 != 0:
        raise ValueError(f"apply_2d_rope expects dim_head divisible by 4, got D={d}")
    dim_half = d // 2
    dim_quarter = d // 4
    inv_freq = 1.0 / (theta ** (torch.arange(dim_quarter, device=q.device).float() / dim_quarter))

    def axis_tables(idx):
        # Angles are computed in float32; sin/cos are evaluated once, tiled to
        # cover both rotated sub-halves, and given a broadcastable head axis.
        angles = idx[..., None].float() * inv_freq
        sin, cos = angles.sin(), angles.cos()
        sin = torch.cat([sin, sin], dim=-1).to(q.dtype)[:, None, :, :]
        cos = torch.cat([cos, cos], dim=-1).to(q.dtype)[:, None, :, :]
        return sin, cos

    def rope(x, sin, cos):
        x1, x2 = x.chunk(2, dim=-1)
        return x * cos + torch.cat([-x2, x1], dim=-1) * sin

    sin_h, cos_h = axis_tables(h_idx)
    sin_w, cos_w = axis_tables(w_idx)

    def rotate_both_axes(t):
        return torch.cat(
            [rope(t[..., :dim_half], sin_h, cos_h), rope(t[..., dim_half:], sin_w, cos_w)],
            dim=-1,
        )

    return rotate_both_axes(q), rotate_both_axes(k)
85
+
86
+
87
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP: LayerNorm, Linear, GELU, Dropout, Linear, Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out
101
+
102
+
103
class Attention(nn.Module):
    """Multi-head self-attention with normalised queries/keys and optional 2-D RoPE.

    The input is layer-normalised, projected to queries/keys/values, and the
    queries and keys are passed through per-head ``RMSNorm``. When flash-attn
    is importable, the input is on CUDA and no ``attn_mask`` is given, the
    flash-attn kernels are used; otherwise attention is computed with plain
    matmul + softmax.

    NOTE(review): the fallback path applies no explicit scaling to the
    attention logits, while the flash-attn calls do not pass ``softmax_scale``.
    If the library default scales by 1/sqrt(head_dim), the two paths produce
    different attention weights -- confirm which one matches training.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = LayerNorm(dim)
        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x, mask=None, attn_mask=None, positions=None):
        """Attend over ``x``.

        Args:
            x: input tokens; the rearrange pattern below expects ``(b, n, dim)``.
            mask: optional boolean key-padding mask, True on tokens to keep
                (assumed ``(b, n)`` from how it is indexed/unpacked below).
            attn_mask: optional boolean mask broadcast against the attention
                logits, True on pairs to keep; forces the non-flash path.
            positions: optional ``(h_idx, w_idx)`` pair passed to ``apply_2d_rope``.
        """
        x = self.norm(x)
        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v))
        q = self.q_norm(q)
        k = self.k_norm(k)

        if positions is not None:
            q, k = apply_2d_rope(q, k, positions[0], positions[1])

        if HAS_FLASH_ATTN and x.is_cuda and attn_mask is None:
            # Inputs that are not already fp16/bf16 are cast to bfloat16 for
            # the kernel; the output is cast back to x.dtype afterwards.
            fa_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.bfloat16
            q_ = rearrange(q, "b h n d -> b n h d").contiguous().to(fa_dtype)
            k_ = rearrange(k, "b h n d -> b n h d").contiguous().to(fa_dtype)
            v_ = rearrange(v, "b h n d -> b n h d").contiguous().to(fa_dtype)
            if exists(mask):
                # Variable-length path: drop padded tokens, attend per
                # sequence, then scatter the result back to the padded layout.
                batch, seqlen = mask.shape
                q_unpad, indices, cu_q, max_q, *_ = unpad_input(q_, mask)
                k_unpad, _, cu_k, max_k, *_ = unpad_input(k_, mask)
                v_unpad, _, _, _, *_ = unpad_input(v_, mask)
                out_unpad = flash_attn_varlen_func(
                    q_unpad, k_unpad, v_unpad,
                    cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
                    max_seqlen_q=max_q, max_seqlen_k=max_k,
                    dropout_p=self.dropout.p if self.training else 0.0,
                    causal=False,
                )
                out = pad_input(out_unpad, indices, batch, seqlen)
            else:
                out = flash_attn_func(
                    q_, k_, v_,
                    dropout_p=self.dropout.p if self.training else 0.0,
                    causal=False,
                )
            out = rearrange(out, "b n h d -> b n (h d)").to(x.dtype)
        else:
            # Fallback: dense attention logits, masked positions filled with
            # the most negative finite value before the softmax.
            dots = torch.matmul(q, k.transpose(-1, -2))
            if exists(mask):
                dots = dots.masked_fill(~mask[:, None, None, :], -torch.finfo(dots.dtype).max)
            if exists(attn_mask):
                dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
            attn = self.dropout(self.attend(dots))
            out = rearrange(torch.matmul(attn, v), "b h n d -> b n (h d)")
        return self.to_out(out)
162
+
163
+
164
class Transformer(nn.Module):
    """Stack of ``depth`` residual (attention, feed-forward) pairs with a final LayerNorm."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = Attention(dim, heads, dim_head, dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout)
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)
        self.norm = LayerNorm(dim)

    def forward(self, x, mask=None, attn_mask=None, positions=None):
        hidden = x
        for block in self.layers:
            attention, feed_forward = block
            # Both sub-layers are residual.
            hidden = attention(hidden, mask=mask, attn_mask=attn_mask, positions=positions) + hidden
            hidden = feed_forward(hidden) + hidden
        return self.norm(hidden)
178
+
179
+
180
class NaViT_Encoder(nn.Module):
    """Patch-packing vision transformer encoder.

    Each image is cut into ``patch_size`` x ``patch_size`` patches. All images
    of one group are packed into a single token sequence; groups are padded to
    a common length. Patch row/column indices are passed to the transformer as
    2-D positions.

    Args:
        image_size: ``(height, width)``; both must be divisible by ``patch_size``.
        patch_size: side length of a square patch.
        dim, depth, heads, mlp_dim, dim_head, dropout: transformer settings.
        channels: number of image channels.
        emb_dropout: dropout applied to the patch embeddings.

    Raises:
        ValueError: if ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim,
                 channels=3, dim_head=64, dropout=0.0, emb_dropout=0.0):
        super().__init__()
        image_height, image_width = image_size
        # Explicit raise instead of assert so the check survives ``python -O``.
        if not (divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size)):
            raise ValueError(
                f"image_size ({image_height}, {image_width}) must be divisible by patch_size {patch_size}"
            )
        self.patch_size = patch_size
        patch_dim = channels * patch_size ** 2
        self.to_patch_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    @property
    def device(self):
        """Device of this module, taken from its first parameter."""
        return next(self.parameters()).device

    def forward(self, batched_images):
        """Patchify, pack and encode a batch of image groups.

        Args:
            batched_images: iterable of groups; each group is an iterable of
                ``(channels, height, width)`` tensors.

        Returns:
            ``(x, mask)``: the encoded patch tokens, padded to the longest
            packed sequence, and a boolean mask that is True on real tokens.
        """
        p = self.patch_size
        first_param = next(self.parameters())
        device, dtype = first_param.device, first_param.dtype
        arange = partial(torch.arange, device=device)
        pad_sequence = partial(orig_pad_sequence, batch_first=True)
        batched_sequences, batched_positions = [], []

        for images in batched_images:
            sequences, positions = [], []
            for image in images:
                _, h, w = image.shape
                ph, pw = h // p, w // p
                seq = rearrange(image, "c (h p1) (w p2) -> (h w) (c p1 p2)", p1=p, p2=p)
                # (row, col) index of every patch, consumed later by the 2-D RoPE.
                pos = torch.stack(torch.meshgrid(arange(ph), arange(pw), indexing="ij"), dim=-1)
                sequences.append(seq)
                positions.append(rearrange(pos, "h w c -> (h w) c"))
            batched_sequences.append(torch.cat(sequences, dim=0))
            batched_positions.append(torch.cat(positions, dim=0))

        # Positions and mask live on the module device; move the patches there
        # too so inputs left on another device do not cause a mismatch.
        patches = pad_sequence(batched_sequences).to(device=device, dtype=dtype)
        patch_positions = pad_sequence(batched_positions)
        lengths = torch.tensor([seq.shape[0] for seq in batched_sequences], device=device)
        mask = arange(patches.shape[1])[None, :] < lengths[:, None]
        x = self.to_patch_embedding(patches)
        h_idx, w_idx = patch_positions.unbind(dim=-1)
        x = self.dropout(x)
        x = self.transformer(x, mask=mask, positions=(h_idx, w_idx))
        return x, mask
228
+
229
+
230
class MLPProjector(nn.Module):
    """Gated (SiLU) MLP that maps vision features to the decoder width.

    The input is layer-normalised, passed through parallel gate and up
    projections, combined multiplicatively, then projected down.
    """

    def __init__(self, vision_hidden_size=1024, llm_hidden_size=512, intermediate_size=2048):
        super().__init__()
        in_features, mid_features = vision_hidden_size, intermediate_size
        self.norm = nn.LayerNorm(in_features)
        self.gate_proj = nn.Linear(in_features, mid_features, bias=False)
        self.up_proj = nn.Linear(in_features, mid_features, bias=False)
        self.down_proj = nn.Linear(mid_features, llm_hidden_size, bias=False)

    def forward(self, x):
        normed = self.norm(x)
        gate = F.silu(self.gate_proj(normed))
        value = self.up_proj(normed)
        return self.down_proj(gate * value)
241
+
242
+
243
class VisualEncoder(nn.Module):
    """Run the image encoder, cap the token count, then project the tokens.

    Args:
        encoder: module returning ``(tokens, mask)`` for a batch of images.
        bridge: module applied to the (possibly truncated) tokens.
        max_visual_tokens: upper bound on the sequence length kept.
    """

    def __init__(self, encoder, bridge, max_visual_tokens):
        super().__init__()
        self.navit = encoder
        self.projector = bridge
        self.max_visual_tokens = max_visual_tokens

    def forward(self, batched_images):
        tokens, valid = self.navit(batched_images)
        limit = self.max_visual_tokens
        if tokens.shape[1] > limit:
            # Truncate tokens and mask together so they stay aligned.
            tokens, valid = tokens[:, :limit], valid[:, :limit]
        projected = self.projector(tokens)
        return projected, valid
256
+
257
+
258
class CustomDecoder(nn.Module):
    """Wrap ``LaTeXDecoderForCausalLM`` so it can be driven by input embeddings.

    The wrapped model's ``embed_drop``, ``layers``, ``norm_final`` and
    ``lm_head`` are called directly, so visual embeddings can be prepended to
    token embeddings. Greedy (batched) and beam-search (batch size 1) decoding
    are provided; both re-run the full sequence at every step (no KV cache).
    """

    def __init__(self, config: "Nav2TexConfig"):
        super().__init__()
        dec = config.decoder_arch
        self._model = LaTeXDecoderForCausalLM(
            LaTeXDecoderConfig(
                vocab_size=dec["vocab_size"],
                pad_id=dec["pad_id"],
                bos_id=dec["bos_id"],
                eos_id=dec["eos_id"],
                d_model=dec["d_model"],
                n_heads=dec["n_heads"],
                n_layers=dec["n_layers"],
                d_ff=dec["d_ff"],
                dropout=dec.get("dropout", 0.1),
                max_seq_len=dec["max_seq_len"],
                rope_theta=dec.get("rope_theta", 10000.0),
                tie_weights=dec.get("tie_weights", True),
            )
        )
        model_config = self._model.config
        self.pad_token_id = model_config.pad_id
        self.eos_token_id = model_config.eos_id
        self._vocab_size = model_config.vocab_size
        self._pad_id = model_config.pad_id
        if not config.decoder_weights_tied:
            self.untie_weights()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped decoder."""
        return self._model.embed_tokens

    def tie_weights(self):
        """Make the output head share its weight tensor with the token embedding."""
        self._model.lm_head.weight = self._model.embed_tokens.weight

    def untie_weights(self):
        """Give the output head its own copy of the embedding weights if they are shared."""
        if self.are_weights_tied():
            self._model.lm_head.weight = nn.Parameter(self._model.embed_tokens.weight.detach().clone())

    def are_weights_tied(self):
        """Return True when head and embedding weights share the same storage pointer."""
        return self._model.lm_head.weight.data_ptr() == self._model.embed_tokens.weight.data_ptr()

    def _forward_embeds(self, inputs_embeds, attention_mask=None):
        """Run the decoder stack on embeddings and return vocabulary logits."""
        x = self._model.embed_drop(inputs_embeds)
        mask = attention_mask.bool() if attention_mask is not None else None
        # NOTE(review): any causal masking presumably happens inside the
        # layers; it is not visible from this wrapper.
        for layer in self._model.layers:
            x = layer(x, mask)
        return self._model.lm_head(self._model.norm_final(x))

    def forward(self, inputs_embeds=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Labels are shifted by one position; positions equal to the pad id are
        ignored (replaced with -100, as are entries already set to -100).

        Returns:
            ``BaseModelOutput`` with the logits in ``last_hidden_state`` and the
            loss (or ``None``) as the only element of ``hidden_states``.

        Raises:
            ValueError: if ``inputs_embeds`` is not provided.
        """
        if inputs_embeds is None:
            raise ValueError("CustomDecoder.forward requires inputs_embeds")
        logits = self._forward_embeds(inputs_embeds, attention_mask)
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            shift_labels = shift_labels.masked_fill(shift_labels == self._pad_id, -100)
            loss = F.cross_entropy(
                shift_logits.view(-1, self._vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return BaseModelOutput(last_hidden_state=logits, hidden_states=(loss,))

    @torch.no_grad()
    def generate(self, inputs_embeds, attention_mask, max_new_tokens, num_beams=1):
        """Decode token ids from prefix embeddings.

        Uses beam search when ``num_beams > 1`` (batch size 1 only), otherwise
        batched greedy decoding.

        Raises:
            ValueError: if beam search is requested with a batch size other than 1.
        """
        if num_beams > 1:
            # Explicit raise instead of assert so the check survives ``python -O``.
            if inputs_embeds.shape[0] != 1:
                raise ValueError("beam search only supports batch_size=1")
            return self._beam_search(inputs_embeds, attention_mask, max_new_tokens, num_beams)

        return self._greedy_batch(inputs_embeds, attention_mask, max_new_tokens)

    @torch.no_grad()
    def _greedy_batch(self, inputs_embeds, attention_mask, max_new_tokens):
        """Greedy decoding with true batch support.

        Returns a ``(batch, max_len)`` long tensor; each row holds the generated
        ids up to and including EOS, right-padded with the pad id.
        """
        eos_id = self.eos_token_id
        pad_id = self._pad_id
        device = inputs_embeds.device
        batch = inputs_embeds.shape[0]

        # generated token ids per sample, and finished flags
        gen_ids = [[] for _ in range(batch)]
        finished = torch.zeros(batch, dtype=torch.bool, device=device)

        cur_embeds = inputs_embeds
        cur_mask = attention_mask

        for _ in range(max_new_tokens):
            logits = self._forward_embeds(cur_embeds, cur_mask)
            next_tok = logits[:, -1, :].argmax(dim=-1)

            # One host transfer per step instead of one ``.item()`` per sample.
            for ids, tok, done in zip(gen_ids, next_tok.tolist(), finished.tolist()):
                if not done:
                    ids.append(tok)
            finished |= (next_tok == eos_id)
            if finished.all():
                break

            tok_emb = self._model.embed_tokens(next_tok.unsqueeze(1))
            cur_embeds = torch.cat([cur_embeds, tok_emb], dim=1)
            if cur_mask is not None:
                cur_mask = torch.cat([cur_mask, cur_mask.new_ones(batch, 1)], dim=1)

        # pad to same length and return (B, max_len); max_len == 0 yields (B, 0)
        max_len = max((len(ids) for ids in gen_ids), default=0)
        out = torch.full((batch, max_len), pad_id, dtype=torch.long, device=device)
        for i, ids in enumerate(gen_ids):
            if ids:
                out[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
        return out

    @torch.no_grad()
    def _beam_search(self, inputs_embeds, attention_mask, max_new_tokens, num_beams):
        """Beam search for a single sample (batch_size=1 only).

        Beams are ``(cumulative_log_prob, token_ids, done)`` tuples. Scores are
        not length-normalised. Returns the best beam as a ``(1, length)`` tensor.
        """
        eos_id = self.eos_token_id
        device = inputs_embeds.device
        vis_emb = inputs_embeds[0]
        vis_len = vis_emb.shape[0]
        vis_mask = attention_mask[0] if attention_mask is not None else None
        beams = [(0.0, [], False) for _ in range(num_beams)]

        for _ in range(max_new_tokens):
            all_embeds, all_masks = [], []
            for _, ids, _ in beams:
                if ids:
                    tok_emb = self._model.embed_tokens(torch.tensor(ids, device=device, dtype=torch.long))
                    seq_emb = torch.cat([vis_emb, tok_emb], dim=0)
                else:
                    seq_emb = vis_emb
                all_embeds.append(seq_emb)
                if vis_mask is not None:
                    all_masks.append(torch.cat([vis_mask, vis_mask.new_ones(len(ids))]) if ids else vis_mask)

            # Right-pad every beam to the longest one so they run as one batch.
            max_len = max(e.shape[0] for e in all_embeds)
            d_model = all_embeds[0].shape[-1]
            padded_embeds = vis_emb.new_zeros(num_beams, max_len, d_model)
            padded_mask = vis_mask.new_zeros(num_beams, max_len) if vis_mask is not None else None
            for idx, emb in enumerate(all_embeds):
                padded_embeds[idx, :emb.shape[0]] = emb
                if padded_mask is not None:
                    padded_mask[idx, :emb.shape[0]] = all_masks[idx]

            logits = self._forward_embeds(padded_embeds, padded_mask)
            candidates = []
            for beam_idx, (score, ids, done) in enumerate(beams):
                if done:
                    candidates.append((score, ids, True))
                    continue
                last_pos = vis_len + len(ids) - 1
                log_p = torch.log_softmax(logits[beam_idx, last_pos, :], dim=-1)
                if len(ids) == 0 and beam_idx > 0:
                    # All beams start identical; suppress the duplicates so the
                    # first step expands only beam 0.
                    log_p = torch.full_like(log_p, -1e9)
                top_lp, top_tok = log_p.topk(num_beams)
                for lp, tok in zip(top_lp.tolist(), top_tok.tolist()):
                    candidates.append((score + lp, ids + [tok], tok == eos_id))
            candidates.sort(key=lambda item: item[0], reverse=True)
            beams = candidates[:num_beams]
            if all(done for _, _, done in beams):
                break

        best_ids = max(beams, key=lambda item: item[0])[1]
        if not best_ids:
            return torch.zeros(1, 0, dtype=torch.long, device=device)
        return torch.tensor(best_ids, dtype=torch.long, device=device).unsqueeze(0)
424
+
425
+
426
class Nav2TexModel(PreTrainedModel):
    """Image-to-LaTeX model: NaViT visual encoder + projector + LaTeX decoder.

    Visual tokens are projected to the decoder width and prepended to the
    token embeddings; the decoder then predicts the token sequence.
    """

    config_class = Nav2TexConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"

    def __init__(self, config: Nav2TexConfig):
        super().__init__(config)
        self.config = config
        self.visual_encoder = VisualEncoder(
            NaViT_Encoder(
                image_size=(config.image_height, config.max_image_width),
                patch_size=config.patch_size,
                dim=config.navit_dim,
                depth=config.navit_depth,
                heads=config.navit_heads,
                mlp_dim=config.navit_mlp_dim,
                dim_head=config.navit_dim_head,
                dropout=config.navit_dropout,
                emb_dropout=config.navit_emb_dropout,
            ),
            MLPProjector(
                vision_hidden_size=config.vision_hidden_size,
                llm_hidden_size=config.llm_hidden_size,
                intermediate_size=config.projector_intermediate_size,
            ),
            max_visual_tokens=config.max_visual_tokens,
        )
        self.decoder = CustomDecoder(config)
        self.post_init()

    def tie_weights(self, **kwargs):
        """Tie or untie the decoder head according to ``config.decoder_weights_tied``.

        Extra keyword arguments are accepted and ignored so the override stays
        callable when the base class passes additional options.
        """
        if self.config.decoder_weights_tied:
            self.decoder.tie_weights()
        else:
            self.decoder.untie_weights()

    def _init_weights(self, module):
        """Intentionally a no-op: sub-modules keep their own initialisation."""
        return

    @staticmethod
    def _to_batched_images(pixel_values):
        """Normalise ``pixel_values`` to a list of image groups.

        A tensor is split along its first axis into one-image groups. A list is
        used as-is, except that a bare 3-D image tensor inside it is wrapped
        into its own one-image group.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(pixel_values, list):
            return [
                [item] if isinstance(item, torch.Tensor) and item.dim() == 3 else item
                for item in pixel_values
            ]
        if isinstance(pixel_values, torch.Tensor):
            return [[img] for img in pixel_values]
        raise TypeError(f"Unsupported pixel_values type: {type(pixel_values)}")

    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Encode images and, when ``input_ids`` is given, run the decoder.

        Without ``input_ids`` only the projected visual tokens are returned.
        Otherwise the result holds the logits in ``last_hidden_state`` and the
        loss (``None`` when ``labels`` is omitted) in ``hidden_states[0]``.
        A missing ``attention_mask`` is treated as "attend to every token".
        """
        batched_images = self._to_batched_images(pixel_values)
        ve, vm = self.visual_encoder(batched_images)
        if input_ids is None:
            return BaseModelOutput(last_hidden_state=ve)
        te = self.decoder.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([ve, te], dim=1)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        am = torch.cat([vm.to(dtype=attention_mask.dtype), attention_mask], dim=1)
        full_labels = None
        if labels is not None:
            # Visual positions never contribute to the loss.
            lv = torch.full((labels.shape[0], ve.shape[1]), -100, dtype=labels.dtype, device=labels.device)
            full_labels = torch.cat([lv, labels], dim=1)
        out = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=am,
            labels=full_labels,
        )
        return BaseModelOutput(last_hidden_state=out.last_hidden_state, hidden_states=(out.hidden_states[0],))

    @torch.no_grad()
    def generate(self, pixel_values, max_new_tokens=None, num_beams=None):
        """Generate token ids for the given images.

        ``max_new_tokens`` falls back to the config value only when it is
        ``None``, so an explicit 0 is honoured. ``num_beams`` falls back to the
        config value when it is ``None`` or 0.
        """
        batched_images = self._to_batched_images(pixel_values)
        ve, vm = self.visual_encoder(batched_images)
        batch = ve.shape[0]
        bos_id = self.config.decoder_arch["bos_id"]
        bos_emb = self.decoder.get_input_embeddings()(
            torch.full((batch, 1), bos_id, dtype=torch.long, device=ve.device)
        )
        inputs_embeds = torch.cat([ve, bos_emb], dim=1)
        attention_mask = torch.cat([
            vm.to(dtype=torch.long),
            torch.ones(batch, 1, dtype=torch.long, device=ve.device),
        ], dim=1)
        if max_new_tokens is None:
            max_new_tokens = self.config.max_new_tokens
        return self.decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams or self.config.num_beams,
        )