| import os |
| import torch |
| from datasets import load_dataset |
| from transformers import AutoTokenizer, AutoModel |
| import chromadb |
| |
| import gradio as gr |
|
|
| |
| |
|
|
| |
| |
# Hugging Face Hub identifiers for the corpus and the encoder checkpoint.
_DATASET_ID = "thankrandomness/mimic-iii-sample"
_ENCODER_ID = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Load the sample dataset from the Hub.
dataset = load_dataset(_DATASET_ID)

# Tokenizer and encoder are loaded from the same checkpoint identifier.
tokenizer = AutoTokenizer.from_pretrained(_ENCODER_ID)
model = AutoModel.from_pretrained(_ENCODER_ID)
|
|
| |
| |
| |
| |
| |
| |
| |
def embed_text(text, max_length=512):
    """Embed text with the module-level ``tokenizer`` and ``model``.

    The final-layer token states are mean-pooled into one vector per input,
    counting only real (non-padding) tokens.

    Args:
        text: A single string, or a list of strings to embed as one batch.
        max_length: Maximum number of tokens per input; longer inputs are
            truncated.

    Returns:
        A NumPy array. For a single string the shape is ``(hidden_size,)``.
        For a list of strings the shape is ``(len(text), hidden_size)``,
        including for a one-element list, so ``embed_text([s])[0]`` is the
        full vector for ``s`` rather than its first component.
    """
    single_input = isinstance(text, str)
    # Pad only to the longest sequence in the batch: padded positions are
    # masked out of the pooling below, so extra padding only adds compute.
    inputs = tokenizer(
        [text] if single_input else list(text),
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state

    # Weight by the attention mask so [PAD] positions do not dilute the mean.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    vectors = pooled.numpy()
    # Drop the batch axis only when the caller passed a bare string.
    return vectors[0] if single_input else vectors
|
|
| |
# Chroma client (default configuration) that stores the annotation embeddings.
client = chromadb.Client()
# get_or_create_collection keeps a re-run in the same process (notebook cell,
# hot reload) from failing because the collection name is already taken.
collection = client.get_or_create_collection(name="pubmedbert_embeddings")
|
|
| |
# Index every annotated note: one Chroma record per (note, annotation) pair,
# each storing the note's embedding with the annotation as its metadata.
_ANNOTATION_KEYS = ("code", "code_system", "description")

for row in dataset['train']:
    for note in row['notes']:
        text = note.get('text', '')
        annotations_list = []

        # `or []` also covers an 'annotations' field that is present but None.
        for annotation in note.get('annotations') or []:
            try:
                annotations_list.append({key: annotation[key] for key in _ANNOTATION_KEYS})
            except KeyError as e:
                print(f"Skipping annotation due to missing key: {e}")

        print(f"Processed annotations for note {note['note_id']}: {annotations_list}")

        if not (text and annotations_list):
            print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
            continue

        # Pass the bare string so the result is the note's full 1-D vector;
        # embedding a one-element list and indexing [0] produced a single
        # scalar instead of the embedding.
        note_vector = embed_text(text).tolist()

        # One batched upsert per note; every annotation shares the note vector.
        collection.upsert(
            ids=[f"note_{note['note_id']}_{j}" for j in range(len(annotations_list))],
            embeddings=[note_vector] * len(annotations_list),
            metadatas=annotations_list,
        )
|
|
| |
def retrieve_relevant_text(input_text):
    """Return stored annotations whose note embeddings are closest to a query.

    Args:
        input_text: Free-text query to embed and search with.

    Returns:
        A list of at most five dicts, in the order Chroma returns them.
        Each dict has ``"similarity_score"`` (the Chroma distance as a float,
        so lower means closer) and ``"annotation"`` (the stored metadata
        dict). The metadata fields are also copied to the top level of each
        dict. A blank query returns an empty list without running a search.
    """
    if not input_text or not input_text.strip():
        return []

    # A bare string yields a 1-D vector; convert to plain floats for Chroma.
    query_vector = embed_text(input_text).tolist()
    results = collection.query(
        query_embeddings=[query_vector],
        n_results=5,
        include=["metadatas", "distances"],
    )

    # Chroma returns one inner list per query embedding; only one was sent.
    metadatas = (results.get("metadatas") or [[]])[0]
    distances = (results.get("distances") or [[]])[0]

    output = []
    for metadata, distance in zip(metadatas, distances):
        annotation = dict(metadata or {})
        output.append({
            **annotation,
            "similarity_score": float(distance),
            "annotation": annotation,
        })
    return output
|
|
| |
def gradio_interface(input_text):
    """Run a retrieval for *input_text* and format the hits as display text.

    Args:
        input_text: Query string entered in the Gradio text input.

    Returns:
        A newline-separated string with one line per retrieved annotation,
        or a short notice when nothing was retrieved.
    """
    results = retrieve_relevant_text(input_text)
    if not results:
        return "No matching annotations found."

    lines = []
    for result in results:
        # Annotation fields are nested under "annotation"; fall back to the
        # top-level dict if that key is absent.
        annotation = result.get("annotation") or result
        lines.append(
            f"Similarity Score: {result['similarity_score']:.2f}, "
            f"Code: {annotation.get('code')}, "
            f"Description: {annotation.get('description')}"
        )
    # The output component is plain text, so return one string, not a list.
    return "\n".join(lines)
|
|
# Expose gradio_interface as a text-in / text-out Gradio app and start it.
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
)
interface.launch()