| import os |
| import json |
| import streamlit as st |
| from transformers import AutoTokenizer, AutoModel |
| import torch |
| import numpy as np |
| import faiss |
|
|
| |
# Configure the Streamlit page; must run before any other st.* call.
st.set_page_config(page_title='KRISSBERT UMLS Linker', layout='wide')
st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')
|
|
| |
# Paths to the locally pre-built UMLS artifacts (expected in the working directory).
METADATA_PATH = 'umls_metadata.json'  # maps FAISS row id (as str) -> {cui, name, definition, source}
EMBED_PATH = 'umls_embeddings.npy'    # concept embedding matrix (loaded below but never used at query time)
INDEX_PATH = 'umls_index.faiss'       # prebuilt FAISS index over the concept embeddings
MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'  # KRISSBERT encoder on the HF hub
|
|
| |
@st.cache_resource
def load_model():
    """Load and cache the KRISSBERT tokenizer and encoder for inference."""
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    encoder = AutoModel.from_pretrained(MODEL_NAME)
    encoder.eval()  # inference mode: disables dropout for deterministic outputs
    return tok, encoder
|
|
# Created once per process; st.cache_resource returns the cached pair on reruns.
tokenizer, model = load_model()
|
|
| |
@st.cache_resource
def load_umls_index():
    """Load and cache the prebuilt FAISS index plus its id -> concept metadata map.

    Returns:
        tuple: (faiss.Index, dict) where the dict maps stringified FAISS row
        ids to entries carrying 'cui', 'name', 'definition' and 'source' keys.
    """
    # Context manager closes the file promptly; the original
    # json.load(open(...)) leaked the file descriptor.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)
    # NOTE(review): the raw embedding matrix (EMBED_PATH) used to be np.load-ed
    # here but was never used or returned; the read is dropped so a potentially
    # large array is not pulled into memory for nothing.
    index = faiss.read_index(INDEX_PATH)
    return index, meta
|
|
# Built once per process; subsequent Streamlit reruns reuse the cached pair.
faiss_index, umls_meta = load_umls_index()
|
|
| |
@st.cache_data
def embed_text(text, _tokenizer, _model):
    """Return the L2-normalized [CLS] embedding of *text* as a 1-D numpy array.

    Args:
        text: Sentence to embed; this is the cache key.
        _tokenizer: HF tokenizer — leading underscore tells Streamlit not to hash it.
        _model: HF encoder — likewise excluded from the cache key.

    Note: switched from st.cache_resource to st.cache_data — the return value
    is data, not a shared resource; cache_data hands each caller its own copy
    instead of one shared mutable ndarray.
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = _model(**inputs)
    # Use the [CLS] token (position 0) as the sentence representation.
    emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    # Normalize so FAISS similarity search behaves like cosine similarity.
    return emb / np.linalg.norm(emb)
|
|
| |
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
# Canned example sentences the user can pick instead of typing their own.
examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.'
]
selected = st.selectbox('🔍 Example queries', ['Choose...'] + examples)
# Pre-fill the text area with the chosen example ('Choose...' means no selection).
sentence = st.text_area('📝 Sentence:', value=(selected if selected != 'Choose...' else ''))
|
|
if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            # FAISS expects a 2-D (n_queries, dim) query array.
            sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
            distances, indices = faiss_index.search(sent_emb, 5)
            results = []
            for idx in indices[0]:
                # FAISS pads the result with -1 when the index holds fewer than
                # k vectors; skip those and any ids missing from the metadata
                # map, so the "no matches" branch below is actually reachable
                # (previously 5 possibly-empty dicts were always appended).
                if idx < 0:
                    continue
                entry = umls_meta.get(str(idx))
                if not entry:
                    continue
                results.append({
                    'cui': entry.get('cui', ''),
                    'name': entry.get('name', ''),
                    'definition': entry.get('definition', ''),
                    'source': entry.get('source', '')
                })

        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown(f"**{item['name']}** (CUI: `{item['cui']}`)")
                if item['definition']:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown(f"_Source: {item['source']}_\n---")
        else:
            st.info('No matches found in UMLS index.')