| from fastapi import FastAPI |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import FileResponse |
| import torch |
| import os |
| import json |
| import random |
| import numpy as np |
| from torch import nn |
| import argparse |
| import logging |
| from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config |
| from transformers import BertTokenizerFast |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, AdamW, BertModel |
|
|
| import requests |
| import uvicorn |
| from pydantic import BaseModel |
| from transformers import pipeline |
|
|
# NOTE(review): appears unused anywhere in this file — candidate for removal.
extra_args = {}
def set_args(argv=None):
    """
    Parse command-line arguments.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
            Passing an explicit list keeps argument parsing usable (and
            testable) when the module runs under a server such as uvicorn.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
    parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
    # Bug fix: help text was copy-pasted from --model_path ("对话模型路径");
    # this option is the vocabulary file path.
    parser.add_argument('--vocab_path', default='/app/bert-base-zh/vocab.txt', type=str, required=False,
                        help='词表路径')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
    return parser.parse_args(argv)
|
|
|
|
def create_logger(args):
    """
    Create a logger that writes to both a log file and the console.

    Args:
        args: parsed arguments; only ``args.log_path`` (file destination)
            is read.

    Returns:
        logging.Logger configured at INFO level.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Bug fix: the original attached new handlers on every call, so calling
    # this twice duplicated every log line. Return early if already set up.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # NOTE: the DEBUG handler level is capped by the logger's INFO level,
    # so the console effectively receives INFO and above.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
|
|
class Word_BERT(nn.Module):
    """BERT encoder with four task heads for pathology-report extraction.

    Heads:
        out: per-token logits over the sequence output (tumour-size span
            tagging), width ``seq_label``.
        cancer / transfer / ly_transfer: classifiers over the pooled output.
    """

    def __init__(self, seq_label=1, cancer_label=8, transfer_label=2, ly_transfer=2,
                 bert_path='/home/user/app/./bert-base-zh'):
        """
        Args:
            seq_label: output width of the per-token head.
            cancer_label: number of cancer-type classes.
            transfer_label: number of metastasis classes.
            ly_transfer: number of lymph-metastasis classes.
            bert_path: local directory of the pretrained BERT weights.
                Generalized from the previously hard-coded path; the default
                preserves the original behavior.
        """
        super(Word_BERT, self).__init__()
        self.bert = BertModel.from_pretrained(bert_path)
        # 768 matches the hidden size of bert-base checkpoints — TODO confirm
        # this holds for the checkpoint at bert_path.
        self.out = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, seq_label)
        )
        self.cancer = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, cancer_label)
        )
        self.transfer = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, transfer_label)
        )
        self.ly_transfer = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, ly_transfer)
        )

    def forward(self, word_input, masks):
        """Run BERT and all four heads.

        Args:
            word_input: (batch, seq) token-id tensor.
            masks: (batch, seq) attention mask.

        Returns:
            Tuple ``(out, cancer, transfer, ly_transfer)`` — per-token logits
            plus the three pooled-output classifier logits.
        """
        output = self.bert(word_input, attention_mask=masks)
        sequence_output = output.last_hidden_state
        pool = output.pooler_output
        out = self.out(sequence_output)
        cancer = self.cancer(pool)
        transfer = self.transfer(pool)
        ly_transfer = self.ly_transfer(pool)
        return out, cancer, transfer, ly_transfer
|
|
def getChat(text, model, tokenizer):
    """
    Run the extraction model on one pathology-report string.

    Args:
        text: raw report text; each character becomes one token.
        model: Word_BERT-style callable returning
            (token_logits, cancer_logits, transfer_logits, ly_transfer_logits).
        tokenizer: object exposing ``convert_tokens_to_ids``.

    Returns:
        Tuple ``(output, size_str, size_4, cancer, transfer, ly_transfer)``:
        formatted summary string, tumour-size text, size category
        ("无" / "<4cm" / ">=4cm"), cancer type, metastasis label,
        lymph-metastasis label.
    """
    # Character-level tokenisation wrapped in [CLS]/[SEP].
    text = ['[CLS]'] + [i for i in text] + ['[SEP]']
    text_ids = tokenizer.convert_tokens_to_ids(text)

    input_ids = torch.tensor(text_ids).long().unsqueeze(0)
    mask_input = torch.ones_like(input_ids).long()

    with torch.no_grad():
        out, cancer, transfer, ly_transfer = model(input_ids, mask_input)
        # torch.sigmoid replaces deprecated F.sigmoid.
        out = torch.sigmoid(out).squeeze(2).cpu().numpy().tolist()
        cancer = cancer.argmax(dim=-1).cpu().numpy().tolist()
        transfer = transfer.argmax(dim=-1).cpu().numpy().tolist()
        ly_transfer = ly_transfer.argmax(dim=-1).cpu().numpy().tolist()

    # Per-token binary decision: is this token part of the tumour-size span?
    pred_thresold = [[1 if jj > 0.4 else 0 for jj in ii] for ii in out]

    # Collect (start, end) spans of consecutive positive tokens.
    # State invariant: a span is open iff start != end.
    size_list = []
    start, end = 0, 0
    for i, j in enumerate(pred_thresold[0]):
        if j == 1 and start == end:
            start = i
        elif j != 1 and start != end:
            end = i
            size_list.append((start, end))
            start = end
    # Bug fix: a span still open at the end of the sequence was silently
    # dropped by the original loop; close it here.
    if start != end:
        size_list.append((start, len(pred_thresold[0])))

    size_text = [text[s:e] for s, e in size_list]
    if len(size_text) == 0:
        size_str = "无"
    else:
        # Only the first detected span is used for the size string.
        size_str = ''.join(size_text[0])

    # Dimensions are separated by '×' or '*'.
    split_w = '×' if '×' in size_str else '*'
    tt = size_str.split(split_w)

    f = 0  # becomes 1 when any parsed dimension is >= 4 (cm)
    # Bug fix: the original indexed tt[0][0] (IndexError when size_str
    # starts with the separator); size_str is always non-empty here.
    if size_str[0].isdigit():
        for token in tt:
            # Bug fix: the original crashed with ValueError on non-numeric
            # dimension text (e.g. "3cm"); skip unparseable tokens instead.
            try:
                if float(token) >= 4:
                    f = 1
            except ValueError:
                continue
    else:
        size_str = "无"

    if size_str == "无":
        size_4 = "无"
    elif f == 0:
        size_4 = "<4cm"
    else:
        size_4 = ">=4cm"

    # Class-id -> label mappings for the three classifier heads.
    cancer_dict = {'腺癌': 0, '肺良性疾病': 1, '鳞癌': 2, '无法判断组织分型': 3, '复合型': 4, '转移癌': 5, '小细胞癌': 6, '大细胞癌': 7}
    id_cancer = {j: i for i, j in cancer_dict.items()}
    transfer_id = {'无': 0, '转移': 1}
    id_transfer = {j: i for i, j in transfer_id.items()}
    lymph_transfer_id = {'无': 0, '淋巴转移': 1}
    id_lymph_transfer = {j: i for i, j in lymph_transfer_id.items()}

    cancer = id_cancer[cancer[0]]
    transfer = id_transfer[transfer[0]]
    ly_transfer = id_lymph_transfer[ly_transfer[0]]

    output = "肿瘤大小:"+size_str+"\n肿瘤大小<>=4cm:"+size_4+"\n"+"病理组织分型:"+cancer+"\n"+"转移:"+transfer+"\n"+"淋巴转移:"+ly_transfer+"\n"

    return output, size_str, size_4, cancer, transfer, ly_transfer
|
|
# FastAPI application instance; routes and static mounts are registered below.
app = FastAPI()
|
|
def model_init():
    """
    Load the tokenizer and the extraction model once, on CPU.

    Returns:
        Tuple ``(tokenizer, model)`` with the model in eval mode.

    NOTE(review): the original body read a global ``args`` that is never
    defined in this module (``set_args()`` is never called), raising
    NameError at import time, and wrote 'cuda'/'cpu' into
    CUDA_VISIBLE_DEVICES, which expects GPU indices. Since the weights are
    loaded with map_location='cpu' and the model is never moved to a GPU,
    the device-selection code is dropped entirely here.
    """
    tokenizer = BertTokenizerFast(vocab_file='/home/user/app/./bert-base-zh/vocab.txt', sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")

    model = Word_BERT()
    # Checkpoint is always mapped to CPU.
    model.load_state_dict(torch.load('/home/user/app/./model.pth', map_location=torch.device('cpu')))

    model.eval()
    return tokenizer, model
| |
|
|
# Load tokenizer and model once at import time so every request shares them.
tokenizer,model_extra = model_init()
|
|
@app.get("/infer_t5")
def t5(input):
    """Run pathology extraction on ``input`` and return every field as JSON."""
    fields = ("output", "size_str", "size_4", "cancer", "transfer", "ly_transfer")
    return dict(zip(fields, getChat(input, model_extra, tokenizer)))
|
|
# NOTE(review): mounting StaticFiles at "/" registers a catch-all route BEFORE
# the "/" and "/postText" handlers below are added; Starlette matches routes in
# registration order, so those later routes are likely unreachable — confirm,
# and consider performing this mount after all route definitions.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve the single-page UI shipped with the service."""
    page = "/app/static/index.html"
    return FileResponse(path=page, media_type="text/html")
|
|
@app.get("/postText")
def postText(input):
    """Same extraction as /infer_t5, exposed under the /postText path."""
    result = getChat(input, model_extra, tokenizer)
    keys = ["output", "size_str", "size_4", "cancer", "transfer", "ly_transfer"]
    return {name: value for name, value in zip(keys, result)}
|
|