Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Pre-embed Clinical Trials Script (Multi-GPU Support) | |
| This script pre-processes and embeds a clinical trial database, | |
| saving the results to a single parquet file for easy sharing on HuggingFace. | |
| Usage: | |
| # Single GPU | |
| python preembed_trials.py --trials trials.csv --embedder path/to/embedder --output trial_embeddings.parquet --devices cuda:0 | |
| # Multi-GPU (parallel embedding) | |
| python preembed_trials.py --trials trial_space_lineitems.csv --embedder ksg-dfci/TrialSpace-1225 --output trial_embeddings.parquet --devices cuda:2,cuda:3 | |
| This will create: | |
| - trial_embeddings.parquet: Trial dataframe with 'embedding' column containing vectors | |
| - trial_embeddings.metadata.json: Metadata about the embedding process (always written alongside the parquet file; optional for consumers) | |
| """ | |
| import argparse | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import json | |
| import os | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Tuple, List | |
| from transformers import AutoTokenizer | |
| import multiprocessing as mp | |
def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
    """Clip ``text`` to at most ``max_tokens`` tokens and return it as a string.

    The text is round-tripped through ``tokenizer``: it is encoded with special
    tokens added and truncation enabled at ``max_tokens``, then decoded with
    special tokens skipped, so the returned string carries no special tokens.

    Args:
        text: Input string to clip.
        tokenizer: Object exposing Hugging Face-style ``encode``/``decode``.
        max_tokens: Token budget passed to ``encode`` as ``max_length``.

    Returns:
        The decoded (possibly shortened) text.
    """
    token_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_tokens,
    )
    clipped = tokenizer.decode(token_ids, skip_special_tokens=True)
    return clipped
def load_trials(file_path: str) -> pd.DataFrame:
    """Load trials from a CSV or Excel file and clean them for embedding.

    The file type is chosen from the extension, case-insensitively
    (``.csv`` -> ``read_csv``; ``.xlsx``/``.xls`` -> ``read_excel``).
    Rows whose ``this_space`` value is missing are dropped, and missing
    ``trial_boilerplate_text`` values are replaced with ``''``. The original
    row index is kept (not reset).

    Args:
        file_path: Path to the trial database.

    Returns:
        A new DataFrame (a copy) containing the cleaned trials.

    Raises:
        ValueError: If the extension is unsupported or any of the required
            columns (``nct_id``, ``this_space``, ``trial_text``,
            ``trial_boilerplate_text``) is missing.
    """
    print(f"\n{'='*70}")
    print(f"Loading trial database from: {file_path}")
    print(f"{'='*70}")

    # Compare extensions case-insensitively so e.g. "TRIALS.CSV" is accepted.
    lowered = str(file_path).lower()
    if lowered.endswith('.csv'):
        df = pd.read_csv(file_path)
    elif lowered.endswith(('.xlsx', '.xls')):
        df = pd.read_excel(file_path)
    else:
        raise ValueError("Unsupported file format. Use CSV or Excel.")

    # Check required columns
    required_cols = ['nct_id', 'this_space', 'trial_text', 'trial_boilerplate_text']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    print(f"✓ Loaded {len(df)} trials")
    print(f" Columns: {', '.join(df.columns.tolist())}")

    # Clean data: 'this_space' is the text that gets embedded, so rows without
    # it are useless; boilerplate is optional, so blank it instead of dropping.
    original_count = len(df)
    df = df[df['this_space'].notna()].copy()
    df['trial_boilerplate_text'] = df['trial_boilerplate_text'].fillna('')
    if len(df) < original_count:
        print(f" ⚠ Removed {original_count - len(df)} trials with missing 'this_space'")
    return df
def embed_chunk_on_device(args: Tuple[int, List[str], str, str]) -> Tuple[int, np.ndarray]:
    """
    Worker function to embed a chunk of texts on a specific device.

    Intended to run in its own process (e.g. a ``spawn`` pool worker): it
    loads a private copy of the SentenceTransformer model on ``device``,
    encodes ``texts`` with L2-normalized output, moves the result to the CPU
    as a NumPy array, and releases the model before returning.

    Args:
        args: Tuple of (chunk_index, texts_to_embed, embedder_path, device),
            packed into a single argument so the function works with
            ``Pool.map``.

    Returns:
        Tuple of (chunk_index, embeddings_array); the index lets the caller
        reassemble chunks in their original order.
    """
    chunk_idx, texts, embedder_path, device = args

    # Import here to ensure fresh CUDA context in spawned process
    from sentence_transformers import SentenceTransformer
    import torch

    print(f" [GPU {device}] Loading model for chunk {chunk_idx} ({len(texts)} texts)...")

    # Load model on specific device
    embedder_model = SentenceTransformer(embedder_path, device=device, trust_remote_code=True)

    # Best-effort model configuration: failures are reported but not fatal.
    # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
    try:
        embedder_model.prompts['query'] = (
            "Instruct: Given a cancer patient summary, retrieve clinical trial options "
            "that are reasonable for that patient; or, given a clinical trial option, "
            "retrieve cancer patients who are reasonable candidates for that trial."
        )
    except Exception as e:
        print(f" [GPU {device}] ⚠ Could not set 'query' prompt: {e}")
    try:
        embedder_model.max_seq_length = 2500
    except Exception as e:
        print(f" [GPU {device}] ⚠ Could not set max_seq_length: {e}")

    print(f" [GPU {device}] Embedding {len(texts)} texts...")

    with torch.no_grad():
        # NOTE(review): the callers in this script already prepend the instruction
        # prefix to every text. sentence-transformers documents `prompt=` as a
        # literal string prepended to each input (`prompt_name=` is the key into
        # `model.prompts`). If that holds for the installed version, each input
        # is really "query" + text and the prompts['query'] entry above is unused.
        # Left unchanged on purpose: trial and patient embeddings must be produced
        # identically -- confirm against the query-side code before changing.
        embeddings = embedder_model.encode(
            texts,
            batch_size=64,
            convert_to_tensor=True,
            normalize_embeddings=True,
            show_progress_bar=True,
            prompt='query'
        )

    embeddings_np = embeddings.cpu().numpy()
    print(f" [GPU {device}] ✓ Chunk {chunk_idx} complete: {embeddings_np.shape}")

    # Explicitly clean up to free GPU memory before the worker exits
    del embedder_model
    del embeddings
    torch.cuda.empty_cache()

    return chunk_idx, embeddings_np
def embed_trials_multi_gpu(df: pd.DataFrame, embedder_path: str, devices: List[str]) -> Tuple[np.ndarray, str]:
    """Embed trials using multiple GPUs in parallel.

    Each ``df['this_space']`` value is converted to ``str``, truncated to 1500
    tokens, and prefixed with the retrieval instruction. The resulting texts
    are split into contiguous, near-equal chunks (sizes differ by at most one),
    one per device. Each chunk is embedded in its own ``spawn``-ed worker via
    ``embed_chunk_on_device``, and the results are stacked back in row order.

    Side effect: adds a ``this_space_trunc`` column to ``df`` in place.

    Args:
        df: Trial rows; must contain a ``this_space`` column.
        embedder_path: Local path or Hugging Face model id (tokenizer + model).
        devices: Device strings such as ``["cuda:0", "cuda:1"]``. If there are
            fewer trials than devices, the surplus devices are left idle.

    Returns:
        ``(embeddings, embedder_path)``; ``embeddings`` has one row per row of ``df``.

    Raises:
        ValueError: If ``devices`` is empty or ``df`` has no rows.
    """
    if not devices:
        raise ValueError("At least one device is required for multi-GPU embedding.")
    if len(df) == 0:
        raise ValueError("No trials to embed: the trial dataframe is empty.")

    print(f"\n{'='*70}")
    print(f"MULTI-GPU EMBEDDING")
    print(f"{'='*70}")
    print(f"Embedder model: {embedder_path}")
    print(f"Devices: {', '.join(devices)}")
    print(f"Total trials: {len(df)}")

    # Load tokenizer for text preparation (on CPU)
    print(f"\nPreparing texts...")
    embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_path, trust_remote_code=True)

    # Prepare texts for embedding
    df['this_space_trunc'] = df['this_space'].apply(
        lambda x: truncate_text(str(x), embedder_tokenizer, max_tokens=1500)
    )

    # Add instruction prefix
    prefix = (
        "Instruct: Given a cancer patient summary, retrieve clinical trial options "
        "that are reasonable for that patient; or, given a clinical trial option, "
        "retrieve cancer patients who are reasonable candidates for that trial. "
    )
    all_texts = [prefix + txt for txt in df['this_space_trunc'].tolist()]

    print(f" Text length stats:")
    print(f" Mean: {np.mean([len(t) for t in all_texts]):.0f} chars")
    print(f" Max: {max([len(t) for t in all_texts])} chars")

    # Split texts into balanced contiguous chunks, one per device. Never create
    # more chunks than texts: an empty chunk would make its worker load a model
    # for no work (and may yield an empty array that does not stack).
    num_chunks = min(len(devices), len(all_texts))
    base_size, remainder = divmod(len(all_texts), num_chunks)
    chunks = []
    start_idx = 0
    for i, device in enumerate(devices[:num_chunks]):
        end_idx = start_idx + base_size + (1 if i < remainder else 0)
        chunk_texts = all_texts[start_idx:end_idx]
        chunks.append((i, chunk_texts, embedder_path, device))
        print(f" Chunk {i} -> {device}: indices {start_idx}-{end_idx} ({len(chunk_texts)} texts)")
        start_idx = end_idx
    if num_chunks < len(devices):
        print(f" ⚠ Only {num_chunks} trials; leaving idle: {', '.join(devices[num_chunks:])}")

    print(f"\n{'='*70}")
    print(f"Starting parallel embedding on {num_chunks} GPUs...")
    print(f"{'='*70}")

    # 'spawn' gives every worker a fresh interpreter (and thus a fresh CUDA context)
    ctx = mp.get_context('spawn')
    with ctx.Pool(processes=num_chunks) as pool:
        results = pool.map(embed_chunk_on_device, chunks)

    # Pool.map already returns results in input order; sorting by chunk index
    # keeps reassembly correct even if the dispatch mechanism changes.
    results.sort(key=lambda r: r[0])
    embeddings_np = np.vstack([r[1] for r in results])

    print(f"\n{'='*70}")
    print(f"✓ Embedding complete")
    print(f" Final shape: {embeddings_np.shape}")
    print(f" Dtype: {embeddings_np.dtype}")
    print(f"{'='*70}")

    return embeddings_np, embedder_path
def embed_trials_single_gpu(df: pd.DataFrame, embedder_path: str, device: str) -> Tuple[np.ndarray, str]:
    """Embed trials using a single device (original behavior).

    Each ``df['this_space']`` value is converted to ``str``, truncated to 1500
    tokens, and prefixed with the retrieval instruction. The texts are then
    encoded in batches of 64 with L2-normalized output.

    Side effect: adds a ``this_space_trunc`` column to ``df`` in place.

    Args:
        df: Trial rows; must contain a ``this_space`` column.
        embedder_path: Local path or Hugging Face model id used for both the
            SentenceTransformer model and its tokenizer.
        device: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.

    Returns:
        ``(embeddings, embedder_path)``; ``embeddings`` has one row per row of ``df``.

    Raises:
        ValueError: If ``df`` has no rows (checked before loading the model).
    """
    # Fail fast instead of loading the model and then crashing on max([]).
    if len(df) == 0:
        raise ValueError("No trials to embed: the trial dataframe is empty.")

    from sentence_transformers import SentenceTransformer

    print(f"\n{'='*70}")
    print(f"Loading embedder model: {embedder_path}")
    print(f"{'='*70}")
    print(f"Device: {device}")

    # Load embedder
    embedder_model = SentenceTransformer(embedder_path, device=device, trust_remote_code=True)
    embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_path, trust_remote_code=True)
    print(f"✓ Embedder loaded")

    # Best-effort model configuration: failures are reported but not fatal.
    # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
    try:
        embedder_model.prompts['query'] = (
            "Instruct: Given a cancer patient summary, retrieve clinical trial options "
            "that are reasonable for that patient; or, given a clinical trial option, "
            "retrieve cancer patients who are reasonable candidates for that trial."
        )
    except Exception as e:
        print(f" ⚠ Could not set 'query' prompt: {e}")
    try:
        embedder_model.max_seq_length = 2500
    except Exception as e:
        print(f" ⚠ Could not set max_seq_length: {e}")

    print(f"\n{'='*70}")
    print(f"Embedding {len(df)} trials")
    print(f"{'='*70}")

    # Prepare texts for embedding
    df['this_space_trunc'] = df['this_space'].apply(
        lambda x: truncate_text(str(x), embedder_tokenizer, max_tokens=1500)
    )

    # Add instruction prefix
    prefix = (
        "Instruct: Given a cancer patient summary, retrieve clinical trial options "
        "that are reasonable for that patient; or, given a clinical trial option, "
        "retrieve cancer patients who are reasonable candidates for that trial. "
    )
    texts_to_embed = [prefix + txt for txt in df['this_space_trunc'].tolist()]

    print(f" Text length stats:")
    print(f" Mean: {np.mean([len(t) for t in texts_to_embed]):.0f} chars")
    print(f" Max: {max([len(t) for t in texts_to_embed])} chars")

    with torch.no_grad():
        # NOTE(review): texts already carry the instruction prefix added above.
        # sentence-transformers documents `prompt=` as a literal string prepended
        # to each input (`prompt_name=` is the key into `model.prompts`). If that
        # holds for the installed version, each input is really "query" + text
        # and prompts['query'] is unused. Left unchanged on purpose: trial and
        # patient embeddings must be produced identically -- confirm first.
        embeddings = embedder_model.encode(
            texts_to_embed,
            batch_size=64,
            convert_to_tensor=True,
            normalize_embeddings=True,
            show_progress_bar=True,
            prompt='query'
        )

    embeddings_np = embeddings.cpu().numpy()
    print(f"✓ Embedding complete")
    print(f" Shape: {embeddings_np.shape}")
    print(f" Dtype: {embeddings_np.dtype}")

    return embeddings_np, embedder_path
def save_embeddings(df: pd.DataFrame, embeddings: np.ndarray, output_path: str, embedder_path: str, devices: List[str]):
    """Save trial data with embeddings to a single parquet file.

    Writes ``output_path``: all columns of ``df`` plus an ``embedding`` column
    holding one list of floats per row. Also writes a JSON metadata file
    next to it, named ``output_path`` with its suffix replaced by
    ``.metadata.json``. Parent directories are created if needed.

    Args:
        df: Trial rows, in the same order as ``embeddings``.
        embeddings: 2-D array with exactly one row per row of ``df``.
        output_path: Destination parquet path.
        embedder_path: Model identifier recorded in the metadata.
        devices: Devices used for embedding, recorded in the metadata.

    Raises:
        ValueError: If ``embeddings`` is not 2-D or its row count differs from
            ``len(df)``. This is checked before anything is written to disk.
    """
    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2:
        raise ValueError(f"Expected a 2-D embeddings array, got shape {embeddings.shape}")
    if embeddings.shape[0] != len(df):
        raise ValueError(
            f"Embedding count ({embeddings.shape[0]}) does not match trial count ({len(df)})"
        )

    print(f"\n{'='*70}")
    print(f"Saving to: {output_path}")
    print(f"{'='*70}")

    # Ensure output directory exists
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Add embeddings as a column (convert each row to a list for parquet compatibility)
    df_out = df.copy()
    df_out['embedding'] = [emb.tolist() for emb in embeddings]

    df_out.to_parquet(output_path, index=False)

    print(f"✓ Saved parquet file: {output_path}")
    print(f" Size: {output_file.stat().st_size / 1024 / 1024:.2f} MB")
    print(f" Rows: {len(df_out)}")
    print(f" Embedding dimension: {embeddings.shape[1]}")

    # Save metadata alongside (optional, for reference)
    metadata = {
        "created_at": datetime.now().isoformat(),
        "embedder_model": embedder_path,
        "num_trials": len(df),
        "embedding_dim": embeddings.shape[1],
        "nct_ids_sample": df['nct_id'].tolist()[:10] + (["..."] if len(df) > 10 else []),
        "embedding_dtype": str(embeddings.dtype),
        "normalized": True,
        "format": "parquet",
        "embedding_column": "embedding",
        "devices_used": devices
    }

    metadata_file = str(output_file.with_suffix('.metadata.json'))
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"✓ Saved metadata: {metadata_file}")

    print(f"\n{'='*70}")
    print(f"PRE-EMBEDDING COMPLETE")
    print(f"{'='*70}")
    print(f"\nTo use these pre-embedded trials in your app:")
    print(f"1. Update config.py with:")
    print(f" PREEMBEDDED_TRIALS = '{output_path}'")
    print(f"2. Restart the application")
    print(f"\nThe app will automatically load these embeddings on startup!")
    print(f"\nTo share on HuggingFace:")
    print(f" huggingface-cli upload your-username/dataset-name {output_path}")
def parse_devices(devices_str: str) -> List[str]:
    """Parse a comma-separated device string into a validated list of devices.

    Whitespace around entries is stripped and empty entries (e.g. from a
    trailing comma) are ignored. If nothing remains (``None``, ``""``, ``" , "``),
    a single auto-detected device is returned: ``"cuda"`` when CUDA is
    available, otherwise ``"cpu"``. Non-CUDA entries are passed through
    unvalidated.

    Args:
        devices_str: e.g. ``"cuda:0,cuda:1"``, ``"cuda"``, or ``"cpu"``.

    Returns:
        List of device strings, in the order given.

    Raises:
        ValueError: For a ``cuda:<id>`` entry whose id is not a non-negative
            integer below ``torch.cuda.device_count()``, or for a bare
            ``cuda`` entry when CUDA is unavailable.
    """
    devices = [d.strip() for d in (devices_str or "").split(',') if d.strip()]
    if not devices:
        return ["cuda" if torch.cuda.is_available() else "cpu"]

    for device in devices:
        if not device.startswith('cuda'):
            continue
        if ':' in device:
            id_text = device.split(':', 1)[1]
            try:
                gpu_id = int(id_text)
            except ValueError:
                raise ValueError(f"Invalid GPU index in device '{device}'") from None
            if gpu_id < 0:
                raise ValueError(f"Invalid GPU index in device '{device}': must be >= 0")
            if gpu_id >= torch.cuda.device_count():
                raise ValueError(f"GPU {gpu_id} not available. Only {torch.cuda.device_count()} GPUs found.")
        elif not torch.cuda.is_available():
            raise ValueError("CUDA not available")

    return devices
def main():
    """Command-line entry point: load trials, embed them, and save parquet + metadata.

    Uses one device per process: with more than one device,
    ``embed_trials_multi_gpu`` is used, otherwise ``embed_trials_single_gpu``.

    Returns:
        Process exit code: 0 on success, 1 if device parsing or any processing
        step failed (an error line is printed, plus a traceback for processing
        failures).
    """
    parser = argparse.ArgumentParser(
        description="Pre-embed clinical trials for faster loading (supports multi-GPU)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Single GPU
python preembed_trials.py --trials data/trials.csv --embedder models/embedder --output trial_embeddings.parquet --devices cuda:0
# Multi-GPU (4 GPUs in parallel)
python preembed_trials.py --trials trials.csv --embedder Qwen/Qwen3-Embedding-0.6B --output trial_embeddings.parquet --devices cuda:0,cuda:1,cuda:2,cuda:3
# CPU only
python preembed_trials.py --trials trials.csv --embedder model --output trial_embeddings.parquet --devices cpu
"""
    )
    parser.add_argument(
        '--trials',
        type=str,
        required=True,
        help='Path to trial database (CSV or Excel)'
    )
    parser.add_argument(
        '--embedder',
        type=str,
        required=True,
        help='Path to embedder model or HuggingFace model name'
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output path for parquet file (e.g., "trial_embeddings.parquet")'
    )
    parser.add_argument(
        '--devices',
        type=str,
        default=None,
        help='Comma-separated list of devices (e.g., "cuda:0,cuda:1,cuda:2" or "cuda:0" or "cpu"). Default: auto-detect single GPU'
    )
    # Keep --device for backwards compatibility
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='(Deprecated) Use --devices instead. Single device to use for embedding.'
    )

    args = parser.parse_args()

    # Handle backwards compatibility with --device (--devices wins if both are given)
    if args.device and not args.devices:
        args.devices = args.device
    elif args.device and args.devices:
        print(f"⚠ Both --device and --devices given; ignoring --device={args.device}")

    # Parse devices; report bad specs the same way as other failures (exit code 1)
    try:
        devices = parse_devices(args.devices)
    except ValueError as e:
        print(f"\n✗ ERROR: {str(e)}")
        return 1

    # Ensure output has .parquet extension (case-insensitive, so X.PARQUET is kept as-is)
    output_path = args.output
    if not output_path.lower().endswith('.parquet'):
        output_path = output_path + '.parquet'

    print(f"\n{'='*70}")
    print(f"CLINICAL TRIAL PRE-EMBEDDING SCRIPT")
    print(f"{'='*70}")
    print(f"Trial Database: {args.trials}")
    print(f"Embedder Model: {args.embedder}")
    print(f"Output File: {output_path}")
    print(f"Devices: {', '.join(devices)}")
    print(f"{'='*70}\n")

    try:
        df = load_trials(args.trials)

        # Embed trials (choose single vs multi-GPU based on device count)
        if len(devices) > 1:
            embeddings, embedder_path = embed_trials_multi_gpu(df, args.embedder, devices)
        else:
            embeddings, embedder_path = embed_trials_single_gpu(df, args.embedder, devices[0])

        save_embeddings(df, embeddings, output_path, embedder_path, devices)

        print(f"\n✓ SUCCESS!")

    except Exception as e:
        # Top-level boundary: report the error with a traceback and signal failure.
        print(f"\n✗ ERROR: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    # Use sys.exit rather than the site-provided exit(), which is meant for the
    # interactive shell and is absent when Python runs with -S.
    import sys

    sys.exit(main())