harryrobert committed on
Commit
211851f
·
verified ·
1 Parent(s): 3372a56

Upload folder using huggingface_hub

Browse files
__pycache__/configuration_latex_decoder.cpython-312.pyc ADDED
Binary file (2.09 kB). View file
 
__pycache__/configuration_latex_ocr.cpython-312.pyc ADDED
Binary file (2.5 kB). View file
 
__pycache__/image_processing_latex_ocr.cpython-312.pyc ADDED
Binary file (4.01 kB). View file
 
__pycache__/image_processing_latex_ocr.cpython-313.pyc ADDED
Binary file (4.05 kB). View file
 
__pycache__/modeling_latex_decoder.cpython-312.pyc ADDED
Binary file (15.7 kB). View file
 
__pycache__/modeling_latex_ocr.cpython-312.pyc ADDED
Binary file (35.6 kB). View file
 
__pycache__/pipeline_latex_ocr.cpython-312.pyc ADDED
Binary file (3.94 kB). View file
 
__pycache__/processing_latex_ocr.cpython-312.pyc ADDED
Binary file (1.98 kB). View file
 
__pycache__/tokenization_latex_ocr.cpython-312.pyc ADDED
Binary file (6.26 kB). View file
 
__pycache__/tokenization_latex_ocr.cpython-313.pyc ADDED
Binary file (6.39 kB). View file
 
image_processing_latex_ocr.py CHANGED
@@ -1,11 +1,31 @@
1
  import torch
2
  import numpy as np
3
- from PIL import Image
4
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
5
  from transformers.utils import logging
6
 
7
  logger = logging.get_logger(__name__)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class LaTeXOCRImageProcessor(BaseImageProcessor):
10
  model_type = "latex_ocr"
11
 
@@ -21,24 +41,28 @@ class LaTeXOCRImageProcessor(BaseImageProcessor):
21
  self.max_image_width = max_image_width
22
  self.patch_size = patch_size
23
 
24
- def preprocess(self, images, **kwargs) -> BatchFeature:
25
  if not isinstance(images, list):
26
  images = [images]
27
-
28
  processed_images = []
29
  for img in images:
30
  if img.mode != "RGB":
31
  img = img.convert("RGB")
32
 
 
 
 
33
  w, h = img.size
34
  new_w = int(round(w * self.image_height / max(h, 1)))
35
  new_w = min(new_w, self.max_image_width)
36
  new_w = max((new_w // self.patch_size) * self.patch_size, self.patch_size)
37
-
38
  if (w, h) != (new_w, self.image_height):
39
  img = img.resize((new_w, self.image_height), Image.BILINEAR)
40
 
41
  img_array = np.array(img).astype(np.float32) / 255.0
 
42
  img_array = np.transpose(img_array, (2, 0, 1))
43
  processed_images.append(img_array)
44
 
 
1
  import torch
2
  import numpy as np
3
+ from PIL import Image, ImageOps, ImageEnhance
4
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
5
  from transformers.utils import logging
6
 
7
  logger = logging.get_logger(__name__)
8
 
9
+
10
def _prepare_for_inference(img: Image.Image) -> Image.Image:
    """Normalize a real-world input image before it is resized for the model.

    Intended for screenshots, camera photos and PDF crops, which are brought
    closer to a clean, light-background rendering.

    Steps, applied in order:
        1. Convert to RGB once, so every later step sees a supported mode.
        2. Measure the mean grayscale luminance (scaled to 0..1).
        3. If the mean is below 0.45 (dark background, e.g. dark mode),
           invert the image.
        4. Auto-contrast with a 1% cutoff to stretch the histogram.
        5. Mild sharpening (factor 1.4).

    Args:
        img: Input PIL image in any mode that can be converted to RGB.

    Returns:
        A new RGB image; the input object is not modified.
    """
    # ImageOps.invert / ImageOps.autocontrast only handle a few modes (L, RGB).
    # Converting up front keeps light-background inputs in other modes
    # (RGBA, LA, 1, ...) from failing inside autocontrast.
    rgb = img.convert("RGB")

    luminance = np.asarray(rgb.convert("L"), dtype=np.float32) / 255.0
    if luminance.mean() < 0.45:
        rgb = ImageOps.invert(rgb)

    rgb = ImageOps.autocontrast(rgb, cutoff=1)
    return ImageEnhance.Sharpness(rgb).enhance(1.4)
27
+
28
+
29
  class LaTeXOCRImageProcessor(BaseImageProcessor):
30
  model_type = "latex_ocr"
31
 
 
41
  self.max_image_width = max_image_width
42
  self.patch_size = patch_size
43
 
44
+ def preprocess(self, images, do_prepare=True, **kwargs) -> BatchFeature:
45
  if not isinstance(images, list):
46
  images = [images]
47
+
48
  processed_images = []
49
  for img in images:
50
  if img.mode != "RGB":
51
  img = img.convert("RGB")
52
 
53
+ if do_prepare:
54
+ img = _prepare_for_inference(img)
55
+
56
  w, h = img.size
57
  new_w = int(round(w * self.image_height / max(h, 1)))
58
  new_w = min(new_w, self.max_image_width)
59
  new_w = max((new_w // self.patch_size) * self.patch_size, self.patch_size)
60
+
61
  if (w, h) != (new_w, self.image_height):
62
  img = img.resize((new_w, self.image_height), Image.BILINEAR)
63
 
64
  img_array = np.array(img).astype(np.float32) / 255.0
65
+ img_array = (img_array - 0.5) / 0.5
66
  img_array = np.transpose(img_array, (2, 0, 1))
67
  processed_images.append(img_array)
68
 
modeling_latex_decoder.py CHANGED
@@ -11,7 +11,7 @@ from transformers.modeling_outputs import CausalLMOutput
11
  try:
12
  from .configuration_latex_decoder import LaTeXDecoderConfig
13
  except ImportError:
14
- from configuration_latex_decoder import LaTeXDecoderConfig
15
 
16
 
17
  class RMSNorm(nn.Module):
 
11
  try:
12
  from .configuration_latex_decoder import LaTeXDecoderConfig
13
  except ImportError:
14
+ from latex_ocr.configuration_latex_decoder import LaTeXDecoderConfig
15
 
16
 
17
  class RMSNorm(nn.Module):
modeling_latex_ocr.py CHANGED
@@ -12,9 +12,9 @@ try:
12
  from .configuration_latex_ocr import LaTeXOCRConfig
13
  from .modeling_latex_decoder import LaTeXDecoderForCausalLM
14
  except ImportError:
15
- from configuration_latex_decoder import LaTeXDecoderConfig
16
- from configuration_latex_ocr import LaTeXOCRConfig
17
- from modeling_latex_decoder import LaTeXDecoderForCausalLM
18
 
19
  try:
20
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -319,14 +319,67 @@ class CustomDecoder(nn.Module):
319
 
320
  @torch.no_grad()
321
  def generate(self, inputs_embeds, attention_mask, max_new_tokens, num_beams=1):
322
- eos_id = self.eos_token_id
323
  device = inputs_embeds.device
324
- batch = inputs_embeds.shape[0]
325
- assert batch == 1, "beam search only supports batch_size=1"
326
- vis_emb = inputs_embeds[0]
327
- vis_len = vis_emb.shape[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  vis_mask = attention_mask[0] if attention_mask is not None else None
329
- beams = [(0.0, [], False) for _ in range(num_beams)]
330
 
331
  for _ in range(max_new_tokens):
332
  all_embeds, all_masks = [], []
@@ -341,7 +394,7 @@ class CustomDecoder(nn.Module):
341
  max_len = max(e.shape[0] for e in all_embeds)
342
  d_model = all_embeds[0].shape[-1]
343
  padded_embeds = vis_emb.new_zeros(num_beams, max_len, d_model)
344
- padded_mask = vis_mask.new_zeros(num_beams, max_len) if vis_mask is not None else None
345
  for idx, emb in enumerate(all_embeds):
346
  padded_embeds[idx, :emb.shape[0]] = emb
347
  if padded_mask is not None:
@@ -366,7 +419,7 @@ class CustomDecoder(nn.Module):
366
 
367
  best_ids = max(beams, key=lambda x: x[0])[1]
368
  if not best_ids:
369
- return torch.zeros(batch, 0, dtype=torch.long, device=device)
370
  return torch.tensor(best_ids, dtype=torch.long, device=device).unsqueeze(0)
371
 
372
 
@@ -400,6 +453,12 @@ class LaTeXOCRModel(PreTrainedModel):
400
  self.decoder = CustomDecoder(config)
401
  self.post_init()
402
 
 
 
 
 
 
 
403
  def _init_weights(self, module):
404
  return
405
 
 
12
  from .configuration_latex_ocr import LaTeXOCRConfig
13
  from .modeling_latex_decoder import LaTeXDecoderForCausalLM
14
  except ImportError:
15
+ from latex_ocr.configuration_latex_decoder import LaTeXDecoderConfig
16
+ from latex_ocr.configuration_latex_ocr import LaTeXOCRConfig
17
+ from latex_ocr.modeling_latex_decoder import LaTeXDecoderForCausalLM
18
 
19
  try:
20
  from flash_attn import flash_attn_func, flash_attn_varlen_func
 
319
 
320
@torch.no_grad()
def generate(self, inputs_embeds, attention_mask, max_new_tokens, num_beams=1):
    """Generate token ids, dispatching to beam search or batched greedy decoding.

    Args:
        inputs_embeds: Prefix embeddings; the first dimension is the batch size.
        attention_mask: Mask matching ``inputs_embeds``; forwarded unchanged.
        max_new_tokens: Maximum number of decoding steps.
        num_beams: Values greater than 1 select beam search (batch size must
            be 1); any other value selects greedy decoding.

    Returns:
        Whatever the selected decoding routine returns (a tensor of token ids).

    Raises:
        AssertionError: If ``num_beams > 1`` and the batch size is not 1.
    """
    if num_beams > 1:
        if inputs_embeds.shape[0] != 1:
            # Raised explicitly instead of via `assert` so the guard still
            # runs under `python -O`; the exception type is unchanged.
            raise AssertionError("beam search only supports batch_size=1")
        return self._beam_search(inputs_embeds, attention_mask, max_new_tokens, num_beams)

    return self._greedy_batch(inputs_embeds, attention_mask, max_new_tokens)
331
+
332
@torch.no_grad()
def _greedy_batch(self, inputs_embeds, attention_mask, max_new_tokens):
    """Greedy decoding with true batch support.

    Args:
        inputs_embeds: Prefix embeddings of shape (B, vis_len, D).
        attention_mask: Mask of shape (B, vis_len), or None. When None, no
            mask is built and None is forwarded to ``_forward_embeds`` at
            every step.
            NOTE(review): assumes ``_forward_embeds`` accepts a None mask, as
            the beam-search path appears to allow -- confirm.
        max_new_tokens: Maximum number of decoding steps.

    Returns:
        Long tensor of shape (B, max_len) on the input device, where
        ``max_len`` is the longest generated sequence. Each row holds that
        sample's tokens up to and including its EOS token (if reached),
        right-padded with ``self._pad_id``. Shape is (B, 0) when nothing was
        generated.
    """
    eos_id = self.eos_token_id
    pad_id = self._pad_id
    device = inputs_embeds.device
    batch = inputs_embeds.shape[0]

    # Generated token ids per sample, and per-sample finished flags.
    gen_ids = [[] for _ in range(batch)]
    finished = torch.zeros(batch, dtype=torch.bool, device=device)

    cur_embeds = inputs_embeds  # (B, vis_len, D)
    cur_mask = attention_mask   # (B, vis_len) or None

    for _ in range(max_new_tokens):
        logits = self._forward_embeds(cur_embeds, cur_mask)  # (B, seq, vocab)
        next_tok = logits[:, -1, :].argmax(dim=-1)           # (B,)

        # Pull the step's results to the host once rather than once per sample.
        already_done = finished.tolist()
        for ids, tok, done in zip(gen_ids, next_tok.tolist(), already_done):
            if not done:
                ids.append(tok)

        finished |= next_tok == eos_id
        if finished.all():
            break

        tok_emb = self._model.embed_tokens(next_tok.unsqueeze(1))  # (B, 1, D)
        cur_embeds = torch.cat([cur_embeds, tok_emb], dim=1)
        if cur_mask is not None:
            cur_mask = torch.cat([cur_mask, cur_mask.new_ones(batch, 1)], dim=1)

    # Right-pad every sample to the longest generated length.
    max_len = max(map(len, gen_ids), default=0)
    out = torch.full((batch, max_len), pad_id, dtype=torch.long, device=device)
    for row, ids in enumerate(gen_ids):
        if ids:
            out[row, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
    return out
373
+
374
+ @torch.no_grad()
375
+ def _beam_search(self, inputs_embeds, attention_mask, max_new_tokens, num_beams):
376
+ """Original beam search (batch_size=1 only)."""
377
+ eos_id = self.eos_token_id
378
+ device = inputs_embeds.device
379
+ vis_emb = inputs_embeds[0]
380
+ vis_len = vis_emb.shape[0]
381
  vis_mask = attention_mask[0] if attention_mask is not None else None
382
+ beams = [(0.0, [], False) for _ in range(num_beams)]
383
 
384
  for _ in range(max_new_tokens):
385
  all_embeds, all_masks = [], []
 
394
  max_len = max(e.shape[0] for e in all_embeds)
395
  d_model = all_embeds[0].shape[-1]
396
  padded_embeds = vis_emb.new_zeros(num_beams, max_len, d_model)
397
+ padded_mask = vis_mask.new_zeros(num_beams, max_len) if vis_mask is not None else None
398
  for idx, emb in enumerate(all_embeds):
399
  padded_embeds[idx, :emb.shape[0]] = emb
400
  if padded_mask is not None:
 
419
 
420
  best_ids = max(beams, key=lambda x: x[0])[1]
421
  if not best_ids:
422
+ return torch.zeros(1, 0, dtype=torch.long, device=device)
423
  return torch.tensor(best_ids, dtype=torch.long, device=device).unsqueeze(0)
424
 
425
 
 
453
  self.decoder = CustomDecoder(config)
454
  self.post_init()
455
 
456
def tie_weights(self, *args, **kwargs):
    """Tie or untie the decoder's weights according to the model config.

    Calls ``self.decoder.tie_weights()`` when
    ``self.config.decoder_weights_tied`` is truthy, otherwise
    ``self.decoder.untie_weights()``.

    Extra positional/keyword arguments are accepted and ignored, so the
    override keeps working if the caller passes arguments.
    NOTE(review): newer transformers releases appear to call ``tie_weights``
    with keywords such as ``missing_keys`` -- confirm against the pinned
    version.
    """
    if self.config.decoder_weights_tied:
        self.decoder.tie_weights()
    else:
        self.decoder.untie_weights()
461
+
462
  def _init_weights(self, module):
463
  return
464
 
pipeline_latex_ocr.py CHANGED
@@ -22,11 +22,11 @@ class LaTeXOCRPipeline:
22
 
23
  sys.path.insert(0, str(path))
24
 
25
- from tokenization_latex_ocr import LaTeXTokenizer
26
- from image_processing_latex_ocr import LaTeXOCRImageProcessor
27
- from processing_latex_ocr import LaTeXOCRProcessor
28
- from modeling_latex_ocr import LaTeXOCRModel
29
- from configuration_latex_ocr import LaTeXOCRConfig
30
 
31
  config = LaTeXOCRConfig.from_pretrained(str(path))
32
  image_processor = LaTeXOCRImageProcessor.from_pretrained(str(path))
@@ -36,16 +36,19 @@ class LaTeXOCRPipeline:
36
 
37
  return cls(model=model, processor=processor, device=device)
38
 
39
- def __call__(self, image, max_new_tokens: int = None, num_beams: int = None) -> str:
40
- if isinstance(image, (str, Path)):
41
- image = Image.open(image).convert("RGB")
42
- elif isinstance(image, Image.Image):
43
- image = image.convert("RGB")
44
- else:
45
- raise TypeError(f"Unsupported image type: {type(image)}")
46
 
47
- inputs = self.processor(images=image, return_tensors="pt")
48
- pixel_values = inputs["pixel_values"].to(self.device)
 
 
 
 
 
 
 
49
 
50
  kwargs = {}
51
  if max_new_tokens is not None:
@@ -53,7 +56,15 @@ class LaTeXOCRPipeline:
53
  if num_beams is not None:
54
  kwargs["num_beams"] = num_beams
55
 
 
 
 
 
 
 
 
56
  with torch.no_grad():
57
  generated_ids = self.model.generate(pixel_values, **kwargs)
58
 
59
- return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
22
 
23
  sys.path.insert(0, str(path))
24
 
25
+ from latex_ocr.tokenization_latex_ocr import LaTeXTokenizer
26
+ from latex_ocr.image_processing_latex_ocr import LaTeXOCRImageProcessor
27
+ from latex_ocr.processing_latex_ocr import LaTeXOCRProcessor
28
+ from latex_ocr.modeling_latex_ocr import LaTeXOCRModel
29
+ from latex_ocr.configuration_latex_ocr import LaTeXOCRConfig
30
 
31
  config = LaTeXOCRConfig.from_pretrained(str(path))
32
  image_processor = LaTeXOCRImageProcessor.from_pretrained(str(path))
 
36
 
37
  return cls(model=model, processor=processor, device=device)
38
 
39
+ def __call__(self, image, max_new_tokens: int = None, num_beams: int = None):
40
+ single = not isinstance(image, list)
41
+ images = [image] if single else image
 
 
 
 
42
 
43
+ loaded = []
44
+ for img in images:
45
+ if isinstance(img, (str, Path)):
46
+ img = Image.open(img).convert("RGB")
47
+ elif isinstance(img, Image.Image):
48
+ img = img.convert("RGB")
49
+ else:
50
+ raise TypeError(f"Unsupported image type: {type(img)}")
51
+ loaded.append(img)
52
 
53
  kwargs = {}
54
  if max_new_tokens is not None:
 
56
  if num_beams is not None:
57
  kwargs["num_beams"] = num_beams
58
 
59
+ # image processor handles variable-width images one at a time;
60
+ # collect pixel_values as a list for NaViT's batched_images path
61
+ pixel_values = [
62
+ self.processor(images=img, return_tensors="pt")["pixel_values"].to(self.device)
63
+ for img in loaded
64
+ ]
65
+
66
  with torch.no_grad():
67
  generated_ids = self.model.generate(pixel_values, **kwargs)
68
 
69
+ results = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
70
+ return results[0] if single else results
processing_latex_ocr.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import ProcessorMixin
2
- from image_processing_latex_ocr import LaTeXOCRImageProcessor
3
- from tokenization_latex_ocr import LaTeXTokenizer
4
 
5
  class LaTeXOCRProcessor(ProcessorMixin):
6
  attributes = ["image_processor", "tokenizer"]
 
1
  from transformers import ProcessorMixin
2
+ from latex_ocr.image_processing_latex_ocr import LaTeXOCRImageProcessor
3
+ from latex_ocr.tokenization_latex_ocr import LaTeXTokenizer
4
 
5
  class LaTeXOCRProcessor(ProcessorMixin):
6
  attributes = ["image_processor", "tokenizer"]
tokenization_latex_ocr.py CHANGED
@@ -23,19 +23,14 @@ class LaTeXTokenizer(PreTrainedTokenizer):
23
  if "model" in data:
24
  self.token2id: Dict[str, int] = data["model"]["vocab"]
25
  self.id2token: Dict[int, str] = {int(v): k for k, v in self.token2id.items()}
26
- self.merges: List[Tuple[str, str]] = []
27
  cfg = {}
28
  else:
29
  self.token2id = data["token2id"]
30
  self.id2token = {int(k): v for k, v in data["id2token"].items()}
31
- self.merges = [tuple(m) for m in data.get("merges", [])]
32
  cfg = data.get("config", {})
33
 
34
- self.bpe_ranks: Dict[Tuple[str, str], int] = {
35
- pair: idx for idx, pair in enumerate(self.merges)
36
- }
37
- self._bpe_cache: Dict[str, str] = {}
38
-
39
  kwargs.setdefault("model_max_length", cfg.get("model_max_length", 256))
40
  kwargs.setdefault("padding_side", cfg.get("padding_side", "right"))
41
  kwargs.setdefault("truncation_side", cfg.get("truncation_side", "right"))
@@ -48,42 +43,6 @@ class LaTeXTokenizer(PreTrainedTokenizer):
48
  **kwargs,
49
  )
50
 
51
- def _get_pairs(self, word: Tuple[str, ...]):
52
- return {(word[i], word[i + 1]) for i in range(len(word) - 1)}
53
-
54
- def _bpe(self, token: str) -> str:
55
- if token in self._bpe_cache:
56
- return self._bpe_cache[token]
57
-
58
- word = tuple(token)
59
- pairs = self._get_pairs(word)
60
-
61
- if not pairs:
62
- return token
63
-
64
- while True:
65
- bigram = min(pairs, key=lambda p: self.bpe_ranks.get(p, float("inf")))
66
- if bigram not in self.bpe_ranks:
67
- break
68
- first, second = bigram
69
- new_word = []
70
- i = 0
71
- while i < len(word):
72
- if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
73
- new_word.append(first + second)
74
- i += 2
75
- else:
76
- new_word.append(word[i])
77
- i += 1
78
- word = tuple(new_word)
79
- pairs = self._get_pairs(word)
80
- if not pairs:
81
- break
82
-
83
- result = " ".join(word)
84
- self._bpe_cache[token] = result
85
- return result
86
-
87
  @property
88
  def vocab_size(self) -> int:
89
  return len(self.token2id)
@@ -99,7 +58,7 @@ class LaTeXTokenizer(PreTrainedTokenizer):
99
  for length in range(min(20, len(text) - i), 0, -1):
100
  substr = text[i:i + length]
101
  if substr in self.token2id:
102
- tokens.extend(self._bpe(substr).split())
103
  i += length
104
  matched = True
105
  break
 
23
  if "model" in data:
24
  self.token2id: Dict[str, int] = data["model"]["vocab"]
25
  self.id2token: Dict[int, str] = {int(v): k for k, v in self.token2id.items()}
26
+ self.merges = []
27
  cfg = {}
28
  else:
29
  self.token2id = data["token2id"]
30
  self.id2token = {int(k): v for k, v in data["id2token"].items()}
31
+ self.merges = data.get("merges", [])
32
  cfg = data.get("config", {})
33
 
 
 
 
 
 
34
  kwargs.setdefault("model_max_length", cfg.get("model_max_length", 256))
35
  kwargs.setdefault("padding_side", cfg.get("padding_side", "right"))
36
  kwargs.setdefault("truncation_side", cfg.get("truncation_side", "right"))
 
43
  **kwargs,
44
  )
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  @property
47
  def vocab_size(self) -> int:
48
  return len(self.token2id)
 
58
  for length in range(min(20, len(text) - i), 0, -1):
59
  substr = text[i:i + length]
60
  if substr in self.token2id:
61
+ tokens.append(substr)
62
  i += length
63
  matched = True
64
  break