| import chromadb |
| import pandas as pd |
| from datasets import load_dataset |
| import json |
|
|
# --- Vector store setup ------------------------------------------------------
# Announce the step, then open a Chroma client rooted at the ./patient_db
# directory.
print("Setting up ChromaDB...")
client = chromadb.PersistentClient(path="./patient_db")

# Reuse the "clinical_notes" collection if it already exists, otherwise create
# it. The metadata entry configures the index space as "cosine".
collection = client.get_or_create_collection(
    "clinical_notes",
    metadata={"hnsw:space": "cosine"},
)
|
|
# --- Source data -------------------------------------------------------------
# Fetch the "train[:200]" slice of the augmented clinical notes dataset and
# convert it to a pandas DataFrame for row-wise processing.
print("Loading dataset...")
dataset = load_dataset(
    path="AGBonnet/augmented-clinical-notes",
    split="train[:200]",
)
df = dataset.to_pandas()
|
|
# --- Ingestion ---------------------------------------------------------------
# Store every row of `df` in the Chroma collection: the full note text is the
# document, and a few fields pulled from the row's summary become metadata.
print("Storing notes into ChromaDB...")

success = 0  # rows written to the collection
skipped = 0  # rows that raised while being parsed or stored

for i, row in df.iterrows():
    try:
        # `summary` is either a JSON-encoded string or an already-parsed
        # mapping; normalise it to the parsed form.
        summary = row["summary"]
        if isinstance(summary, str):
            summary = json.loads(summary)

        # Missing keys fall back to "Unknown" so every record carries the
        # same metadata fields.
        patient_info = summary.get("patient information", {})
        age = patient_info.get("age", "Unknown")
        sex = patient_info.get("sex", "Unknown")
        visit_motivation = summary.get("visit motivation", "Unknown")

        # The dataset's own `idx` doubles as the record id.
        # NOTE(review): ids repeat when the script is re-run against the same
        # on-disk store; confirm how `add` treats existing ids (or consider
        # `upsert`).
        note_id = str(row["idx"])
        collection.add(
            documents=[row["full_note"]],
            metadatas=[{
                "idx": note_id,
                "age": str(age),
                "sex": str(sex),
                # Cap the free-text motivation at 200 characters.
                "visit_motivation": str(visit_motivation)[:200],
            }],
            ids=[note_id],
        )
    except Exception as e:
        # Best-effort ingestion: one bad row must not abort the run, but the
        # reason is reported instead of being silently discarded.
        skipped += 1
        print(f"  Skipped row {i}: {type(e).__name__}: {e}")
    else:
        success += 1

# Escapes keep the source ASCII-only so the emoji cannot be re-garbled by a
# wrong file encoding (check mark, warning sign).
print(f"\n\u2705 Successfully stored: {success} notes")
print(f"\u26a0\ufe0f Skipped: {skipped} rows")
print(f"\nTotal notes in ChromaDB: {collection.count()}")
|
|
| |
|
|
def retrieve_similar_notes(query: str, n_results: int = 3):
    """Find the stored clinical notes most similar to a transcript or query.

    Queries the module-level ``collection`` with ``query`` as the only query
    text, prints a short preview of each match, and returns the matches.

    Args:
        query: Free-text transcript or search phrase.
        n_results: Maximum number of matches to request.

    Returns:
        A ``(notes, metadatas)`` pair: the matched note texts and their
        metadata entries, in the order the collection returned them.
    """
    results = collection.query(
        query_texts=[query],
        n_results=n_results,
    )

    # Results are grouped per query text; only one was sent, so take group 0.
    notes = results["documents"][0]
    metadatas = results["metadatas"][0]

    # Report the number of matches actually returned, which may be lower than
    # the number requested. The escape is the magnifying-glass emoji, kept
    # ASCII-only so a wrong file encoding cannot garble it.
    print(f"\n\U0001f50d Top {len(notes)} similar notes retrieved:")
    for rank, (note, meta) in enumerate(zip(notes, metadatas), start=1):
        # The collection lives in an on-disk store and may hold records that
        # were not written with these metadata fields, so read them
        # defensively rather than failing the whole preview.
        fields = meta or {}
        age = fields.get("age", "Unknown")
        sex = fields.get("sex", "Unknown")
        visit = str(fields.get("visit_motivation", "Unknown"))

        print(f"\n[{rank}] Age: {age} | Sex: {sex}")
        print(f" Visit: {visit[:100]}")
        print(f" Note preview: {note[:150]}...")

    return notes, metadatas
|
|
|
|
| |
if __name__ == "__main__":
    # Smoke test when run as a script: look up notes similar to a sample
    # complaint and print the previews.
    test_query = "Patient has neck pain and difficulty walking"
    retrieve_similar_notes(query=test_query)