# NOTE(review): removed extraction artifacts that preceded the source
# ("Spaces:" and two "Build error" lines) — they were not part of the program.
| import sys | |
| import json | |
| import transformers | |
| import torch | |
def vectorize_with_pretrained_embeddings(sentences):
    """
    Produce one BERT embedding per sentence.

    Each sentence is tokenized (with [CLS]/[SEP] added and truncation to
    BERT's 512-token limit), run through a frozen ``bert-base-cased`` model,
    and the last hidden state is mean-pooled over the token dimension to
    yield a fixed-size vector.

    Args:
        sentences: List of n sentence strings.

    Returns:
        A 2D numpy array of shape (n, 768) with one embedding per sentence.
    """
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
    pretrained_model = transformers.BertModel.from_pretrained(
        'bert-base-cased', output_hidden_states=False)
    pretrained_model.eval()  # inference mode: disable dropout

    embeddings = []
    for sentence in sentences:
        # Let the tokenizer add [CLS]/[SEP], build the attention mask and
        # token_type_ids, and truncate safely. The previous hand-rolled
        # "[CLS] ... [SEP]" + tokenized[:512] slice could cut off the final
        # [SEP] on long inputs.
        encoded = tokenizer(
            sentence,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        with torch.no_grad():
            # BUG FIX: the old code passed an all-ones tensor as the second
            # *positional* argument of BertModel.forward, which in modern
            # `transformers` is `attention_mask`, not `token_type_ids` — and
            # single-sentence token_type_ids should be 0, not 1. Passing the
            # tokenizer's own keyword outputs supplies all three correctly.
            outputs = pretrained_model(**encoded)[0]  # (1, seq_len, 768)
        # Mean-pool across the token dimension -> constant-sized (1, 768).
        embeddings.append(torch.mean(outputs, dim=1))

    embeddings = torch.cat(embeddings, dim=0)  # stack to (n, 768)
    return embeddings.cpu().detach().numpy()
def main():
    """Read a JSON request from stdin and print its embedding as CSV text.

    Expected JSON keys: "Passage", "QuestionText", "Distractors" (each
    defaults to the empty string when absent). The three fields are joined
    with newlines into one input, embedded with BERT, flattened, and written
    to stdout as a single comma-separated line.
    """
    # Step 1: parse the JSON payload arriving on stdin.
    payload = json.loads(sys.stdin.read())

    # Step 2: pull out the fields, tolerating missing keys.
    passage = payload.get("Passage", "")
    question = payload.get("QuestionText", "")
    distractors = payload.get("Distractors", "")

    # Fold question, distractors, and passage into one newline-joined text.
    combined_input = [f"{question}\n{distractors}\n{passage}"]

    vector = vectorize_with_pretrained_embeddings(combined_input)
    # Flatten to 1D and emit as one comma-separated line.
    print(",".join(str(component) for component in vector.flatten()))


if __name__ == "__main__":
    main()