Spaces:
Running
on
L4
Running
on
L4
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Pre-embed Patient Summaries Script | |
| This script pre-processes and embeds a patient database, | |
| saving the results to a single Parquet file for faster loading | |
| in the main application and compatibility with Hugging Face datasets. | |
| Usage: | |
| python preembed_patients.py --patients ../v20_public_data/patient_summaries_and_their_spaces.parquet --embedder ksg-dfci/TrialSpace-1225 --output synthetic_patient_embeddings.parquet --gpus 0,1 --patient-boilerplate-col patient_boilerplate_text --patient-id-col pseudo_mrn | |
| This will create: | |
| - synthetic_patient_embeddings.parquet: Patient dataframe with embedding vectors as a column | |
| The parquet file contains: | |
| - All original patient columns (patient_id, patient_summary, patient_boilerplate, etc.) | |
| - patient_embedding: The embedding vector for each patient (stored as list of floats) | |
| - Metadata stored in parquet file metadata (embedder model, creation date, etc.) | |
| To upload to Hugging Face: | |
| from datasets import Dataset | |
| ds = Dataset.from_parquet("synthetic_patient_embeddings.parquet") | |
| ds.push_to_hub("your-username/patient-embeddings") | |
| """ | |
| import argparse | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import json | |
| import pyarrow as pa | |
| import pyarrow.parquet as pq | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Tuple, List | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer | |
def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
    """Clip ``text`` to at most ``max_tokens`` token ids and return it as text.

    Encoding adds special tokens and truncates to ``max_tokens`` ids; decoding
    then drops the special tokens again.

    Args:
        text: Text to shorten.
        tokenizer: Object with Hugging Face-style ``encode`` / ``decode``
            methods (e.g. a ``transformers`` tokenizer).
        max_tokens: Value passed to ``encode`` as ``max_length``.

    Returns:
        The decoded (possibly shortened) text.
    """
    token_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_tokens,
    )
    return tokenizer.decode(token_ids, skip_special_tokens=True)
def load_patients(file_path: str, patient_id_col: str = 'patient_id', patient_boilerplate_col: str = 'patient_boilerplate') -> pd.DataFrame:
    """Load a patient table and normalize it for embedding.

    The reader is chosen from the file extension (case-insensitive):
    ``.parquet``, ``.csv``, ``.xlsx`` or ``.xls``.

    Normalization:
      * ``patient_id_col`` is renamed to ``patient_id``.
      * Rows whose ``patient_summary`` is missing or whitespace-only are dropped.
      * ``patient_boilerplate_col`` (if given and present) is renamed to
        ``patient_boilerplate`` with missing values filled by ``''``; otherwise a
        ``patient_boilerplate`` column of empty strings is created.
      * The index is reset to ``0..n-1`` so row position matches the order in
        which summaries are later embedded.

    Args:
        file_path: Path to the patient table.
        patient_id_col: Name of the patient ID column in the file.
        patient_boilerplate_col: Name of the boilerplate column in the file;
            a falsy value (e.g. ``''``) skips it.

    Returns:
        The cleaned DataFrame.

    Raises:
        ValueError: If the extension is unsupported, a required column
            (``patient_id_col`` or ``patient_summary``) is missing, or a rename
            would collide with an existing ``patient_id`` /
            ``patient_boilerplate`` column.
    """
    print(f"\n{'='*70}")
    print(f"Loading patient database from: {file_path}")
    print(f"{'='*70}")

    suffix = Path(file_path).suffix.lower()
    if suffix == '.parquet':
        df = pd.read_parquet(file_path)
    elif suffix == '.csv':
        df = pd.read_csv(file_path)
    elif suffix in ('.xlsx', '.xls'):
        df = pd.read_excel(file_path)
    else:
        raise ValueError("Unsupported file format. Use Parquet, CSV, or Excel.")

    # Check required columns
    required_cols = [patient_id_col, 'patient_summary']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    # Rename the ID column to the standard name. Renaming onto an existing
    # 'patient_id' column would create duplicate column names, which only
    # fails much later (after embedding) when the table is written out.
    if patient_id_col != 'patient_id':
        if 'patient_id' in df.columns:
            raise ValueError(
                f"Cannot rename '{patient_id_col}' to 'patient_id': "
                f"the file already has a 'patient_id' column"
            )
        df = df.rename(columns={patient_id_col: 'patient_id'})
        print(f" Renamed column '{patient_id_col}' to 'patient_id'")

    print(f"β Loaded {len(df)} patients")
    print(f" Columns: {', '.join(df.columns.tolist())}")

    # Drop rows whose summary is missing or blank. Values are cast to str
    # before stripping so non-string cells neither crash the .str accessor
    # nor get silently counted as "empty".
    original_count = len(df)
    summaries = df['patient_summary']
    has_text = summaries.notna() & (summaries.astype(str).str.strip().str.len() > 0)
    df = df[has_text].copy()

    # Handle boilerplate column
    if patient_boilerplate_col and patient_boilerplate_col in df.columns:
        if patient_boilerplate_col != 'patient_boilerplate':
            if 'patient_boilerplate' in df.columns:
                raise ValueError(
                    f"Cannot rename '{patient_boilerplate_col}' to 'patient_boilerplate': "
                    f"the file already has a 'patient_boilerplate' column"
                )
            df = df.rename(columns={patient_boilerplate_col: 'patient_boilerplate'})
            print(f" Renamed column '{patient_boilerplate_col}' to 'patient_boilerplate'")
        df['patient_boilerplate'] = df['patient_boilerplate'].fillna('')
        non_empty_bp = (df['patient_boilerplate'].astype(str).str.strip().str.len() > 0).sum()
        print(f" β Found patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text")
    else:
        df['patient_boilerplate'] = ''
        if patient_boilerplate_col:
            print(f" β Column '{patient_boilerplate_col}' not found - patient_boilerplate will be empty")
        else:
            print(f" β No boilerplate column specified - patient_boilerplate will be empty")

    if len(df) < original_count:
        print(f" β Removed {original_count - len(df)} patients with missing/empty 'patient_summary'")

    # A positional index keeps row order aligned with the embedding matrix and
    # avoids serializing a gapped index into the output Parquet file.
    return df.reset_index(drop=True)
def embed_patients(df: pd.DataFrame, embedder_path: str, device: str = None, gpus: list = None) -> Tuple[np.ndarray, str]:
    """Embed every patient summary in ``df`` with a SentenceTransformer model.

    Each summary is truncated to 1500 tokens with the model's tokenizer,
    prefixed with the retrieval instruction, and encoded with
    ``normalize_embeddings=True``. The single-device and multi-GPU paths
    encode exactly the same strings, so the result does not depend on the
    hardware layout.

    Side effect: adds a ``patient_summary_trunc`` column to ``df`` in place.

    Args:
        df: Patient table; must contain a ``patient_summary`` column.
        embedder_path: Local path or Hugging Face model id of the embedder.
        device: Single device string (e.g. 'cuda:0', 'cpu'); used only when
            ``gpus`` is not given. Auto-detected (CUDA if available) when None.
        gpus: GPU indices. More than one starts a multi-process pool over
            ``cuda:<i>`` devices; exactly one selects ``cuda:<index>``.

    Returns:
        ``(embeddings, embedder_path)``, where ``embeddings`` holds one row per
        row of ``df``.

    Raises:
        ValueError: If ``df`` has no rows (checked before loading the model).
    """
    if len(df) == 0:
        raise ValueError("No patient summaries to embed: the patient table is empty")

    print(f"\n{'='*70}")
    print(f"Loading embedder model: {embedder_path}")
    print(f"{'='*70}")

    # Determine device configuration
    use_multi_gpu = gpus is not None and len(gpus) > 1
    if use_multi_gpu:
        target_devices = [f"cuda:{gpu}" for gpu in gpus]
        print(f"Multi-GPU mode: {target_devices}")
        # Load model on CPU first for the multi-process pool
        embedder_model = SentenceTransformer(embedder_path, device='cpu', trust_remote_code=True)
    else:
        if gpus is not None and len(gpus) == 1:
            device = f"cuda:{gpus[0]}"
        elif device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {device}")
        embedder_model = SentenceTransformer(embedder_path, device=device, trust_remote_code=True)

    embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_path, trust_remote_code=True)
    print(f"β Embedder loaded")

    instruction = (
        "Instruct: Given a cancer patient summary, retrieve clinical trial options "
        "that are reasonable for that patient; or, given a clinical trial option, "
        "retrieve cancer patients who are reasonable candidates for that trial."
    )

    # Best effort: register the instruction as the 'query' prompt and raise the
    # sequence limit. These may be unsupported by some models, so failures are
    # reported rather than fatal.
    try:
        embedder_model.prompts['query'] = instruction
    except Exception as exc:
        print(f" (could not set 'query' prompt: {exc})")
    try:
        embedder_model.max_seq_length = 2500
    except Exception as exc:
        print(f" (could not set max_seq_length: {exc})")

    print(f"\n{'='*70}")
    print(f"Embedding {len(df)} patient summaries")
    print(f"{'='*70}")

    # Truncate each summary to 1500 tokens; stored on df (in place).
    df['patient_summary_trunc'] = df['patient_summary'].apply(
        lambda x: truncate_text(str(x), embedder_tokenizer, max_tokens=1500)
    )

    # The instruction is prepended manually so both encode paths below see
    # byte-identical input text.
    prefix = instruction + " "
    texts_to_embed = [prefix + txt for txt in df['patient_summary_trunc'].tolist()]
    text_lengths = [len(t) for t in texts_to_embed]
    print(f" Text length stats:")
    print(f" Mean: {np.mean(text_lengths):.0f} chars")
    print(f" Max: {max(text_lengths)} chars")

    if use_multi_gpu:
        print(f" Starting multi-GPU pool on {target_devices}...")
        pool = embedder_model.start_multi_process_pool(target_devices=target_devices)
        try:
            embeddings_np = embedder_model.encode_multi_process(
                texts_to_embed,
                pool,
                batch_size=64,
                normalize_embeddings=True,
            )
        finally:
            embedder_model.stop_multi_process_pool(pool)
    else:
        # No ``prompt=`` argument here: ``prompt`` is a literal string that
        # SentenceTransformer prepends to every input (``prompt='query'`` would
        # prepend the word "query"), and the instruction is already in each text.
        with torch.no_grad():
            embeddings = embedder_model.encode(
                texts_to_embed,
                batch_size=64,
                convert_to_tensor=True,
                normalize_embeddings=True,
                show_progress_bar=True,
            )
        embeddings_np = embeddings.cpu().numpy()

    print(f"β Embedding complete")
    print(f" Shape: {embeddings_np.shape}")
    print(f" Dtype: {embeddings_np.dtype}")
    return embeddings_np, embedder_path
def save_embeddings(df: pd.DataFrame, embeddings: np.ndarray, output_path: str, embedder_path: str, gpus: list = None):
    """Write ``df`` plus one embedding per row to a single Parquet file.

    Embeddings go into a ``patient_embedding`` column of Python float lists,
    which Hugging Face ``datasets`` and PyArrow read directly. Run details are
    stored as JSON under the ``patient_embedding_metadata`` key of the Parquet
    schema metadata. The DataFrame index is not written, so no
    ``__index_level_0__`` column appears when rows were filtered upstream.

    Args:
        df: Patient rows, in the same order as ``embeddings``.
        embeddings: 2-D array with one row per row of ``df``.
        output_path: Destination file; ``.parquet`` is appended if missing and
            parent directories are created.
        embedder_path: Model path/name recorded in the metadata.
        gpus: GPU indices recorded in the metadata (falsy -> "single device").

    Raises:
        ValueError: If ``embeddings`` is not 2-D or its row count differs from
            ``len(df)``. Checked before anything is written to disk.
    """
    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2 or embeddings.shape[0] != len(df):
        raise ValueError(
            f"Expected a 2-D embedding array with {len(df)} rows, "
            f"got shape {embeddings.shape}"
        )

    # Ensure output path ends with .parquet (str() also accepts Path objects)
    output_path = str(output_path)
    if not output_path.endswith('.parquet'):
        output_path = f"{output_path}.parquet"

    print(f"\n{'='*70}")
    print(f"Saving to: {output_path}")
    print(f"{'='*70}")

    # exist_ok makes this a no-op when the directory (including '.') exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # .tolist() yields plain Python floats per row for a list-valued column.
    df_out = df.copy()
    df_out['patient_embedding'] = [emb.tolist() for emb in embeddings]

    metadata = {
        "created_at": datetime.now().isoformat(),
        "embedder_model": embedder_path,
        "num_patients": str(len(df)),
        "embedding_dim": str(embeddings.shape[1]),
        "embedding_dtype": str(embeddings.dtype),
        "normalized": "true",
        "gpus_used": str(gpus) if gpus else "single device",
        "format_version": "2.0",  # Version indicator for the new format
    }

    # preserve_index=False: the index carries no information here, and keeping
    # it would add an '__index_level_0__' column whenever rows were filtered.
    table = pa.Table.from_pandas(df_out, preserve_index=False)
    schema_metadata = dict(table.schema.metadata or {})
    schema_metadata[b'patient_embedding_metadata'] = json.dumps(metadata).encode('utf-8')
    table = table.replace_schema_metadata(schema_metadata)

    pq.write_table(table, output_path)

    file_size_mb = Path(output_path).stat().st_size / 1024 / 1024
    print(f"β Saved parquet file: {output_path}")
    print(f" Size: {file_size_mb:.2f} MB")
    print(f" Columns: {', '.join(df_out.columns.tolist())}")
    print(f" Embedding column: patient_embedding (dim={embeddings.shape[1]})")
    print(f"\n{'='*70}")
    print(f"PRE-EMBEDDING COMPLETE")
    print(f"{'='*70}")
    print(f"\nTo use these pre-embedded patients in your app:")
    print(f"1. Update config.py with:")
    print(f" PREEMBEDDED_PATIENTS = '{output_path}'")
    print(f"2. Restart the application")
    print(f"\nThe app will automatically load these embeddings on startup!")
    print(f"\nTo upload to Hugging Face Hub:")
    print(f" from datasets import Dataset")
    print(f" ds = Dataset.from_parquet('{output_path}')")
    print(f" ds.push_to_hub('your-username/patient-embeddings')")
def main():
    """Command-line entry point: parse options, then load, embed and save.

    Returns:
        Process exit status: 0 on success; 1 when ``--gpus`` cannot be parsed
        as comma-separated integers or when any pipeline step raises (the
        traceback is printed in that case).
    """
    rule = '=' * 70

    epilog_lines = [
        "Examples:",
        "python preembed_patients.py --patients data/patients.parquet --embedder models/embedder --output embeddings/patient_embeddings.parquet",
        "python preembed_patients.py --patients patients.csv --embedder Qwen/Qwen3-Embedding-0.6B --output patient_embeddings.parquet --device cuda",
        "python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --patient-id-col mrn",
        "python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --gpus 0,1,2,3",
        "python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --patient-boilerplate-col boilerplate_summary",
        "Hugging Face Upload:",
        "After creating the parquet file, you can upload to Hugging Face Hub:",
        "from datasets import Dataset",
        'ds = Dataset.from_parquet("patient_embeddings.parquet")',
        'ds.push_to_hub("your-username/patient-embeddings")',
    ]
    parser = argparse.ArgumentParser(
        description="Pre-embed patient summaries for faster loading",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n" + "\n".join(epilog_lines) + "\n",
    )

    # (flag, add_argument keyword options), in --help display order.
    option_specs = [
        ('--patients', dict(
            type=str,
            required=True,
            help='Path to patient database (Parquet, CSV, or Excel). Required columns: patient_summary and the patient ID column (default: patient_id, or specify with --patient-id-col)',
        )),
        ('--embedder', dict(
            type=str,
            required=True,
            help='Path to embedder model or HuggingFace model name',
        )),
        ('--output', dict(
            type=str,
            required=True,
            help='Output path for the parquet file (e.g., "patient_embeddings.parquet")',
        )),
        ('--device', dict(
            type=str,
            default=None,
            help='Device to use for embedding (default: auto-detect). Examples: cuda, cuda:0, cuda:3, cpu. Ignored if --gpus is specified.',
        )),
        ('--patient-id-col', dict(
            type=str,
            default='patient_id',
            help='Name of the patient ID column in the input file (default: patient_id)',
        )),
        ('--patient-boilerplate-col', dict(
            type=str,
            default='patient_boilerplate',
            help='Name of the patient boilerplate column in the input file (default: patient_boilerplate). Set to empty string to skip.',
        )),
        ('--gpus', dict(
            type=str,
            default=None,
            help='Comma-separated list of GPU indices for multi-GPU parallel processing (e.g., "0,1,2,3"). Overrides --device if specified.',
        )),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # "0, 1,2" -> [0, 1, 2]; any non-integer entry aborts with status 1.
    gpu_list = None
    if args.gpus:
        try:
            gpu_list = list(map(int, map(str.strip, args.gpus.split(','))))
        except ValueError:
            print(f"β ERROR: Invalid GPU list format: {args.gpus}")
            print(" Use comma-separated integers, e.g., '0,1,2,3'")
            return 1

    if gpu_list:
        device_line = f"GPUs: {gpu_list} (multi-GPU mode)"
    elif args.device:
        device_line = f"Device: {args.device}"
    else:
        device_line = "Device: auto-detect"

    banner = [
        f"\n{rule}",
        "PATIENT SUMMARY PRE-EMBEDDING SCRIPT",
        rule,
        f"Patient Database: {args.patients}",
        f"Embedder Model: {args.embedder}",
        f"Output File: {args.output}",
        f"Patient ID Col: {args.patient_id_col}",
        f"Boilerplate Col: {args.patient_boilerplate_col or '(none)'}",
        device_line,
        f"{rule}\n",
    ]
    for line in banner:
        print(line)

    try:
        patients = load_patients(args.patients, args.patient_id_col, args.patient_boilerplate_col)
        vectors, model_name = embed_patients(patients, args.embedder, args.device, gpu_list)
        # Everything (rows + embeddings + metadata) goes into one parquet file.
        save_embeddings(patients, vectors, args.output, model_name, gpu_list)
        print("\nβ SUCCESS!")
    except Exception as err:
        print(f"\nβ ERROR: {err}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    import sys

    # sys.exit is always available; the bare ``exit`` builtin is injected by the
    # ``site`` module and is missing under ``python -S`` or in frozen builds.
    sys.exit(main())