Spaces:
Sleeping
Sleeping
File size: 7,285 Bytes
c1678bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from html import escape
import os
# 修复权限问题
os.environ['STREAMLIT_CONFIG_DIR'] = os.getcwd()
os.environ['STREAMLIT_GATHER_USAGE_STATS'] = 'false'
# 设置页面
st.set_page_config(page_title="NER 实体识别", page_icon="🔍", layout="wide")
# 标题和描述
st.title("🔍 命名实体识别 (NER)")
st.markdown("使用 `dslim/bert-base-NER` 模型识别文本中的实体并分类")
# 初始化模型
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
return tokenizer, model
# 加载模型
with st.spinner("正在加载模型,请稍候..."):
try:
tokenizer, model = load_model()
st.success("模型加载成功!")
except Exception as e:
st.error(f"加载模型时出错: {e}")
st.stop()
# 实体类型颜色映射
entity_colors = {
'PER': '#FF6B6B', # 人物 - 红色
'ORG': '#4ECDC4', # 组织 - 青绿色
'LOC': '#FFD166', # 地点 - 黄色
'MISC': '#9E6FDC' # 其他 - 紫色
}
# 实体类型描述
entity_descriptions = {
'PER': '人物 (Person)',
'ORG': '组织 (Organization)',
'LOC': '地点 (Location)',
'MISC': '其他 (Miscellaneous)'
}
# 显示实体类型说明
with st.expander("实体类型说明"):
cols = st.columns(4)
for i, (key, value) in enumerate(entity_descriptions.items()):
color = entity_colors.get(key, '#CCCCCC')
cols[i].markdown(f"<span style='background-color:{color}; padding:5px; border-radius:5px; color:white;'>{key}</span> - {value}", unsafe_allow_html=True)
# 输入文本
default_text = "My name is Clara and I live in Berkeley, California."
text_input = st.text_area("输入要分析的文本:", value=default_text, height=100)
# 处理按钮
if st.button("识别实体", type="primary"):
if not text_input.strip():
st.warning("请输入一些文本进行分析")
else:
with st.spinner("正在分析文本..."):
try:
# 对输入文本进行编码
inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
# 使用模型进行预测
outputs = model(**inputs)
# 获取预测结果
predictions = torch.argmax(outputs.logits, dim=-1)[0]
# 解码预测结果
predicted_labels = [model.config.id2label[t.item()] for t in predictions]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# 处理输出
current_entity = None
current_tokens = []
results = []
for token, label in zip(tokens, predicted_labels):
# 跳过特殊token
if token in ['[CLS]', '[SEP]']:
continue
# 处理子词token
if token.startswith('##'):
token = token[2:]
# 处理实体标签
if label != 'O':
entity_type = label.split('-')[-1]
if label.startswith('B-'):
# 如果是新实体的开始
if current_entity:
results.append((' '.join(current_tokens), current_entity))
current_entity = entity_type
current_tokens = [token]
else:
# 继续当前实体
current_tokens.append(token)
else:
# 如果不是实体
if current_entity:
results.append((' '.join(current_tokens), current_entity))
current_entity = None
current_tokens = []
# 添加最后一个实体
if current_entity:
results.append((' '.join(current_tokens), current_entity))
# 显示结果
st.subheader("分析结果")
# 创建两列布局
col1, col2 = st.columns([2, 1])
with col1:
st.markdown("**文本中的实体:**")
# 高亮显示文本中的实体
highlighted_text = text_input
for entity, e_type in results:
color = entity_colors.get(e_type, '#CCCCCC')
highlighted_entity = f"<mark style='background-color: {color}; padding: 2px 4px; border-radius: 4px;'>{entity} [{e_type}]</mark>"
highlighted_text = highlighted_text.replace(entity, highlighted_entity)
st.markdown(highlighted_text, unsafe_allow_html=True)
with col2:
st.markdown("**检测到的实体列表:**")
if not results:
st.info("未检测到任何实体")
else:
# 按类型分组实体
entities_by_type = {}
for entity, e_type in results:
if e_type not in entities_by_type:
entities_by_type[e_type] = []
if entity not in entities_by_type[e_type]:
entities_by_type[e_type].append(entity)
# 显示每种类型的实体
for e_type, entities in entities_by_type.items():
color = entity_colors.get(e_type, '#CCCCCC')
st.markdown(f"<span style='background-color:{color}; padding:2px 6px; border-radius:4px; color:white;'>{e_type}</span>", unsafe_allow_html=True)
for entity in entities:
st.markdown(f"- {entity}")
# 显示模型信息
with st.expander("模型信息"):
st.markdown("""
**模型:** dslim/bert-base-NER
**描述:** 基于BERT的命名实体识别模型,能够识别人物(PER)、组织(ORG)、地点(LOC)和其他(MISC)实体。
""")
except Exception as e:
st.error(f"分析过程中出错: {e}")
# 添加页脚
st.markdown("---")
st.markdown("使用 🤗 Transformers 和 Streamlit 构建 | 模型: dslim/bert-base-NER") |