from dataclasses import dataclass
from typing import Optional, Tuple, Union

from sagemaker_inference import encoder
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, BertPreTrainedModel
from transformers.models.bert import BertModel
from transformers.modeling_outputs import ModelOutput


@dataclass
class MultipleChoiceModelOutput(ModelOutput):
    """Output container for multiple-choice models: a loss (when labels are
    given), per-choice logits, and optional hidden states and attentions."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class BertForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # One score per (prompt, candidate) pair; the scores are regrouped
        # into per-example logits in forward().
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Flatten (batch_size, num_choices, seq_len) into
        # (batch_size * num_choices, seq_len) so every choice is encoded in one pass.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled [CLS] representation of each (prompt, candidate) pair.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Regroup the flat per-pair scores into (batch_size, num_choices).
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
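
# Note on shapes: forward() expects inputs of shape
# (batch_size, num_choices, seq_len) and returns logits of shape
# (batch_size, num_choices) -- one score per candidate. predict_fn below
# adds the required batch dimension with unsqueeze(0).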


def model_fn(model_dir):
    # Called by SageMaker at startup; model_dir holds the artifacts
    # extracted from model.tar.gz (config, weights, and tokenizer files).
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = BertForMultipleChoice.from_pretrained(model_dir)
    return {"model": model, "tokenizer": tokenizer}


def predict_fn(data, model):
    prompt = data["prompt"]
    candidates = data["candidates"]

    # Tokenize the prompt against every candidate as (sentence A, sentence B)
    # pairs: one row per candidate, padded to a common length.
    inputs = model["tokenizer"](
        [[prompt, candidate] for candidate in candidates],
        return_tensors="pt",
        padding=True,
    )

    # Dummy label: only the logits are used at inference time.
    labels = torch.tensor(0).unsqueeze(0)

    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects:
        # (num_choices, seq_len) -> (1, num_choices, seq_len). The model
        # itself lives under the "model" key of the dict returned by
        # model_fn, so it must be looked up before being called.
        outputs = model["model"](
            **{k: v.unsqueeze(0) for k, v in inputs.items()}, labels=labels
        )

    return outputs.logits


def output_fn(prediction, content_type):
    # prediction has shape (1, num_choices); drop the batch dimension and
    # convert each candidate's score to a plain Python float so the result
    # maps candidate index -> score and serializes cleanly.
    result = {i: float(score) for i, score in enumerate(prediction.squeeze(0))}
    return encoder.encode(result, content_type)
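

# A minimal local smoke test of the three handlers, as a sketch: the
# checkpoint name and payload below are illustrative only, and
# bert-base-uncased's multiple-choice head is randomly initialized, so the
# scores are only meaningful with a fine-tuned checkpoint in its place.
if __name__ == "__main__":
    # Stand-in for model_fn's return value; point from_pretrained at your
    # fine-tuned model directory for real predictions.
    artifacts = {
        "tokenizer": AutoTokenizer.from_pretrained("bert-base-uncased"),
        "model": BertForMultipleChoice.from_pretrained("bert-base-uncased"),
    }
    payload = {
        "prompt": "The capital of France is",
        "candidates": ["Paris.", "Berlin.", "Madrid."],
    }
    logits = predict_fn(payload, artifacts)
    print(logits)  # shape: (1, num_choices)
    print(output_fn(logits, "application/json"))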