saracandu committed on
Commit
707b7ba
·
verified ·
1 Parent(s): 0de87c2
Files changed (1) hide show
  1. tokenizer_stlenc.py +41 -20
tokenizer_stlenc.py CHANGED
@@ -2,18 +2,30 @@ import json
2
  import os
3
  import torch
4
  from typing import Any, Dict, List, Optional, Tuple, Union
5
- from transformers import PreTrainedTokenizer
6
- from huggingface_hub import hf_hub_download
7
 
8
  class STLTokenizer(PreTrainedTokenizer):
9
- def __init__(self, vocab_file="vocab.json", unk_token="unk", pad_token="pad",
10
- bos_token="/s", eos_token="s", model_max_length=512, **kwargs):
11
-
 
 
 
 
 
 
 
 
 
12
  current_dir = os.path.dirname(__file__)
13
  full_vocab_path = os.path.join(current_dir, vocab_file)
14
 
15
  if not os.path.exists(full_vocab_path):
16
- full_vocab_path = hf_hub_download("saracandu/stldec_random", "vocab.json")
 
 
 
 
17
 
18
  with open(full_vocab_path, "r", encoding="utf-8") as f:
19
  self.vocab = json.load(f)
@@ -30,36 +42,40 @@ class STLTokenizer(PreTrainedTokenizer):
30
  )
31
 
32
  @property
33
- def vocab_size(self):
34
  return len(self.vocab)
35
 
36
- def get_vocab(self):
37
  return dict(self.vocab)
38
 
39
- def _tokenize(self, text):
40
- # La tua logica di tokenizzazione
41
  text = f'{self.bos_token} {text} {self.eos_token}'.replace(' ', '@')
42
- tokens, i = [], 0
 
 
43
  while i < len(text):
44
  best_match = None
45
  for j in range(min(i + 50, len(text)), i, -1):
46
- sub = text[i:j]
47
- if sub in self.vocab:
48
- best_match = sub
49
  break
 
50
  if best_match:
51
- tokens.append(best_match); i += len(best_match)
 
52
  else:
53
- tokens.append(self.unk_token); i += 1
 
54
  return tokens
55
 
56
- def _convert_token_to_id(self, token):
57
  return self.vocab.get(token, self.vocab.get(self.unk_token))
58
 
59
- def _convert_id_to_token(self, index):
60
  return self.id_to_token.get(index, self.unk_token)
61
 
62
- def save_vocabulary(self, save_directory, filename_prefix=None):
63
  if not os.path.isdir(save_directory):
64
  os.makedirs(save_directory)
65
 
@@ -69,4 +85,9 @@ class STLTokenizer(PreTrainedTokenizer):
69
  with open(vocab_file, "w", encoding="utf-8") as f:
70
  json.dump(self.vocab, f, indent=2, ensure_ascii=False)
71
 
72
- return (vocab_file,)
 
 
 
 
 
 
2
  import os
3
  import torch
4
  from typing import Any, Dict, List, Optional, Tuple, Union
5
+ from transformers import PreTrainedTokenizer, AutoTokenizer
 
6
 
7
  class STLTokenizer(PreTrainedTokenizer):
8
+ model_type = "stl_encoder"
9
+
10
+ def __init__(
11
+ self,
12
+ vocab_file="vocab.json",
13
+ unk_token="unk",
14
+ pad_token="pad",
15
+ bos_token="/s",
16
+ eos_token="s",
17
+ model_max_length=512,
18
+ **kwargs
19
+ ):
20
  current_dir = os.path.dirname(__file__)
21
  full_vocab_path = os.path.join(current_dir, vocab_file)
22
 
23
  if not os.path.exists(full_vocab_path):
24
+ from huggingface_hub import hf_hub_download
25
+ try:
26
+ full_vocab_path = hf_hub_download("saracandu/stlenc", vocab_file)
27
+ except:
28
+ full_vocab_path = vocab_file
29
 
30
  with open(full_vocab_path, "r", encoding="utf-8") as f:
31
  self.vocab = json.load(f)
 
42
  )
43
 
44
  @property
45
def vocab_size(self) -> int:
    """Return the number of entries in ``self.vocab``."""
    entry_count = len(self.vocab)
    return entry_count
47
 
48
def get_vocab(self) -> Dict[str, int]:
    """Return a new dict built from ``self.vocab``.

    The result is a separate object, so adding or removing keys on it
    does not affect ``self.vocab``.
    """
    snapshot = dict(self.vocab)
    return snapshot
50
 
51
+ def _tokenize(self, text: str) -> List[str]:
 
52
  text = f'{self.bos_token} {text} {self.eos_token}'.replace(' ', '@')
53
+
54
+ tokens = []
55
+ i = 0
56
  while i < len(text):
57
  best_match = None
58
  for j in range(min(i + 50, len(text)), i, -1):
59
+ subtoken = text[i:j]
60
+ if subtoken in self.vocab:
61
+ best_match = subtoken
62
  break
63
+
64
  if best_match:
65
+ tokens.append(best_match)
66
+ i += len(best_match)
67
  else:
68
+ tokens.append(self.unk_token)
69
+ i += 1
70
  return tokens
71
 
72
+ def _convert_token_to_id(self, token: str) -> int:
73
  return self.vocab.get(token, self.vocab.get(self.unk_token))
74
 
75
+ def _convert_id_to_token(self, index: int) -> str:
76
  return self.id_to_token.get(index, self.unk_token)
77
 
78
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
79
  if not os.path.isdir(save_directory):
80
  os.makedirs(save_directory)
81
 
 
85
  with open(vocab_file, "w", encoding="utf-8") as f:
86
  json.dump(self.vocab, f, indent=2, ensure_ascii=False)
87
 
88
+ return (vocab_file,)
89
+
90
# Best-effort registration of the tokenizer with AutoTokenizer at import time.
# A failure here must not prevent the module from being imported, so the
# exception is not re-raised; it is logged at debug level instead of being
# discarded, so that a broken registration can still be diagnosed.
# NOTE(review): AutoTokenizer.register is documented as taking a config class
# as its first argument; a plain string is passed here -- confirm this has the
# intended effect.
import logging

try:
    AutoTokenizer.register("stl_encoder", STLTokenizer)
except Exception:
    logging.getLogger(__name__).debug(
        "AutoTokenizer registration of STLTokenizer failed", exc_info=True
    )