| import torch | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from .configuration_greedy import GreedyConfig | |
| from freegroup import tools | |
class GreedyModel(PreTrainedModel):
    """Parameter-free "model" that scores next tokens with a greedy,
    stack-based reduction heuristic over free-group relators.

    Logits are computed procedurally from ``config.reducables`` and
    ``config.reciprocals`` rather than learned; the single stub parameter
    exists only so the HF ``PreTrainedModel`` machinery sees a tensor.

    NOTE(review): the original source had its indentation destroyed; the
    nesting below was reconstructed from which names each statement uses
    (per-stack vs. per-batch) — confirm against the original repository.
    """

    config_class = GreedyConfig

    def __init__(self, config: GreedyConfig):
        super().__init__(config)
        # Dummy parameter: gives the module at least one tensor so device
        # movement / saving utilities work on an otherwise stateless model.
        self.stub = torch.nn.parameter.Parameter(torch.tensor(0.))

    def _reduce_step(self, token, stack, reducables):
        """Push ``token`` onto ``stack`` and greedily cancel any suffix that
        matches a rotation of a relator.

        Args:
            token: 0-dim tensor holding the token id to push.
            stack: list of ints; mutated in place.
            reducables: list of relators (lists of token ids) checked in
                addition to the global ``config.reciprocals``.

        Returns:
            The (mutated) ``stack``, for convenience.
        """
        stack.append(token.item())
        for relator in self.config.reciprocals + reducables:
            n = len(relator)
            # The length-n cyclic rotations of ``relator`` are exactly the
            # length-n contiguous subwords of ``relator * 2``.
            if len(stack) >= n and tools.occurs(stack[-n:], relator * 2):
                del stack[-n:]
        return stack

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Route the generation cache back into ``forward``.

        Newer ``transformers`` versions pass the cache as ``past_key_values``
        rather than ``past``; accept either name so the reduction stacks
        actually survive between generation steps instead of being rebuilt.
        """
        past = kwargs.pop('past', None)
        if past is None:
            past = kwargs.pop('past_key_values', None)
        return {'input_ids': input_ids, 'past': past}

    def forward(self, input_ids=None, past=None, **kwargs):
        """Compute heuristic next-token logits for every unscored position.

        Args:
            input_ids: LongTensor of shape ``(batch, seq_len)``.
            past: optional ``(stacks, hidden_states)`` tuple previously
                returned in ``past_key_values``; lets generation resume
                without rescoring the already-seen prefix.

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` have shape
            ``(batch, seq_len, vocab_size)`` and whose ``past_key_values``
            carries the updated ``(stacks, hidden_states)`` cache.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if input_ids is None:
            raise ValueError("Can't be None")
        batch_size, sequence_length = input_ids.shape

        if past is None:
            # One reduction stack per (batch element, relator family).
            stacks = [
                [[] for _ in self.config.reducables] for _ in range(batch_size)
            ]
            hidden_states = None
        else:
            stacks, hidden_states = past

        # Positions already scored are cached along dim 0 of hidden_states.
        begin_idx = 0 if hidden_states is None else hidden_states.size(0)

        for t in range(begin_idx, sequence_length):
            # Allocate on the input's device so GPU generation works; the
            # original allocated on CPU and would device-mismatch on CUDA.
            last_hidden_states = torch.zeros(
                (batch_size, self.config.vocab_size), device=input_ids.device
            )
            for batch_idx, word in enumerate(input_ids):
                for stack, reducables in zip(stacks[batch_idx], self.config.reducables):
                    self._reduce_step(word[t], stack, reducables)
                    if not stack:
                        continue
                    last = stack[-1]
                    for r in reducables:
                        if last not in r:
                            # Not inside this relator: encourage starting it.
                            last_hidden_states[batch_idx][r[0]] += 1
                        else:
                            # Inside it: encourage the cyclically next letter.
                            pos = r.index(last)
                            last_hidden_states[batch_idx][r[(pos + 1) % len(r)]] += 1
                    for r in self.config.reciprocals:
                        if last in r:
                            pos = r.index(last)
                            last_hidden_states[batch_idx][r[(pos + 1) % len(r)]] += 1
                # Hard-forbid immediately cancelling the token just consumed.
                for r in self.config.reciprocals:
                    if word[t] in r:
                        pos = r.index(word[t])
                        last_hidden_states[batch_idx][r[(pos + 1) % len(r)]] = -torch.inf
                # All stacks fully reduced: the word is done, force EOS.
                if all(len(s) == 0 for s in stacks[batch_idx]):
                    last_hidden_states[batch_idx][self.config.eos_token_id] = torch.inf
            # ``last_hidden_states`` is freshly allocated each step, so no
            # defensive ``.clone()`` is needed before caching it.
            if hidden_states is None:
                hidden_states = last_hidden_states.unsqueeze(0)
            else:
                hidden_states = torch.cat((hidden_states, last_hidden_states.unsqueeze(0)))

        # Internal layout is (seq, batch, vocab); HF expects (batch, seq, vocab).
        return CausalLMOutputWithPast(
            logits=hidden_states.permute(1, 0, 2),
            past_key_values=(stacks, hidden_states),
        )