import re
import unicodedata

import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    GenerationMixin,
    Pipeline,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerFast,
)


class EOLSTM(nn.Module):
    """Character-level LSTM language model with an MLP projection head."""

    def __init__(self, vocab_size, pad_token_id, embed_dim=256, hidden_dim=768,
                 num_layers=3, dropout_lstm=0.2, dropout_head=0.25, head_scale=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        # Inter-layer dropout only applies when the LSTM is stacked.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            dropout=dropout_lstm if num_layers > 1 else 0,
                            batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, int(hidden_dim * head_scale)),
            nn.GELU(),
            nn.Dropout(dropout_head),
            nn.Linear(int(hidden_dim * head_scale), vocab_size),
        )

    def forward(self, x, hidden=None):
        # x: token ids of shape (batch, seq) or unbatched (seq,).
        x = self.embedding(x)
        lstm_out, hidden = self.lstm(x, hidden)
        lstm_out = self.ln(lstm_out)
        logits = self.fc(lstm_out)
        return logits, hidden

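# A minimal forward-pass sketch, assuming the default 512-token vocab and
# pad id 1; the shapes and values here are illustrative only.
def _demo_eolstm_forward():
    model = EOLSTM(vocab_size=512, pad_token_id=1)
    ids = torch.randint(0, 512, (2, 16))  # (batch=2, seq=16) random token ids
    logits, (h, c) = model(ids)
    assert logits.shape == (2, 16, 512)  # one next-token distribution per step
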
class EOLSTMConfig(PretrainedConfig):
    """Serializable hyperparameters for EOLSTM (Hugging Face config API)."""

    model_type = "epo_lstm"

    def __init__(
        self,
        vocab_size=512,
        pad_token_id=1,
        sep_token_id=2,
        eos_token_id=3,
        embed_dim=256,
        hidden_dim=768,
        num_layers=3,
        dropout_lstm=0.2,
        dropout_head=0.25,
        head_scale=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.sep_token_id = sep_token_id
        self.eos_token_id = eos_token_id
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout_lstm = dropout_lstm
        self.dropout_head = dropout_head
        self.head_scale = head_scale


class EOLSTMPretrained(PreTrainedModel):
    config_class = EOLSTMConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = EOLSTM(
            vocab_size=config.vocab_size,
            pad_token_id=config.pad_token_id,
            embed_dim=config.embed_dim,
            hidden_dim=config.hidden_dim,
            num_layers=config.num_layers,
            dropout_lstm=config.dropout_lstm,
            dropout_head=config.dropout_head,
            head_scale=config.head_scale,
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        # attention_mask and labels are accepted for HF compatibility but
        # unused; only the logits are returned, the recurrent state is dropped.
        logits, _ = self.model(input_ids)
        return logits

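# A minimal save/load round trip; the "epo-lstm" directory name is an
# assumption for illustration, not a published checkpoint.
def _demo_pretrained_roundtrip():
    config = EOLSTMConfig(vocab_size=512)
    model = EOLSTMPretrained(config)
    model.save_pretrained("epo-lstm")
    reloaded = EOLSTMPretrained.from_pretrained("epo-lstm")
    assert reloaded.config.vocab_size == config.vocab_size
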
class TextGenerator:
    def __init__(self, model, eos_token_id, sep_token_id, device='cpu'):
        self.eos_token_id = eos_token_id
        self.sep_token_id = sep_token_id
        self.model = model.to(device)
        self.model.eval()
        self.device = device

    def generate(self,
                 prompt: list,
                 max_length: int = 1000,
                 temperature: float = 1.0,
                 top_k: int = 0,
                 top_p: float = 1.0,
                 repetition_penalty: float = 1.0) -> list:
        # `prompt` is a list of token ids; the full id sequence
        # (prompt + continuation) is returned.
        generated_tokens = list(prompt)  # copy so the caller's list is untouched

        # Recurrent state carried between steps, so after the first pass only
        # the newest token has to be fed through the model.
        past = None
        current_tokens = list(generated_tokens)

        with torch.no_grad():
            for _ in range(max_length):
                inputs = torch.tensor(current_tokens, dtype=torch.long,
                                      device=self.device)

                logits, past = self.model(inputs, past)
                logits = logits[-1, :] / temperature

                # Repetition penalty (CTRL-style): shrink positive logits and
                # amplify negative ones, so already-seen tokens become less
                # likely regardless of the logit's sign.
                if repetition_penalty != 1.0:
                    for token in set(generated_tokens):
                        if logits[token] > 0:
                            logits[token] /= repetition_penalty
                        else:
                            logits[token] *= repetition_penalty

                probs = F.softmax(logits, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)

                # Top-k: keep only the k most probable tokens, then renormalize.
                if top_k > 0:
                    sorted_probs = sorted_probs[..., :top_k]
                    sorted_indices = sorted_indices[..., :top_k]
                    sorted_probs /= sorted_probs.sum()

                # Top-p (nucleus): keep the smallest prefix whose cumulative
                # probability reaches top_p; the top token is always kept so
                # the mask can never be empty.
                if top_p < 1.0:
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    mask = cumulative_probs <= top_p
                    mask[..., 0] = True
                    sorted_probs = sorted_probs[mask]
                    sorted_indices = sorted_indices[mask]
                    sorted_probs /= sorted_probs.sum()

                next_token = sorted_indices[
                    torch.multinomial(sorted_probs, num_samples=1)
                ].item()

                if next_token in (self.eos_token_id, self.sep_token_id):
                    break

                generated_tokens.append(next_token)
                current_tokens = [next_token]

        return generated_tokens

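# A minimal sampling sketch; the 4-token prompt and special-token ids below
# are assumptions for illustration, not values from a trained tokenizer.
def _demo_text_generator():
    model = EOLSTM(vocab_size=512, pad_token_id=1)
    gen = TextGenerator(model, eos_token_id=3, sep_token_id=2)
    ids = gen.generate(prompt=[5, 6, 7, 8], max_length=20,
                       temperature=0.8, top_k=50, top_p=0.9,
                       repetition_penalty=1.2)
    return ids  # prompt ids followed by the sampled continuation
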
class EOLSTMGenerator(EOLSTMPretrained, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.generator = TextGenerator(
            model=self.model,
            eos_token_id=config.eos_token_id,
            sep_token_id=config.sep_token_id,
        )

    def generate(
        self,
        input_ids=None,
        attention_mask=None,
        max_length=1000,
        temperature=1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        **kwargs,
    ):
        # Overrides GenerationMixin.generate with the custom sampling loop
        # above; `input_ids` is a plain list of token ids, and the full id
        # sequence is returned.
        return self.generator.generate(
            prompt=input_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

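# A direct-generation sketch, assuming a character-level tokenizer whose
# encode/decode work on plain id lists; the prompt is illustrative.
def _demo_generator(tokenizer: PreTrainedTokenizerFast):
    model = EOLSTMGenerator(EOLSTMConfig())
    ids = tokenizer.encode("saluton mondo")
    out = model.generate(input_ids=ids, max_length=50, temperature=0.7)
    return tokenizer.decode(out)
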
class LSTMGeneratorPipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        # All runtime parameters are routed to _forward; preprocess and
        # postprocess take none.
        generate_kwargs = {
            "max_length": kwargs.get('max_length', 1000),
            "temperature": kwargs.get('temperature', 0.7),
            "top_k": kwargs.get('top_k', 0),
            "top_p": kwargs.get('top_p', 1.0),
            "repetition_penalty": kwargs.get('repetition_penalty', 1.5),
        }
        return {}, generate_kwargs, {}

    def preprocess(self, text, **kwargs):
        def remove_accents(input_str):
            nfkd_form = unicodedata.normalize('NFKD', input_str)
            return u"".join([c for c in nfkd_form if not unicodedata.combining(c)])

        # Each pair swaps a letter for a placeholder before accent stripping
        # and back afterwards: the x-digraphs protect Esperanto/Cyrillic
        # diacritics, while the final entries normalize quotes and dashes.
        accented_letters = [
            ('ĉ', 'cx'), ('ĝ', 'gx'), ('ĵ', 'jx'), ('ĥ', 'hx'), ('ŝ', 'sx'),
            ('ŭ', 'ux'), ('й', 'иx'), ('ў', 'уx'), ('\\"', "\\'"), ('е', 'є'),
            ('-', '—'),
        ]

        # Any character outside this whitelist becomes a space.
        non_eo = re.compile('[^0123456789abcĉdefgĝhĥijĵklmnopqrsŝtuŭvwxyzабвгдежзиклмнопрстуўфхцчшщъыьэюяі.,!:>"? -]')

        text = text.lower()
        text = re.sub(r'\s+', ' ', text)
        for let, letx in accented_letters:
            text = re.sub(let, letx, text)
        text = remove_accents(text)
        for let, letx in accented_letters:
            text = re.sub(letx, let, text)
        text = non_eo.sub(' ', text)
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        if not text.endswith('>'):
            text = text + ' '

        text = '[BOS]' + text

        return self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
        )

    def _forward(self, model_inputs, **kwargs):
        # kwargs already carry the defaults filled in by _sanitize_parameters.
        return self.model.generate(
            input_ids=model_inputs["input_ids"],
            **kwargs,
        )

    def postprocess(self, model_outputs):
        # Drop the leading [BOS] token added in preprocess.
        model_outputs = model_outputs[1:]

        def split_list(lst, delimiter):
            result = []
            temp = []
            for item in lst:
                if item == delimiter:
                    result.append(temp)
                    temp = []
                else:
                    temp.append(item)
            result.append(temp)
            return result

        # Invert the vocab so ids map back to characters/tokens.
        rev_voc = {self.tokenizer.vocab[key]: key for key in self.tokenizer.vocab}

        return [{"generated_text": ''.join([rev_voc[i] for i in output])}
                for output in split_list(model_outputs, self.model.config.eos_token_id)]

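# A minimal end-to-end sketch using the standard Hugging Face registration
# hook; the task name, model directory, and tokenizer file are assumptions
# for illustration, not values shipped with this module.
def _demo_pipeline():
    from transformers import pipeline
    from transformers.pipelines import PIPELINE_REGISTRY

    PIPELINE_REGISTRY.register_pipeline(
        "epo-text-generation",
        pipeline_class=LSTMGeneratorPipeline,
        pt_model=EOLSTMGenerator,
    )
    pipe = pipeline(
        "epo-text-generation",
        model=EOLSTMGenerator.from_pretrained("epo-lstm"),
        tokenizer=PreTrainedTokenizerFast(tokenizer_file="tokenizer.json"),
    )
    return pipe("saluton", max_length=200, temperature=0.7, top_p=0.9)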