import json

import torch
from transformers import RobertaForTokenClassification, AutoTokenizer

# Load the fine-tuned token-classification model and its tokenizer
model = RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract', do_lower_case=True)

# Run on GPU when available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)
# Maps each label id to the field it marks (the model's own Chinese label set)
label_list = {
    0: '其他',       # other
    1: '电话',       # phone number
    2: '毕业时间',   # graduation date
    3: '出生日期',   # date of birth
    4: '项目名称',   # project name
    5: '毕业院校',   # alma mater
    6: '职务',       # job title
    7: '籍贯',       # native place
    8: '学位',       # degree
    9: '性别',       # gender
    10: '姓名',      # name
    11: '工作时间',  # employment period
    12: '落户市县',  # registered city/county
    13: '项目时间',  # project period
    14: '最高学历',  # highest education level
    15: '工作单位',  # employer
    16: '政治面貌',  # political affiliation
    17: '工作内容',  # job duties
    18: '项目责任',  # project responsibilities
}


def get_info(text):
    # Normalize whitespace: collapse newlines, carriage returns, tabs and
    # spaces into commas
    text = text.strip()
    text = text.replace('\n', ',')
    text = text.replace('\r', ',')
    text = text.replace('\t', ',')
    text = text.replace(' ', ',')
    # Collapse the comma runs left behind by consecutive whitespace
    while ',,' in text:
        text = text.replace(',,', ',')
    # Split long inputs on commas and greedily pack the pieces into blocks of
    # at most 300 characters so each block fits the model's input limit
    # (note: the separating commas are dropped when pieces are merged)
    block_list = []
    if len(text) > 300:
        sentence_list = text.split(',')
        boundary = 300
        block = sentence_list[0]
        for i in range(1, len(sentence_list)):
            if len(block) + len(sentence_list[i]) <= boundary:
                block += sentence_list[i]
            else:
                block_list.append(block)
                block = sentence_list[i]
        block_list.append(block)
    else:
        block_list.append(text)
    # Tokenize all blocks as one padded batch and run a single forward pass
    _input = tokenizer(block_list, return_tensors='pt', padding=True, truncation=True)
    input_ids = _input['input_ids'].to(device)
    attention_mask = _input['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]

    # Take the highest-scoring label per token, then flatten the batch so the
    # predictions can be scanned as one token stream (this relies on padding
    # and special tokens being classified as 0, '其他')
    ids = torch.argmax(logits, dim=-1)
    input_ids = input_ids.reshape(-1)
    ids = ids.reshape(-1)

    # Gather maximal runs of identically labelled tokens into text spans
    extracted_info = {}
    word_list = []
    label = None
    flag = None  # label id of the span currently being collected
    for idx, label_id in enumerate(ids):
        label_id = label_id.item()
        if label_id != 0 and (flag is None or flag == label_id):
            # Start a new span or extend the current one
            if flag is None:
                flag = label_id
            label = label_list[label_id]
            word_list.append(input_ids[idx].item())
            if label not in extracted_info:
                extracted_info[label] = []
        else:
            # The current span ended: decode it and store it under its label
            if word_list:
                sentence = tokenizer.decode(word_list)
                extracted_info[label].append(sentence)
            flag = None
            word_list = []
            if label_id != 0:
                # A span with a different label starts on this very token
                flag = label_id
                label = label_list[label_id]
                word_list.append(input_ids[idx].item())
                if label not in extracted_info:
                    extracted_info[label] = []
    # Flush a span that runs all the way to the final token
    if word_list:
        extracted_info[label].append(tokenizer.decode(word_list))
    return extracted_info
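
# --- Example usage: a minimal sketch; the resume text below is invented
# --- purely for illustration and is not taken from the model card
if __name__ == '__main__':
    sample_text = (
        '姓名:张三 性别:男 电话:13800000000\n'
        '毕业院校:某某大学 学位:学士 毕业时间:2018年6月'
    )
    info = get_info(sample_text)
    # Pretty-print the extracted fields as JSON (the reason json is imported)
    print(json.dumps(info, ensure_ascii=False, indent=2))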