Spaces:
Configuration error
Configuration error
Commit ·
d541e5a
0
Parent(s):
added inference scripts, model and vocab
Browse files- .gitattributes +1 -0
- .gitignore +161 -0
- .idea/.gitignore +8 -0
- main.py +20 -0
- model_lr0.0001_bs256_epoch50.pt +3 -0
- src/__init__.py +0 -0
- src/evaluator.py +15 -0
- src/model.py +81 -0
- src/tokenizer.py +72 -0
- src/util.py +88 -0
- test.py +25 -0
- vocab.pt +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
| 161 |
+
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
main.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from src.evaluator import evaluate
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def main():
    """Parse the command-line options and run compression or decompression.

    The parsed namespace is handed unchanged to ``evaluate``, which reads its
    ``checkpoint``, ``vocab``, ``text`` and ``decompress`` attributes.
    """
    parser = argparse.ArgumentParser(description='inference with model.')
    # checkpoint, vocab and text are mandatory: without them ``evaluate`` would
    # receive ``None`` and fail later with an unrelated-looking error, so let
    # argparse reject the invocation up front with a usage message.
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to the checkpoint file')
    parser.add_argument("--decompress", action="store_true", help="decompress the input text")
    parser.add_argument('--vocab', type=str, required=True, help='Path to the vocab file')
    parser.add_argument('--text', type=str, required=True, help='Text to be tokenized')
    args = parser.parse_args()

    # load model and vocab, then print the result
    evaluate(args)


if __name__ == "__main__":
    main()
|
model_lr0.0001_bs256_epoch50.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31923ca96e3c2471ad6252dfb615b15cde784be5a7792c7379d1c9a9b27a7f4e
|
| 3 |
+
size 551468733
|
src/__init__.py
ADDED
|
File without changes
|
src/evaluator.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.model import Model2
|
| 2 |
+
from src.tokenizer import Tokenizer
|
| 3 |
+
from src.util import *
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def evaluate(args):
    """Load the vocab and model, then print the compressed or decompressed text.

    Parameters
    ----------
    args: namespace providing ``vocab`` (path to the saved vocab),
        ``checkpoint`` (path to the model checkpoint), ``text`` (input string)
        and ``decompress`` (truthy to decompress, falsy to compress).
    """
    vocab = torch.load(args.vocab, map_location=torch.device('cpu'))
    # 300 is the embedding width and 256 the LSTM state size (see Model2's signature).
    model = Model2(len(vocab), 300, 256, vocab['<PAD>'])
    load_from_checkpoint(model, args.checkpoint)
    # The checkpoint is deserialized onto the CPU while Tokenizer creates its
    # index tensors on the shared ``device``; move the weights there as well so
    # inputs and parameters live on the same device.
    model.to(device)

    tokenizer = Tokenizer(vocab)
    transform = decompress if args.decompress else compress

    print()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        print(transform(args.text, tokenizer, model))
|
src/model.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from src.util import device
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Transpose(nn.Module):
    """Module wrapper around ``torch.transpose`` so it can sit in ``nn.Sequential``.

    Parameters
    ----------
    dim0, dim1: the two dimensions to swap. When ``dim0`` is ``None`` the last
        two dimensions of each incoming tensor are swapped instead.
    """

    def __init__(self, dim0=None, dim1=None):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, tensor):
        """Return ``tensor`` with the configured pair of dimensions swapped."""
        if self.dim0 is None:
            # Resolve the default on every call instead of storing it on the
            # module: caching it would pin the dims to the rank of the first
            # tensor seen and swap the wrong axes for inputs of another rank.
            return torch.transpose(tensor, tensor.dim() - 2, tensor.dim() - 1)

        return torch.transpose(tensor, self.dim0, self.dim1)
|
| 19 |
+
|
| 20 |
+
class Model2(nn.Module):
    """LSTM-cell language model producing next-token logits for every position.

    Token ids are embedded, fed step by step through an ``nn.LSTMCell``, and the
    hidden state of each step is projected to vocabulary size by four stacked
    feed-forward stages.

    Parameters
    ----------
    vocab_size: number of embeddings and width of the output logits.
    embedding_dim: width of the token embeddings.
    state_size: width of the LSTM hidden and cell states.
    pad_index: token id treated as padding.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        state_size,
        pad_index,
    ):
        super().__init__()
        self.state_size = state_size
        self.pad_index = pad_index
        # NOTE: attribute names and the order inside each nn.Sequential define
        # the state_dict keys, so they must not change if saved checkpoints
        # are to keep loading.
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=pad_index,
        )

        self.rnn_layer = nn.LSTMCell(input_size=embedding_dim, hidden_size=state_size)
        self.lin1 = nn.Sequential(
            nn.Linear(state_size, state_size * 4),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )
        self.lin2 = nn.Sequential(
            nn.Linear(state_size * 4, state_size * 8),
            # BatchNorm1d normalizes over dim 1, so move the feature axis
            # there and back again afterwards.
            Transpose(),
            nn.BatchNorm1d(state_size * 8),
            Transpose(),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )
        self.lin3 = nn.Sequential(
            nn.Linear(state_size * 8, state_size * 16),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )
        self.lin4 = nn.Sequential(nn.Linear(state_size * 16, vocab_size))

    def forward(self, X):
        """Compute logits for a batch of token-id sequences.

        Parameters
        ----------
        X: integer tensor of shape (N, T) holding token ids.

        Returns
        -------
        Tensor of shape (N, T, vocab_size).
        """
        N, T = X.shape
        # (N, T, 1) so that a time slice broadcasts against the (N, state_size) states
        keep = (X != self.pad_index).unsqueeze(-1)
        X = self.embedding_layer(X)

        # Allocate the initial states next to the embedded input rather than on
        # a global device, so the module works wherever its input lives.
        state = torch.zeros((N, self.state_size), device=X.device, dtype=X.dtype)
        c = torch.zeros_like(state)
        states = []
        for t in range(T):
            next_state, next_c = self.rnn_layer(X[:, t, :], (state, c))
            # padded positions carry the previous state forward unchanged
            state = torch.where(keep[:, t], next_state, state)
            c = torch.where(keep[:, t], next_c, c)

            states.append(state)

        # (N, T, states)
        states = torch.stack(states, dim=1)
        output = self.lin1(states)
        output = self.lin2(output)
        output = self.lin3(output)
        output = self.lin4(output)

        return output
|
src/tokenizer.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from torchtext.vocab import Vocab
|
| 7 |
+
from torch import nn, Tensor
|
| 8 |
+
|
| 9 |
+
from src.util import device
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Tokenizer(nn.Module):
    """Whitespace tokenizer that turns sentences into padded tensors of token ids.

    Parameters
    ----------
    vocab: either a path to a vocab file loadable with ``torch.load`` or an
        already loaded vocab object. The vocab must map ``'<EDGE>'``,
        ``'<PAD>'`` and ``'<UNK>'`` to indices and, when called with a list of
        tokens, return the list of their indices.
    """

    def __init__(self, vocab: "str | Vocab"):
        super().__init__()

        if isinstance(vocab, str):
            # Raise instead of assert so the check survives ``python -O``.
            if not os.path.exists(vocab):
                raise FileNotFoundError(f"vocab file not found: {vocab}")
            self.vocab = torch.load(vocab, map_location=device)
        else:
            self.vocab = vocab

        # Look the special tokens up on the loaded vocab, not on the
        # constructor argument, which may still be a path string.
        self.edge_index = self.vocab['<EDGE>']
        self.pad_index = self.vocab['<PAD>']
        self.unk_index = self.vocab['<UNK>']

    def get_tensors(self, data):
        """
        Builds torch.Tensor from a variable length 2D python list. The return value is a tuple of two tensors, one for input and the other for output.

        Each input row is the edge token followed by the sentence; each output
        row is the sentence followed by the edge token. Both are padded with
        the pad index to the longest sentence length plus one.

        Parameters
        ----------
        data: Nested list of token indices
            [[1,2,3],
             [4,2,3,4,2],
             [223,4,2]]
            This example has three sentences.

        """
        N = len(data)
        # one extra column for the edge token; default=0 keeps an empty batch valid
        max_len = max((len(datum) for datum in data), default=0) + 1
        X = np.full((N, max_len), self.pad_index, np.int64)
        Y = np.full((N, max_len), self.pad_index, np.int64)

        for i, datum in enumerate(data):
            n = len(datum)
            # prepend the inputs with edge token
            X[i, 0] = self.edge_index
            if n:
                X[i, 1:n + 1] = datum
                Y[i, :n] = datum

            # finish the outputs with edge token: it goes AFTER the last token
            # (index n), so no target is overwritten and empty sentences work
            Y[i, n] = self.edge_index

        return torch.tensor(X, device=device), torch.tensor(Y, device=device)

    def forward(self, text: List[str]) -> Tuple[Tensor, Tensor]:
        """
        Tokenizes a list of natural text by splitting on whitespace.

        Parameters
        ----------
        text: List[str]. A list of natural language strings.

        Returns
        -------
        Tuple of two tensors of token ids (inputs, outputs) as built by get_tensors.
        """

        text = [sentence.split() for sentence in text]
        tokenized = [self.vocab(sentence) for sentence in text]
        return self.get_tensors(tokenized)
|
src/util.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
# Device shared by the modules that import it: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def stringify(array):
    """Render a nested list of strings as text: one line per inner list, items separated by spaces."""
    lines = (' '.join(words) for words in array)
    return '\n'.join(lines)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def compress(text, tokenizer, model):
    """
    Replaces every token the model predicts correctly with the placeholder "X".

    text: str.
        Each line represents a single document.
    tokenizer: Tokenizer.
    model: callable mapping the (N, T) index tensor from the tokenizer to
        logits of shape (N, T, vocab).

    Returns the text with the same line and token layout. A token becomes "X"
    when the model's top prediction after the preceding tokens equals it and
    it does not map to the tokenizer's unknown index.
    """
    lines = text.split("\n")
    tokens = [sentence.split() for sentence in lines]
    indices, _ = tokenizer(lines)

    # inference only: skip building the autograd graph
    with torch.no_grad():
        logits = model(indices)
    next_token_predicted = logits.argmax(dim=2)

    # slices are for skipping edge tokens
    prediction_mask = (indices[:, 1:] == next_token_predicted[:, :-1]).tolist()

    # replace correctly predicted tokens with "X"
    for sentence, sentence_mask in zip(tokens, prediction_mask):
        # zip stops at the sentence length, which ignores pad positions
        for j, (token, predicted_successfully) in enumerate(zip(sentence, sentence_mask)):
            if predicted_successfully and tokenizer.vocab[token] != tokenizer.unk_index:
                sentence[j] = "X"

    return "\n".join(" ".join(sentence) for sentence in tokens)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def decompress(text, tokenizer, model):
    """
    Restores compressed text by re-predicting every "X" placeholder.

    text: str.
        Each line represents a single document.
    tokenizer: Tokenizer. Only its ``vocab`` is used here.
    model: callable returning logits of shape (1, T, vocab) for a (1, T)
        index tensor.

    Returns the restored documents, one per line, tokens joined by spaces.
    """
    uncompressed = []
    # inference only: skip building the autograd graph
    with torch.no_grad():
        for document in text.split("\n"):
            # every document starts from a fresh prefix holding only the edge token
            prefix = ['<EDGE>']
            for token in document.split():
                if token != "X":
                    prefix.append(token)
                    continue

                # only infer when X is found
                indices = torch.tensor([tokenizer.vocab(prefix)],
                                       dtype=torch.int,
                                       device=device)
                logits = model(indices)
                # the prediction for X is the logit at the last position
                index = logits[:, -1, :].argmax(dim=1)
                # .item() hands lookup_token a plain int instead of a tensor
                prefix.append(tokenizer.vocab.lookup_token(index.item()))

            # drop the leading edge token
            uncompressed.append(prefix[1:])

    return stringify(uncompressed)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_from_checkpoint(model, checkpoint_path):
    """
    Loads model weights from a checkpoint and switches the model to eval mode.

    Parameters:
    ----------
    model: The module whose ``load_state_dict`` receives the stored weights.
    checkpoint_path: The path to the checkpoint. The file must hold a mapping
        with a ``'model_state_dict'`` entry.

    Raises:
    ------
    FileNotFoundError: If no checkpoint is found in the provided path.
    """
    if not os.path.exists(checkpoint_path):
        # FileNotFoundError is still an Exception, so existing handlers keep working.
        raise FileNotFoundError("No checkpoint found in the provided path")

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("loaded existing model.")
|
test.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.evaluator import evaluate
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Smoke run: compress a sample text, then decompress a copy that already
# contains "X" placeholders, printing each stage.
SAMPLE_TEXT = """dr. tonie mcdonald is a life long levittown resident who taught and rose through the ranks of the district she now leads .
he received his ba in chemistry , magna cum laude , from amherst college in 1 9 8 1 ."""

COMPRESSED_SAMPLE = """dr. tonie mcdonald is X life long levittown resident who taught and rose through X ranks of the district she now leads .
he received his ba X chemistry X magna cum laude X from amherst college in X X 8 1 ."""

parser = argparse.ArgumentParser(description='inference test with model.')
parser.add_argument(
    '--checkpoint',
    type=str,
    help='Path to the checkpoint file',
    default='model_lr0.0001_bs256_epoch50.pt',
)
parser.add_argument(
    "--decompress",
    action="store_true",
    help="decompress the input text",
    default=False,
)
parser.add_argument('--vocab', type=str, help='Path to the vocab file', default='vocab.pt')
parser.add_argument('--text', type=str, help='Text to be tokenized', default=SAMPLE_TEXT)
args = parser.parse_args()

print("--- input ---")
print(args.text)

# compress
print("--- compress ---")
evaluate(args)

# decompress: reuse the same namespace with the direction flag and text swapped
print("--- decompress ---")
args.decompress = True
args.text = COMPRESSED_SAMPLE
evaluate(args)
|
vocab.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:38847aa134accb833b3afc3204db2ce8650400907885a7efd3a1c541f58d3f0d
|
| 3 |
+
size 133355
|