riyadhrazzaq committed on
Commit
d541e5a
·
0 Parent(s):

added inference scripts, model and vocab

Browse files
Files changed (12) hide show
  1. .gitattributes +1 -0
  2. .gitignore +161 -0
  3. .idea/.gitignore +8 -0
  4. main.py +20 -0
  5. model_lr0.0001_bs256_epoch50.pt +3 -0
  6. src/__init__.py +0 -0
  7. src/evaluator.py +15 -0
  8. src/model.py +81 -0
  9. src/tokenizer.py +72 -0
  10. src/util.py +88 -0
  11. test.py +25 -0
  12. vocab.pt +3 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
main.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from src.evaluator import evaluate
4
+
5
+
6
def main():
    """Parse the command-line options and run inference.

    ``--checkpoint``, ``--vocab`` and ``--text`` are mandatory: ``evaluate``
    hands them straight to ``torch.load`` / ``str.split``, so leaving one out
    used to surface as an obscure error on ``None`` instead of a usage message.
    ``--decompress`` switches from compression (the default) to decompression.
    """
    parser = argparse.ArgumentParser(description='inference with model.')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to the checkpoint file')
    parser.add_argument("--decompress", action="store_true", help="decompress the input text")
    parser.add_argument('--vocab', type=str, required=True, help='Path to the vocab file')
    parser.add_argument('--text', type=str, required=True, help='Text to be tokenized')
    args = parser.parse_args()

    # evaluate() loads the model and vocab itself, then prints the result
    evaluate(args)


if __name__ == "__main__":
    main()
model_lr0.0001_bs256_epoch50.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31923ca96e3c2471ad6252dfb615b15cde784be5a7792c7379d1c9a9b27a7f4e
3
+ size 551468733
src/__init__.py ADDED
File without changes
src/evaluator.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.model import Model2
2
+ from src.tokenizer import Tokenizer
3
+ from src.util import *
4
+
5
+
6
def evaluate(args):
    """Load the vocab and model, then print the compressed or decompressed text.

    Parameters
    ----------
    args: namespace with the attributes ``vocab`` (path to the vocab file),
        ``checkpoint`` (path to the model checkpoint), ``text`` (input string)
        and ``decompress`` (bool; when true the text is decompressed instead
        of compressed).
    """
    vocab = torch.load(args.vocab, map_location=torch.device('cpu'))
    # 300 is the embedding dimension and 256 the LSTM state size
    model = Model2(len(vocab), 300, 256, vocab['<PAD>'])
    load_from_checkpoint(model, args.checkpoint)
    # The checkpoint is loaded onto the CPU, but the tokenizer and decompress()
    # build their tensors on `device` (star-imported from src.util). Move the
    # model there as well, otherwise inference fails on a CUDA machine.
    model.to(device)

    tokenizer = Tokenizer(vocab)
    transform = decompress if args.decompress else compress

    print()
    print(transform(args.text, tokenizer, model))
src/model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from src.util import device
5
+
6
+
7
class Transpose(nn.Module):
    """Swap two dimensions of a tensor, as a layer usable in ``nn.Sequential``.

    Parameters
    ----------
    dim0, dim1: the dimensions to swap. When ``dim0`` is ``None`` the last two
        dimensions of each input are swapped (``dim1`` is ignored in that case).
    """

    def __init__(self, dim0=None, dim1=None):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, tensor):
        """Return ``tensor`` with the configured dimensions swapped."""
        if self.dim0 is None:
            # Resolve the default per call instead of caching it on self:
            # caching froze the dims to the rank of the first input, so a later
            # input of a different rank was transposed along the wrong axes.
            last = tensor.dim() - 1
            return torch.transpose(tensor, last - 1, last)

        return torch.transpose(tensor, self.dim0, self.dim1)
19
+
20
class Model2(nn.Module):
    """Next-token predictor: embedding -> LSTM cell -> four-stage MLP head.

    Parameters
    ----------
    vocab_size: number of entries in the vocabulary; also the output width.
    embedding_dim: size of each token embedding.
    state_size: size of the LSTM hidden and cell states.
    pad_index: index of the padding token; padded positions do not update
        the recurrent state.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        state_size,
        pad_index,
    ):
        super().__init__()
        self.state_size = state_size
        self.pad_index = pad_index
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=pad_index,
        )

        self.rnn_layer = nn.LSTMCell(input_size=embedding_dim, hidden_size=state_size)
        self.lin1 = nn.Sequential(
            nn.Linear(state_size, state_size * 4),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )
        self.lin2 = nn.Sequential(
            nn.Linear(state_size * 4, state_size * 8),
            # BatchNorm1d normalises dimension 1, so move the feature axis
            # there and back: (N, T, C) -> (N, C, T) -> (N, T, C)
            Transpose(),
            nn.BatchNorm1d(state_size * 8),
            Transpose(),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )
        self.lin3 = nn.Sequential(
            nn.Linear(state_size * 8, state_size * 16),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )
        self.lin4 = nn.Sequential(nn.Linear(state_size * 16, vocab_size))

    def forward(self, X):
        """Compute next-token logits for a batch of token indices.

        Parameters
        ----------
        X: integer tensor of shape (N, T) holding token indices.

        Returns
        -------
        Tensor of shape (N, T, vocab_size) with the logits for every position.
        """
        N, T = X.shape
        # (N, T, 1) so that a time slice broadcasts against the (N, state) states
        non_pad_mask = (X != self.pad_index).unsqueeze(-1)
        X = self.embedding_layer(X)

        # Allocate the initial states next to the embeddings (same device and
        # dtype) rather than on the module-global device, so the model works
        # wherever its weights and inputs actually live.
        state = X.new_zeros((N, self.state_size))
        c = X.new_zeros((N, self.state_size))
        states = []
        for t in range(T):
            next_state, next_c = self.rnn_layer(X[:, t, :], (state, c))
            # padded positions keep the previous state untouched
            keep = non_pad_mask[:, t]
            state = torch.where(keep, next_state, state)
            c = torch.where(keep, next_c, c)

            states.append(state)

        # (N, T, state_size)
        states = torch.stack(states, dim=1)
        output = self.lin1(states)
        output = self.lin2(output)
        output = self.lin3(output)
        output = self.lin4(output)

        return output
src/tokenizer.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torchtext.vocab import Vocab
7
+ from torch import nn, Tensor
8
+
9
+ from src.util import device
10
+
11
+
12
class Tokenizer(nn.Module):
    """Turn whitespace-separated text into padded tensors of token indices.

    Parameters
    ----------
    vocab: either a vocabulary object or the path of a file holding one
        (loaded with ``torch.load``). The vocabulary must contain the special
        tokens ``<EDGE>``, ``<PAD>`` and ``<UNK>``.

    Raises
    ------
    FileNotFoundError: if ``vocab`` is a path that does not exist.
    """

    def __init__(self, vocab: str | Vocab):
        super().__init__()

        if isinstance(vocab, str):
            # explicit check instead of `assert`, which is stripped under -O
            if not os.path.exists(vocab):
                raise FileNotFoundError(f"vocab file not found: {vocab}")
            # NOTE: torch.load unpickles the file; only load trusted vocab files
            self.vocab = torch.load(vocab, map_location=device)
        else:
            self.vocab = vocab

        # Look the special tokens up in the loaded vocabulary, not in the
        # `vocab` argument, which is still a plain string when a path was given.
        self.edge_index = self.vocab['<EDGE>']
        self.pad_index = self.vocab['<PAD>']
        self.unk_index = self.vocab['<UNK>']

    def get_tensors(self, data):
        """
        Builds torch.Tensor from a variable length 2D python list. The return value is a tuple of two tensors, one for input and the other for output.

        Each input row is the edge token followed by the sentence; each output
        row is the sentence followed by the edge token. Both are padded with
        the pad index to the longest sentence length plus one.

        Parameters
        ----------
        data: Nested list of token indices
            [[1,2,3],
            [4,2,3,4,2],
            [223,4,2]]
            This example has three sentences.

        Returns
        -------
        Tuple (X, Y) of int64 tensors of shape (len(data), longest sentence + 1).
        """
        N = len(data)
        # default=0 keeps an empty batch from raising inside max()
        max_len = max((len(datum) for datum in data), default=0) + 1
        X = np.full((N, max_len), self.pad_index, np.int64)
        Y = np.full((N, max_len), self.pad_index, np.int64)

        for i, datum in enumerate(data):
            length = len(datum)
            # prepend the inputs with edge token
            X[i, 0] = self.edge_index
            if length:
                X[i, 1:length + 1] = datum
                Y[i, :length] = datum

            # finish the outputs with edge token, placed after the last real
            # token (it used to overwrite that token, and reused a stale loop
            # index when the sentence was empty)
            Y[i, length] = self.edge_index

        return torch.tensor(X, device=device), torch.tensor(Y, device=device)

    def forward(self, text: List[str]) -> Tuple[Tensor, Tensor]:
        """
        Tokenizes a list of natural text on whitespace and converts it to tensors.

        Parameters
        ----------
        text: List[str]. A list of natural language strings.

        Returns
        -------
        Tuple[Tensor, Tensor]. The input and output tensors of token ids built
        by ``get_tensors``.
        """
        text = [sentence.split() for sentence in text]
        # calling the vocab on a token list is expected to return its indices
        tokenized = [self.vocab(sentence) for sentence in text]
        return self.get_tensors(tokenized)
src/util.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5
+
6
+
7
def stringify(array):
    """Render a nested list of tokens as text.

    Tokens within a row are joined by single spaces and rows by newlines.
    """
    rendered_rows = (' '.join(row) for row in array)
    return '\n'.join(rendered_rows)
9
+
10
+
11
def compress(text, tokenizer, model):
    """
    Replace every token the model predicts correctly with the placeholder "X".

    Parameters
    ----------
    text: str.
        Each line represents a single document.
    tokenizer: Tokenizer.
        Called with the list of lines; its first return value is the index
        tensor fed to the model. Its ``vocab`` and ``unk_index`` are used to
        detect unknown tokens.
    model: callable mapping the index tensor to logits of shape
        (lines, positions, vocab).

    Returns
    -------
    str. The text with correctly predicted tokens replaced by "X". Tokens that
    map to the unknown index are never replaced.
    """
    lines = text.split("\n")
    tokens = [sentence.split() for sentence in lines]
    indices, _ = tokenizer(lines)

    # inference only: do not build an autograd graph
    with torch.no_grad():
        logits = model(indices)
    next_token_predicted = logits.argmax(dim=2)

    # slices are for skipping edge tokens; converted to nested python lists
    # once so the loop below does not test tensors element by element
    prediction_mask = (indices[:, 1:] == next_token_predicted[:, :-1]).tolist()

    # replace correctly predicted tokens with "X"
    for sentence, sentence_mask in zip(tokens, prediction_mask):
        # zip stops at the sentence length, which ignores pad positions
        for j, (token, predicted_successfully) in enumerate(zip(sentence, sentence_mask)):
            if predicted_successfully and tokenizer.vocab[token] != tokenizer.unk_index:
                sentence[j] = "X"

    return "\n".join(" ".join(sentence) for sentence in tokens)
37
+
38
+
39
def decompress(text, tokenizer, model):
    """
    Restore text produced by ``compress`` by predicting every "X" placeholder.

    Parameters
    ----------
    text: str.
        Each line represents a single document.
    tokenizer: Tokenizer.
        Only its ``vocab`` is used: called on a token list to obtain indices,
        and ``lookup_token`` to map a predicted index back to a token.
    model: callable mapping an index tensor of shape (1, prefix length) to
        logits of shape (1, prefix length, vocab).

    Returns
    -------
    str. The text with every "X" replaced by the model's prediction.
    """
    sentence_tokens = [document.split() for document in text.split("\n")]

    uncompressed = []
    # inference only: do not build an autograd graph
    with torch.no_grad():
        for sentence in sentence_tokens:
            # every sentence starts from a fresh prefix
            prefix = ['<EDGE>']
            for token in sentence:
                if token != "X":
                    prefix.append(token)
                    continue

                # only infer when X is found
                indices = torch.tensor([tokenizer.vocab(prefix)],
                                       dtype=torch.int,
                                       device=device)
                logits = model(indices)
                # prediction for X is the logit row at the last prefix position;
                # .item() hands lookup_token a plain int instead of a tensor
                index = logits[0, -1, :].argmax().item()
                prefix.append(tokenizer.vocab.lookup_token(index))

            # drop the leading edge token
            uncompressed.append(prefix[1:])

    return stringify(uncompressed)
68
+
69
+
70
def load_from_checkpoint(model, checkpoint_path):
    """
    Loads model weights from a checkpoint and puts the model in eval mode.

    Parameters:
    ----------
    model: The model whose state dict is restored in place.
    checkpoint_path: The path to the checkpoint. The file must hold a dict
        with the weights under the key 'model_state_dict'.

    Raises:
    ------
    FileNotFoundError: If no checkpoint is found in the provided path.
        (A subclass of Exception, so callers catching Exception still work.)
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError("No checkpoint found in the provided path")

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # The weights are always mapped to the CPU here.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("loaded existing model.")
test.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.evaluator import evaluate
2
+ import argparse
3
+
4
+
5
def main():
    """Run a compress / decompress smoke test with the bundled model and vocab.

    Every option has a default, so the script runs without arguments. The
    work lives in a function behind the ``__main__`` guard so that importing
    this module no longer parses ``sys.argv`` or loads the model.
    """
    parser = argparse.ArgumentParser(description='inference test with model.')
    parser.add_argument('--checkpoint', type=str, help='Path to the checkpoint file', default='model_lr0.0001_bs256_epoch50.pt')
    parser.add_argument("--decompress", action="store_true", help="decompress the input text", default=False)
    parser.add_argument('--vocab', type=str, help='Path to the vocab file', default='vocab.pt')
    parser.add_argument('--text', type=str, help='Text to be tokenized', default="""dr. tonie mcdonald is a life long levittown resident who taught and rose through the ranks of the district she now leads .
he received his ba in chemistry , magna cum laude , from amherst college in 1 9 8 1 .""")
    args = parser.parse_args()

    print("--- input ---")
    print(args.text)

    # compress
    print("--- compress ---")
    evaluate(args)

    # decompress a hand-written compressed version of the default text
    print("--- decompress ---")
    args.decompress = True
    args.text = """dr. tonie mcdonald is X life long levittown resident who taught and rose through X ranks of the district she now leads .
he received his ba X chemistry X magna cum laude X from amherst college in X X 8 1 ."""
    evaluate(args)


if __name__ == "__main__":
    main()
vocab.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38847aa134accb833b3afc3204db2ce8650400907885a7efd3a1c541f58d3f0d
3
+ size 133355