import streamlit as st import transformers from transformers import AutoTokenizer, AutoModelForTokenClassification import torch from html import escape import os # 修复权限问题 os.environ['STREAMLIT_CONFIG_DIR'] = os.getcwd() os.environ['STREAMLIT_GATHER_USAGE_STATS'] = 'false' # 设置页面 st.set_page_config(page_title="NER 实体识别", page_icon="🔍", layout="wide") # 标题和描述 st.title("🔍 命名实体识别 (NER)") st.markdown("使用 `dslim/bert-base-NER` 模型识别文本中的实体并分类") # 初始化模型 @st.cache_resource def load_model(): tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER") model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER") return tokenizer, model # 加载模型 with st.spinner("正在加载模型,请稍候..."): try: tokenizer, model = load_model() st.success("模型加载成功!") except Exception as e: st.error(f"加载模型时出错: {e}") st.stop() # 实体类型颜色映射 entity_colors = { 'PER': '#FF6B6B', # 人物 - 红色 'ORG': '#4ECDC4', # 组织 - 青绿色 'LOC': '#FFD166', # 地点 - 黄色 'MISC': '#9E6FDC' # 其他 - 紫色 } # 实体类型描述 entity_descriptions = { 'PER': '人物 (Person)', 'ORG': '组织 (Organization)', 'LOC': '地点 (Location)', 'MISC': '其他 (Miscellaneous)' } # 显示实体类型说明 with st.expander("实体类型说明"): cols = st.columns(4) for i, (key, value) in enumerate(entity_descriptions.items()): color = entity_colors.get(key, '#CCCCCC') cols[i].markdown(f"{key} - {value}", unsafe_allow_html=True) # 输入文本 default_text = "My name is Clara and I live in Berkeley, California." text_input = st.text_area("输入要分析的文本:", value=default_text, height=100) # 处理按钮 if st.button("识别实体", type="primary"): if not text_input.strip(): st.warning("请输入一些文本进行分析") else: with st.spinner("正在分析文本..."): try: # 对输入文本进行编码 inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512) # 使用模型进行预测 outputs = model(**inputs) # 获取预测结果 predictions = torch.argmax(outputs.logits, dim=-1)[0] # 解码预测结果 predicted_labels = [model.config.id2label[t.item()] for t in predictions] tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # 处理输出 current_entity = None current_tokens = [] results = [] for token, label in zip(tokens, predicted_labels): # 跳过特殊token if token in ['[CLS]', '[SEP]']: continue # 处理子词token if token.startswith('##'): token = token[2:] # 处理实体标签 if label != 'O': entity_type = label.split('-')[-1] if label.startswith('B-'): # 如果是新实体的开始 if current_entity: results.append((' '.join(current_tokens), current_entity)) current_entity = entity_type current_tokens = [token] else: # 继续当前实体 current_tokens.append(token) else: # 如果不是实体 if current_entity: results.append((' '.join(current_tokens), current_entity)) current_entity = None current_tokens = [] # 添加最后一个实体 if current_entity: results.append((' '.join(current_tokens), current_entity)) # 显示结果 st.subheader("分析结果") # 创建两列布局 col1, col2 = st.columns([2, 1]) with col1: st.markdown("**文本中的实体:**") # 高亮显示文本中的实体 highlighted_text = text_input for entity, e_type in results: color = entity_colors.get(e_type, '#CCCCCC') highlighted_entity = f"{entity} [{e_type}]" highlighted_text = highlighted_text.replace(entity, highlighted_entity) st.markdown(highlighted_text, unsafe_allow_html=True) with col2: st.markdown("**检测到的实体列表:**") if not results: st.info("未检测到任何实体") else: # 按类型分组实体 entities_by_type = {} for entity, e_type in results: if e_type not in entities_by_type: entities_by_type[e_type] = [] if entity not in entities_by_type[e_type]: entities_by_type[e_type].append(entity) # 显示每种类型的实体 for e_type, entities in entities_by_type.items(): color = entity_colors.get(e_type, '#CCCCCC') st.markdown(f"{e_type}", unsafe_allow_html=True) for entity in entities: st.markdown(f"- {entity}") # 显示模型信息 with st.expander("模型信息"): st.markdown(""" **模型:** dslim/bert-base-NER **描述:** 基于BERT的命名实体识别模型,能够识别人物(PER)、组织(ORG)、地点(LOC)和其他(MISC)实体。 """) except Exception as e: st.error(f"分析过程中出错: {e}") # 添加页脚 st.markdown("---") st.markdown("使用 🤗 Transformers 和 Streamlit 构建 | 模型: dslim/bert-base-NER")