#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 5 10:29:03 2023

@author: peter
"""

import torch
import transformers


class QaracDecoderHead(torch.nn.Module):

    def __init__(self, config):
        """
        Creates the Decoder head

        Parameters
        ----------
        config : transformers.RobertaConfig
            Config for the RobertaModel that this head will be attached to.

        Returns
        -------
        None.

        """
        super().__init__()
        # Two fresh Roberta layers combine the vector with the base model's
        # hidden states, followed by a language-modelling head.
        self.layer_0 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
        self.layer_1 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
        self.head = transformers.models.roberta.modeling_roberta.RobertaLMHead(config)

    def forward(self, vector, hidden_states, attention_mask=None):
        """
        Predicts text from a vector and the hidden states of the base model

        Parameters
        ----------
        vector : torch.Tensor
            Vector to be decoded, of shape (batch_size, 1, hidden_size).
        hidden_states : torch.Tensor
            Last hidden states of the base model.
        attention_mask : torch.Tensor, optional
            Attention mask for the seed text. The default is None.

        Returns
        -------
        torch.Tensor
            Logits over the vocabulary for the predicted text.

        """
        # Prepend the vector to the hidden states so the decoder layers can
        # attend to it alongside the seed text.
        vectors = torch.cat([vector, hidden_states], dim=1)
        # Extend the mask with a column of ones for the vector position.
        attentions = None if attention_mask is None else torch.cat(
            [torch.ones((hidden_states.shape[0], 1),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device),
             attention_mask],
            dim=1)

        def extend(mask):
            # RobertaLayer applies its mask additively to the attention
            # scores, so convert a 0/1 mask into 0/-inf form broadcastable
            # to (batch, heads, seq_len, seq_len).
            if mask is None:
                return None
            return (1.0 - mask[:, None, None, :].to(vectors.dtype)) * torch.finfo(vectors.dtype).min

        l0 = self.layer_0(vectors, extend(attentions))
        # Drop the vector position again before the second layer and the
        # language-modelling head.
        return self.head(self.layer_1(l0[0][:, 1:], extend(attention_mask))[0])
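
# Shape walkthrough for QaracDecoderHead (illustrative; B, S and H are
# assumed names for batch size, seed length and hidden size, not anything
# defined in this module):
#
#   vector         (B, 1, H)          the encoded "meaning" to decode
#   hidden_states  (B, S, H)          base-model states for the seed text
#   concatenated   (B, S+1, H)        input to layer_0
#   logits         (B, S, vocab_size) after the vector position is dropped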


# GenerationMixin now lives in transformers.generation; the old
# transformers.generation_utils module path has been removed.
class QaracDecoderModel(transformers.RobertaModel,
                        transformers.generation.GenerationMixin):

    def __init__(self, model_path, config, tokenizer):
        """
        Creates the decoder model from a pretrained base model

        Parameters
        ----------
        model_path : str
            Path or name of the pretrained RobertaModel to build on.
        config : transformers.RobertaConfig
            Config for the base model.
        tokenizer : transformers.RobertaTokenizer
            Tokenizer for the model.

        Returns
        -------
        None.

        """
        super().__init__(config)
        self.decoder_base = transformers.RobertaModel.from_pretrained(model_path,
                                                                      config=config)
        self.decoder_head = QaracDecoderHead(self.config)
        self.tokenizer = tokenizer

    def forward(self, inputs, **kwargs):
        """
        Predicts text from inputs

        Parameters
        ----------
        inputs : tuple of torch.Tensors OR torch.Tensor
            Vector to be converted to text and seed text OR tokenized seed text
        kwargs : optional keyword arguments
            vector : torch.Tensor vector to be decoded. May be supplied
            via a keyword argument when this is invoked by .generate

        Returns
        -------
        transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text

        """
        # The vector either arrives as a kwarg (when invoked by .generate)
        # or as the first element of the inputs tuple.
        (v, s) = (kwargs['vector'], inputs) if 'vector' in kwargs else inputs
        # The seed may be a bare tensor of token ids or a tokenizer output
        # carrying its own attention mask.
        (seed, attention_mask) = (s, None) if torch.is_tensor(s) else (s['input_ids'],
                                                                       s['attention_mask'])
        logits = self.decoder_head(torch.unsqueeze(v, 1),
                                   self.decoder_base(seed,
                                                     attention_mask=attention_mask,
                                                     use_cache=('vector' in kwargs)).last_hidden_state,
                                   attention_mask)
        # Wrap the raw logits so that .generate can consume the output.
        return transformers.modeling_outputs.CausalLMOutputWithCrossAttentions(logits=logits)

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      attention_mask=None,
                                      **kwargs):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # Package the ids and mask under the 'inputs' key to match the
        # forward() signature above.
        model_inputs = {'inputs': {'input_ids': input_ids,
                                   'attention_mask': attention_mask}}
        # Pass the vector through to forward(); otherwise .generate would
        # silently drop it.
        if 'vector' in kwargs:
            model_inputs['vector'] = kwargs['vector']
        return model_inputs
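

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# checkpoint name 'roberta-base' and the random vector below are assumptions
# chosen purely for demonstration.
if __name__ == '__main__':
    config = transformers.RobertaConfig.from_pretrained('roberta-base')
    tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-base')
    decoder = QaracDecoderModel('roberta-base', config, tokenizer)
    # One dummy "meaning" vector per example, of the model's hidden size.
    vector = torch.randn(1, config.hidden_size)
    seed = tokenizer('Hello world', return_tensors='pt')
    output = decoder((vector, seed))
    print(output.logits.shape)  # (1, seed length, vocab size)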