| | import os |
| | import numpy as np |
| | import transformers |
| | import torch |
| | import torch.nn as nn |
| | from torch import cuda |
| | from transformers import BertTokenizer |
| | from BERT_inference import BertClassificationModel |
| |
|
| |
|
| | def encoder(max_len,text): |
| |
|
| | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") |
| | tokenizer = tokenizer( |
| | text, |
| | padding = True, |
| | truncation = True, |
| | max_length = max_len, |
| | return_tensors='pt' |
| | ) |
| | input_ids = tokenizer['input_ids'] |
| | token_type_ids = tokenizer['token_type_ids'] |
| | attention_mask = tokenizer['attention_mask'] |
| | return input_ids,token_type_ids,attention_mask |
| |
|
| |
|
| | def predict(model,device,text): |
| | model.to(device) |
| | model.eval() |
| | with torch.no_grad(): |
| | input_ids,token_type_ids,attention_mask = encoder(512,text) |
| | input_ids,token_type_ids,attention_mask=input_ids.to(device),token_type_ids.to(device),attention_mask.to(device) |
| | out_put = model(input_ids,token_type_ids,attention_mask) |
| | |
| | probs = torch.nn.functional.softmax(out_put).detach().cpu().numpy().tolist() |
| | |
| | return probs[0][1] |
| | |
| | |
| | def inference_matrix(topics): |
| | device = torch.device('cuda' if cuda.is_available() else 'cpu') |
| | load_path = "bert_model.pkl" |
| | model = torch.load(load_path,map_location=torch.device(device)) |
| | matrix = np.zeros([len(topics),len(topics)],dtype=float) |
| | for i,i_text in enumerate(topics): |
| | for j,j_text in enumerate(topics): |
| | if(i == j): |
| | matrix[i][j] = 0 |
| | else: |
| | test = i_text+" 是否包含 "+j_text |
| | outputs = predict(model,device,test) |
| | |
| | |
| | matrix[i][j] = outputs |
| |
|
| | return matrix |
| | if __name__ == "__main__": |
| |
|
| | print("yes") |
| | topics = ['在本次报告中我将介绍分布式并行加速算法模型架构内存和计算优化以及集群架构等关键技术', '在现代机器学习任务中大模型训练已成为解决复杂问题的重要手段', '首先分布式并行加速策略包括数据并行模型并行流水线并行和张量并行等四种方式', '选择合适的集群架构是实现大模型的分布式训练的关键', '这些策略帮助我们将训练数据和模型分布到多个设备上以加速大模型训练过程'] |
| | print(inference_matrix(topics)) |