#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Pre-embed Clinical Trials Script

This script pre-processes and embeds a clinical trial database,
saving the results to disk for faster loading in the main application.

Usage:
    python preembed_trials.py --trials trials.csv --embedder path/to/embedder --output trial_embeddings
    python preembed_trials.py --trials /data1/ken/meta/2024/v17b/trial_space_lineitems.csv --embedder /ksg/kehl_mm_data/meta/2024/v17/v17_models/reranker_round2.model --output trial_embeddings --device cuda:2

This will create:
    - trial_embeddings_data.pkl: Trial dataframe
    - trial_embeddings_vectors.npy: Embedding vectors
    - trial_embeddings_metadata.json: Metadata about the embedding process
"""
import argparse
import json
import re
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
    """Truncate text to a maximum number of tokens.

    The text is encoded with the tokenizer's own truncation and then decoded
    back to a string, so the result is whatever ``tokenizer.decode`` produces
    for the kept token ids.

    Args:
        text: Text to truncate.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...,
            truncation=..., max_length=...)`` and
            ``decode(ids, skip_special_tokens=...)``.
        max_tokens: Passed as ``max_length`` to ``encode``. Special tokens are
            added during encoding, so they count toward this budget.

    Returns:
        The decoded text with special tokens stripped.
    """
    token_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_tokens,
    )
    return tokenizer.decode(token_ids, skip_special_tokens=True)
def load_trials(file_path: str) -> pd.DataFrame:
    """Load trials from a CSV or Excel file and apply basic cleaning.

    Args:
        file_path: Path to a ``.csv``, ``.xlsx`` or ``.xls`` file. The
            extension is matched case-insensitively.

    Returns:
        A dataframe with rows lacking ``this_space`` removed, missing
        ``trial_boilerplate_text`` values replaced by ``''``, and a fresh
        0..n-1 index.

    Raises:
        ValueError: If the extension is not supported or any of the required
            columns (``nct_id``, ``this_space``, ``trial_text``,
            ``trial_boilerplate_text``) is missing.
    """
    banner = '=' * 70
    print(f"\n{banner}")
    print(f"Loading trial database from: {file_path}")
    print(f"{banner}")

    # Compare on a lower-cased string so "TRIALS.CSV" and Path objects work.
    lowered = str(file_path).lower()
    if lowered.endswith('.csv'):
        df = pd.read_csv(file_path)
    elif lowered.endswith(('.xlsx', '.xls')):
        df = pd.read_excel(file_path)
    else:
        raise ValueError("Unsupported file format. Use CSV or Excel.")

    # Check required columns
    required_cols = ['nct_id', 'this_space', 'trial_text', 'trial_boilerplate_text']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    print(f"β Loaded {len(df)} trials")
    print(f" Columns: {', '.join(df.columns.tolist())}")

    # Clean data
    original_count = len(df)
    # Embeddings are later stored as a plain array addressed by position, so
    # the index is reset to keep label i and row position i in agreement
    # after rows have been dropped.
    df = df[df['this_space'].notna()].reset_index(drop=True)
    df['trial_boilerplate_text'] = df['trial_boilerplate_text'].fillna('')

    removed = original_count - len(df)
    if removed > 0:
        print(f" β Removed {removed} trials with missing 'this_space'")
    return df
def embed_trials(df: pd.DataFrame, embedder_path: str, device: Optional[str] = None) -> Tuple[np.ndarray, str]:
    """Embed trials using the specified embedder model.

    Args:
        df: Trial dataframe with a ``this_space`` column. A
            ``this_space_trunc`` column (the token-truncated text) is added
            to this dataframe in place.
        embedder_path: Path or model name handed to ``SentenceTransformer``
            and ``AutoTokenizer.from_pretrained``.
        device: Device string for the model (e.g. ``"cuda:2"``). When None,
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Returns:
        ``(embeddings, embedder_path)`` where ``embeddings`` is the encoder
        output moved to CPU as a NumPy array (encoded with
        ``normalize_embeddings=True``).

    Raises:
        ValueError: If ``df`` has no rows.
    """
    # Fail fast, before the (expensive) model load, instead of crashing later
    # on max() of an empty list.
    if len(df) == 0:
        raise ValueError("No trials to embed: the trial dataframe is empty.")

    banner = '=' * 70
    max_tokens = 1500
    instruction = (
        "Instruct: Given a cancer patient summary, retrieve clinical trial options "
        "that are reasonable for that patient; or, given a clinical trial option, "
        "retrieve cancer patients who are reasonable candidates for that trial."
    )

    print(f"\n{banner}")
    print(f"Loading embedder model: {embedder_path}")
    print(f"{banner}")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Load embedder
    embedder_model = SentenceTransformer(embedder_path, device=device, trust_remote_code=True)
    embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_path, trust_remote_code=True)
    print("β Embedder loaded")

    # Both settings are best-effort: a model that rejects them is still used,
    # but the failure is reported instead of being silently discarded.
    try:
        embedder_model.prompts['query'] = instruction
    except Exception as exc:
        print(f" Warning: could not set the 'query' prompt on the embedder: {exc}")
    try:
        embedder_model.max_seq_length = max_tokens
    except Exception as exc:
        print(f" Warning: could not set max_seq_length to {max_tokens}: {exc}")

    print(f"\n{banner}")
    print(f"Embedding {len(df)} trials")
    print(f"{banner}")

    # Prepare texts for embedding (mutates df: the truncated text is kept as
    # a column so it travels with the dataframe).
    df['this_space_trunc'] = df['this_space'].apply(
        lambda x: truncate_text(str(x), embedder_tokenizer, max_tokens=max_tokens)
    )

    # Add instruction prefix (instruction text followed by a single space).
    prefix = instruction + " "
    texts_to_embed = [prefix + txt for txt in df['this_space_trunc'].tolist()]

    text_lengths = [len(t) for t in texts_to_embed]
    print(" Text length stats:")
    print(f" Mean: {np.mean(text_lengths):.0f} chars")
    print(f" Max: {max(text_lengths)} chars")

    # NOTE(review): sentence-transformers documents `prompt` as literal text
    # prepended to each input and `prompt_name` as the key into
    # `model.prompts`; `prompt='query'` may therefore prepend the word
    # "query" rather than select the prompt registered above. It is left
    # unchanged because stored vectors must match however queries are
    # embedded elsewhere -- confirm against the consuming application.
    with torch.no_grad():
        embeddings = embedder_model.encode(
            texts_to_embed,
            batch_size=64,
            convert_to_tensor=True,
            normalize_embeddings=True,
            show_progress_bar=True,
            prompt='query',
        )
    embeddings_np = embeddings.cpu().numpy()

    print("β Embedding complete")
    print(f" Shape: {embeddings_np.shape}")
    print(f" Dtype: {embeddings_np.dtype}")
    return embeddings_np, embedder_path
def save_embeddings(df: pd.DataFrame, embeddings: np.ndarray, output_prefix: str, embedder_path: str):
    """Save trial data, embeddings, and metadata to disk.

    Writes three files next to each other:
    ``{output_prefix}_data.pkl`` (pickled dataframe),
    ``{output_prefix}_vectors.npy`` (embedding array) and
    ``{output_prefix}_metadata.json`` (summary of the run). The parent
    directory of ``output_prefix`` is created if needed.

    Args:
        df: Trial dataframe; must contain an ``nct_id`` column.
        embeddings: 2-D array with one row per row of ``df``.
        output_prefix: Path prefix shared by the three output files.
        embedder_path: Recorded in the metadata as ``embedder_model``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D or its row count differs
            from ``len(df)``. Nothing is written in that case.
    """
    # Validate before touching the disk so a mismatched pair is never saved.
    if embeddings.ndim != 2:
        raise ValueError(
            f"Expected a 2-D embedding array, got shape {embeddings.shape}"
        )
    num_trials = len(df)
    if embeddings.shape[0] != num_trials:
        raise ValueError(
            f"Embedding rows ({embeddings.shape[0]}) do not match "
            f"number of trials ({num_trials})"
        )

    banner = '=' * 70
    megabyte = 1024 * 1024
    print(f"\n{banner}")
    print(f"Saving to: {output_prefix}_*")
    print(f"{banner}")

    Path(output_prefix).parent.mkdir(parents=True, exist_ok=True)

    # Save dataframe
    df_file = f"{output_prefix}_data.pkl"
    df.to_pickle(df_file)
    print(f"β Saved trial dataframe: {df_file}")
    print(f" Size: {Path(df_file).stat().st_size / megabyte:.2f} MB")

    # Save embeddings
    embeddings_file = f"{output_prefix}_vectors.npy"
    np.save(embeddings_file, embeddings)
    print(f"β Saved embeddings: {embeddings_file}")
    print(f" Size: {Path(embeddings_file).stat().st_size / megabyte:.2f} MB")

    # Save metadata. Only the first ten ids are recorded, followed by "..."
    # when there are more.
    nct_ids = df['nct_id'].tolist()
    if len(nct_ids) > 10:
        nct_preview = nct_ids[:10] + ["..."]
    else:
        nct_preview = nct_ids
    metadata = {
        "created_at": datetime.now().isoformat(),
        "embedder_model": embedder_path,
        "num_trials": num_trials,
        "embedding_dim": int(embeddings.shape[1]),
        "nct_ids": nct_preview,
        "embedding_dtype": str(embeddings.dtype),
        # Hard-coded: the array itself is not inspected here.
        "normalized": True,
    }
    metadata_file = f"{output_prefix}_metadata.json"
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print(f"β Saved metadata: {metadata_file}")

    print(f"\n{banner}")
    print("PRE-EMBEDDING COMPLETE")
    print(f"{banner}")
    print("\nTo use these pre-embedded trials in your app:")
    print("1. Update config.py with:")
    print(f" PREEMBEDDED_TRIALS = '{output_prefix}'")
    print("2. Restart the application")
    print("\nThe app will automatically load these embeddings on startup!")
def main(argv=None):
    """Parse command-line arguments and run load -> embed -> save.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the process's command-line arguments.

    Returns:
        0 on success, 1 if any step raised an exception (the error and its
        traceback are printed). Argument errors exit through argparse.
    """
    parser = argparse.ArgumentParser(
        description="Pre-embed clinical trials for faster loading",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "\n"
            "Examples:\n"
            "python preembed_trials.py --trials data/trials.csv --embedder models/embedder --output embeddings/trial_embeddings\n"
            "python preembed_trials.py --trials trials.xlsx --embedder Qwen/Qwen3-Embedding-0.6B --output trial_embeddings --device cuda\n"
        ),
    )
    parser.add_argument(
        '--trials',
        type=str,
        required=True,
        help='Path to trial database (CSV or Excel)',
    )
    parser.add_argument(
        '--embedder',
        type=str,
        required=True,
        help='Path to embedder model or HuggingFace model name',
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output prefix for saved files (e.g., "trial_embeddings" will create trial_embeddings_data.pkl, etc.)',
    )
    # No `choices` restriction: indexed devices such as "cuda:2" must pass.
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device to use for embedding (default: auto-detect)',
    )
    args = parser.parse_args(argv)

    banner = '=' * 70
    print(f"\n{banner}")
    print("CLINICAL TRIAL PRE-EMBEDDING SCRIPT")
    print(f"{banner}")
    print(f"Trial Database: {args.trials}")
    print(f"Embedder Model: {args.embedder}")
    print(f"Output Prefix: {args.output}")
    print(f"{banner}\n")

    # Top-level boundary: any failure is reported and turned into exit code 1.
    try:
        # Load trials
        df = load_trials(args.trials)
        # Embed trials
        embeddings, embedder_path = embed_trials(df, args.embedder, args.device)
        # Save everything
        save_embeddings(df, embeddings, args.output, embedder_path)
        print("\nβ SUCCESS!")
    except Exception as e:
        print(f"\nβ ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
# Script entry point: main()'s return value becomes the process exit status.
# SystemExit is raised directly because the bare `exit` name is injected by
# the `site` module and is not guaranteed to exist in every interpreter setup.
if __name__ == "__main__":
    raise SystemExit(main())