from typing import Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler():
    """Inference Endpoints custom handler returning mean-pooled GPT-2 logits.

    Loads GPT-2 in fp16 on the GPU once at startup; each call tokenizes the
    incoming documents and returns, per document, the attention-masked mean of
    the model's output logits over the sequence dimension.
    """

    def __init__(self, path=""):
        # NOTE(review): `path` is the model directory the endpoint runtime
        # passes in, but this handler ignores it and always pulls "gpt2"
        # from the hub — confirm that is intentional.
        self.model = AutoModelForCausalLM.from_pretrained(
            "gpt2", torch_dtype=torch.float16, output_hidden_states=True
        )
        self.model = self.model.cuda()
        self.model.eval()  # inference only: disable dropout etc.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, list]:
        """
        Args:
            data (:obj:`dict`):
                expects an ``"inputs"`` key holding an iterable of text
                documents; if the key is absent, ``data`` itself is iterated.
        Return:
            A :obj:`dict` with key ``"logits"``: one mean-pooled logits row
            (nested Python list) per input document.
        """
        inputs = data.pop("inputs", data)
        all_logits = []

        for doc in inputs:
            # BUG FIX: tokenize the current document, not the whole batch
            # (the original re-tokenized `inputs` on every iteration).
            tokenized = self.tokenizer(
                doc,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            token_ids = tokenized.input_ids.cuda()
            token_mask = tokenized.attention_mask.cuda()

            with torch.no_grad():
                # BUG FIX: `model` was an unresolved name — use self.model.
                out = self.model(token_ids, attention_mask=token_mask)

            # Zero out padding positions, then average the logits over the
            # sequence axis (sum of masked logits / number of real tokens).
            meaned_logits = (out.logits * token_mask.unsqueeze(-1)).sum(
                1
            ) / token_mask.sum(1).unsqueeze(-1)
            # (Removed dead code: a sorted-logits mean was computed here but
            # never used or returned.)
            all_logits.append(meaned_logits.cpu().numpy().tolist())

        return {"logits": all_logits}