harryrobert commited on
Commit
187408a
·
verified ·
1 Parent(s): 41da41c

Upload folder using huggingface_hub

Browse files
__pycache__/configuration_latex_decoder.cpython-312.pyc ADDED
Binary file (2.09 kB). View file
 
__pycache__/configuration_latex_ocr.cpython-312.pyc ADDED
Binary file (2.5 kB). View file
 
__pycache__/image_processing_latex_ocr.cpython-312.pyc ADDED
Binary file (4.01 kB). View file
 
__pycache__/image_processing_latex_ocr.cpython-313.pyc ADDED
Binary file (4.05 kB). View file
 
__pycache__/modeling_latex_decoder.cpython-312.pyc ADDED
Binary file (15.7 kB). View file
 
__pycache__/modeling_latex_ocr.cpython-312.pyc ADDED
Binary file (35.6 kB). View file
 
__pycache__/pipeline_latex_ocr.cpython-312.pyc ADDED
Binary file (3.94 kB). View file
 
__pycache__/processing_latex_ocr.cpython-312.pyc ADDED
Binary file (1.98 kB). View file
 
__pycache__/tokenization_latex_ocr.cpython-312.pyc ADDED
Binary file (6.26 kB). View file
 
__pycache__/tokenization_latex_ocr.cpython-313.pyc ADDED
Binary file (6.39 kB). View file
 
config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": ["Nav2TexModel"],
3
+ "model_type": "nav2tex",
4
+ "processor_class": "Nav2TexProcessor",
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_latex_ocr.Nav2TexConfig",
7
+ "AutoProcessor": "processing_latex_ocr.Nav2TexProcessor",
8
+ "AutoImageProcessor": "image_processing_latex_ocr.Nav2TexImageProcessor",
9
+ "AutoModel": "modeling_latex_ocr.Nav2TexModel",
10
+ "AutoTokenizer": ["tokenization_latex_ocr.LaTeXTokenizer", null],
11
+ "Pipeline": "pipeline_latex_ocr.Nav2TexPipeline"
12
+ },
13
+ "decoder_weights_tied": false,
14
+ "tie_word_embeddings": false,
15
+ "pad_token_id": 0,
16
+ "bos_token_id": 2,
17
+ "eos_token_id": 3,
18
+ "decoder_start_token_id": 2,
19
+ "navit_dim": 512,
20
+ "vision_hidden_size": 512,
21
+ "llm_hidden_size": 512,
22
+ "vocab_size": 2046,
23
+ "patch_size": 16,
24
+ "image_height": 64,
25
+ "max_visual_tokens": 256,
26
+ "max_new_tokens": 200,
27
+ "num_beams": 4,
28
+ "decoder_arch": {
29
+ "vocab_size": 2046,
30
+ "pad_id": 0,
31
+ "bos_id": 2,
32
+ "eos_id": 3,
33
+ "d_model": 512,
34
+ "n_heads": 8,
35
+ "n_layers": 6,
36
+ "d_ff": 1408,
37
+ "dropout": 0.1,
38
+ "max_seq_len": 200,
39
+ "rope_theta": 10000.0,
40
+ "tie_weights": false
41
+ }
42
+ }
configuration_latex_decoder.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class LaTeXDecoderConfig(PretrainedConfig):
5
+ model_type = "latex_decoder"
6
+
7
+ def __init__(
8
+ self,
9
+ vocab_size: int = 8192,
10
+ pad_id: int = 0,
11
+ bos_id: int = 2,
12
+ eos_id: int = 3,
13
+ d_model: int = 512,
14
+ n_heads: int = 8,
15
+ n_layers: int = 6,
16
+ d_ff: int = 1408,
17
+ dropout: float = 0.1,
18
+ max_seq_len: int = 200,
19
+ rope_theta: float = 10000.0,
20
+ tie_weights: bool = False,
21
+ **kwargs,
22
+ ):
23
+ kwargs.pop("pad_token_id", None)
24
+ kwargs.pop("bos_token_id", None)
25
+ kwargs.pop("eos_token_id", None)
26
+ super().__init__(
27
+ pad_token_id=pad_id,
28
+ bos_token_id=bos_id,
29
+ eos_token_id=eos_id,
30
+ **kwargs,
31
+ )
32
+ self.vocab_size = vocab_size
33
+ self.pad_id = pad_id
34
+ self.bos_id = bos_id
35
+ self.eos_id = eos_id
36
+ self.d_model = d_model
37
+ self.n_heads = n_heads
38
+ self.n_layers = n_layers
39
+ self.d_ff = d_ff
40
+ self.dropout = dropout
41
+ self.max_seq_len = max_seq_len
42
+ self.rope_theta = rope_theta
43
+ self.tie_weights = tie_weights
44
+
45
+ @property
46
+ def head_dim(self) -> int:
47
+ assert self.d_model % self.n_heads == 0
48
+ return self.d_model // self.n_heads
configuration_latex_ocr.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class Nav2TexConfig(PretrainedConfig):
5
+ model_type = "nav2tex"
6
+
7
+ def __init__(
8
+ self,
9
+ patch_size: int = 16,
10
+ image_height: int = 64,
11
+ max_image_width: int = 1024,
12
+ max_image_height: int = 640,
13
+ resize_in_dataset: bool = True,
14
+ max_token_len: int = 200,
15
+ navit_dim: int = 512,
16
+ navit_depth: int = 8,
17
+ navit_heads: int = 8,
18
+ navit_dim_head: int = 64,
19
+ navit_mlp_dim: int = 2048,
20
+ navit_dropout: float = 0.0,
21
+ navit_emb_dropout: float = 0.0,
22
+ vision_hidden_size: int = 512,
23
+ llm_hidden_size: int = 512,
24
+ projector_intermediate_size: int = 1024,
25
+ max_visual_tokens: int = 256,
26
+ max_new_tokens: int = 200,
27
+ num_beams: int = 4,
28
+ decoder_arch: dict | None = None,
29
+ decoder_weights_tied: bool = False,
30
+ **kwargs,
31
+ ):
32
+ super().__init__(**kwargs)
33
+ self.patch_size = patch_size
34
+ self.image_height = image_height
35
+ self.max_image_width = max_image_width
36
+ self.max_image_height = max_image_height
37
+ self.resize_in_dataset = resize_in_dataset
38
+ self.max_token_len = max_token_len
39
+ self.navit_dim = navit_dim
40
+ self.navit_depth = navit_depth
41
+ self.navit_heads = navit_heads
42
+ self.navit_dim_head = navit_dim_head
43
+ self.navit_mlp_dim = navit_mlp_dim
44
+ self.navit_dropout = navit_dropout
45
+ self.navit_emb_dropout = navit_emb_dropout
46
+ self.vision_hidden_size = vision_hidden_size
47
+ self.llm_hidden_size = llm_hidden_size
48
+ self.projector_intermediate_size = projector_intermediate_size
49
+ self.max_visual_tokens = max_visual_tokens
50
+ self.max_new_tokens = max_new_tokens
51
+ self.num_beams = num_beams
52
+ self.decoder_arch = decoder_arch or {
53
+ "vocab_size": 2046,
54
+ "pad_id": 0,
55
+ "bos_id": 2,
56
+ "eos_id": 3,
57
+ "d_model": 512,
58
+ "n_heads": 8,
59
+ "n_layers": 6,
60
+ "d_ff": 1408,
61
+ "dropout": 0.1,
62
+ "max_seq_len": 200,
63
+ "rope_theta": 10000.0,
64
+ "tie_weights": True,
65
+ }
66
+ self.decoder_weights_tied = decoder_weights_tied
image_processing_latex_ocr.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image, ImageOps, ImageEnhance
4
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ def _prepare_for_inference(img: Image.Image) -> Image.Image:
11
+ """
12
+ Normalize real-world inputs (screenshots, camera, PDF crops) to the
13
+ clean white-background style the model was trained on.
14
+
15
+ Steps applied in order:
16
+ 1. Convert to grayscale luminance to check background tone
17
+ 2. If dark background (mean < 0.45), invert — handles dark mode / night mode
18
+ 3. Auto-contrast to stretch histogram — fixes low-contrast scans/photos
19
+ 4. Mild sharpening to counter screenshot JPEG blur
20
+ """
21
+ arr = np.array(img.convert("L"), dtype=np.float32) / 255.0
22
+ if arr.mean() < 0.45:
23
+ img = ImageOps.invert(img.convert("RGB"))
24
+ img = ImageOps.autocontrast(img, cutoff=1)
25
+ img = ImageEnhance.Sharpness(img).enhance(1.4)
26
+ return img.convert("RGB")
27
+
28
+
29
+ class Nav2TexImageProcessor(BaseImageProcessor):
30
+ model_type = "nav2tex"
31
+
32
+ def __init__(
33
+ self,
34
+ image_height=64,
35
+ max_image_width=1024,
36
+ patch_size=16,
37
+ **kwargs
38
+ ):
39
+ super().__init__(**kwargs)
40
+ self.image_height = image_height
41
+ self.max_image_width = max_image_width
42
+ self.patch_size = patch_size
43
+
44
+ def preprocess(self, images, do_prepare=True, **kwargs) -> BatchFeature:
45
+ if not isinstance(images, list):
46
+ images = [images]
47
+
48
+ processed_images = []
49
+ for img in images:
50
+ if img.mode != "RGB":
51
+ img = img.convert("RGB")
52
+
53
+ if do_prepare:
54
+ img = _prepare_for_inference(img)
55
+
56
+ w, h = img.size
57
+ new_w = int(round(w * self.image_height / max(h, 1)))
58
+ new_w = min(new_w, self.max_image_width)
59
+ new_w = max((new_w // self.patch_size) * self.patch_size, self.patch_size)
60
+
61
+ if (w, h) != (new_w, self.image_height):
62
+ img = img.resize((new_w, self.image_height), Image.BILINEAR)
63
+
64
+ img_array = np.array(img).astype(np.float32) / 255.0
65
+ img_array = (img_array - 0.5) / 0.5
66
+ img_array = np.transpose(img_array, (2, 0, 1))
67
+ processed_images.append(img_array)
68
+
69
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type="pt")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:059d08d87e8d08a784566c7da2ac5a6fb900fb3a45076e4265dc52fab861f30f
3
+ size 194225496
modeling_latex_decoder.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # update v2
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Optional
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import CausalLMOutput
10
+
11
+ try:
12
+ from .configuration_latex_decoder import LaTeXDecoderConfig
13
+ except ImportError:
14
+ from latex_ocr.configuration_latex_decoder import LaTeXDecoderConfig
15
+
16
+
17
+ class RMSNorm(nn.Module):
18
+ def __init__(self, d_model: int, eps: float = 1e-6):
19
+ super().__init__()
20
+ self.eps = eps
21
+ self.weight = nn.Parameter(torch.ones(d_model))
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+ rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
25
+ return x / rms * self.weight
26
+
27
+
28
+ def _build_rope_cache(seq_len, head_dim, theta=10000.0, device=None, dtype=torch.float32):
29
+ half = head_dim // 2
30
+ inv_freq = 1.0 / (theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
31
+ pos = torch.arange(seq_len, device=device, dtype=torch.float32)
32
+ freqs = torch.outer(pos, inv_freq)
33
+ emb = torch.cat([freqs, freqs], dim=-1)
34
+ return emb.cos().to(dtype), emb.sin().to(dtype)
35
+
36
+
37
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
38
+ half = x.shape[-1] // 2
39
+ x1, x2 = x[..., :half], x[..., half:]
40
+ return torch.cat([-x2, x1], dim=-1)
41
+
42
+
43
+ def apply_rope(q, k, cos, sin):
44
+ cos = cos.unsqueeze(0).unsqueeze(0)
45
+ sin = sin.unsqueeze(0).unsqueeze(0)
46
+ return q * cos + _rotate_half(q) * sin, k * cos + _rotate_half(k) * sin
47
+
48
+
49
+ class CausalSelfAttention(nn.Module):
50
+ def __init__(self, cfg: LaTeXDecoderConfig):
51
+ super().__init__()
52
+ self.n_heads = cfg.n_heads
53
+ self.head_dim = cfg.head_dim
54
+ self.d_model = cfg.d_model
55
+ self.dropout_p = cfg.dropout
56
+ self.rope_theta = cfg.rope_theta
57
+
58
+ self.qkv_proj = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
59
+ self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
60
+ self._rope_cache: dict = {}
61
+
62
+ def _get_rope(self, seq_len, device, dtype):
63
+ key = (seq_len, str(device), dtype)
64
+ if key not in self._rope_cache:
65
+ self._rope_cache[key] = _build_rope_cache(seq_len, self.head_dim, self.rope_theta, device, dtype)
66
+ return self._rope_cache[key]
67
+
68
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
69
+ B, T, C = x.shape
70
+ q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
71
+
72
+ q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
73
+ k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
74
+ v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
75
+
76
+ cos, sin = self._get_rope(T, x.device, q.dtype)
77
+ q, k = apply_rope(q, k, cos, sin)
78
+
79
+ dropout_p = self.dropout_p if self.training else 0.0
80
+
81
+ if attention_mask is not None:
82
+ causal = torch.triu(torch.full((T, T), float("-inf"), device=x.device, dtype=q.dtype), diagonal=1)
83
+ pad = (~attention_mask).unsqueeze(1).unsqueeze(2)
84
+ attn_bias = causal.unsqueeze(0).unsqueeze(0).expand(B, 1, T, T).clone()
85
+ attn_bias = attn_bias.masked_fill(pad, float("-inf"))
86
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p, is_causal=False)
87
+ else:
88
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
89
+
90
+ return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))
91
+
92
+
93
+ class SwiGLUFFN(nn.Module):
94
+ def __init__(self, cfg: LaTeXDecoderConfig):
95
+ super().__init__()
96
+ self.gate_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
97
+ self.up_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
98
+ self.down_proj = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
99
+ self.dropout = nn.Dropout(cfg.dropout)
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
103
+
104
+
105
+ class TransformerBlock(nn.Module):
106
+ def __init__(self, cfg: LaTeXDecoderConfig):
107
+ super().__init__()
108
+ self.norm1 = RMSNorm(cfg.d_model)
109
+ self.attn = CausalSelfAttention(cfg)
110
+ self.norm2 = RMSNorm(cfg.d_model)
111
+ self.ffn = SwiGLUFFN(cfg)
112
+ self.drop = nn.Dropout(cfg.dropout)
113
+
114
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
115
+ x = x + self.drop(self.attn(self.norm1(x), attention_mask))
116
+ x = x + self.drop(self.ffn(self.norm2(x)))
117
+ return x
118
+
119
+
120
+ class LaTeXDecoderForCausalLM(PreTrainedModel):
121
+ config_class = LaTeXDecoderConfig
122
+ base_model_prefix = "model"
123
+ supports_gradient_checkpointing = False
124
+
125
+ def __init__(self, config: LaTeXDecoderConfig):
126
+ super().__init__(config)
127
+
128
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
129
+ self.embed_drop = nn.Dropout(config.dropout)
130
+ self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
131
+ self.norm_final = RMSNorm(config.d_model)
132
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
133
+
134
+ if config.tie_weights:
135
+ self.lm_head.weight = self.embed_tokens.weight
136
+
137
+ self.post_init()
138
+
139
+ def _init_weights(self, module: nn.Module):
140
+ if isinstance(module, nn.Linear):
141
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
142
+ if module.bias is not None:
143
+ nn.init.zeros_(module.bias)
144
+ elif isinstance(module, nn.Embedding):
145
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
146
+
147
+ def forward(
148
+ self,
149
+ input_ids: torch.Tensor,
150
+ attention_mask: Optional[torch.Tensor] = None,
151
+ labels: Optional[torch.Tensor] = None,
152
+ **kwargs,
153
+ ) -> CausalLMOutput:
154
+ x = self.embed_drop(self.embed_tokens(input_ids))
155
+ for layer in self.layers:
156
+ x = layer(x, attention_mask)
157
+ logits = self.lm_head(self.norm_final(x))
158
+
159
+ loss = None
160
+ if labels is not None:
161
+ shift_logits = logits[:, :-1, :].contiguous()
162
+ shift_labels = labels[:, 1:].contiguous()
163
+ shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_id, -100)
164
+ loss = F.cross_entropy(
165
+ shift_logits.view(-1, self.config.vocab_size),
166
+ shift_labels.view(-1),
167
+ ignore_index=-100,
168
+ )
169
+
170
+ return CausalLMOutput(loss=loss, logits=logits)
171
+
172
+ @torch.inference_mode()
173
+ def generate(
174
+ self,
175
+ prompt_ids: torch.Tensor,
176
+ max_new_tokens: int = 200,
177
+ temperature: float = 1.0,
178
+ top_p: float = 0.9,
179
+ eos_id: Optional[int] = None,
180
+ ) -> torch.Tensor:
181
+ eos = eos_id if eos_id is not None else self.config.eos_id
182
+ generated = prompt_ids.clone()
183
+
184
+ for _ in range(max_new_tokens):
185
+ ctx = generated[:, -self.config.max_seq_len:]
186
+ logits = self.forward(ctx).logits[:, -1, :]
187
+
188
+ if temperature == 0.0:
189
+ next_id = logits.argmax(dim=-1, keepdim=True)
190
+ else:
191
+ probs = F.softmax(logits / temperature, dim=-1)
192
+ sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
193
+ cumsum = sorted_probs.cumsum(dim=-1)
194
+ sorted_probs[cumsum - sorted_probs > top_p] = 0.0
195
+ sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
196
+ next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
197
+
198
+ generated = torch.cat([generated, next_id], dim=-1)
199
+ if next_id.item() == eos:
200
+ break
201
+
202
+ return generated
modeling_latex_ocr.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from functools import partial
5
+ from torch import nn
6
+ from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import BaseModelOutput
9
+
10
+ try:
11
+ from .configuration_latex_decoder import LaTeXDecoderConfig
12
+ from .configuration_latex_ocr import Nav2TexConfig
13
+ from .modeling_latex_decoder import LaTeXDecoderForCausalLM
14
+ except ImportError:
15
+ from nav2tex.configuration_latex_decoder import LaTeXDecoderConfig
16
+ from nav2tex.configuration_latex_ocr import Nav2TexConfig
17
+ from nav2tex.modeling_latex_decoder import LaTeXDecoderForCausalLM
18
+
19
+ try:
20
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
21
+ from flash_attn.bert_padding import pad_input, unpad_input
22
+ HAS_FLASH_ATTN = True
23
+ except ImportError:
24
+ HAS_FLASH_ATTN = False
25
+
26
+
27
+ def exists(val):
28
+ return val is not None
29
+
30
+
31
+ def divisible_by(numer, denom):
32
+ return (numer % denom) == 0
33
+
34
+
35
+ class LayerNorm(nn.Module):
36
+ def __init__(self, dim):
37
+ super().__init__()
38
+ self.normalized_shape = (dim,)
39
+ self.eps = 1e-5
40
+ self.weight = nn.Parameter(torch.ones(dim))
41
+ self.bias = nn.Parameter(torch.zeros(dim))
42
+
43
+ def forward(self, x):
44
+ return F.layer_norm(
45
+ x.float(), self.normalized_shape,
46
+ self.weight.float(), self.bias.float(), self.eps,
47
+ ).to(x.dtype)
48
+
49
+
50
+ class RMSNorm(nn.Module):
51
+ def __init__(self, heads, dim):
52
+ super().__init__()
53
+ self.scale = dim ** 0.5
54
+ self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
55
+
56
+ def forward(self, x):
57
+ return F.normalize(x, dim=-1) * self.scale * self.gamma.to(x.dtype)
58
+
59
+
60
+ def rotate_half(x):
61
+ x1, x2 = x.chunk(2, dim=-1)
62
+ return torch.cat([-x2, x1], dim=-1)
63
+
64
+
65
+ def apply_2d_rope(q, k, h_idx, w_idx):
66
+ _, _, _, d = q.shape
67
+ if d % 4 != 0:
68
+ raise ValueError(f"apply_2d_rope expects dim_head divisible by 4, got D={d}")
69
+ dim_half = d // 2
70
+ dim_quarter = d // 4
71
+ inv_freq = 1.0 / (10000 ** (torch.arange(dim_quarter, device=q.device).float() / dim_quarter))
72
+ h_theta = h_idx[..., None].float() * inv_freq
73
+ w_theta = w_idx[..., None].float() * inv_freq
74
+ sin_h = torch.cat([h_theta.sin(), h_theta.sin()], dim=-1).to(q.dtype)[:, None, :, :]
75
+ cos_h = torch.cat([h_theta.cos(), h_theta.cos()], dim=-1).to(q.dtype)[:, None, :, :]
76
+ sin_w = torch.cat([w_theta.sin(), w_theta.sin()], dim=-1).to(q.dtype)[:, None, :, :]
77
+ cos_w = torch.cat([w_theta.cos(), w_theta.cos()], dim=-1).to(q.dtype)[:, None, :, :]
78
+
79
+ def rope(x, sin, cos):
80
+ return x * cos + rotate_half(x) * sin
81
+
82
+ q = torch.cat([rope(q[..., :dim_half], sin_h, cos_h), rope(q[..., dim_half:], sin_w, cos_w)], dim=-1)
83
+ k = torch.cat([rope(k[..., :dim_half], sin_h, cos_h), rope(k[..., dim_half:], sin_w, cos_w)], dim=-1)
84
+ return q, k
85
+
86
+
87
+ class FeedForward(nn.Module):
88
+ def __init__(self, dim, hidden_dim, dropout=0.0):
89
+ super().__init__()
90
+ self.net = nn.Sequential(
91
+ LayerNorm(dim),
92
+ nn.Linear(dim, hidden_dim),
93
+ nn.GELU(),
94
+ nn.Dropout(dropout),
95
+ nn.Linear(hidden_dim, dim),
96
+ nn.Dropout(dropout),
97
+ )
98
+
99
+ def forward(self, x):
100
+ return self.net(x)
101
+
102
+
103
+ class Attention(nn.Module):
104
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
105
+ super().__init__()
106
+ inner_dim = dim_head * heads
107
+ self.heads = heads
108
+ self.norm = LayerNorm(dim)
109
+ self.q_norm = RMSNorm(heads, dim_head)
110
+ self.k_norm = RMSNorm(heads, dim_head)
111
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
112
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
113
+ self.attend = nn.Softmax(dim=-1)
114
+ self.dropout = nn.Dropout(dropout)
115
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout))
116
+
117
+ def forward(self, x, mask=None, attn_mask=None, positions=None):
118
+ x = self.norm(x)
119
+ q = self.to_q(x)
120
+ k, v = self.to_kv(x).chunk(2, dim=-1)
121
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v))
122
+ q = self.q_norm(q)
123
+ k = self.k_norm(k)
124
+
125
+ if positions is not None:
126
+ q, k = apply_2d_rope(q, k, positions[0], positions[1])
127
+
128
+ if HAS_FLASH_ATTN and x.is_cuda and attn_mask is None:
129
+ fa_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.bfloat16
130
+ q_ = rearrange(q, "b h n d -> b n h d").contiguous().to(fa_dtype)
131
+ k_ = rearrange(k, "b h n d -> b n h d").contiguous().to(fa_dtype)
132
+ v_ = rearrange(v, "b h n d -> b n h d").contiguous().to(fa_dtype)
133
+ if exists(mask):
134
+ batch, seqlen = mask.shape
135
+ q_unpad, indices, cu_q, max_q, *_ = unpad_input(q_, mask)
136
+ k_unpad, _, cu_k, max_k, *_ = unpad_input(k_, mask)
137
+ v_unpad, _, _, _, *_ = unpad_input(v_, mask)
138
+ out_unpad = flash_attn_varlen_func(
139
+ q_unpad, k_unpad, v_unpad,
140
+ cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
141
+ max_seqlen_q=max_q, max_seqlen_k=max_k,
142
+ dropout_p=self.dropout.p if self.training else 0.0,
143
+ causal=False,
144
+ )
145
+ out = pad_input(out_unpad, indices, batch, seqlen)
146
+ else:
147
+ out = flash_attn_func(
148
+ q_, k_, v_,
149
+ dropout_p=self.dropout.p if self.training else 0.0,
150
+ causal=False,
151
+ )
152
+ out = rearrange(out, "b n h d -> b n (h d)").to(x.dtype)
153
+ else:
154
+ dots = torch.matmul(q, k.transpose(-1, -2))
155
+ if exists(mask):
156
+ dots = dots.masked_fill(~mask[:, None, None, :], -torch.finfo(dots.dtype).max)
157
+ if exists(attn_mask):
158
+ dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
159
+ attn = self.dropout(self.attend(dots))
160
+ out = rearrange(torch.matmul(attn, v), "b h n d -> b n (h d)")
161
+ return self.to_out(out)
162
+
163
+
164
+ class Transformer(nn.Module):
165
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
166
+ super().__init__()
167
+ self.layers = nn.ModuleList([
168
+ nn.ModuleList([Attention(dim, heads, dim_head, dropout), FeedForward(dim, mlp_dim, dropout)])
169
+ for _ in range(depth)
170
+ ])
171
+ self.norm = LayerNorm(dim)
172
+
173
+ def forward(self, x, mask=None, attn_mask=None, positions=None):
174
+ for attn, ff in self.layers:
175
+ x = attn(x, mask=mask, attn_mask=attn_mask, positions=positions) + x
176
+ x = ff(x) + x
177
+ return self.norm(x)
178
+
179
+
180
+ class NaViT_Encoder(nn.Module):
181
+ def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim,
182
+ channels=3, dim_head=64, dropout=0.0, emb_dropout=0.0):
183
+ super().__init__()
184
+ image_height, image_width = image_size
185
+ assert divisible_by(image_height, patch_size)
186
+ assert divisible_by(image_width, patch_size)
187
+ self.patch_size = patch_size
188
+ self.to_patch_embedding = nn.Sequential(
189
+ LayerNorm(channels * patch_size ** 2),
190
+ nn.Linear(channels * patch_size ** 2, dim),
191
+ LayerNorm(dim),
192
+ )
193
+ self.dropout = nn.Dropout(emb_dropout)
194
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
195
+
196
+ @property
197
+ def device(self):
198
+ return next(self.parameters()).device
199
+
200
+ def forward(self, batched_images):
201
+ p = self.patch_size
202
+ device = self.device
203
+ arange = partial(torch.arange, device=device)
204
+ pad_sequence = partial(orig_pad_sequence, batch_first=True)
205
+ batched_sequences, batched_positions = [], []
206
+
207
+ for images in batched_images:
208
+ sequences, positions = [], []
209
+ for image in images:
210
+ _, h, w = image.shape
211
+ ph, pw = h // p, w // p
212
+ seq = rearrange(image, "c (h p1) (w p2) -> (h w) (c p1 p2)", p1=p, p2=p)
213
+ pos = torch.stack(torch.meshgrid(arange(ph), arange(pw), indexing="ij"), dim=-1)
214
+ sequences.append(seq)
215
+ positions.append(rearrange(pos, "h w c -> (h w) c"))
216
+ batched_sequences.append(torch.cat(sequences, dim=0))
217
+ batched_positions.append(torch.cat(positions, dim=0))
218
+
219
+ patches = pad_sequence(batched_sequences)
220
+ patch_positions = pad_sequence(batched_positions)
221
+ lengths = torch.tensor([seq.shape[0] for seq in batched_sequences], device=device)
222
+ mask = torch.arange(patches.shape[1], device=device)[None, :] < lengths[:, None]
223
+ x = self.to_patch_embedding(patches.to(next(self.parameters()).dtype))
224
+ h_idx, w_idx = patch_positions.unbind(dim=-1)
225
+ x = self.dropout(x)
226
+ x = self.transformer(x, mask=mask, positions=(h_idx, w_idx))
227
+ return x, mask
228
+
229
+
230
+ class MLPProjector(nn.Module):
231
+ def __init__(self, vision_hidden_size=1024, llm_hidden_size=512, intermediate_size=2048):
232
+ super().__init__()
233
+ self.norm = nn.LayerNorm(vision_hidden_size)
234
+ self.gate_proj = nn.Linear(vision_hidden_size, intermediate_size, bias=False)
235
+ self.up_proj = nn.Linear(vision_hidden_size, intermediate_size, bias=False)
236
+ self.down_proj = nn.Linear(intermediate_size, llm_hidden_size, bias=False)
237
+
238
+ def forward(self, x):
239
+ x = self.norm(x)
240
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
241
+
242
+
243
+ class VisualEncoder(nn.Module):
244
+ def __init__(self, encoder, bridge, max_visual_tokens):
245
+ super().__init__()
246
+ self.navit = encoder
247
+ self.projector = bridge
248
+ self.max_visual_tokens = max_visual_tokens
249
+
250
+ def forward(self, batched_images):
251
+ x, mask = self.navit(batched_images)
252
+ if x.shape[1] > self.max_visual_tokens:
253
+ x = x[:, :self.max_visual_tokens]
254
+ mask = mask[:, :self.max_visual_tokens]
255
+ return self.projector(x), mask
256
+
257
+
258
+ class CustomDecoder(nn.Module):
259
+ def __init__(self, config: Nav2TexConfig):
260
+ super().__init__()
261
+ dec = config.decoder_arch
262
+ self._model = LaTeXDecoderForCausalLM(
263
+ LaTeXDecoderConfig(
264
+ vocab_size=dec["vocab_size"],
265
+ pad_id=dec["pad_id"],
266
+ bos_id=dec["bos_id"],
267
+ eos_id=dec["eos_id"],
268
+ d_model=dec["d_model"],
269
+ n_heads=dec["n_heads"],
270
+ n_layers=dec["n_layers"],
271
+ d_ff=dec["d_ff"],
272
+ dropout=dec.get("dropout", 0.1),
273
+ max_seq_len=dec["max_seq_len"],
274
+ rope_theta=dec.get("rope_theta", 10000.0),
275
+ tie_weights=dec.get("tie_weights", True),
276
+ )
277
+ )
278
+ self.pad_token_id = self._model.config.pad_id
279
+ self.eos_token_id = self._model.config.eos_id
280
+ self._vocab_size = self._model.config.vocab_size
281
+ self._pad_id = self._model.config.pad_id
282
+ if not config.decoder_weights_tied:
283
+ self.untie_weights()
284
+
285
+ def get_input_embeddings(self):
286
+ return self._model.embed_tokens
287
+
288
+ def tie_weights(self):
289
+ self._model.lm_head.weight = self._model.embed_tokens.weight
290
+
291
+ def untie_weights(self):
292
+ if self.are_weights_tied():
293
+ self._model.lm_head.weight = nn.Parameter(self._model.embed_tokens.weight.detach().clone())
294
+
295
+ def are_weights_tied(self):
296
+ return self._model.lm_head.weight.data_ptr() == self._model.embed_tokens.weight.data_ptr()
297
+
298
+ def _forward_embeds(self, inputs_embeds, attention_mask=None):
299
+ x = self._model.embed_drop(inputs_embeds)
300
+ mask = attention_mask.bool() if attention_mask is not None else None
301
+ for layer in self._model.layers:
302
+ x = layer(x, mask)
303
+ return self._model.lm_head(self._model.norm_final(x))
304
+
305
+ def forward(self, inputs_embeds=None, attention_mask=None, labels=None, **kwargs):
306
+ logits = self._forward_embeds(inputs_embeds, attention_mask)
307
+ loss = None
308
+ if labels is not None:
309
+ shift_logits = logits[:, :-1].contiguous()
310
+ shift_labels = labels[:, 1:].contiguous().masked_fill(
311
+ labels[:, 1:].contiguous() == self._pad_id, -100
312
+ )
313
+ loss = F.cross_entropy(
314
+ shift_logits.view(-1, self._vocab_size),
315
+ shift_labels.view(-1),
316
+ ignore_index=-100,
317
+ )
318
+ return BaseModelOutput(last_hidden_state=logits, hidden_states=(loss,))
319
+
320
+ @torch.no_grad()
321
+ def generate(self, inputs_embeds, attention_mask, max_new_tokens, num_beams=1):
322
+ device = inputs_embeds.device
323
+ batch = inputs_embeds.shape[0]
324
+
325
+ if num_beams > 1:
326
+ # beam search: only supports batch_size=1
327
+ assert batch == 1, "beam search only supports batch_size=1"
328
+ return self._beam_search(inputs_embeds, attention_mask, max_new_tokens, num_beams)
329
+
330
+ return self._greedy_batch(inputs_embeds, attention_mask, max_new_tokens)
331
+
332
+ @torch.no_grad()
333
+ def _greedy_batch(self, inputs_embeds, attention_mask, max_new_tokens):
334
+ """Greedy decoding with true batch support."""
335
+ eos_id = self.eos_token_id
336
+ pad_id = self._pad_id
337
+ device = inputs_embeds.device
338
+ batch = inputs_embeds.shape[0]
339
+ d_model = inputs_embeds.shape[-1]
340
+
341
+ # generated token ids per sample, and finished flags
342
+ gen_ids = [[] for _ in range(batch)]
343
+ finished = torch.zeros(batch, dtype=torch.bool, device=device)
344
+
345
+ cur_embeds = inputs_embeds # (B, vis_len, D)
346
+ cur_mask = attention_mask # (B, vis_len)
347
+
348
+ for _ in range(max_new_tokens):
349
+ logits = self._forward_embeds(cur_embeds, cur_mask) # (B, seq, vocab)
350
+ next_tok = logits[:, -1, :].argmax(dim=-1) # (B,)
351
+
352
+ for i in range(batch):
353
+ if not finished[i]:
354
+ gen_ids[i].append(next_tok[i].item())
355
+ finished |= (next_tok == eos_id)
356
+ if finished.all():
357
+ break
358
+
359
+ tok_emb = self._model.embed_tokens(next_tok.unsqueeze(1)) # (B, 1, D)
360
+ tok_mask = cur_mask.new_ones(batch, 1)
361
+ cur_embeds = torch.cat([cur_embeds, tok_emb], dim=1)
362
+ cur_mask = torch.cat([cur_mask, tok_mask], dim=1)
363
+
364
+ # pad to same length and return (B, max_len)
365
+ max_len = max((len(ids) for ids in gen_ids), default=0)
366
+ if max_len == 0:
367
+ return torch.zeros(batch, 0, dtype=torch.long, device=device)
368
+ out = torch.full((batch, max_len), pad_id, dtype=torch.long, device=device)
369
+ for i, ids in enumerate(gen_ids):
370
+ if ids:
371
+ out[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
372
+ return out
373
+
374
+ @torch.no_grad()
375
+ def _beam_search(self, inputs_embeds, attention_mask, max_new_tokens, num_beams):
376
+ """Original beam search (batch_size=1 only)."""
377
+ eos_id = self.eos_token_id
378
+ device = inputs_embeds.device
379
+ vis_emb = inputs_embeds[0]
380
+ vis_len = vis_emb.shape[0]
381
+ vis_mask = attention_mask[0] if attention_mask is not None else None
382
+ beams = [(0.0, [], False) for _ in range(num_beams)]
383
+
384
+ for _ in range(max_new_tokens):
385
+ all_embeds, all_masks = [], []
386
+ for score, ids, _ in beams:
387
+ tok_emb = self._model.embed_tokens(torch.tensor(ids, device=device, dtype=torch.long)) if ids else None
388
+ seq_emb = torch.cat([vis_emb, tok_emb], dim=0) if tok_emb is not None else vis_emb
389
+ all_embeds.append(seq_emb)
390
+ if vis_mask is not None:
391
+ tok_mask = vis_mask.new_ones(len(ids)) if ids else vis_mask.new_zeros(0)
392
+ all_masks.append(torch.cat([vis_mask, tok_mask]) if ids else vis_mask)
393
+
394
+ max_len = max(e.shape[0] for e in all_embeds)
395
+ d_model = all_embeds[0].shape[-1]
396
+ padded_embeds = vis_emb.new_zeros(num_beams, max_len, d_model)
397
+ padded_mask = vis_mask.new_zeros(num_beams, max_len) if vis_mask is not None else None
398
+ for idx, emb in enumerate(all_embeds):
399
+ padded_embeds[idx, :emb.shape[0]] = emb
400
+ if padded_mask is not None:
401
+ padded_mask[idx, :emb.shape[0]] = all_masks[idx]
402
+
403
+ logits = self._forward_embeds(padded_embeds, padded_mask)
404
+ candidates = []
405
+ for beam_idx, (score, ids, done) in enumerate(beams):
406
+ if done:
407
+ candidates.append((score, ids, True))
408
+ continue
409
+ last_pos = vis_len + len(ids) - 1
410
+ log_p = torch.log_softmax(logits[beam_idx, last_pos, :], dim=-1)
411
+ if len(ids) == 0 and beam_idx > 0:
412
+ log_p = log_p.fill_(-1e9)
413
+ for lp, tok in zip(*map(lambda t: t.tolist(), log_p.topk(num_beams))):
414
+ candidates.append((score + lp, ids + [tok], tok == eos_id))
415
+ candidates.sort(key=lambda x: -x[0])
416
+ beams = candidates[:num_beams]
417
+ if all(done for _, _, done in beams):
418
+ break
419
+
420
+ best_ids = max(beams, key=lambda x: x[0])[1]
421
+ if not best_ids:
422
+ return torch.zeros(1, 0, dtype=torch.long, device=device)
423
+ return torch.tensor(best_ids, dtype=torch.long, device=device).unsqueeze(0)
424
+
425
+
426
+ class Nav2TexModel(PreTrainedModel):
427
+ config_class = Nav2TexConfig
428
+ base_model_prefix = "model"
429
+ main_input_name = "pixel_values"
430
+
431
+ def __init__(self, config: Nav2TexConfig):
432
+ super().__init__(config)
433
+ self.config = config
434
+ self.visual_encoder = VisualEncoder(
435
+ NaViT_Encoder(
436
+ image_size=(config.image_height, config.max_image_width),
437
+ patch_size=config.patch_size,
438
+ dim=config.navit_dim,
439
+ depth=config.navit_depth,
440
+ heads=config.navit_heads,
441
+ mlp_dim=config.navit_mlp_dim,
442
+ dim_head=config.navit_dim_head,
443
+ dropout=config.navit_dropout,
444
+ emb_dropout=config.navit_emb_dropout,
445
+ ),
446
+ MLPProjector(
447
+ vision_hidden_size=config.vision_hidden_size,
448
+ llm_hidden_size=config.llm_hidden_size,
449
+ intermediate_size=config.projector_intermediate_size,
450
+ ),
451
+ max_visual_tokens=config.max_visual_tokens,
452
+ )
453
+ self.decoder = CustomDecoder(config)
454
+ self.post_init()
455
+
456
+ def tie_weights(self):
457
+ if self.config.decoder_weights_tied:
458
+ self.decoder.tie_weights()
459
+ else:
460
+ self.decoder.untie_weights()
461
+
462
+ def _init_weights(self, module):
463
+ return
464
+
465
+ @staticmethod
466
+ def _to_batched_images(pixel_values):
467
+ if isinstance(pixel_values, list):
468
+ return pixel_values
469
+ if isinstance(pixel_values, torch.Tensor):
470
+ return [[img] for img in pixel_values]
471
+ raise TypeError(f"Unsupported pixel_values type: {type(pixel_values)}")
472
+
473
+ def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None, **kwargs):
474
+ batched_images = self._to_batched_images(pixel_values)
475
+ ve, vm = self.visual_encoder(batched_images)
476
+ if input_ids is None:
477
+ return BaseModelOutput(last_hidden_state=ve)
478
+ te = self.decoder.get_input_embeddings()(input_ids)
479
+ inputs_embeds = torch.cat([ve, te], dim=1)
480
+ am = torch.cat([vm.to(dtype=attention_mask.dtype), attention_mask], dim=1)
481
+ lv = torch.full((labels.shape[0], ve.shape[1]), -100, dtype=labels.dtype, device=labels.device)
482
+ out = self.decoder(
483
+ inputs_embeds=inputs_embeds,
484
+ attention_mask=am,
485
+ labels=torch.cat([lv, labels], dim=1),
486
+ )
487
+ return BaseModelOutput(last_hidden_state=out.last_hidden_state, hidden_states=(out.hidden_states[0],))
488
+
489
+ @torch.no_grad()
490
+ def generate(self, pixel_values, max_new_tokens=None, num_beams=None):
491
+ batched_images = self._to_batched_images(pixel_values)
492
+ ve, vm = self.visual_encoder(batched_images)
493
+ batch = ve.shape[0]
494
+ bos_id = self.config.decoder_arch["bos_id"]
495
+ bos_emb = self.decoder.get_input_embeddings()(
496
+ torch.full((batch, 1), bos_id, dtype=torch.long, device=ve.device)
497
+ )
498
+ inputs_embeds = torch.cat([ve, bos_emb], dim=1)
499
+ attention_mask = torch.cat([
500
+ vm.to(dtype=torch.long),
501
+ torch.ones(batch, 1, dtype=torch.long, device=ve.device)
502
+ ], dim=1)
503
+ return self.decoder.generate(
504
+ inputs_embeds=inputs_embeds,
505
+ attention_mask=attention_mask,
506
+ max_new_tokens=max_new_tokens or self.config.max_new_tokens,
507
+ num_beams=num_beams or self.config.num_beams,
508
+ )
pipeline_latex_ocr.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ from pathlib import Path
4
+ from PIL import Image
5
+ from huggingface_hub import snapshot_download
6
+
7
+
8
+ class Nav2TexPipeline:
9
+ def __init__(self, model, processor, device):
10
+ self.model = model
11
+ self.processor = processor
12
+ self.device = device
13
+
14
+ @classmethod
15
+ def from_pretrained(cls, repo_id_or_path: str, device: str = None):
16
+ if device is None:
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ path = Path(repo_id_or_path)
20
+ if not path.exists():
21
+ path = Path(snapshot_download(repo_id_or_path))
22
+
23
+ sys.path.insert(0, str(path))
24
+
25
+ from nav2tex.tokenization_latex_ocr import LaTeXTokenizer
26
+ from nav2tex.image_processing_latex_ocr import Nav2TexImageProcessor
27
+ from nav2tex.processing_latex_ocr import Nav2TexProcessor
28
+ from nav2tex.modeling_latex_ocr import Nav2TexModel
29
+ from nav2tex.configuration_latex_ocr import Nav2TexConfig
30
+
31
+ config = Nav2TexConfig.from_pretrained(str(path))
32
+ image_processor = Nav2TexImageProcessor.from_pretrained(str(path))
33
+ tokenizer = LaTeXTokenizer(str(path / "tokenizer.json"))
34
+ processor = Nav2TexProcessor(image_processor=image_processor, tokenizer=tokenizer)
35
+ model = Nav2TexModel.from_pretrained(str(path), config=config).to(device).eval()
36
+
37
+ return cls(model=model, processor=processor, device=device)
38
+
39
+ def __call__(self, image, max_new_tokens: int = None, num_beams: int = None):
40
+ single = not isinstance(image, list)
41
+ images = [image] if single else image
42
+
43
+ loaded = []
44
+ for img in images:
45
+ if isinstance(img, (str, Path)):
46
+ img = Image.open(img).convert("RGB")
47
+ elif isinstance(img, Image.Image):
48
+ img = img.convert("RGB")
49
+ else:
50
+ raise TypeError(f"Unsupported image type: {type(img)}")
51
+ loaded.append(img)
52
+
53
+ kwargs = {}
54
+ if max_new_tokens is not None:
55
+ kwargs["max_new_tokens"] = max_new_tokens
56
+ if num_beams is not None:
57
+ kwargs["num_beams"] = num_beams
58
+
59
+ # image processor handles variable-width images one at a time;
60
+ # collect pixel_values as a list for NaViT's batched_images path
61
+ pixel_values = [
62
+ self.processor(images=img, return_tensors="pt")["pixel_values"].to(self.device)
63
+ for img in loaded
64
+ ]
65
+
66
+ with torch.no_grad():
67
+ generated_ids = self.model.generate(pixel_values, **kwargs)
68
+
69
+ results = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
70
+ return results[0] if single else results
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "image_processor_type": "Nav2TexImageProcessor",
3
+ "image_height": 64,
4
+ "max_image_width": 1024,
5
+ "patch_size": 16,
6
+ "auto_map": {
7
+ "AutoProcessor": "processing_latex_ocr.Nav2TexProcessor",
8
+ "AutoImageProcessor": "image_processing_latex_ocr.Nav2TexImageProcessor"
9
+ }
10
+ }
processing_latex_ocr.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ProcessorMixin
2
+ from nav2tex.image_processing_latex_ocr import Nav2TexImageProcessor
3
+ from nav2tex.tokenization_latex_ocr import LaTeXTokenizer
4
+
5
+ class Nav2TexProcessor(ProcessorMixin):
6
+ attributes = ["image_processor", "tokenizer"]
7
+ image_processor_class = "AutoImageProcessor"
8
+ tokenizer_class = "AutoTokenizer"
9
+
10
+ def __init__(self, image_processor, tokenizer):
11
+ super().__init__(image_processor, tokenizer)
12
+
13
+ def __call__(self, images=None, text=None, return_tensors=None, **kwargs):
14
+ if images is None and text is None:
15
+ raise ValueError("You must specify either images or text.")
16
+
17
+ output = {}
18
+ if images is not None:
19
+ image_inputs = self.image_processor(images, return_tensors=return_tensors, **kwargs)
20
+ output.update(image_inputs)
21
+
22
+ if text is not None:
23
+ text_inputs = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
24
+ output.update(text_inputs)
25
+
26
+ return output
27
+
28
+ def batch_decode(self, *args, **kwargs):
29
+ return self.tokenizer.batch_decode(*args, **kwargs)
30
+
31
+ def decode(self, *args, **kwargs):
32
+ return self.tokenizer.decode(*args, **kwargs)
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": {
3
+ "content": "<pad>",
4
+ "single_word": false,
5
+ "lstrip": false,
6
+ "rstrip": false,
7
+ "normalized": false
8
+ },
9
+ "unk_token": {
10
+ "content": "<unk>",
11
+ "single_word": false,
12
+ "lstrip": false,
13
+ "rstrip": false,
14
+ "normalized": false
15
+ },
16
+ "bos_token": {
17
+ "content": "<bos>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false
22
+ },
23
+ "eos_token": {
24
+ "content": "<eos>",
25
+ "single_word": false,
26
+ "lstrip": false,
27
+ "rstrip": false,
28
+ "normalized": false
29
+ }
30
+ }
tokenization_latex_ocr.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, List, Optional, Tuple
4
+ from transformers import PreTrainedTokenizer
5
+
6
+
7
+ class LaTeXTokenizer(PreTrainedTokenizer):
8
+ vocab_files_names = {"vocab_file": "tokenizer.json"}
9
+ model_input_names = ["input_ids", "attention_mask"]
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_file: str,
14
+ pad_token="<pad>",
15
+ unk_token="<unk>",
16
+ bos_token="<bos>",
17
+ eos_token="<eos>",
18
+ **kwargs,
19
+ ):
20
+ with open(Path(vocab_file), encoding="utf-8") as f:
21
+ data = json.load(f)
22
+
23
+ if "model" in data:
24
+ self.token2id: Dict[str, int] = data["model"]["vocab"]
25
+ self.id2token: Dict[int, str] = {int(v): k for k, v in self.token2id.items()}
26
+ self.merges = []
27
+ cfg = {}
28
+ else:
29
+ self.token2id = data["token2id"]
30
+ self.id2token = {int(k): v for k, v in data["id2token"].items()}
31
+ self.merges = data.get("merges", [])
32
+ cfg = data.get("config", {})
33
+
34
+ kwargs.setdefault("model_max_length", cfg.get("model_max_length", 256))
35
+ kwargs.setdefault("padding_side", cfg.get("padding_side", "right"))
36
+ kwargs.setdefault("truncation_side", cfg.get("truncation_side", "right"))
37
+
38
+ super().__init__(
39
+ pad_token=pad_token,
40
+ unk_token=unk_token,
41
+ bos_token=bos_token,
42
+ eos_token=eos_token,
43
+ **kwargs,
44
+ )
45
+
46
+ @property
47
+ def vocab_size(self) -> int:
48
+ return len(self.token2id)
49
+
50
+ def get_vocab(self) -> Dict[str, int]:
51
+ return dict(self.token2id)
52
+
53
+ def _tokenize(self, text: str) -> List[str]:
54
+ tokens = []
55
+ i = 0
56
+ while i < len(text):
57
+ matched = False
58
+ for length in range(min(20, len(text) - i), 0, -1):
59
+ substr = text[i:i + length]
60
+ if substr in self.token2id:
61
+ tokens.append(substr)
62
+ i += length
63
+ matched = True
64
+ break
65
+ if not matched:
66
+ tokens.append(text[i])
67
+ i += 1
68
+ return tokens
69
+
70
+ def _convert_token_to_id(self, token: str) -> int:
71
+ return self.token2id.get(token, self.token2id.get("<unk>", 1))
72
+
73
+ def _convert_id_to_token(self, index: int) -> str:
74
+ return self.id2token.get(index, "<unk>")
75
+
76
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
77
+ return "".join(tokens)
78
+
79
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
80
+ save_dir = Path(save_directory)
81
+ save_dir.mkdir(parents=True, exist_ok=True)
82
+ vocab_file = save_dir / (
83
+ (filename_prefix + "-" if filename_prefix else "") + "tokenizer.json"
84
+ )
85
+ data = {
86
+ "token2id": self.token2id,
87
+ "id2token": {str(k): v for k, v in self.id2token.items()},
88
+ "merges": [list(p) for p in self.merges],
89
+ "config": {
90
+ "vocab_size": self.vocab_size,
91
+ "pad_token": str(self.pad_token),
92
+ "unk_token": str(self.unk_token),
93
+ "bos_token": str(self.bos_token),
94
+ "eos_token": str(self.eos_token),
95
+ "pad_id": self.pad_token_id,
96
+ "unk_id": self.unk_token_id,
97
+ "bos_id": self.bos_token_id,
98
+ "eos_id": self.eos_token_id,
99
+ "model_max_length": self.model_max_length,
100
+ "padding_side": self.padding_side,
101
+ "truncation_side": self.truncation_side,
102
+ },
103
+ }
104
+ with open(vocab_file, "w", encoding="utf-8") as f:
105
+ json.dump(data, f, ensure_ascii=False, indent=2)
106
+ return (str(vocab_file),)
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 2046,
3
+ "n_frozen": 697,
4
+ "special_tokens": [
5
+ "<pad>",
6
+ "<unk>",
7
+ "<bos>",
8
+ "<eos>"
9
+ ],
10
+ "pad_token": "<pad>",
11
+ "unk_token": "<unk>",
12
+ "bos_token": "<bos>",
13
+ "eos_token": "<eos>",
14
+ "pad_id": 0,
15
+ "unk_id": 1,
16
+ "bos_id": 2,
17
+ "eos_id": 3,
18
+ "model_max_length": 256,
19
+ "padding_side": "right",
20
+ "truncation_side": "right",
21
+ "tokenizer_version": 2,
22
+ "tokenizer_class": "LaTeXTokenizer",
23
+ "auto_map": {
24
+ "AutoTokenizer": ["tokenization_latex_ocr.LaTeXTokenizer", null]
25
+ }
26
+ }