File size: 7,285 Bytes
c1678bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from html import escape
import os

# 修复权限问题
os.environ['STREAMLIT_CONFIG_DIR'] = os.getcwd()
os.environ['STREAMLIT_GATHER_USAGE_STATS'] = 'false'
# 设置页面
st.set_page_config(page_title="NER 实体识别", page_icon="🔍", layout="wide")

# 标题和描述
st.title("🔍 命名实体识别 (NER)")
st.markdown("使用 `dslim/bert-base-NER` 模型识别文本中的实体并分类")

# 初始化模型
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    return tokenizer, model

# 加载模型
with st.spinner("正在加载模型,请稍候..."):
    try:
        tokenizer, model = load_model()
        st.success("模型加载成功!")
    except Exception as e:
        st.error(f"加载模型时出错: {e}")
        st.stop()

# 实体类型颜色映射
entity_colors = {
    'PER': '#FF6B6B',  # 人物 - 红色
    'ORG': '#4ECDC4',  # 组织 - 青绿色
    'LOC': '#FFD166',  # 地点 - 黄色
    'MISC': '#9E6FDC'  # 其他 - 紫色
}

# 实体类型描述
entity_descriptions = {
    'PER': '人物 (Person)',
    'ORG': '组织 (Organization)',
    'LOC': '地点 (Location)',
    'MISC': '其他 (Miscellaneous)'
}

# 显示实体类型说明
with st.expander("实体类型说明"):
    cols = st.columns(4)
    for i, (key, value) in enumerate(entity_descriptions.items()):
        color = entity_colors.get(key, '#CCCCCC')
        cols[i].markdown(f"<span style='background-color:{color}; padding:5px; border-radius:5px; color:white;'>{key}</span> - {value}", unsafe_allow_html=True)

# 输入文本
default_text = "My name is Clara and I live in Berkeley, California."
text_input = st.text_area("输入要分析的文本:", value=default_text, height=100)

# 处理按钮
if st.button("识别实体", type="primary"):
    if not text_input.strip():
        st.warning("请输入一些文本进行分析")
    else:
        with st.spinner("正在分析文本..."):
            try:
                # 对输入文本进行编码
                inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
                
                # 使用模型进行预测
                outputs = model(**inputs)
                
                # 获取预测结果
                predictions = torch.argmax(outputs.logits, dim=-1)[0]
                
                # 解码预测结果
                predicted_labels = [model.config.id2label[t.item()] for t in predictions]
                tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
                
                # 处理输出
                current_entity = None
                current_tokens = []
                results = []
                
                for token, label in zip(tokens, predicted_labels):
                    # 跳过特殊token
                    if token in ['[CLS]', '[SEP]']:
                        continue
                    
                    # 处理子词token
                    if token.startswith('##'):
                        token = token[2:]
                    
                    # 处理实体标签
                    if label != 'O':
                        entity_type = label.split('-')[-1]
                        
                        if label.startswith('B-'):
                            # 如果是新实体的开始
                            if current_entity:
                                results.append((' '.join(current_tokens), current_entity))
                            current_entity = entity_type
                            current_tokens = [token]
                        else:
                            # 继续当前实体
                            current_tokens.append(token)
                    else:
                        # 如果不是实体
                        if current_entity:
                            results.append((' '.join(current_tokens), current_entity))
                            current_entity = None
                            current_tokens = []
                
                # 添加最后一个实体
                if current_entity:
                    results.append((' '.join(current_tokens), current_entity))
                
                # 显示结果
                st.subheader("分析结果")
                
                # 创建两列布局
                col1, col2 = st.columns([2, 1])
                
                with col1:
                    st.markdown("**文本中的实体:**")
                    
                    # 高亮显示文本中的实体
                    highlighted_text = text_input
                    for entity, e_type in results:
                        color = entity_colors.get(e_type, '#CCCCCC')
                        highlighted_entity = f"<mark style='background-color: {color}; padding: 2px 4px; border-radius: 4px;'>{entity} [{e_type}]</mark>"
                        highlighted_text = highlighted_text.replace(entity, highlighted_entity)
                    
                    st.markdown(highlighted_text, unsafe_allow_html=True)
                
                with col2:
                    st.markdown("**检测到的实体列表:**")
                    
                    if not results:
                        st.info("未检测到任何实体")
                    else:
                        # 按类型分组实体
                        entities_by_type = {}
                        for entity, e_type in results:
                            if e_type not in entities_by_type:
                                entities_by_type[e_type] = []
                            if entity not in entities_by_type[e_type]:
                                entities_by_type[e_type].append(entity)
                        
                        # 显示每种类型的实体
                        for e_type, entities in entities_by_type.items():
                            color = entity_colors.get(e_type, '#CCCCCC')
                            st.markdown(f"<span style='background-color:{color}; padding:2px 6px; border-radius:4px; color:white;'>{e_type}</span>", unsafe_allow_html=True)
                            for entity in entities:
                                st.markdown(f"- {entity}")
                
                # 显示模型信息
                with st.expander("模型信息"):
                    st.markdown("""

                    **模型:** dslim/bert-base-NER  

                    **描述:** 基于BERT的命名实体识别模型,能够识别人物(PER)、组织(ORG)、地点(LOC)和其他(MISC)实体。

                    """)
                    
            except Exception as e:
                st.error(f"分析过程中出错: {e}")

# 添加页脚
st.markdown("---")
st.markdown("使用 🤗 Transformers 和 Streamlit 构建 | 模型: dslim/bert-base-NER")