|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Pre-embed Patient Summaries Script |
|
|
|
|
|
This script pre-processes and embeds a patient database, |
|
|
saving the results to a single Parquet file for faster loading |
|
|
in the main application and compatibility with Hugging Face datasets. |
|
|
|
|
|
Usage: |
|
|
python preembed_patients.py --patients ../v20_public_data/patient_summaries_and_their_spaces.parquet --embedder ksg-dfci/TrialSpace-1225 --output synthetic_patient_embeddings.parquet --gpus 0,1 --patient-boilerplate-col patient_boilerplate_text --patient-id-col pseudo_mrn |
|
|
|
|
|
|
|
|
This will create: |
|
|
- synthetic_patient_embeddings.parquet: Patient dataframe with embedding vectors as a column |
|
|
|
|
|
The parquet file contains: |
|
|
- All original patient columns (patient_id, patient_summary, patient_boilerplate, etc.) |
|
|
- patient_embedding: The embedding vector for each patient (stored as list of floats) |
|
|
- Metadata stored in parquet file metadata (embedder model, creation date, etc.) |
|
|
|
|
|
To upload to Hugging Face: |
|
|
from datasets import Dataset |
|
|
ds = Dataset.from_parquet("synthetic_patient_embeddings.parquet") |
|
|
ds.push_to_hub("your-username/patient-embeddings") |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
import json |
|
|
import pyarrow as pa |
|
|
import pyarrow.parquet as pq |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
from typing import Tuple, List |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
|
|
|
def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
    """Clip *text* so its tokenized form is at most ``max_tokens`` tokens.

    The text is round-tripped through the tokenizer: encoded with
    truncation enabled, then decoded back to a plain string with special
    tokens stripped.
    """
    token_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_tokens,
    )
    return tokenizer.decode(token_ids, skip_special_tokens=True)
|
|
|
|
|
|
|
|
def load_patients(file_path: str, patient_id_col: str = 'patient_id', patient_boilerplate_col: str = 'patient_boilerplate') -> pd.DataFrame:
    """Read the patient table from disk and normalize its columns.

    Supports Parquet, CSV, and Excel inputs. The configured ID and
    boilerplate columns are renamed to the canonical 'patient_id' /
    'patient_boilerplate' names, and rows whose 'patient_summary' is
    missing or whitespace-only are dropped.

    Raises:
        ValueError: on an unsupported file extension or when a required
            column is absent.
    """
    banner = "=" * 70
    print(f"\n{banner}")
    print(f"Loading patient database from: {file_path}")
    print(banner)

    # Dispatch on the file extension to the matching pandas reader.
    if file_path.endswith('.parquet'):
        patients = pd.read_parquet(file_path)
    elif file_path.endswith('.csv'):
        patients = pd.read_csv(file_path)
    elif file_path.endswith(('.xlsx', '.xls')):
        patients = pd.read_excel(file_path)
    else:
        raise ValueError("Unsupported file format. Use Parquet, CSV, or Excel.")

    # Fail fast when a mandatory column is absent.
    missing = [c for c in (patient_id_col, 'patient_summary') if c not in patients.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    # Canonicalize the ID column name so downstream code is uniform.
    if patient_id_col != 'patient_id':
        patients = patients.rename(columns={patient_id_col: 'patient_id'})
        print(f" Renamed column '{patient_id_col}' to 'patient_id'")

    print(f"β Loaded {len(patients)} patients")
    print(f" Columns: {', '.join(patients.columns.tolist())}")

    # Drop rows with no usable summary; embedding them would be meaningless.
    before = len(patients)
    patients = patients[~patients['patient_summary'].isnull()].copy()
    patients = patients[patients['patient_summary'].str.strip().str.len() > 0].copy()

    # Normalize the optional boilerplate column, creating an empty one when
    # it is absent so the output schema stays stable.
    if patient_boilerplate_col and patient_boilerplate_col in patients.columns:
        if patient_boilerplate_col != 'patient_boilerplate':
            patients = patients.rename(columns={patient_boilerplate_col: 'patient_boilerplate'})
            print(f" Renamed column '{patient_boilerplate_col}' to 'patient_boilerplate'")
        patients['patient_boilerplate'] = patients['patient_boilerplate'].fillna('')
        non_empty_bp = (patients['patient_boilerplate'].str.strip().str.len() > 0).sum()
        print(f" β Found patient_boilerplate column: {non_empty_bp}/{len(patients)} patients have boilerplate text")
    else:
        patients['patient_boilerplate'] = ''
        if patient_boilerplate_col:
            print(f" β Column '{patient_boilerplate_col}' not found - patient_boilerplate will be empty")
        else:
            print(f" β No boilerplate column specified - patient_boilerplate will be empty")

    if len(patients) < before:
        print(f" β Removed {before - len(patients)} patients with missing/empty 'patient_summary'")

    return patients
|
|
|
|
|
|
|
|
def embed_patients(df: pd.DataFrame, embedder_path: str, device: str = None, gpus: list = None) -> Tuple[np.ndarray, str]:
    """Embed patient summaries using the specified embedder model.

    Args:
        df: DataFrame with patient data. Must contain 'patient_summary';
            a 'patient_summary_trunc' column is added to it in place.
        embedder_path: Path to embedder model or HuggingFace model name.
        device: Single device string (e.g., 'cuda:0', 'cpu') - used if gpus
            not specified. Auto-detected when None.
        gpus: List of GPU indices for multi-GPU parallel processing
            (e.g., [0, 1, 2, 3]).

    Returns:
        Tuple of (embeddings as a 2-D numpy array, one row per patient,
        L2-normalized; the embedder_path that produced them).
    """
    print(f"\n{'='*70}")
    print(f"Loading embedder model: {embedder_path}")
    print(f"{'='*70}")

    use_multi_gpu = gpus is not None and len(gpus) > 1

    if use_multi_gpu:
        target_devices = [f"cuda:{gpu}" for gpu in gpus]
        print(f"Multi-GPU mode: {target_devices}")
        # Load on CPU; start_multi_process_pool replicates the model onto
        # each target device later.
        embedder_model = SentenceTransformer(embedder_path, device='cpu', trust_remote_code=True)
    else:
        if gpus is not None and len(gpus) == 1:
            device = f"cuda:{gpus[0]}"
        elif device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {device}")
        embedder_model = SentenceTransformer(embedder_path, device=device, trust_remote_code=True)

    embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_path, trust_remote_code=True)

    print(f"β Embedder loaded")

    # Best-effort: widen the context window when the model allows it.
    # BUG FIX: narrowed from a bare `except:` so KeyboardInterrupt /
    # SystemExit still propagate.
    try:
        embedder_model.max_seq_length = 2500
    except Exception:
        pass

    print(f"\n{'='*70}")
    print(f"Embedding {len(df)} patient summaries")
    print(f"{'='*70}")

    # Truncate each summary so the instruction prefix plus the summary fits
    # comfortably in the model's context window.
    df['patient_summary_trunc'] = df['patient_summary'].apply(
        lambda x: truncate_text(str(x), embedder_tokenizer, max_tokens=1500)
    )

    # BUG FIX: the retrieval instruction is prepended manually here, and
    # ONLY here, so both paths embed identical inputs. Previously the
    # single-device path also registered this text as the 'query' prompt
    # and passed prompt='query' to encode(), which made SentenceTransformer
    # prepend the instruction a second time — the multi-GPU path applied it
    # once, so the two modes produced different embeddings.
    prefix = (
        "Instruct: Given a cancer patient summary, retrieve clinical trial options "
        "that are reasonable for that patient; or, given a clinical trial option, "
        "retrieve cancer patients who are reasonable candidates for that trial. "
    )
    texts_to_embed = [prefix + txt for txt in df['patient_summary_trunc'].tolist()]

    print(f" Text length stats:")
    print(f" Mean: {np.mean([len(t) for t in texts_to_embed]):.0f} chars")
    print(f" Max: {max([len(t) for t in texts_to_embed])} chars")

    if use_multi_gpu:
        print(f" Starting multi-GPU pool on {target_devices}...")
        pool = embedder_model.start_multi_process_pool(target_devices=target_devices)
        try:
            embeddings_np = embedder_model.encode_multi_process(
                texts_to_embed,
                pool,
                batch_size=64,
                normalize_embeddings=True,
            )
        finally:
            # Always tear the worker pool down, even if encoding fails.
            embedder_model.stop_multi_process_pool(pool)
    else:
        with torch.no_grad():
            embeddings = embedder_model.encode(
                texts_to_embed,
                batch_size=64,
                convert_to_tensor=True,
                normalize_embeddings=True,
                show_progress_bar=True,
            )
        embeddings_np = embeddings.cpu().numpy()

    print(f"β Embedding complete")
    print(f" Shape: {embeddings_np.shape}")
    print(f" Dtype: {embeddings_np.dtype}")

    return embeddings_np, embedder_path
|
|
|
|
|
|
|
|
def save_embeddings(df: pd.DataFrame, embeddings: np.ndarray, output_path: str, embedder_path: str, gpus: list = None):
    """Write patient rows plus their embedding vectors to one Parquet file.

    The embeddings become a list-of-floats column ('patient_embedding'),
    which both PyArrow and Hugging Face datasets understand; run metadata
    is attached to the Parquet schema metadata under the key
    'patient_embedding_metadata'.
    """
    sep = "=" * 70
    print(f"\n{sep}")
    print(f"Saving to: {output_path}")
    print(sep)

    # Force a .parquet suffix (the banner above intentionally shows the
    # path exactly as supplied).
    if not output_path.endswith('.parquet'):
        output_path = f"{output_path}.parquet"

    # Make sure the destination directory exists.
    parent = Path(output_path).parent
    if str(parent) and str(parent) != '.':
        parent.mkdir(parents=True, exist_ok=True)

    # Store each vector as a plain Python list so Arrow serializes it as a
    # list column.
    enriched = df.copy()
    enriched['patient_embedding'] = [vec.tolist() for vec in embeddings]

    # Provenance metadata; Parquet schema metadata values must be strings.
    metadata = {
        "created_at": datetime.now().isoformat(),
        "embedder_model": embedder_path,
        "num_patients": str(len(df)),
        "embedding_dim": str(embeddings.shape[1]),
        "embedding_dtype": str(embeddings.dtype),
        "normalized": "true",
        "gpus_used": str(gpus) if gpus else "single device",
        "format_version": "2.0",
    }

    table = pa.Table.from_pandas(enriched)

    # Merge our metadata into whatever schema metadata pandas already set.
    schema_meta = dict(table.schema.metadata or {})
    schema_meta[b'patient_embedding_metadata'] = json.dumps(metadata).encode('utf-8')
    table = table.replace_schema_metadata(schema_meta)

    pq.write_table(table, output_path)

    file_size_mb = Path(output_path).stat().st_size / 1024 / 1024
    print(f"β Saved parquet file: {output_path}")
    print(f" Size: {file_size_mb:.2f} MB")
    print(f" Columns: {', '.join(enriched.columns.tolist())}")
    print(f" Embedding column: patient_embedding (dim={embeddings.shape[1]})")

    print(f"\n{sep}")
    print(f"PRE-EMBEDDING COMPLETE")
    print(sep)
    print(f"\nTo use these pre-embedded patients in your app:")
    print(f"1. Update config.py with:")
    print(f" PREEMBEDDED_PATIENTS = '{output_path}'")
    print(f"2. Restart the application")
    print(f"\nThe app will automatically load these embeddings on startup!")
    print(f"\nTo upload to Hugging Face Hub:")
    print(f" from datasets import Dataset")
    print(f" ds = Dataset.from_parquet('{output_path}')")
    print(f" ds.push_to_hub('your-username/patient-embeddings')")
|
|
|
def main():
    """Command-line entry point.

    Parses arguments, loads the patient table, embeds every summary, and
    writes the enriched table (rows + embedding vectors + metadata) to a
    Parquet file.

    Returns:
        int: 0 on success, 1 on failure (invalid --gpus value or any
        exception raised by the load/embed/save pipeline).
    """
    parser = argparse.ArgumentParser(
        description="Pre-embed patient summaries for faster loading",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python preembed_patients.py --patients data/patients.parquet --embedder models/embedder --output embeddings/patient_embeddings.parquet
python preembed_patients.py --patients patients.csv --embedder Qwen/Qwen3-Embedding-0.6B --output patient_embeddings.parquet --device cuda
python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --patient-id-col mrn
python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --gpus 0,1,2,3
python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --patient-boilerplate-col boilerplate_summary

Hugging Face Upload:
After creating the parquet file, you can upload to Hugging Face Hub:
from datasets import Dataset
ds = Dataset.from_parquet("patient_embeddings.parquet")
ds.push_to_hub("your-username/patient-embeddings")
"""
    )

    # ---- argument definitions ------------------------------------------
    parser.add_argument(
        '--patients',
        type=str,
        required=True,
        help='Path to patient database (Parquet, CSV, or Excel). Required columns: patient_summary and the patient ID column (default: patient_id, or specify with --patient-id-col)'
    )

    parser.add_argument(
        '--embedder',
        type=str,
        required=True,
        help='Path to embedder model or HuggingFace model name'
    )

    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output path for the parquet file (e.g., "patient_embeddings.parquet")'
    )

    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device to use for embedding (default: auto-detect). Examples: cuda, cuda:0, cuda:3, cpu. Ignored if --gpus is specified.'
    )

    parser.add_argument(
        '--patient-id-col',
        type=str,
        default='patient_id',
        help='Name of the patient ID column in the input file (default: patient_id)'
    )

    parser.add_argument(
        '--patient-boilerplate-col',
        type=str,
        default='patient_boilerplate',
        help='Name of the patient boilerplate column in the input file (default: patient_boilerplate). Set to empty string to skip.'
    )

    parser.add_argument(
        '--gpus',
        type=str,
        default=None,
        help='Comma-separated list of GPU indices for multi-GPU parallel processing (e.g., "0,1,2,3"). Overrides --device if specified.'
    )

    args = parser.parse_args()

    # Parse "--gpus 0,1,2,3" into a list of ints; bail out with exit code 1
    # on any non-integer entry.
    gpu_list = None
    if args.gpus:
        try:
            gpu_list = [int(g.strip()) for g in args.gpus.split(',')]
        except ValueError:
            print(f"β ERROR: Invalid GPU list format: {args.gpus}")
            print(" Use comma-separated integers, e.g., '0,1,2,3'")
            return 1

    # Echo the effective configuration before doing any work.
    print(f"\n{'='*70}")
    print(f"PATIENT SUMMARY PRE-EMBEDDING SCRIPT")
    print(f"{'='*70}")
    print(f"Patient Database: {args.patients}")
    print(f"Embedder Model: {args.embedder}")
    print(f"Output File: {args.output}")
    print(f"Patient ID Col: {args.patient_id_col}")
    print(f"Boilerplate Col: {args.patient_boilerplate_col or '(none)'}")
    if gpu_list:
        print(f"GPUs: {gpu_list} (multi-GPU mode)")
    elif args.device:
        print(f"Device: {args.device}")
    else:
        print(f"Device: auto-detect")
    print(f"{'='*70}\n")

    try:
        # 1. Load and normalize the patient table.
        df = load_patients(args.patients, args.patient_id_col, args.patient_boilerplate_col)

        # 2. Embed every patient summary (single device or multi-GPU).
        embeddings, embedder_path = embed_patients(df, args.embedder, args.device, gpu_list)

        # 3. Persist rows + vectors + metadata to Parquet.
        save_embeddings(df, embeddings, args.output, embedder_path, gpu_list)

        print(f"\nβ SUCCESS!")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"\nβ ERROR: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Use `raise SystemExit` instead of the site-provided exit() builtin,
    # which is not guaranteed to exist (e.g. under `python -S` or in frozen
    # apps). The process exit code is main()'s return value (0 or 1).
    raise SystemExit(main())
|
|
|