renyi211 commited on
Commit
c1678bf
·
verified ·
1 Parent(s): 9262062

Upload app.py

Browse files
Files changed (1) hide show
  1. src/app.py +170 -0
src/app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import transformers
3
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
4
+ import torch
5
+ from html import escape
6
+ import os
7
+
8
+ # 修复权限问题
9
+ os.environ['STREAMLIT_CONFIG_DIR'] = os.getcwd()
10
+ os.environ['STREAMLIT_GATHER_USAGE_STATS'] = 'false'
11
+ # 设置页面
12
+ st.set_page_config(page_title="NER 实体识别", page_icon="🔍", layout="wide")
13
+
14
+ # 标题和描述
15
+ st.title("🔍 命名实体识别 (NER)")
16
+ st.markdown("使用 `dslim/bert-base-NER` 模型识别文本中的实体并分类")
17
+
18
+ # 初始化模型
19
+ @st.cache_resource
20
+ def load_model():
21
+ tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
22
+ model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
23
+ return tokenizer, model
24
+
25
+ # 加载模型
26
+ with st.spinner("正在加载模型,请稍候..."):
27
+ try:
28
+ tokenizer, model = load_model()
29
+ st.success("模型加载成功!")
30
+ except Exception as e:
31
+ st.error(f"加载模型时出错: {e}")
32
+ st.stop()
33
+
34
+ # 实体类型颜色映射
35
+ entity_colors = {
36
+ 'PER': '#FF6B6B', # 人物 - 红色
37
+ 'ORG': '#4ECDC4', # 组织 - 青绿色
38
+ 'LOC': '#FFD166', # 地点 - 黄色
39
+ 'MISC': '#9E6FDC' # 其他 - 紫色
40
+ }
41
+
42
+ # 实体类型描述
43
+ entity_descriptions = {
44
+ 'PER': '人物 (Person)',
45
+ 'ORG': '组织 (Organization)',
46
+ 'LOC': '地点 (Location)',
47
+ 'MISC': '其他 (Miscellaneous)'
48
+ }
49
+
50
+ # 显示实体类型说明
51
+ with st.expander("实体类型说明"):
52
+ cols = st.columns(4)
53
+ for i, (key, value) in enumerate(entity_descriptions.items()):
54
+ color = entity_colors.get(key, '#CCCCCC')
55
+ cols[i].markdown(f"<span style='background-color:{color}; padding:5px; border-radius:5px; color:white;'>{key}</span> - {value}", unsafe_allow_html=True)
56
+
57
+ # 输入文本
58
+ default_text = "My name is Clara and I live in Berkeley, California."
59
+ text_input = st.text_area("输入要分析的文本:", value=default_text, height=100)
60
+
61
+ # 处理按钮
62
+ if st.button("识别实体", type="primary"):
63
+ if not text_input.strip():
64
+ st.warning("请输入一些文本进行分析")
65
+ else:
66
+ with st.spinner("正在分析文本..."):
67
+ try:
68
+ # 对输入文本进行编码
69
+ inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
70
+
71
+ # 使用模型进行预测
72
+ outputs = model(**inputs)
73
+
74
+ # 获取预测结果
75
+ predictions = torch.argmax(outputs.logits, dim=-1)[0]
76
+
77
+ # 解码预测结果
78
+ predicted_labels = [model.config.id2label[t.item()] for t in predictions]
79
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
80
+
81
+ # 处理输出
82
+ current_entity = None
83
+ current_tokens = []
84
+ results = []
85
+
86
+ for token, label in zip(tokens, predicted_labels):
87
+ # 跳过特殊token
88
+ if token in ['[CLS]', '[SEP]']:
89
+ continue
90
+
91
+ # 处理子词token
92
+ if token.startswith('##'):
93
+ token = token[2:]
94
+
95
+ # 处理实体标签
96
+ if label != 'O':
97
+ entity_type = label.split('-')[-1]
98
+
99
+ if label.startswith('B-'):
100
+ # 如果是新实体的开始
101
+ if current_entity:
102
+ results.append((' '.join(current_tokens), current_entity))
103
+ current_entity = entity_type
104
+ current_tokens = [token]
105
+ else:
106
+ # 继续当前实体
107
+ current_tokens.append(token)
108
+ else:
109
+ # 如果不是实体
110
+ if current_entity:
111
+ results.append((' '.join(current_tokens), current_entity))
112
+ current_entity = None
113
+ current_tokens = []
114
+
115
+ # 添加最后一个实体
116
+ if current_entity:
117
+ results.append((' '.join(current_tokens), current_entity))
118
+
119
+ # 显示结果
120
+ st.subheader("分析结果")
121
+
122
+ # 创建两列布局
123
+ col1, col2 = st.columns([2, 1])
124
+
125
+ with col1:
126
+ st.markdown("**文本中的实体:**")
127
+
128
+ # 高亮显示文本中的实体
129
+ highlighted_text = text_input
130
+ for entity, e_type in results:
131
+ color = entity_colors.get(e_type, '#CCCCCC')
132
+ highlighted_entity = f"<mark style='background-color: {color}; padding: 2px 4px; border-radius: 4px;'>{entity} [{e_type}]</mark>"
133
+ highlighted_text = highlighted_text.replace(entity, highlighted_entity)
134
+
135
+ st.markdown(highlighted_text, unsafe_allow_html=True)
136
+
137
+ with col2:
138
+ st.markdown("**检测到的实体列表:**")
139
+
140
+ if not results:
141
+ st.info("未检测到任何实体")
142
+ else:
143
+ # 按类型分组实体
144
+ entities_by_type = {}
145
+ for entity, e_type in results:
146
+ if e_type not in entities_by_type:
147
+ entities_by_type[e_type] = []
148
+ if entity not in entities_by_type[e_type]:
149
+ entities_by_type[e_type].append(entity)
150
+
151
+ # 显示每种类型的实体
152
+ for e_type, entities in entities_by_type.items():
153
+ color = entity_colors.get(e_type, '#CCCCCC')
154
+ st.markdown(f"<span style='background-color:{color}; padding:2px 6px; border-radius:4px; color:white;'>{e_type}</span>", unsafe_allow_html=True)
155
+ for entity in entities:
156
+ st.markdown(f"- {entity}")
157
+
158
+ # 显示模型信息
159
+ with st.expander("模型信息"):
160
+ st.markdown("""
161
+ **模型:** dslim/bert-base-NER
162
+ **描述:** 基于BERT的命名实体识别模型,能够识别人物(PER)、组织(ORG)、地点(LOC)和其他(MISC)实体。
163
+ """)
164
+
165
+ except Exception as e:
166
+ st.error(f"分析过程中出错: {e}")
167
+
168
+ # 添加页脚
169
+ st.markdown("---")
170
+ st.markdown("使用 🤗 Transformers 和 Streamlit 构建 | 模型: dslim/bert-base-NER")