renyi211's picture
Upload app.py
c1678bf verified
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from html import escape
import os
# 修复权限问题
os.environ['STREAMLIT_CONFIG_DIR'] = os.getcwd()
os.environ['STREAMLIT_GATHER_USAGE_STATS'] = 'false'
# 设置页面
st.set_page_config(page_title="NER 实体识别", page_icon="🔍", layout="wide")
# 标题和描述
st.title("🔍 命名实体识别 (NER)")
st.markdown("使用 `dslim/bert-base-NER` 模型识别文本中的实体并分类")
# 初始化模型
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
return tokenizer, model
# 加载模型
with st.spinner("正在加载模型,请稍候..."):
try:
tokenizer, model = load_model()
st.success("模型加载成功!")
except Exception as e:
st.error(f"加载模型时出错: {e}")
st.stop()
# 实体类型颜色映射
entity_colors = {
'PER': '#FF6B6B', # 人物 - 红色
'ORG': '#4ECDC4', # 组织 - 青绿色
'LOC': '#FFD166', # 地点 - 黄色
'MISC': '#9E6FDC' # 其他 - 紫色
}
# 实体类型描述
entity_descriptions = {
'PER': '人物 (Person)',
'ORG': '组织 (Organization)',
'LOC': '地点 (Location)',
'MISC': '其他 (Miscellaneous)'
}
# 显示实体类型说明
with st.expander("实体类型说明"):
cols = st.columns(4)
for i, (key, value) in enumerate(entity_descriptions.items()):
color = entity_colors.get(key, '#CCCCCC')
cols[i].markdown(f"<span style='background-color:{color}; padding:5px; border-radius:5px; color:white;'>{key}</span> - {value}", unsafe_allow_html=True)
# 输入文本
default_text = "My name is Clara and I live in Berkeley, California."
text_input = st.text_area("输入要分析的文本:", value=default_text, height=100)
# 处理按钮
if st.button("识别实体", type="primary"):
if not text_input.strip():
st.warning("请输入一些文本进行分析")
else:
with st.spinner("正在分析文本..."):
try:
# 对输入文本进行编码
inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
# 使用模型进行预测
outputs = model(**inputs)
# 获取预测结果
predictions = torch.argmax(outputs.logits, dim=-1)[0]
# 解码预测结果
predicted_labels = [model.config.id2label[t.item()] for t in predictions]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# 处理输出
current_entity = None
current_tokens = []
results = []
for token, label in zip(tokens, predicted_labels):
# 跳过特殊token
if token in ['[CLS]', '[SEP]']:
continue
# 处理子词token
if token.startswith('##'):
token = token[2:]
# 处理实体标签
if label != 'O':
entity_type = label.split('-')[-1]
if label.startswith('B-'):
# 如果是新实体的开始
if current_entity:
results.append((' '.join(current_tokens), current_entity))
current_entity = entity_type
current_tokens = [token]
else:
# 继续当前实体
current_tokens.append(token)
else:
# 如果不是实体
if current_entity:
results.append((' '.join(current_tokens), current_entity))
current_entity = None
current_tokens = []
# 添加最后一个实体
if current_entity:
results.append((' '.join(current_tokens), current_entity))
# 显示结果
st.subheader("分析结果")
# 创建两列布局
col1, col2 = st.columns([2, 1])
with col1:
st.markdown("**文本中的实体:**")
# 高亮显示文本中的实体
highlighted_text = text_input
for entity, e_type in results:
color = entity_colors.get(e_type, '#CCCCCC')
highlighted_entity = f"<mark style='background-color: {color}; padding: 2px 4px; border-radius: 4px;'>{entity} [{e_type}]</mark>"
highlighted_text = highlighted_text.replace(entity, highlighted_entity)
st.markdown(highlighted_text, unsafe_allow_html=True)
with col2:
st.markdown("**检测到的实体列表:**")
if not results:
st.info("未检测到任何实体")
else:
# 按类型分组实体
entities_by_type = {}
for entity, e_type in results:
if e_type not in entities_by_type:
entities_by_type[e_type] = []
if entity not in entities_by_type[e_type]:
entities_by_type[e_type].append(entity)
# 显示每种类型的实体
for e_type, entities in entities_by_type.items():
color = entity_colors.get(e_type, '#CCCCCC')
st.markdown(f"<span style='background-color:{color}; padding:2px 6px; border-radius:4px; color:white;'>{e_type}</span>", unsafe_allow_html=True)
for entity in entities:
st.markdown(f"- {entity}")
# 显示模型信息
with st.expander("模型信息"):
st.markdown("""
**模型:** dslim/bert-base-NER
**描述:** 基于BERT的命名实体识别模型,能够识别人物(PER)、组织(ORG)、地点(LOC)和其他(MISC)实体。
""")
except Exception as e:
st.error(f"分析过程中出错: {e}")
# 添加页脚
st.markdown("---")
st.markdown("使用 🤗 Transformers 和 Streamlit 构建 | 模型: dslim/bert-base-NER")