{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "nb-title",
   "metadata": {},
   "source": [
    "# BERT NER inference\n",
    "\n",
    "Load the label list from `labels.txt` and the fine-tuned `BertNerModel` weights from `bert-model.pth` (built on `bert-base-chinese`), then tag a sample Chinese sentence."
   ]
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-05-09T08:24:27.682520Z",
     "start_time": "2024-05-09T08:24:23.650272Z"
    }
   },
   "source": [
    "# Imports\n",
    "import torch\n",
    "from transformers import BertTokenizerFast"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-09T08:24:27.871979Z",
     "start_time": "2024-05-09T08:24:27.683484Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Project-local model wrapper (from the local `utils` module)\n",
    "from utils import BertNerModel"
   ],
   "id": "bed4b6400ac293d8",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-09T08:24:27.886940Z",
     "start_time": "2024-05-09T08:24:27.872977Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Load the label list; a label's position in the list is its class id.\n",
    "# The file contains Chinese tags, so decode explicitly (assumed UTF-8) instead of\n",
    "# relying on the platform default encoding (e.g. GBK on Chinese-locale Windows).\n",
    "with open('labels.txt', 'r', encoding='utf-8') as f:\n",
    "    labels = [line.rstrip('\\n') for line in f]\n",
    "labels"
   ],
   "id": "614210e3e65da37f",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['B-BANK',\n",
       " 'E-BANK',\n",
       " 'O',\n",
       " 'B-COMMENTS_N',\n",
       " 'E-COMMENTS_N',\n",
       " 'B-COMMENTS_ADJ',\n",
       " 'E-COMMENTS_ADJ',\n",
       " 'B-PRODUCT',\n",
       " 'E-PRODUCT',\n",
       " 'I-PRODUCT',\n",
       " 'I-COMMENTS_N',\n",
       " 'I-BANK',\n",
       " 'I-COMMENTS_ADJ',\n",
       " 'B-product_name',\n",
       " 'I-product_name',\n",
       " 'B-time',\n",
       " 'I-time',\n",
       " 'E-time',\n",
       " 'B-person_name',\n",
       " 'I-person_name',\n",
       " 'E-person_name',\n",
       " 'E-product_name',\n",
       " 'B-org_name',\n",
       " 'I-org_name',\n",
       " 'E-org_name',\n",
       " 'B-location',\n",
       " 'I-location',\n",
       " 'E-location',\n",
       " 'B-company_name',\n",
       " 'I-company_name',\n",
       " 'E-company_name',\n",
       " 'B-GPE',\n",
       " 'I-GPE',\n",
       " 'E-GPE',\n",
       " 'B-PER',\n",
       " 'I-PER',\n",
       " 'E-PER',\n",
       " 'B-LOC',\n",
       " 'I-LOC',\n",
       " 'E-LOC',\n",
       " 'B-ORG',\n",
       " 'I-ORG',\n",
       " 'E-ORG',\n",
       " 'B-body',\n",
       " 'E-body',\n",
       " 'I-body',\n",
       " 'B-symp',\n",
       " 'E-symp',\n",
       " 'I-symp',\n",
       " 'B-chec',\n",
       " 'E-chec',\n",
       " 'I-chec',\n",
       " 'B-dise',\n",
       " 'I-dise',\n",
       " 'E-dise',\n",
       " 'B-cure',\n",
       " 'I-cure',\n",
       " 'E-cure',\n",
       " 'B-身体部位',\n",
       " 'I-身体部位',\n",
       " 'B-检查和检验',\n",
       " 'E-检查和检验',\n",
       " 'I-检查和检验',\n",
       " 'E-身体部位',\n",
       " 'B-症状和体征',\n",
       " 'E-症状和体征',\n",
       " 'I-症状和体征',\n",
       " 'B-疾病和诊断',\n",
       " 'I-疾病和诊断',\n",
       " 'E-疾病和诊断',\n",
       " 'B-治疗',\n",
       " 'I-治疗',\n",
       " 'E-治疗',\n",
       " 'B-解剖部位',\n",
       " 'E-解剖部位',\n",
       " 'B-手术',\n",
       " 'I-手术',\n",
       " 'E-手术',\n",
       " 'B-影像检查',\n",
       " 'E-影像检查',\n",
       " 'I-解剖部位',\n",
       " 'B-药物',\n",
       " 'E-药物',\n",
       " 'I-药物',\n",
       " 'B-实验室检验',\n",
       " 'I-实验室检验',\n",
       " 'E-实验室检验',\n",
       " 'I-影像检查',\n",
       " 'B-name',\n",
       " 'I-name',\n",
       " 'E-name',\n",
       " 'B-address',\n",
       " 'E-address',\n",
       " 'B-organization',\n",
       " 'E-organization',\n",
       " 'B-game',\n",
       " 'I-game',\n",
       " 'E-game',\n",
       " 'I-address',\n",
       " 'B-scene',\n",
       " 'I-scene',\n",
       " 'E-scene',\n",
       " 'B-book',\n",
       " 'I-book',\n",
       " 'E-book',\n",
       " 'I-organization',\n",
       " 'B-company',\n",
       " 'I-company',\n",
       " 'E-company',\n",
       " 'B-position',\n",
       " 'E-position',\n",
       " 'I-position',\n",
       " 'B-government',\n",
       " 'I-government',\n",
       " 'E-government',\n",
       " 'B-movie',\n",
       " 'I-movie',\n",
       " 'E-movie',\n",
       " 'B-bod',\n",
       " 'I-bod',\n",
       " 'E-bod',\n",
       " 'B-dis',\n",
       " 'I-dis',\n",
       " 'E-dis',\n",
       " 'B-sym',\n",
       " 'I-sym',\n",
       " 'E-sym',\n",
       " 'B-pro',\n",
       " 'I-pro',\n",
       " 'E-pro',\n",
       " 'B-ite',\n",
       " 'I-ite',\n",
       " 'E-ite',\n",
       " 'B-mic',\n",
       " 'I-mic',\n",
       " 'E-mic',\n",
       " 'B-dep',\n",
       " 'E-dep',\n",
       " 'B-dru',\n",
       " 'I-dru',\n",
       " 'E-dru',\n",
       " 'I-dep',\n",
       " 'B-equ',\n",
       " 'I-equ',\n",
       " 'E-equ',\n",
       " 'B-Time',\n",
       " 'I-Time',\n",
       " 'E-Time',\n",
       " 'B-Person',\n",
       " 'B-Location',\n",
       " 'I-Location',\n",
       " 'E-Location',\n",
       " 'E-Person',\n",
       " 'B-Thing',\n",
       " 'E-Thing',\n",
       " 'B-Metric',\n",
       " 'E-Metric',\n",
       " 'I-Person',\n",
       " 'I-Thing',\n",
       " 'B-Organization',\n",
       " 'I-Organization',\n",
       " 'E-Organization',\n",
       " 'I-Metric',\n",
       " 'B-Abstract',\n",
       " 'I-Abstract',\n",
       " 'E-Abstract',\n",
       " 'B-Physical',\n",
       " 'I-Physical',\n",
       " 'E-Physical',\n",
       " 'B-Term',\n",
       " 'I-Term',\n",
       " 'E-Term',\n",
       " 'B-ABstract',\n",
       " 'I-ABstract',\n",
       " 'E-ABstract',\n",
       " 'B-HCCX',\n",
       " 'E-HCCX',\n",
       " 'I-HCCX',\n",
       " 'B-MISC',\n",
       " 'E-MISC',\n",
       " 'B-HPPX',\n",
       " 'E-HPPX',\n",
       " 'I-HPPX',\n",
       " 'I-MISC',\n",
       " 'B-XH',\n",
       " 'I-XH',\n",
       " 'E-XH',\n",
       " 'B-EQU',\n",
       " 'I-EQU',\n",
       " 'E-EQU',\n",
       " 'B-TIME',\n",
       " 'E-TIME',\n",
       " 'I-TIME',\n",
       " 'B-FAC',\n",
       " 'I-FAC',\n",
       " 'E-FAC',\n",
       " 'B-Symptom',\n",
       " 'E-Symptom',\n",
       " 'B-Medical_Examination',\n",
       " 'E-Medical_Examination',\n",
       " 'I-Medical_Examination',\n",
       " 'B-Drug',\n",
       " 'I-Drug',\n",
       " 'E-Drug',\n",
       " 'B-Drug_Category',\n",
       " 'I-Drug_Category',\n",
       " 'E-Drug_Category',\n",
       " 'I-Symptom',\n",
       " 'B-Operation',\n",
       " 'E-Operation',\n",
       " 'I-Operation',\n",
       " 'B-NAME',\n",
       " 'I-NAME',\n",
       " 'E-NAME',\n",
       " 'B-CONT',\n",
       " 'I-CONT',\n",
       " 'E-CONT',\n",
       " 'B-EDU',\n",
       " 'I-EDU',\n",
       " 'E-EDU',\n",
       " 'B-TITLE',\n",
       " 'I-TITLE',\n",
       " 'E-TITLE',\n",
       " 'B-RACE',\n",
       " 'E-RACE',\n",
       " 'B-PRO',\n",
       " 'I-PRO',\n",
       " 'E-PRO',\n",
       " 'I-RACE',\n",
       " 'B-T',\n",
       " 'I-T',\n",
       " 'E-T']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Load the fine-tuned model",
   "id": "91f0e81784b80263"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-09T08:24:29.299162Z",
     "start_time": "2024-05-09T08:24:27.888935Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model = BertNerModel(labels)\n",
    "# NOTE(review): the 'newly initialized' classifier warning below is presumably emitted while\n",
    "# BertNerModel builds the base bert-base-chinese model; bert-model.pth should then overwrite\n",
    "# those weights - confirm the load reports that all keys matched.\n",
    "# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine;\n",
    "# evaluate_one_text moves the model to CUDA afterwards when it is available.\n",
    "model.load_state_dict(torch.load(\"bert-model.pth\", map_location=\"cpu\"))"
   ],
   "id": "2129b37797e3c37a",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": 
"2024-05-09T08:24:29.314124Z",
     "start_time": "2024-05-09T08:24:29.302155Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Inverse lookup for predictions: class id -> label string (same order as `labels`)\n",
    "tr_label_dict = dict(enumerate(labels))"
   ],
   "id": "89256fa2c48b0519",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-09T08:24:29.376988Z",
     "start_time": "2024-05-09T08:24:29.316121Z"
    }
   },
   "cell_type": "code",
   "source": [
    "tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese', do_lower_case=True)\n",
    "# If True, every sub-token of a word is kept for prediction, not just the first one.\n",
    "label_all_tokens = True\n",
    "# Sequence length shared by both tokenizer calls so the keep-mask and the logits line up.\n",
    "MAX_LEN = 512\n",
    "\n",
    "\n",
    "def align_word_ids(texts, max_length=MAX_LEN):\n",
    "    \"\"\"Return a keep-mask aligned with the tokenized `texts`.\n",
    "\n",
    "    One entry per token position (padded/truncated to `max_length`):\n",
    "    -100 where the tokenizer reports no word (special tokens and padding),\n",
    "    1 for the first token of each word, and for later tokens of the same\n",
    "    word 1 if `label_all_tokens` else -100.\n",
    "    \"\"\"\n",
    "    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=max_length, truncation=True)\n",
    "    label_ids = []\n",
    "    previous_word_idx = None\n",
    "    for word_idx in tokenized_inputs.word_ids():\n",
    "        if word_idx is None:\n",
    "            label_ids.append(-100)\n",
    "        elif word_idx != previous_word_idx:\n",
    "            label_ids.append(1)\n",
    "        else:\n",
    "            label_ids.append(1 if label_all_tokens else -100)\n",
    "        previous_word_idx = word_idx\n",
    "    return label_ids\n",
    "\n",
    "\n",
    "def evaluate_one_text(model, sentence, max_length=MAX_LEN):\n",
    "    \"\"\"Tag a single sentence with `model` and print the sentence and its predicted labels.\n",
    "\n",
    "    Side effects: moves `model` to CUDA when available and switches it to\n",
    "    eval mode (dropout off), so repeated calls give the same predictions.\n",
    "    \"\"\"\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    text = tokenizer(sentence, padding='max_length', max_length=max_length, truncation=True, return_tensors=\"pt\")\n",
    "    mask = text['attention_mask'][0].unsqueeze(0).to(device)\n",
    "    input_id = text['input_ids'][0].unsqueeze(0).to(device)\n",
    "    # Boolean (1, max_length) mask of the token positions that receive a prediction.\n",
    "    keep = torch.tensor(align_word_ids(sentence, max_length), device=device).unsqueeze(0) != -100\n",
    "\n",
    "    # Inference only: skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        logits = model(input_id, mask, None)\n",
    "    # NOTE(review): assumes the model returns a tuple/ModelOutput whose first element is the\n",
    "    # (1, seq_len, num_labels) logits tensor - confirm against utils.BertNerModel.\n",
    "    logits_clean = logits[0][keep]\n",
    "\n",
    "    predictions = logits_clean.argmax(dim=1).tolist()\n",
    "    prediction_label = [tr_label_dict[i] for i in predictions]\n",
    "    print(sentence)\n",
    "    print(prediction_label)"
   ],
   "id": "12a17ea429a5710c",
"outputs": [],
   "execution_count": 6
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Inference on a sample sentence\n",
    "\n",
    "`align_word_ids` marks which of the 512 token positions receive a prediction (`-100` = special tokens such as `[CLS]`/`[SEP]`, or padding); only the first 20 positions are shown."
   ],
   "id": "sample-inference-md"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-09T08:24:29.408903Z",
     "start_time": "2024-05-09T08:24:29.377983Z"
    }
   },
   "cell_type": "code",
   "source": [
    "sample_sentence = '悉尼遭袭2名死伤中国公民为留学生'\n",
    "align_word_ids(sample_sentence)[:20]"
   ],
   "id": "31f15809e7f8913c",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-100, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-09T08:24:30.907891Z",
     "start_time": "2024-05-09T08:24:29.411863Z"
    }
   },
   "cell_type": "code",
   "source": "evaluate_one_text(model, sample_sentence)",
   "id": "acc1f02571cf9e6c",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "悉尼遭袭2名死伤中国公民为留学生\n",
      "['B-LOC', 'E-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'E-LOC', 'O', 'O', 'O', 'O', 'O', 'O']\n"
     ]
    }
   ],
   "execution_count": 8
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}