itriedcoding committed on
Commit
a982fff
·
verified ·
1 Parent(s): cf8d8b2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py CHANGED
@@ -8,6 +8,46 @@ from huggingface_hub import hf_hub_download
8
 
9
  REPO_ID = "itriedcoding/Sage"
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Custom model class matching Sage architecture
12
  class TransformerLM(nn.Module):
13
  def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
 
8
 
9
  REPO_ID = "itriedcoding/Sage"
10
 
11
class CharacterTokenizer:
    """Character-level tokenizer that maps each distinct character to an id.

    Ids 0 and 1 are reserved for the ``<PAD>`` and ``<UNK>`` tokens. Every
    character seen by :meth:`fit` receives an id starting at 2, assigned in
    sorted character order.
    """

    def __init__(self):
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.vocab_size = 0
        self.pad_token_id = 0
        self.unk_token_id = 1

    def fit(self, texts):
        """Build the vocabulary from an iterable of texts.

        Each item is converted with ``str()`` before its characters are
        collected. Calling ``fit`` again replaces the previous vocabulary
        instead of merging into it, so ids stay unique and contiguous.

        Args:
            texts: Iterable of objects whose ``str()`` form supplies characters.
        """
        chars = set()
        for text in texts:
            chars.update(str(text))

        # Start from a fresh mapping: reusing the old dict would keep stale
        # characters whose ids can collide with the newly assigned ones.
        vocab = {'<PAD>': 0, '<UNK>': 1}
        for offset, char in enumerate(sorted(chars)):
            vocab[char] = offset + 2

        self.char_to_idx = vocab
        self.idx_to_char = {idx: char for char, idx in vocab.items()}
        self.vocab_size = len(vocab)

    def encode(self, text, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Convert text into lists of token ids, one list per input item.

        Args:
            text: A single string, or an iterable of items that are each
                converted with ``str()``. A single string is treated as a
                batch of one.
            max_length: Target length used by ``padding`` and ``truncation``.
                A falsy value (``None`` or ``0``) disables both.
            padding: If true and ``max_length`` is set, right-pad each
                sequence with ``pad_token_id`` up to ``max_length``.
            truncation: If true and ``max_length`` is set, cut each sequence
                to at most ``max_length`` ids.
            return_tensors: ``'pt'`` to return a ``torch.long`` tensor; any
                other value returns nested Python lists.

        Returns:
            A list of id lists, or a tensor built from it when
            ``return_tensors == 'pt'`` (which requires equal-length rows).
        """
        if isinstance(text, str):
            text = [text]
        encoded = []
        for item in text:
            tokens = [self.char_to_idx.get(c, self.unk_token_id) for c in str(item)]
            if truncation and max_length:
                tokens = tokens[:max_length]
            if padding and max_length:
                # A negative count yields an empty list, so sequences that are
                # already long enough are left unchanged.
                tokens = tokens + [self.pad_token_id] * (max_length - len(tokens))
            encoded.append(tokens)
        if return_tensors == 'pt':
            return torch.tensor(encoded, dtype=torch.long)
        return encoded

    def decode(self, token_ids):
        """Convert token ids back into text.

        Ids missing from the vocabulary are rendered as ``'<UNK>'``.

        Args:
            token_ids: A single id, an iterable of ids, a batch (list of id
                lists), or a ``torch.Tensor`` holding any of those.

        Returns:
            One string for a flat sequence (or a single id), or a list of
            strings for a batch such as the output of :meth:`encode`.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A 0-d tensor (or a bare id) is a sequence of length one.
            token_ids = [token_ids]
        ids = list(token_ids)
        if ids and isinstance(ids[0], list):
            # Batched input, e.g. what encode() returns: decode row by row.
            return [self.decode(row) for row in ids]
        return ''.join(self.idx_to_char.get(idx, '<UNK>') for idx in ids)
50
+
51
  # Custom model class matching Sage architecture
52
  class TransformerLM(nn.Module):
53
  def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):