Spaces:
Sleeping
Sleeping
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -8,6 +8,46 @@ from huggingface_hub import hf_hub_download
|
|
| 8 |
|
| 9 |
# Repository identifier in "owner/name" form; presumably the Hub repo passed
# to hf_hub_download (usage is not visible in this hunk — confirm).
REPO_ID = "itriedcoding/Sage"
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# Custom model class matching Sage architecture
|
| 12 |
class TransformerLM(nn.Module):
|
| 13 |
def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
|
|
|
|
| 8 |
|
| 9 |
# Repository identifier in "owner/name" form; presumably the Hub repo passed
# to hf_hub_download (usage is not visible in this hunk — confirm).
REPO_ID = "itriedcoding/Sage"
|
| 10 |
|
| 11 |
+
class CharacterTokenizer:
    """Character-level tokenizer mapping single characters to integer ids.

    Ids 0 and 1 are reserved for the ``<PAD>`` and ``<UNK>`` tokens. Every
    distinct character seen by :meth:`fit` receives an id starting at 2,
    assigned in sorted character order.
    """

    def __init__(self):
        """Create an empty tokenizer; call :meth:`fit` to build the vocabulary."""
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.vocab_size = 0
        self.pad_token_id = 0
        self.unk_token_id = 1

    def fit(self, texts):
        """Build the vocabulary from an iterable of texts.

        Each item is converted with ``str()`` before its characters are
        collected. Calling ``fit`` again replaces the previous vocabulary
        entirely rather than merging into it.

        Args:
            texts: Iterable of items whose string form supplies the characters.
        """
        chars = set()
        for text in texts:
            chars.update(str(text))

        # Start from a fresh mapping so a refit cannot keep stale characters
        # or hand the same id to two different characters.
        char_to_idx = {'<PAD>': 0, '<UNK>': 1}
        for idx, char in enumerate(sorted(chars), start=2):
            char_to_idx[char] = idx

        self.char_to_idx = char_to_idx
        self.idx_to_char = {v: k for k, v in char_to_idx.items()}
        self.vocab_size = len(char_to_idx)

    def encode(self, text, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Convert one text or a list of texts into lists of token ids.

        Args:
            text: A single string (treated as a batch of one) or an iterable
                of items; each item is converted with ``str()``.
            max_length: Target length used by ``padding`` and ``truncation``.
                Ignored when falsy (``None`` or ``0``).
            padding: If true and ``max_length`` is set, right-pad shorter
                sequences with ``pad_token_id`` up to ``max_length``.
            truncation: If true and ``max_length`` is set, cut sequences to
                at most ``max_length`` ids.
            return_tensors: ``'pt'`` to return a ``torch.long`` tensor;
                any other value returns a list of lists.

        Returns:
            A list with one id list per input text, or a 2-D tensor when
            ``return_tensors == 'pt'``. Characters not in the vocabulary map
            to ``unk_token_id``.
        """
        if isinstance(text, str):
            text = [text]

        encoded = []
        for t in text:
            tokens = [self.char_to_idx.get(c, self.unk_token_id) for c in str(t)]
            if truncation and max_length:
                tokens = tokens[:max_length]
            if padding and max_length:
                # A negative count yields an empty list, so over-long
                # sequences are left untouched when truncation is off.
                tokens = tokens + [self.pad_token_id] * (max_length - len(tokens))
            encoded.append(tokens)

        if return_tensors == 'pt':
            return torch.tensor(encoded, dtype=torch.long)
        return encoded

    def decode(self, token_ids):
        """Convert token ids back into text.

        Args:
            token_ids: A flat sequence of ids, a batch (list/tuple of
                sequences), or a ``torch.Tensor`` holding either.

        Returns:
            A string for a flat sequence, or a list of strings (one per row)
            for a batch. Ids missing from the vocabulary decode to the
            literal ``'<UNK>'``; the pad id decodes to ``'<PAD>'``.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        # encode(..., return_tensors='pt') always yields a 2-D batch, so
        # accept nested input and decode it row by row.
        if (
            isinstance(token_ids, (list, tuple))
            and token_ids
            and isinstance(token_ids[0], (list, tuple))
        ):
            return [self.decode(row) for row in token_ids]

        return ''.join(self.idx_to_char.get(idx, '<UNK>') for idx in token_ids)
|
| 50 |
+
|
| 51 |
# Custom model class matching Sage architecture
|
| 52 |
class TransformerLM(nn.Module):
|
| 53 |
def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
|