# epo_lstm/epo_lstm.py — uploaded by timcryt (commit eff5017)
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
import torch.nn.functional as F
from transformers import Pipeline
from transformers import PreTrainedTokenizerFast
import re
import unicodedata
class EOLSTM(nn.Module):
    """Token-level LSTM language model: embedding -> stacked LSTM -> LayerNorm -> MLP head."""

    def __init__(self, vocab_size, pad_token_id, embed_dim=256, hidden_dim=768,
                 num_layers=3, dropout_lstm=0.2, dropout_head=0.25, head_scale=2):
        super().__init__()
        # Inter-layer dropout only applies when the LSTM is actually stacked;
        # torch warns if dropout > 0 with a single layer.
        lstm_dropout = dropout_lstm if num_layers > 1 else 0
        inner_dim = int(hidden_dim * head_scale)
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            dropout=lstm_dropout, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        # Two-layer head widened by `head_scale` before projecting to the vocabulary.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout_head),
            nn.Linear(inner_dim, vocab_size),
        )

    def forward(self, x, hidden=None):
        """Return (per-position logits over the vocabulary, final LSTM state)."""
        embedded = self.embedding(x)
        features, hidden = self.lstm(embedded, hidden)
        logits = self.fc(self.ln(features))
        return logits, hidden
class EOLSTMConfig(PretrainedConfig):
    """HuggingFace configuration for the EOLSTM model.

    Holds the special-token ids and every architectural hyperparameter
    needed to rebuild the underlying ``EOLSTM`` module.
    """

    model_type = "epo_lstm"

    def __init__(
        self,
        vocab_size=512,
        pad_token_id=1,
        sep_token_id=2,
        eos_token_id=3,
        embed_dim=256,
        hidden_dim=768,
        num_layers=3,
        dropout_lstm=0.2,
        dropout_head=0.25,
        head_scale=2,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Assign after super().__init__ so these explicit values win over
        # anything the base class may have derived from kwargs.
        hyperparams = {
            "vocab_size": vocab_size,
            "pad_token_id": pad_token_id,
            "sep_token_id": sep_token_id,
            "eos_token_id": eos_token_id,
            "embed_dim": embed_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "dropout_lstm": dropout_lstm,
            "dropout_head": dropout_head,
            "head_scale": head_scale,
        }
        for attr, value in hyperparams.items():
            setattr(self, attr, value)
class EOLSTMPretrained(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds a raw ``EOLSTM`` from a config."""

    config_class = EOLSTMConfig

    def __init__(self, config):
        super().__init__(config)
        # Mirror every architectural field of the config onto the raw module.
        self.model = EOLSTM(
            config.vocab_size,
            config.pad_token_id,
            embed_dim=config.embed_dim,
            hidden_dim=config.hidden_dim,
            num_layers=config.num_layers,
            dropout_lstm=config.dropout_lstm,
            dropout_head=config.dropout_head,
            head_scale=config.head_scale,
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Delegate to the wrapped module.

        ``attention_mask`` and ``labels`` exist only for HF-API compatibility
        and are ignored; returns the ``(logits, hidden)`` pair of ``EOLSTM``.
        """
        return self.model(input_ids)
class TextGenerator:
    """Autoregressive sampler for an ``EOLSTM``-style recurrent model.

    The wrapped model must accept ``(token_tensor, hidden)`` and return
    ``(logits, hidden)``; only the logits of the last position are sampled
    from, and the recurrent state is carried across steps so that each
    iteration feeds the model just the newly sampled token.
    """

    def __init__(self, model, eos_token_id, sep_token_id, device='cpu'):
        self.eos_token_id = eos_token_id
        self.sep_token_id = sep_token_id
        self.model = model.to(device)
        self.model.eval()
        self.device = device

    def generate(self,
                 prompt: list,
                 max_length: int = 1000,
                 temperature: float = 1.0,
                 top_k: int = 0,
                 top_p: float = 1.0,
                 repetition_penalty: float = 1.0) -> list:
        """Sample up to ``max_length`` tokens continuing ``prompt``.

        Args:
            prompt: list of token ids to condition on (not mutated).
            max_length: maximum number of new tokens to sample.
            temperature: logit divisor; <1 sharpens, >1 flattens.
            top_k: if >0, restrict sampling to the k most probable tokens.
            top_p: if <1, restrict sampling to the smallest nucleus of
                tokens whose cumulative probability covers ``top_p``.
            repetition_penalty: >1 discourages already-generated tokens.

        Returns:
            prompt ids followed by sampled ids; generation stops (without
            appending the terminator) on ``eos_token_id`` or ``sep_token_id``.
        """
        # Copy the prompt: the original implementation aliased it and
        # appended in place, mutating the caller's list.
        generated_tokens = list(prompt)
        past = None
        current_tokens = list(generated_tokens)
        with torch.no_grad():
            for _ in range(max_length):
                inputs = torch.tensor(current_tokens).to(self.device)
                logits, past = self.model(inputs, past)
                logits = logits[-1, :] / temperature
                if repetition_penalty != 1.0:
                    # CTRL-style penalty: dividing a *negative* logit would
                    # raise its probability, so multiply negatives instead.
                    for token in set(generated_tokens):
                        if logits[token] > 0:
                            logits[token] /= repetition_penalty
                        else:
                            logits[token] *= repetition_penalty
                probs = F.softmax(logits, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                if top_k > 0:
                    sorted_probs = sorted_probs[..., :top_k]
                    sorted_indices = sorted_indices[..., :top_k]
                    sorted_probs /= sorted_probs.sum()
                if top_p < 1.0:
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    mask = cumulative_probs <= top_p
                    # Always keep the most probable token: if its probability
                    # alone exceeds top_p the mask would be empty and
                    # torch.multinomial would crash.
                    mask[..., 0] = True
                    sorted_probs = sorted_probs[mask]
                    sorted_indices = sorted_indices[mask]
                    sorted_probs /= sorted_probs.sum()
                next_token = sorted_indices[
                    torch.multinomial(sorted_probs, num_samples=1)
                ].item()
                if next_token == self.eos_token_id or next_token == self.sep_token_id:
                    break
                generated_tokens.append(next_token)
                # Recurrent state already encodes the history; feed only the
                # newly sampled token next step.
                current_tokens = [next_token]
        return generated_tokens
class EOLSTMGenerator(EOLSTMPretrained, GenerationMixin):
    """``EOLSTMPretrained`` with list-of-token-ids generation.

    Note: ``generate`` delegates to the custom ``TextGenerator`` sampler
    rather than the HF ``GenerationMixin`` machinery.
    """

    def __init__(self, config):
        super().__init__(config)
        # Sampler shares the wrapped module and the config's terminator ids.
        self.generator = TextGenerator(
            model=self.model,
            eos_token_id=config.eos_token_id,
            sep_token_id=config.sep_token_id,
        )

    def generate(
        self,
        input_ids=None,
        attention_mask=None,
        max_length=1000,
        temperature=1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        **kwargs
    ):
        """Sample a continuation of ``input_ids`` (a list of token ids).

        ``attention_mask`` and extra ``kwargs`` are accepted for HF-API
        compatibility but ignored by the underlying sampler.
        """
        sampling_options = dict(
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        return self.generator.generate(prompt=input_ids, **sampling_options)
class LSTMGeneratorPipeline(Pipeline):
    """Text-generation pipeline for the EOLSTM model.

    Normalizes raw text into the model's restricted character alphabet
    (Esperanto x-convention plus Cyrillic), tokenizes, samples a
    continuation, and decodes the generated ids back into strings split
    on the EOS token.
    """

    def _sanitize_parameters(self, **kwargs):
        # Route every user-supplied knob to _forward; preprocess and
        # postprocess take no runtime parameters.
        # NOTE(review): these fallbacks (temperature=0.7, repetition_penalty=1.5)
        # differ from the ones hard-coded in _forward (1.0, 1.0) — since this
        # dict always supplies every key, these values are the ones that take
        # effect; confirm that is intended.
        generate_kwargs = {
            "max_length": kwargs.get('max_length', 1000),
            "temperature": kwargs.get('temperature', 0.7),
            "top_k": kwargs.get('top_k', 0),
            "top_p": kwargs.get('top_p', 1.0),
            "repetition_penalty": kwargs.get('repetition_penalty', 1.5)
        }
        return {}, generate_kwargs, {}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def preprocess(self, text, **kwargs):
        """Normalize ``text`` into the model alphabet and tokenize it."""
        def remove_accents(input_str):
            # Strip all combining marks after NFKD decomposition.
            nfkd_form = unicodedata.normalize('NFKD', input_str)
            return u"".join([c for c in nfkd_form if not unicodedata.combining(c)])
        # (character, accent-free placeholder) pairs. Characters that must
        # survive remove_accents() are first rewritten to their placeholder
        # (e.g. Esperanto 'ĉ' -> x-convention 'cx'), accents are stripped,
        # then the pairs are applied in the reverse direction — which also
        # deliberately converts user-typed x-convention digraphs ('cx', 'ux',
        # ...) into the accented letters.
        acc_leters = [
            ('ĉ', 'cx'), ('ĝ', 'gx'), ('ĵ', 'jx'), ('ĥ', 'hx'), ('ŝ', 'sx'), ('ŭ', 'ux'), ('й', 'иx'), ('ў', 'уx'), ('\\"', "\\'"), ('е', 'є'), ('-', '—'),
        ]
        # Anything outside the model's alphabet is replaced by a space.
        non_eo = re.compile('[^0123456789abcĉdefgĝhĥijĵklmnopqrsŝtuŭvwxyzабвгдежзиклмнопрстуўфхцчшщъыьэюяі.,!:>"? -]')
        text = text.lower()
        # Collapse all whitespace runs to single spaces.
        text = re.sub('\\s+', ' ', text)
        for let, letx in acc_leters:
            text = re.sub(let, letx, text)
        text = remove_accents(text)
        # Restore the protected characters / apply the x-convention.
        for let, letx in acc_leters:
            text = re.sub(letx, let, text)
        text = non_eo.sub(' ', text)
        text = re.sub('\\s+', ' ', text)
        text = text.strip()
        # NOTE(review): '>' presumably closes a control tag in the prompt
        # format; otherwise a trailing space separates prompt from the
        # continuation — confirm against the training data format.
        if not text.endswith('>'):
            text = text + ' '
        text = '[BOS]' + text
        return self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
        )

    def _forward(self, model_inputs, **kwargs):
        """Run generation; returns the flat list of token ids from the model."""
        input_ids = model_inputs["input_ids"]
        return self.model.generate(
            input_ids=input_ids,
            max_length=kwargs.get('max_length', 1000),
            temperature=kwargs.get('temperature', 1.0),
            top_k=kwargs.get('top_k', 0),
            top_p=kwargs.get('top_p', 1.0),
            repetition_penalty=kwargs.get('repetition_penalty', 1.0)
        )

    def postprocess(self, model_outputs):
        """Decode generated ids into text, one entry per EOS-delimited segment."""
        # Drop the leading [BOS] token id.
        model_outputs = model_outputs[1:]
        def split_list(lst, delimiter):
            # Split a flat list into sublists on every occurrence of
            # `delimiter`; the delimiter itself is dropped.
            result = []
            temp = []
            for item in lst:
                if item == delimiter:
                    result.append(temp)
                    temp = []
                else:
                    temp.append(item)
            result.append(temp)
            return result
        # Invert the tokenizer vocabulary: id -> token string.
        rev_voc = {self.tokenizer.vocab[key]: key for key in self.tokenizer.vocab}
        return [{"generated_text": ''.join([rev_voc[i] for i in output])}
                for output in split_list(model_outputs, self.model.config.eos_token_id)]