Update function.py
Browse files- function.py +97 -97
function.py
CHANGED
|
@@ -1,98 +1,98 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from transformers import RobertaForTokenClassification, AutoTokenizer
|
| 3 |
-
model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
|
| 4 |
-
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
|
| 5 |
-
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 6 |
-
model.eval()
|
| 7 |
-
model.to(device)
|
| 8 |
-
import json
|
| 9 |
-
label_list={
|
| 10 |
-
0:'其他',
|
| 11 |
-
1:'电话',
|
| 12 |
-
2:'毕业时间', #毕业时间
|
| 13 |
-
3:'出生日期', #出生日期
|
| 14 |
-
4:'项目名称', #项目名称
|
| 15 |
-
5:'毕业院校', #毕业院校
|
| 16 |
-
6:'职务', #职务
|
| 17 |
-
7:'籍贯', #籍贯
|
| 18 |
-
8:'学位', #学位
|
| 19 |
-
9:'性别', #性别
|
| 20 |
-
10:'姓名', #姓名
|
| 21 |
-
11:'工作时间', #工作时间
|
| 22 |
-
12:'落户市县', #落户市县
|
| 23 |
-
13:'项目时间', #项目时间
|
| 24 |
-
14:'最高学历', #最高学历
|
| 25 |
-
15:'工作单位', #工作单位
|
| 26 |
-
16:'政治面貌', #政治面貌
|
| 27 |
-
17:'工作内容', #工作内容
|
| 28 |
-
18:'项目责任', #项目责任
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
def get_info(text):
|
| 32 |
-
#文本处理
|
| 33 |
-
text=text.strip()
|
| 34 |
-
text=text.replace('\n',',') # 将换行符替换为逗号
|
| 35 |
-
text=text.replace('\r',',') # 将回车符替换为逗号
|
| 36 |
-
text=text.replace('\t',',') # 将制表符替换为逗号
|
| 37 |
-
text=text.replace(' ',',') # 将空格替换为逗号
|
| 38 |
-
#将连续的逗号合并成一个逗号
|
| 39 |
-
while ',,' in text:
|
| 40 |
-
text=text.replace(',,',',')
|
| 41 |
-
block_list=[]
|
| 42 |
-
if len(text)>300:
|
| 43 |
-
#切块策略
|
| 44 |
-
#先切分成句
|
| 45 |
-
sentence_list=text.split(',')
|
| 46 |
-
#然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
|
| 47 |
-
boundary=300
|
| 48 |
-
block_list=[]
|
| 49 |
-
block=sentence_list[0]
|
| 50 |
-
for i in range(1,len(sentence_list)):
|
| 51 |
-
if len(block)+len(sentence_list[i])<=boundary:
|
| 52 |
-
block+=sentence_list[i]
|
| 53 |
-
else:
|
| 54 |
-
block_list.append(block)
|
| 55 |
-
block=sentence_list[i]
|
| 56 |
-
block_list.append(block)
|
| 57 |
-
else:
|
| 58 |
-
block_list.append(text)
|
| 59 |
-
_input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
|
| 60 |
-
#如果有GPU,将输入数据移到GPU
|
| 61 |
-
input_ids = _input['input_ids'].to(device)
|
| 62 |
-
attention_mask = _input['attention_mask'].to(device)
|
| 63 |
-
# 模型推理
|
| 64 |
-
with torch.no_grad():
|
| 65 |
-
logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
|
| 66 |
-
|
| 67 |
-
# 获取预测的标签ID
|
| 68 |
-
print(logits.shape)
|
| 69 |
-
ids = torch.argmax(logits, dim=-1)
|
| 70 |
-
input_ids=input_ids.reshape(-1)
|
| 71 |
-
#将张量在最后一个维度拼接,并以0为分界,拼接成句
|
| 72 |
-
ids =ids.reshape(-1)
|
| 73 |
-
# 按标签组合成提取内容
|
| 74 |
-
extracted_info = {}
|
| 75 |
-
word_list=[]
|
| 76 |
-
flag=None
|
| 77 |
-
for idx, label_id in enumerate(ids):
|
| 78 |
-
label_id = label_id.item()
|
| 79 |
-
if label_id!= 0 and (flag==None or flag==label_id): #不等于零时
|
| 80 |
-
if flag==None:
|
| 81 |
-
flag=label_id
|
| 82 |
-
label = label_list[label_id] # 获取对应的标签
|
| 83 |
-
word_list.append(input_ids[idx].item())
|
| 84 |
-
if label not in extracted_info:
|
| 85 |
-
extracted_info[label] = []
|
| 86 |
-
else:
|
| 87 |
-
if word_list:
|
| 88 |
-
sentence=''.join(tokenizer.decode(word_list))
|
| 89 |
-
extracted_info[label].append(sentence)
|
| 90 |
-
flag=None
|
| 91 |
-
word_list=[]
|
| 92 |
-
if label_id!= 0:
|
| 93 |
-
label = label_list[label_id] # 获取对应的标签
|
| 94 |
-
word_list.append(input_ids[idx].item())
|
| 95 |
-
if label not in extracted_info:
|
| 96 |
-
extracted_info[label] = []
|
| 97 |
-
# 返回JSON格式的提取内容
|
| 98 |
return extracted_info
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import RobertaForTokenClassification, AutoTokenizer
|
| 3 |
+
model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
|
| 4 |
+
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
|
| 5 |
+
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 6 |
+
model.eval()
|
| 7 |
+
model.to(device)
|
| 8 |
+
import json
|
| 9 |
+
label_list={
|
| 10 |
+
0:'其他',
|
| 11 |
+
1:'电话',
|
| 12 |
+
2:'毕业时间', #毕业时间
|
| 13 |
+
3:'出生日期', #出生日期
|
| 14 |
+
4:'项目名称', #项目名称
|
| 15 |
+
5:'毕业院校', #毕业院校
|
| 16 |
+
6:'职务', #职务
|
| 17 |
+
7:'籍贯', #籍贯
|
| 18 |
+
8:'学位', #学位
|
| 19 |
+
9:'性别', #性别
|
| 20 |
+
10:'姓名', #姓名
|
| 21 |
+
11:'工作时间', #工作时间
|
| 22 |
+
12:'落户市县', #落户市县
|
| 23 |
+
13:'项目时间', #项目时间
|
| 24 |
+
14:'最高学历', #最高学历
|
| 25 |
+
15:'工作单位', #工作单位
|
| 26 |
+
16:'政治面貌', #政治面貌
|
| 27 |
+
17:'工作内容', #工作内容
|
| 28 |
+
18:'项目责任', #项目责任
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
def get_info(text):
|
| 32 |
+
#文本处理
|
| 33 |
+
text=text.strip()
|
| 34 |
+
text=text.replace('\n',',') # 将换行符替换为逗号
|
| 35 |
+
text=text.replace('\r',',') # 将回车符替换为逗号
|
| 36 |
+
text=text.replace('\t',',') # 将制表符替换为逗号
|
| 37 |
+
text=text.replace(' ',',') # 将空格替换为逗号
|
| 38 |
+
#将连续的逗号合并成一个逗号
|
| 39 |
+
while ',,' in text:
|
| 40 |
+
text=text.replace(',,',',')
|
| 41 |
+
block_list=[]
|
| 42 |
+
if len(text)>300:
|
| 43 |
+
#切块策略
|
| 44 |
+
#先切分成句
|
| 45 |
+
sentence_list=text.split(',')
|
| 46 |
+
#然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
|
| 47 |
+
boundary=300
|
| 48 |
+
block_list=[]
|
| 49 |
+
block=sentence_list[0]
|
| 50 |
+
for i in range(1,len(sentence_list)):
|
| 51 |
+
if len(block)+len(sentence_list[i])<=boundary:
|
| 52 |
+
block+=sentence_list[i]
|
| 53 |
+
else:
|
| 54 |
+
block_list.append(block)
|
| 55 |
+
block=sentence_list[i]
|
| 56 |
+
block_list.append(block)
|
| 57 |
+
else:
|
| 58 |
+
block_list.append(text)
|
| 59 |
+
_input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
|
| 60 |
+
#如果有GPU,将输入数据移到GPU
|
| 61 |
+
input_ids = _input['input_ids'].to(device)
|
| 62 |
+
attention_mask = _input['attention_mask'].to(device)
|
| 63 |
+
# 模型推理
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
|
| 66 |
+
|
| 67 |
+
# 获取预测的标签ID
|
| 68 |
+
#print(logits.shape)
|
| 69 |
+
ids = torch.argmax(logits, dim=-1)
|
| 70 |
+
input_ids=input_ids.reshape(-1)
|
| 71 |
+
#将张量在最后一个维度拼接,并以0为分界,拼接成句
|
| 72 |
+
ids =ids.reshape(-1)
|
| 73 |
+
# 按标签组合成提取内容
|
| 74 |
+
extracted_info = {}
|
| 75 |
+
word_list=[]
|
| 76 |
+
flag=None
|
| 77 |
+
for idx, label_id in enumerate(ids):
|
| 78 |
+
label_id = label_id.item()
|
| 79 |
+
if label_id!= 0 and (flag==None or flag==label_id): #不等于零时
|
| 80 |
+
if flag==None:
|
| 81 |
+
flag=label_id
|
| 82 |
+
label = label_list[label_id] # 获取对应的标签
|
| 83 |
+
word_list.append(input_ids[idx].item())
|
| 84 |
+
if label not in extracted_info:
|
| 85 |
+
extracted_info[label] = []
|
| 86 |
+
else:
|
| 87 |
+
if word_list:
|
| 88 |
+
sentence=''.join(tokenizer.decode(word_list))
|
| 89 |
+
extracted_info[label].append(sentence)
|
| 90 |
+
flag=None
|
| 91 |
+
word_list=[]
|
| 92 |
+
if label_id!= 0:
|
| 93 |
+
label = label_list[label_id] # 获取对应的标签
|
| 94 |
+
word_list.append(input_ids[idx].item())
|
| 95 |
+
if label not in extracted_info:
|
| 96 |
+
extracted_info[label] = []
|
| 97 |
+
# 返回JSON格式的提取内容
|
| 98 |
return extracted_info
|