# Source captured from a Hugging Face Spaces page (Space status: Sleeping).
| import streamlit as st | |
| import transformers | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
| import torch | |
| from html import escape | |
| import os | |
# Work around permission problems: point Streamlit's config directory at the
# current working directory and turn off usage-statistics gathering.
# NOTE(review): these variables are assigned after `streamlit` has already been
# imported, so they may not influence the running server. Confirm both the
# variable names and whether they must be set before launch instead.
os.environ.update(
    STREAMLIT_CONFIG_DIR=os.getcwd(),
    STREAMLIT_GATHER_USAGE_STATS='false',
)

# Page setup: browser-tab title, icon and wide layout.
st.set_page_config(page_title="NER 实体识别", page_icon="🔍", layout="wide")

# Heading plus a one-line description of the model in use.
st.title("🔍 命名实体识别 (NER)")
st.markdown("使用 `dslim/bert-base-NER` 模型识别文本中的实体并分类")
# Initialise the model.
@st.cache_resource(show_spinner=False)
def load_model(model_name="dslim/bert-base-NER"):
    """Load and return ``(tokenizer, model)`` for token-classification NER.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would re-read the tokenizer and model weights on every
    click. ``st.cache_resource`` keeps one shared instance per ``model_name``
    and reuses it across reruns. The caller already shows its own spinner,
    so the cache's built-in spinner is disabled.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.
            Defaults to the model this app was built for.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # A fast tokenizer is requested explicitly because character offset
    # mappings (used to locate entities in the input text) need one.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return tokenizer, model
# Load the tokenizer/model behind a spinner. On failure, report the error and
# halt this script run so nothing below executes without a model.
with st.spinner("正在加载模型,请稍候..."):
    try:
        tokenizer, model = load_model()
        st.success("模型加载成功!")
    except Exception as load_error:
        st.error(f"加载模型时出错: {load_error}")
        st.stop()
# Display metadata for each entity type: (tag, highlight colour, description).
_ENTITY_TYPE_TABLE = (
    ('PER', '#FF6B6B', '人物 (Person)'),          # person - red
    ('ORG', '#4ECDC4', '组织 (Organization)'),    # organization - teal
    ('LOC', '#FFD166', '地点 (Location)'),        # location - yellow
    ('MISC', '#9E6FDC', '其他 (Miscellaneous)'),  # miscellaneous - purple
)

# Entity type -> highlight colour.
entity_colors = {tag: color for tag, color, _ in _ENTITY_TYPE_TABLE}

# Entity type -> human-readable description.
entity_descriptions = {tag: label for tag, _, label in _ENTITY_TYPE_TABLE}
# Legend: one column per entity type, showing a coloured badge and its description.
with st.expander("实体类型说明"):
    legend_columns = st.columns(4)
    for slot, (tag, description) in enumerate(entity_descriptions.items()):
        badge_color = entity_colors.get(tag, '#CCCCCC')
        badge_html = (
            f"<span style='background-color:{badge_color}; padding:5px; "
            f"border-radius:5px; color:white;'>{tag}</span> - {description}"
        )
        legend_columns[slot].markdown(badge_html, unsafe_allow_html=True)
# Text to analyse, pre-filled with an example sentence.
default_text = "My name is Clara and I live in Berkeley, California."
text_input = st.text_area(
    "输入要分析的文本:",
    value=default_text,
    height=100,
)
def _group_entities(tokens, labels, offsets):
    """Merge per-token BIO predictions into character-level entity spans.

    Args:
        tokens: Word-piece strings from ``tokenizer.convert_ids_to_tokens``.
        labels: One predicted label per token (``"O"``, ``"B-PER"``, ``"I-LOC"``, ...).
        offsets: One ``(start, end)`` character-offset pair per token into the
            original text (a fast tokenizer's ``offset_mapping``). Special
            tokens such as [CLS]/[SEP] have empty ``(0, 0)`` spans and are skipped.

    Returns:
        A list of ``(start, end, entity_type)`` tuples in text order, where
        ``text[start:end]`` is the entity exactly as written in the input.
    """
    spans = []
    current = None  # [start, end, entity_type] of the entity being built

    for token, label, (start, end) in zip(tokens, labels, offsets):
        if start == end:
            # Special tokens cover no characters of the input.
            continue
        if label == "O":
            if current is not None:
                spans.append(tuple(current))
                current = None
            continue

        entity_type = label.split("-")[-1]
        # A token extends the open entity only if it has the same type and is
        # either an inside tag or a "##" word-piece of the same word. An
        # orphan I- tag (nothing open, or a different type) starts a new entity.
        continues = (
            current is not None
            and current[2] == entity_type
            and (label.startswith("I-") or token.startswith("##"))
        )
        if continues:
            current[1] = end
        else:
            if current is not None:
                spans.append(tuple(current))
            current = [start, end, entity_type]

    if current is not None:
        spans.append(tuple(current))
    return spans


def _highlight_entities(text, spans, colors):
    """Return *text* as HTML with each ``(start, end, type)`` span wrapped in ``<mark>``.

    The output is built in a single left-to-right pass over the sorted,
    non-overlapping spans. This avoids re-matching entity text inside markup
    inserted for an earlier entity. Every slice of user text is HTML-escaped,
    so the result is safe to render with ``unsafe_allow_html=True``.
    """
    parts = []
    cursor = 0
    for start, end, entity_type in spans:
        color = colors.get(entity_type, "#CCCCCC")
        parts.append(escape(text[cursor:start]))
        parts.append(
            f"<mark style='background-color: {color}; padding: 2px 4px; "
            f"border-radius: 4px;'>{escape(text[start:end])} [{entity_type}]</mark>"
        )
        cursor = end
    parts.append(escape(text[cursor:]))
    return "".join(parts)


# Analyse button.
if st.button("识别实体", type="primary"):
    if not text_input.strip():
        st.warning("请输入一些文本进行分析")
    else:
        with st.spinner("正在分析文本..."):
            try:
                # Encode the text. The offset mapping ties every token back to its
                # exact character range in the input (requires a fast tokenizer).
                inputs = tokenizer(
                    text_input,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512,
                    return_offsets_mapping=True,
                )
                # The model's forward() does not accept offset_mapping, so remove it.
                offsets = inputs.pop("offset_mapping")[0].tolist()

                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    outputs = model(**inputs)
                predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
                predicted_labels = [model.config.id2label[p] for p in predictions]
                tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

                # Character spans, plus (surface text, type) pairs taken verbatim
                # from the input, so "Berkeley" stays "Berkeley", not "Berk eley".
                spans = _group_entities(tokens, predicted_labels, offsets)
                results = [(text_input[s:e], t) for s, e, t in spans]

                # Show results.
                st.subheader("分析结果")
                col1, col2 = st.columns([2, 1])

                with col1:
                    st.markdown("**文本中的实体:**")
                    # Escaped input with each entity highlighted in place.
                    st.markdown(
                        _highlight_entities(text_input, spans, entity_colors),
                        unsafe_allow_html=True,
                    )

                with col2:
                    st.markdown("**检测到的实体列表:**")
                    if not results:
                        st.info("未检测到任何实体")
                    else:
                        # Group unique entity strings by type, in first-seen order.
                        entities_by_type = {}
                        for entity, e_type in results:
                            bucket = entities_by_type.setdefault(e_type, [])
                            if entity not in bucket:
                                bucket.append(entity)

                        for e_type, entities in entities_by_type.items():
                            color = entity_colors.get(e_type, '#CCCCCC')
                            st.markdown(
                                f"<span style='background-color:{color}; padding:2px 6px; "
                                f"border-radius:4px; color:white;'>{e_type}</span>",
                                unsafe_allow_html=True,
                            )
                            for entity in entities:
                                st.markdown(f"- {entity}")

                # Model information.
                with st.expander("模型信息"):
                    st.markdown(
                        "**模型:** dslim/bert-base-NER\n\n"
                        "**描述:** 基于BERT的命名实体识别模型,能够识别人物(PER)、组织(ORG)、"
                        "地点(LOC)和其他(MISC)实体。"
                    )
            except Exception as e:
                st.error(f"分析过程中出错: {e}")
# Footer: a horizontal rule followed by a credits line.
for footer_line in (
    "---",
    "使用 🤗 Transformers 和 Streamlit 构建 | 模型: dslim/bert-base-NER",
):
    st.markdown(footer_line)