|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Patient Matching Pipeline - Gradio Web Interface |
|
|
|
|
|
This interface allows users to: |
|
|
1. Configure models (embedder, trial_checker, boilerplate_checker) |
|
|
2. Upload patient database OR load pre-embedded patients |
|
|
3. Enter a set of clinical criteria (trial eligibility criteria)
|
|
4. Get ranked patient recommendations with eligibility predictions |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
import os |
|
|
import json |
|
|
import pickle |
|
|
import html |
|
|
from typing import List, Tuple |
|
|
from pathlib import Path |
|
|
import pyarrow.parquet as pq |
|
|
|
|
|
|
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
) |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
# Optional project configuration. When config.py is importable, the auto-load
# routines below read model paths and patient sources from it at startup.
try:
    import config
except ImportError:
    HAS_CONFIG = False
    print("○ No config.py found - using manual model loading")
else:
    HAS_CONFIG = True
    print("✓ Found config.py - will auto-load models on startup")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AppState:
    """Process-wide mutable container shared by every Gradio callback.

    Holds the three loaded models with their tokenizers, the patient database
    and its embedding matrix, the last search results and the compute device.
    """

    def __init__(self):
        # Models / tokenizers: None until the corresponding load_* call succeeds.
        self.embedder_model = self.embedder_tokenizer = None
        self.trial_checker_model = self.trial_checker_tokenizer = None
        self.boilerplate_checker_model = self.boilerplate_checker_tokenizer = None

        # Patient dataframe, embedding matrix and preview table.
        self.reset_patients()

        # Full rows (not the display table) of the most recent search.
        self.last_results_df = None

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Status strings written by the startup auto-load routines.
        self.auto_load_status = dict.fromkeys(
            ("embedder", "trial_checker", "boilerplate_checker", "patients"), ""
        )

    def reset_patients(self):
        """Forget any loaded patient data (dataframe, embeddings and preview)."""
        self.patient_df, self.patient_embeddings, self.patient_preview_df = None, None, None


state = AppState()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sequence-length cap assigned to the embedder's max_seq_length in
# load_embedder_model; also the token budget for the query text in match_patients.
MAX_EMBEDDER_SEQ_LEN = 2500

# Tokenizer max_length for trial-checker inputs (criteria + patient summary).
MAX_TRIAL_CHECKER_LENGTH = 4096

# Tokenizer max_length for boilerplate-checker inputs.
# NOTE(review): 3192 is an unusual length (3072 or 4096 would be typical) — confirm it is intended.
MAX_BOILERPLATE_CHECKER_LENGTH = 3192

# Number of inputs per forward pass for both classifier models in match_patients.
CLASSIFIER_BATCH_SIZE = 32

# Default contents of the "Clinical Criteria" textbox; create_interface uses
# config.CLINICAL_SPACE_TEMPLATE instead when config.py defines it.
DEFAULT_CLINICAL_SPACE_TEMPLATE = """Age range allowed:


Sex allowed:


Cancer type allowed:


Histology allowed:


Cancer burden allowed:


Prior treatment required:


Prior treatment excluded:


Biomarkers required:


Biomarkers excluded: """

# Default contents of the "Boilerplate Exclusion Criteria" textbox; create_interface
# uses config.BOILERPLATE_TEMPLATE instead when config.py defines it.
DEFAULT_BOILERPLATE_TEMPLATE = """History of pneumonitis:


Heart failure or cardiac dysfunction:


Renal dysfunction:


Liver dysfunction:


Uncontrolled brain metastases:


HIV or hepatitis infection:


Poor performance status (ECOG >= 2):


Other relevant exclusions: """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
    """Clip *text* to at most *max_tokens* tokens by round-tripping it through *tokenizer*.

    The text is encoded with special tokens (which count toward the budget) and
    truncation enabled, then decoded with the special tokens stripped again.

    Args:
        text: text to shorten.
        tokenizer: object exposing Hugging Face style ``encode``/``decode``.
        max_tokens: maximum encoded length, special tokens included.

    Returns:
        The decoded, possibly shortened text.
    """
    token_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_tokens,
    )
    clipped = tokenizer.decode(token_ids, skip_special_tokens=True)
    return clipped
|
|
|
|
|
|
|
|
def format_probability_visual(val, is_exclusion=False):
    """Render a probability as a colour-coded string for the results table.

    Eligibility (``is_exclusion=False``): >= 0.8 green and bold, >= 0.5 yellow,
    otherwise red. Exclusion (``is_exclusion=True``), where high is bad:
    >= 0.5 red and bold, >= 0.2 yellow, otherwise green.

    Args:
        val: value to format; anything ``float()`` accepts.
        is_exclusion: use the exclusion colour scale instead of eligibility.

    Returns:
        The formatted string. Values ``float()`` cannot convert are returned
        unchanged. NaN gets a neutral "⚪" marker rather than falling through
        to a colour (previously a NaN exclusion score showed as green).
    """
    try:
        val_float = float(val)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric cell: leave it untouched instead of crashing the table.
        return val

    # Every comparison with NaN is False, so without this check NaN would land
    # in the "else" branches and look like a confident red/green verdict.
    if np.isnan(val_float):
        return f"⚪ {val_float:.2f}"

    if not is_exclusion:
        if val_float >= 0.8:
            return f"🟢 **{val_float:.2f}**"
        if val_float >= 0.5:
            return f"🟡 {val_float:.2f}"
        return f"🔴 {val_float:.2f}"

    if val_float >= 0.5:
        return f"🔴 **{val_float:.2f}**"
    if val_float >= 0.2:
        return f"🟡 {val_float:.2f}"
    return f"🟢 {val_float:.2f}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_load_models_from_config():
    """Load every model named in ``config.MODEL_CONFIG`` at startup.

    Does nothing when config.py is absent. For each of the embedder, trial
    checker and boilerplate checker whose MODEL_CONFIG entry is truthy, calls
    the matching ``load_*`` function, stores its status message in
    ``state.auto_load_status`` and echoes it to stdout.
    """
    if not HAS_CONFIG:
        return

    rule = "=" * 70
    print("\n" + rule)
    print("AUTO-LOADING MODELS FROM CONFIG")
    print(rule)

    # (MODEL_CONFIG / auto_load_status key, label for the log line, loader).
    # Every loader returns a tuple whose first element is the status message.
    steps = (
        ("embedder", "embedder", load_embedder_model),
        ("trial_checker", "trial checker", load_trial_checker),
        ("boilerplate_checker", "boilerplate checker", load_boilerplate_checker),
    )
    for position, (key, label, loader) in enumerate(steps, start=1):
        model_path = config.MODEL_CONFIG.get(key)
        if not model_path:
            continue
        print(f"\n[{position}/{len(steps)}] Loading {label}: {model_path}")
        status = loader(model_path)[0]
        state.auto_load_status[key] = status
        print(status)

    print("\n" + rule)
    print("MODEL AUTO-LOADING COMPLETE")
    print(rule + "\n")
|
|
|
|
|
|
|
|
def auto_load_patients_from_config():
    """Populate the patient database at startup from settings in config.py.

    Resolution order:
      1. ``config.PREEMBEDDED_PATIENTS`` (if set): an http(s) URL, a local
         parquet file (path or prefix), or a legacy ``{prefix}_data.pkl`` set.
         If none of these exist, the failure is recorded and no fallback to
         ``DEFAULT_PATIENT_DB`` is attempted.
      2. ``config.DEFAULT_PATIENT_DB``: a raw patient file that is embedded
         with the already-loaded embedder (skipped when no embedder is loaded).

    Status messages go to ``state.auto_load_status["patients"]``; previews go
    to ``state.patient_preview_df``.
    """
    if not HAS_CONFIG:
        return

    rule = "=" * 70

    def _load_preembedded(source_kind, source):
        # Shared banner + state bookkeeping for every pre-embedded source.
        print("\n" + rule)
        print(f"AUTO-LOADING PRE-EMBEDDED PATIENTS ({source_kind}): {source}")
        print(rule)

        status, preview = load_preembedded_patients(source)
        state.auto_load_status["patients"] = status
        state.patient_preview_df = preview

        print(rule)
        print("PRE-EMBEDDED PATIENTS AUTO-LOADING COMPLETE")
        print(rule + "\n")

    preembed_path = getattr(config, 'PREEMBEDDED_PATIENTS', None)
    if preembed_path:
        if preembed_path.startswith(("http://", "https://")):
            _load_preembedded("URL", preembed_path)
            return

        parquet_path = preembed_path if preembed_path.endswith('.parquet') else f"{preembed_path}.parquet"
        if os.path.exists(parquet_path):
            _load_preembedded("parquet", parquet_path)
        elif os.path.exists(f"{preembed_path}_data.pkl"):
            _load_preembedded("legacy", preembed_path)
        else:
            print(f"✗ Pre-embedded patient files not found: {preembed_path}")
            state.auto_load_status["patients"] = f"✗ Pre-embedded files not found: {preembed_path}"
        return

    patient_db = getattr(config, 'DEFAULT_PATIENT_DB', None)
    if not patient_db:
        print("○ No patient database specified in config")
        return

    if not os.path.exists(patient_db):
        print(f"✗ Default patient database not found: {patient_db}")
        state.auto_load_status["patients"] = f"✗ Patient database file not found: {patient_db}"
        return

    if state.embedder_model is None:
        print("○ Embedder not loaded yet - skipping patient database auto-load")
        state.auto_load_status["patients"] = "○ Waiting for embedder model to be loaded..."
        return

    print("\n" + rule)
    print(f"AUTO-LOADING PATIENT DATABASE: {patient_db}")
    print(rule)

    class FilePath:
        # Stand-in for an uploaded-file object: load_and_embed_patients reads `.name`.
        def __init__(self, path):
            self.name = path

    status, preview = load_and_embed_patients(FilePath(patient_db), show_progress=True)
    state.auto_load_status["patients"] = status
    state.patient_preview_df = preview

    print(rule)
    print("PATIENT DATABASE AUTO-LOADING COMPLETE")
    print(rule + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_embedder_model(model_path: str) -> Tuple[str, str, str]:
    """Load a SentenceTransformer embedder and its tokenizer into ``state``.

    The model and tokenizer are loaded into locals first and committed to
    ``state`` only after both succeed. A failed load therefore leaves the
    previously loaded embedder/tokenizer pair intact. The old code assigned
    the model before loading the tokenizer, so a tokenizer failure could leave
    a new model paired with an old (or missing) tokenizer.

    If patients are already loaded, their embeddings are cleared on success
    because they were produced by the previous model.

    Args:
        model_path: local path or hub id accepted by SentenceTransformer and
            AutoTokenizer.

    Returns:
        ``(status_message, error_detail, warning_message)``. On success
        error_detail is ""; on failure status starts with "✗", error_detail is
        the exception text and warning_message is "".
    """
    try:
        will_need_reembed = state.patient_df is not None and len(state.patient_df) > 0

        if will_need_reembed:
            warning_msg = f"\n⚠️ Warning: {len(state.patient_df)} patients are currently loaded. They will need to be re-embedded with the new model."
        else:
            warning_msg = ""

        model = SentenceTransformer(model_path, device=state.device, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Best-effort configuration: failure here should not reject an
        # otherwise usable model, but it should not be silent either.
        # NOTE(review): callers pass encode(..., prompt='query'). In
        # sentence-transformers `prompt=` is literal prefix text, while
        # `prompt_name=` selects an entry of `.prompts`. This entry may
        # therefore be unused — confirm against the installed version.
        try:
            model.prompts['query'] = (
                "Instruct: Given a cancer patient summary, retrieve clinical trial options "
                "that are reasonable for that patient; or, given a clinical trial option, "
                "retrieve cancer patients who are reasonable candidates for that trial."
            )
        except Exception as prompt_err:
            print(f"⚠ Could not register 'query' prompt on embedder: {prompt_err}")

        try:
            model.max_seq_length = MAX_EMBEDDER_SEQ_LEN
        except Exception as seq_err:
            print(f"⚠ Could not set embedder max_seq_length: {seq_err}")

        # Commit only once everything above succeeded.
        state.embedder_model = model
        state.embedder_tokenizer = tokenizer

        success_msg = f"✓ Embedder model loaded from {model_path}{warning_msg}"

        if will_need_reembed:
            # Stored vectors came from the previous model and are no longer comparable.
            state.patient_embeddings = None
            success_msg += "\n→ Patient embeddings cleared. Please reload patient database to re-embed."

        return success_msg, "", warning_msg

    except Exception as e:
        return f"✗ Error loading embedder model: {str(e)}", str(e), ""
|
|
|
|
|
|
|
|
def load_trial_checker(model_path: str) -> Tuple[str, str]:
    """Load the trial-eligibility sequence classifier and its tokenizer into ``state``.

    On CUDA the model is loaded in float16, otherwise float32, and it is put
    in eval mode. Both objects are loaded into locals and committed together.
    The old code replaced the tokenizer before attempting the model, so a
    failed model load left a mismatched tokenizer/model pair in ``state``.

    Args:
        model_path: local path or hub id of the classifier.

    Returns:
        ``(status_message, error_detail)``; error_detail is "" on success.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if state.device == "cuda" else torch.float32,
        ).to(state.device)
        model.eval()
    except Exception as e:
        return f"✗ Error loading trial checker: {str(e)}", str(e)

    # Commit both together so tokenizer and model always come from the same checkpoint.
    state.trial_checker_tokenizer = tokenizer
    state.trial_checker_model = model
    return f"✓ Trial checker loaded from {model_path}", ""
|
|
|
|
|
|
|
|
def load_boilerplate_checker(model_path: str) -> Tuple[str, str]:
    """Load the boilerplate-exclusion sequence classifier and its tokenizer into ``state``.

    On CUDA the model is loaded in float16, otherwise float32, and it is put
    in eval mode. Both objects are loaded into locals and committed together.
    The old code replaced the tokenizer before attempting the model, so a
    failed model load left a mismatched tokenizer/model pair in ``state``.

    Args:
        model_path: local path or hub id of the classifier.

    Returns:
        ``(status_message, error_detail)``; error_detail is "" on success.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if state.device == "cuda" else torch.float32,
        ).to(state.device)
        model.eval()
    except Exception as e:
        return f"✗ Error loading boilerplate checker: {str(e)}", str(e)

    # Commit both together so tokenizer and model always come from the same checkpoint.
    state.boilerplate_checker_tokenizer = tokenizer
    state.boilerplate_checker_model = model
    return f"✓ Boilerplate checker loaded from {model_path}", ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_preembedded_patients(preembedded_path: str) -> Tuple[str, pd.DataFrame]:
    """Load a pre-embedded patient database and install it in ``state``.

    Supports two formats:

    1. Parquet: a single file with a ``patient_embedding`` column.
       - Used when the path ends with ``.parquet``, when ``{path}.parquet``
         exists locally, or when the path is an http(s) URL.
       - Embeddings are stored as lists in ``patient_embedding``.
       - Optional metadata is stored in the parquet schema metadata.

    2. Legacy: separate files sharing a prefix (e.g. "patient_embeddings"):
       ``{prefix}_data.pkl``, ``{prefix}_vectors.npy`` and an optional
       ``{prefix}_metadata.json``.

    URLs always use the parquet loader, which can read them directly. The
    legacy loader only checks local files. Previously a URL without a
    ``.parquet`` suffix was sent to the legacy loader and always failed.

    Returns:
        ``(status_message, preview_df)``. On any failure preview_df is None
        and the status starts with "✗".
    """
    try:
        is_url = preembedded_path.startswith(("http://", "https://"))
        is_parquet = (
            is_url
            or preembedded_path.endswith('.parquet')
            or os.path.exists(f"{preembedded_path}.parquet")
        )

        if is_parquet:
            return _load_preembedded_parquet(preembedded_path)
        return _load_preembedded_legacy(preembedded_path)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"✗ Error loading pre-embedded patients: {str(e)}", None
|
|
|
|
|
|
|
|
def _load_preembedded_parquet(parquet_path: str) -> Tuple[str, pd.DataFrame]:
    """Load pre-embedded patients from the single-file parquet format.

    Local paths may be given with or without the ``.parquet`` suffix; http(s)
    URLs are read verbatim with ``pd.read_parquet``, and their schema metadata
    is not printed. The file must contain ``patient_embedding``,
    ``patient_id`` and ``patient_summary`` columns. ``patient_boilerplate`` is
    optional; missing values are normalised to ''.

    On success ``state.patient_df`` is set to the rows without the embedding
    column, and ``state.patient_embeddings`` to a float32 matrix with one row
    per patient.

    Returns:
        ``(status_message, preview_df)`` on success, ``(error_message, None)``
        on a detectable problem. Other errors propagate to the caller.
    """
    is_url = parquet_path.startswith("http://") or parquet_path.startswith("https://")

    # Local paths may be a bare prefix; URLs are used exactly as given.
    if not is_url and not parquet_path.endswith('.parquet'):
        parquet_path = f"{parquet_path}.parquet"

    if not is_url and not os.path.exists(parquet_path):
        return f"✗ Pre-embedded parquet file not found: {parquet_path}", None

    print(f"\n{'='*70}")
    print("LOADING PRE-EMBEDDED PATIENTS (Parquet Format)")
    print(f"{'='*70}")
    print(f"Loading from: {parquet_path}")

    try:
        if is_url:
            df = pd.read_parquet(parquet_path)
            print("Metadata: (Skipped for URL)")
        else:
            parquet_file = pq.read_table(parquet_path)
            schema_meta = parquet_file.schema.metadata
            if schema_meta and b'patient_embedding_metadata' in schema_meta:
                metadata = json.loads(schema_meta[b'patient_embedding_metadata'].decode('utf-8'))
                print("Metadata:")
                print(f" Created: {metadata.get('created_at', 'unknown')}")
                print(f" Embedder: {metadata.get('embedder_model', 'unknown')}")
                print(f" Patients: {metadata.get('num_patients', 'unknown')}")
                print(f" Embedding dim: {metadata.get('embedding_dim', 'unknown')}")
            df = parquet_file.to_pandas()
    except Exception as e:
        error_msg = f"✗ Failed to read parquet file from {parquet_path}: {str(e)}"
        print(error_msg)
        return error_msg, None

    print(f"✓ Loaded {len(df)} patients")
    print(f" Columns: {', '.join(df.columns.tolist())}")

    for column in ('patient_embedding', 'patient_id', 'patient_summary'):
        if column not in df.columns:
            return f"✗ Parquet file missing '{column}' column: {parquet_path}", None

    # An empty database would later fail obscurely in similarity ranking.
    if len(df) == 0:
        return f"✗ Parquet file contains no patients: {parquet_path}", None

    if 'patient_boilerplate' in df.columns:
        # Normalise missing values to '' (as load_and_embed_patients does).
        # Otherwise astype(str) turns NaN/None into "nan"/"None", which would
        # be counted as boilerplate text.
        df['patient_boilerplate'] = df['patient_boilerplate'].fillna('')
        non_empty_bp = (df['patient_boilerplate'].astype(str).str.strip().str.len() > 0).sum()
        print(f" ✓ patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text")
    else:
        print(" ⚠ No patient_boilerplate column found")
        df['patient_boilerplate'] = ''

    print("Converting embeddings to numpy array...")
    embeddings = np.array(df['patient_embedding'].tolist(), dtype=np.float32)
    print(f"✓ Loaded embeddings: {embeddings.shape}")

    # Matching needs an (n_patients, dim) matrix; reject anything else now
    # rather than failing later in np.dot.
    if embeddings.ndim != 2 or embeddings.shape[0] != len(df):
        return (
            f"✗ Invalid 'patient_embedding' column: expected one fixed-length vector per patient, "
            f"got array of shape {embeddings.shape}",
            None,
        )

    df_without_embeddings = df.drop(columns=['patient_embedding'])

    state.patient_df = df_without_embeddings
    state.patient_embeddings = embeddings

    print(f"{'='*70}")
    print("PRE-EMBEDDED PATIENTS LOADED SUCCESSFULLY")
    print(f"{'='*70}\n")

    preview = df_without_embeddings[['patient_id', 'patient_summary']].head(10)
    return f"✓ Loaded {len(df)} pre-embedded patients from {os.path.basename(parquet_path)}", preview
|
|
|
|
|
|
|
|
def _load_preembedded_legacy(preembedded_prefix: str) -> Tuple[str, pd.DataFrame]:
    """Load pre-embedded patients from the legacy three-file format.

    Reads these files:
      - ``{prefix}_data.pkl``: a pickled DataFrame.
      - ``{prefix}_vectors.npy``: the embedding matrix.
      - ``{prefix}_metadata.json``: optional; only printed.

    The dataframe must contain ``patient_id`` and ``patient_summary``. Checking
    this before touching ``state`` avoids installing a dataframe that later
    breaks the preview and matching. ``patient_boilerplate`` is optional, and
    missing values are normalised to ''.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only point this at data files from a trusted source.

    Returns:
        ``(status_message, preview_df)`` on success, ``(error_message, None)``
        on a detectable problem. Other errors propagate to the caller.
    """
    data_file = f"{preembedded_prefix}_data.pkl"
    vectors_file = f"{preembedded_prefix}_vectors.npy"
    metadata_file = f"{preembedded_prefix}_metadata.json"

    if not os.path.exists(data_file):
        return f"✗ Pre-embedded data file not found: {data_file}", None
    if not os.path.exists(vectors_file):
        return f"✗ Pre-embedded vectors file not found: {vectors_file}", None

    print(f"\n{'='*70}")
    print("LOADING PRE-EMBEDDED PATIENTS (Legacy Format)")
    print(f"{'='*70}")
    print(f"Loading from: {preembedded_prefix}_*")

    if os.path.exists(metadata_file):
        with open(metadata_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        print("Metadata:")
        print(f" Created: {metadata.get('created_at', 'unknown')}")
        print(f" Embedder: {metadata.get('embedder_model', 'unknown')}")
        print(f" Patients: {metadata.get('num_patients', 'unknown')}")
        print(f" Embedding dim: {metadata.get('embedding_dim', 'unknown')}")

    print("Loading patient dataframe...")
    with open(data_file, 'rb') as f:
        # NOTE(security): unpickling runs code embedded in the file — trusted input only.
        df = pickle.load(f)
    print(f"✓ Loaded {len(df)} patients")
    print(f" Columns: {', '.join(df.columns.tolist())}")

    missing = [col for col in ('patient_id', 'patient_summary') if col not in df.columns]
    if missing:
        return f"✗ Pre-embedded data file missing columns: {', '.join(missing)}", None

    if 'patient_boilerplate' in df.columns:
        # Normalise NaN/None to '' so they are not stringified and counted as text.
        df['patient_boilerplate'] = df['patient_boilerplate'].fillna('')
        non_empty_bp = (df['patient_boilerplate'].astype(str).str.strip().str.len() > 0).sum()
        print(f" ✓ patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text")
    else:
        print(" ⚠ No patient_boilerplate column found")
        df['patient_boilerplate'] = ''

    print("Loading embeddings...")
    embeddings = np.load(vectors_file)
    print(f"✓ Loaded embeddings: {embeddings.shape}")

    # Matching computes np.dot(embeddings, query), so this must be a 2-D matrix.
    if embeddings.ndim != 2:
        return f"✗ Expected a 2-D embedding matrix in {vectors_file}, got shape {embeddings.shape}", None

    if len(df) != embeddings.shape[0]:
        return (
            f"✗ Mismatch: {len(df)} patients but {embeddings.shape[0]} embeddings",
            None,
        )

    state.patient_df = df
    state.patient_embeddings = embeddings

    print(f"{'='*70}")
    print("PRE-EMBEDDED PATIENTS LOADED SUCCESSFULLY")
    print(f"{'='*70}\n")

    preview = df[['patient_id', 'patient_summary']].head(10)
    return f"✓ Loaded {len(df)} pre-embedded patients from {preembedded_prefix}_*", preview
|
|
|
|
|
|
|
|
def load_and_embed_patients(file, show_progress: bool = False) -> Tuple[str, pd.DataFrame]:
    """Read a patient table from disk, embed every summary and install it in ``state``.

    Args:
        file: an uploaded-file object exposing ``.name``, or a plain path
            (``str`` / ``os.PathLike``). Parquet, CSV and Excel
            (.xlsx/.xls) are supported; extension matching is
            case-insensitive. The table needs ``patient_id`` and
            ``patient_summary`` columns; ``patient_boilerplate`` is optional.
        show_progress: when True (startup/CLI use), progress goes to stdout and
            the encoder's progress bar is shown. When False, a Gradio toast is
            shown instead.

    Returns:
        ``(status_message, preview_df)``. preview_df holds the first 10 patients
        on success and is None on failure.
    """
    try:
        if state.embedder_model is None:
            return "✗ Please load the embedder model first!", None

        if file is None:
            return "✗ Please upload a patient database file.", None

        # Accept a plain path as well as an upload object with `.name`.
        path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
        lowered = path.lower()

        if lowered.endswith('.parquet'):
            df = pd.read_parquet(path)
        elif lowered.endswith('.csv'):
            df = pd.read_csv(path)
        elif lowered.endswith(('.xlsx', '.xls')):
            df = pd.read_excel(path)
        else:
            return "✗ Unsupported format. Use Parquet, CSV, or Excel.", None

        required_cols = ['patient_id', 'patient_summary']
        missing = [col for col in required_cols if col not in df.columns]
        if missing:
            return f"✗ Missing columns: {', '.join(missing)}", None

        # Drop patients with no usable summary text.
        df = df[~df['patient_summary'].isnull()].copy()
        df = df[df['patient_summary'].astype(str).str.strip().str.len() > 0].copy()

        # Embedding zero texts yields no usable matrix; fail clearly here instead.
        if len(df) == 0:
            return "✗ No patients with a non-empty patient_summary were found.", None

        if 'patient_boilerplate' not in df.columns:
            df['patient_boilerplate'] = ''
        else:
            df['patient_boilerplate'] = df['patient_boilerplate'].fillna('')

        df['patient_summary_trunc'] = df['patient_summary'].apply(
            lambda x: truncate_text(str(x), state.embedder_tokenizer, max_tokens=1500)
        )

        # Must stay identical to the prefix match_patients puts on the query so
        # patient and query vectors are produced the same way.
        prefix = (
            "Instruct: Given a cancer patient summary, retrieve clinical trial options "
            "that are reasonable for that patient; or, given a clinical trial option, "
            "retrieve cancer patients who are reasonable candidates for that trial. "
        )
        texts_to_embed = [prefix + txt for txt in df['patient_summary_trunc'].tolist()]

        if not show_progress:
            gr.Info(f"Embedding {len(df)} patient summaries...")
        else:
            print(f"Embedding {len(df)} patient summaries...")

        # NOTE(review): in sentence-transformers `prompt=` is literal text that is
        # prepended to each input, so 'query' is likely prepended verbatim (not the
        # registered prompts['query']). This is left unchanged because existing
        # pre-embedded files must stay comparable — confirm and fix in both places
        # together.
        with torch.no_grad():
            embeddings = state.embedder_model.encode(
                texts_to_embed,
                batch_size=64,
                convert_to_tensor=True,
                normalize_embeddings=True,
                show_progress_bar=show_progress,
                prompt='query'
            )

        state.patient_df = df
        state.patient_embeddings = embeddings.cpu().numpy()

        preview = df[['patient_id', 'patient_summary']].head(10)

        success_msg = f"✓ Loaded and embedded {len(df)} patients"
        if show_progress:
            print(success_msg)

        return success_msg, preview

    except Exception as e:
        return f"✗ Error processing patients: {str(e)}", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _classify_in_batches(model, tokenizer, texts: List[str], max_length: int) -> np.ndarray:
    """Return P(label 1) for each text, running *model* in CLASSIFIER_BATCH_SIZE chunks."""
    batch_probs = []
    for start in range(0, len(texts), CLASSIFIER_BATCH_SIZE):
        encodings = tokenizer(
            texts[start:start + CLASSIFIER_BATCH_SIZE],
            truncation=True,
            max_length=max_length,
            padding=True,
            return_tensors='pt'
        ).to(state.device)

        with torch.no_grad():
            logits = model(**encodings).logits
        # Softmax in float32: on CUDA the checkers are loaded in float16.
        batch_probs.append(torch.softmax(logits.float(), dim=1)[:, 1].cpu().numpy())

    if not batch_probs:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(batch_probs)


def match_patients(
    clinical_space: str,
    boilerplate_criteria: str,
    top_k_check: int = 1000,
    eligibility_threshold: float = 0.5,
    exclusion_threshold: float = 0.5,
) -> Tuple[pd.DataFrame, str]:
    """Rank patients against clinical criteria and score the top candidates.

    Steps:
      1. Embed the criteria and rank every loaded patient by dot-product
         similarity.
      2. Run the trial checker on the top ``top_k_check`` patients
         (criteria + patient summary).
      3. Run the boilerplate checker on the same patients (patient boilerplate
         or summary + exclusion criteria).
      4. Sort by eligibility probability and store the full rows in
         ``state.last_results_df``.

    Args:
        clinical_space: eligibility criteria text; must be non-blank.
        boilerplate_criteria: boilerplate exclusion criteria text.
        top_k_check: how many top-ranked patients to classify. Floats (e.g.
            from a Gradio number input) are truncated to int.
        eligibility_threshold: minimum eligibility probability counted as "eligible".
        exclusion_threshold: exclusion probability below which a patient counts
            as having no boilerplate exclusion (default keeps the old 0.5).

    Returns:
        ``(display_df, summary_markdown)``. On any error this is an empty
        DataFrame and an error markdown string, and a warning toast is shown.
    """
    try:
        if state.embedder_model is None:
            raise ValueError("Embedder model not loaded")
        if state.patient_embeddings is None:
            raise ValueError("Patient database not loaded")
        if state.trial_checker_model is None:
            raise ValueError("Trial checker model not loaded")
        if state.boilerplate_checker_model is None:
            raise ValueError("Boilerplate checker model not loaded")

        if not clinical_space or not clinical_space.strip():
            raise ValueError("Please enter clinical criteria")

        # Must match the prefix used when embedding patients.
        prefix = (
            "Instruct: Given a cancer patient summary, retrieve clinical trial options "
            "that are reasonable for that patient; or, given a clinical trial option, "
            "retrieve cancer patients who are reasonable candidates for that trial. "
        )

        query_text = truncate_text(clinical_space, state.embedder_tokenizer, max_tokens=MAX_EMBEDDER_SEQ_LEN)
        query_text_with_prefix = prefix + query_text

        gr.Info("Ranking all patients by similarity...")

        # NOTE(review): `prompt='query'` mirrors load_and_embed_patients; in
        # sentence-transformers it is likely prepended as literal text — keep both
        # sides in sync if this is ever changed.
        with torch.no_grad():
            query_emb = state.embedder_model.encode(
                [query_text_with_prefix],
                convert_to_tensor=True,
                normalize_embeddings=True,
                prompt='query'
            )

        # The query is L2-normalised; the dot product equals cosine similarity
        # when patient vectors are normalised too (true for load_and_embed_patients,
        # assumed for pre-embedded files). reshape(-1) instead of squeeze() keeps
        # a 1-D array even when only one patient is loaded.
        query_emb_np = query_emb.cpu().numpy()
        similarities = np.dot(state.patient_embeddings, query_emb_np.T).reshape(-1)

        sorted_indices = np.argsort(similarities)[::-1]

        all_patients_ranked = state.patient_df.iloc[sorted_indices].copy()
        all_patients_ranked['similarity_score'] = similarities[sorted_indices]

        # UI number inputs may deliver floats; DataFrame.head needs an int.
        top_k_check = min(int(top_k_check), len(all_patients_ranked))
        patients_to_check = all_patients_ranked.head(top_k_check).copy()
        if len(patients_to_check) == 0:
            raise ValueError("No patients to check (empty database or number of patients to check < 1)")

        gr.Info(f"Running eligibility checks on top {len(patients_to_check)} patients...")

        trial_check_inputs = [
            f"{clinical_space}\nNow here is the patient summary:{row['patient_summary']}"
            for _, row in patients_to_check.iterrows()
        ]
        patients_to_check['eligibility_probability'] = _classify_in_batches(
            state.trial_checker_model,
            state.trial_checker_tokenizer,
            trial_check_inputs,
            MAX_TRIAL_CHECKER_LENGTH,
        )

        def get_boilerplate_text(row):
            # Prefer dedicated boilerplate text; fall back to the full summary.
            bp = row.get('patient_boilerplate', '')
            if isinstance(bp, str) and bp.strip():
                return bp
            return row['patient_summary']

        boilerplate_check_inputs = [
            f"Patient history: {get_boilerplate_text(row)}\nTrial exclusions:{boilerplate_criteria}"
            for _, row in patients_to_check.iterrows()
        ]
        patients_to_check['exclusion_probability'] = _classify_in_batches(
            state.boilerplate_checker_model,
            state.boilerplate_checker_tokenizer,
            boilerplate_check_inputs,
            MAX_BOILERPLATE_CHECKER_LENGTH,
        )

        patients_to_check = patients_to_check.sort_values('eligibility_probability', ascending=False)

        # Full rows (untruncated summaries) for get_patient_details.
        state.last_results_df = patients_to_check.copy()

        is_eligible = patients_to_check['eligibility_probability'] >= eligibility_threshold
        is_not_excluded = patients_to_check['exclusion_probability'] < exclusion_threshold
        num_eligible = is_eligible.sum()
        num_no_exclusion = is_not_excluded.sum()
        num_both = (is_eligible & is_not_excluded).sum()

        bottom_line = f"""


### 📊 Summary: Patients Meeting Your Criteria


| Metric | Count |


|--------|-------|


| Total patients in database | **{len(state.patient_df)}** |


| Top patients checked with classifiers | **{len(patients_to_check)}** |


| Meeting eligibility criteria (≥{eligibility_threshold}) | **{num_eligible}** |


| Without boilerplate exclusions (<{exclusion_threshold}) | **{num_no_exclusion}** |


| **Meeting BOTH criteria** | **{num_both}** |


"""

        patients_to_check['eligibility_display'] = patients_to_check['eligibility_probability'].apply(
            lambda x: format_probability_visual(x, is_exclusion=False)
        )
        patients_to_check['exclusion_display'] = patients_to_check['exclusion_probability'].apply(
            lambda x: format_probability_visual(x, is_exclusion=True)
        )
        patients_to_check['similarity_display'] = patients_to_check['similarity_score'].apply(
            lambda x: f"{x:.3f}"
        )
        patients_to_check['summary_preview'] = patients_to_check['patient_summary'].apply(
            lambda x: str(x)[:300] + "..." if len(str(x)) > 300 else str(x)
        )

        display_cols = [
            'patient_id',
            'eligibility_display',
            'exclusion_display',
            'similarity_display',
            'summary_preview'
        ]
        result_df = patients_to_check[display_cols].reset_index(drop=True)
        result_df.columns = [
            'Patient ID',
            'Eligibility',
            'Exclusion',
            'Similarity',
            'Summary Preview'
        ]

        return result_df, bottom_line

    except Exception as e:
        # gr.Error only shows a toast when *raised*, and the old code merely
        # constructed it (a no-op). Raising would discard the markdown error
        # below, so surface it with gr.Warning instead.
        gr.Warning(f"Error matching patients: {str(e)}")
        return pd.DataFrame(), f"**Error:** {str(e)}"
|
|
|
|
|
|
|
|
def get_patient_details(df: pd.DataFrame, evt: gr.SelectData) -> str:
    """Build the detail view (markdown + HTML) for the results row the user clicked.

    The clicked row's 'Patient ID' is looked up in ``state.last_results_df``,
    which holds the full, untruncated rows of the last search.

    Args:
        df: the displayed results table (columns as built by match_patients).
        evt: Gradio select event; ``evt.index[0]`` is taken as the row position.

    Returns:
        Markdown with the patient's scores, full summary and the text used for
        the boilerplate check, or a short status/error message.
    """
    try:
        if df is None or len(df) == 0:
            return "No patient selected"

        row_idx = evt.index[0]
        patient_id = df.iloc[row_idx]['Patient ID']

        if state.last_results_df is None:
            return "No results available"

        results = state.last_results_df
        ids = results['patient_id']
        matching_rows = results[ids == patient_id]
        if len(matching_rows) == 0:
            # The value read back from the table may not keep its original dtype
            # (e.g. "123" vs 123), so fall back to comparing string forms.
            matching_rows = results[ids.astype(str) == str(patient_id)]

        if len(matching_rows) == 0:
            return f"Error: Could not find patient {patient_id}"

        patient_row = matching_rows.iloc[0]

        raw_boilerplate = patient_row.get('patient_boilerplate', '')
        # Check the type first: bool(pd.NA) raises and NaN is truthy, so the
        # old `raw and isinstance(...)` ordering could crash or misfire.
        has_separate_boilerplate = isinstance(raw_boilerplate, str) and bool(raw_boilerplate.strip())

        if has_separate_boilerplate:
            boilerplate_text = raw_boilerplate
        else:
            boilerplate_text = "(No separate boilerplate column - patient summary was used for boilerplate exclusion check)"

        # Everything interpolated into the markdown/HTML comes from the
        # patient data file, so escape it.
        patient_id_escaped = html.escape(str(patient_id))
        summary_escaped = html.escape(str(patient_row['patient_summary']))
        boilerplate_escaped = html.escape(str(boilerplate_text))

        details = f"""


# Patient Details: {patient_id_escaped}





---





## Scores


- **Eligibility Probability:** {patient_row['eligibility_probability']:.3f}


- **Exclusion Probability:** {patient_row['exclusion_probability']:.3f}


- **Similarity Score:** {patient_row['similarity_score']:.3f}





---





## Full Patient Summary


<pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{summary_escaped}</pre>





---





## Boilerplate Exclusion Check Input


<pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{boilerplate_escaped}</pre>


"""
        return details

    except Exception as e:
        return f"Error retrieving patient details: {str(e)}"
|
|
|
|
|
|
|
|
def request_identified_patients():
    """Placeholder action for requesting the identified patient list.

    Shows a warning toast when no search results exist yet; otherwise shows an
    info toast saying the feature is not implemented. Returns None either way.
    """
    results = state.last_results_df
    have_results = results is not None and len(results) > 0

    if not have_results:
        gr.Warning("No results to request - run a search first")
        return

    # TODO: implement the actual request workflow for state.last_results_df.
    gr.Info("Request functionality not yet implemented")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the patient search prototype.

    Three tabs are created:

    1. **Search**: clinical-criteria and boilerplate text boxes, search
       settings, the results table and a patient-details pane. These are wired
       to ``match_patients``, ``get_patient_details`` and
       ``request_identified_patients``.
    2. **Patient Database**: load a pre-embedded patient set
       (``load_preembedded_patients``) or upload and embed a new file
       (``load_and_embed_patients``).
    3. **Model Configuration**: load the embedder, trial checker and
       boilerplate checker models.

    Initial widget values come from ``config`` when ``HAS_CONFIG`` is true.
    Status and preview values come from the module-level ``state``. ``config``
    is never dereferenced when config.py could not be imported.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """

    def _config_attr(name, default=""):
        # `config` is only bound when the import at the top of the file
        # succeeded, so it must not be touched unless HAS_CONFIG is true.
        if not HAS_CONFIG:
            return default
        return getattr(config, name, default)

    def _model_config_value(key):
        # Tolerate a config.py without MODEL_CONFIG instead of raising
        # AttributeError while the UI is being built.
        model_config = _config_attr("MODEL_CONFIG", None) or {}
        return model_config.get(key, "")

    # ----- Look & feel -----
    theme = gr.themes.Soft(
        primary_hue="teal",
        secondary_hue="slate",
    ).set(
        body_background_fill="*neutral_50",
        block_background_fill="white",
        block_border_width="1px",
        block_label_background_fill="*primary_50",
    )

    custom_css = """
    .gradio-container { font-family: 'Inter', Arial, sans-serif !important; }
    .model-status { min-height: 80px !important; font-size: 0.9em; }
    .status-box { background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; padding: 10px; }
    h1 { color: #0d9488; }
    """

    # Pre-fill the criteria boxes from config when available, otherwise use
    # the module-level defaults.
    clinical_space_template = _config_attr(
        "CLINICAL_SPACE_TEMPLATE", DEFAULT_CLINICAL_SPACE_TEMPLATE
    )
    boilerplate_template = _config_attr(
        "BOILERPLATE_TEMPLATE", DEFAULT_BOILERPLATE_TEMPLATE
    )

    with gr.Blocks(title="Patient Search Prototype", theme=theme, css=custom_css) as demo:

        # ----- Header -----
        with gr.Row(variant="panel"):
            with gr.Column(scale=4):
                gr.Markdown("""
                # 🔬 Patient Search Prototype

                **Find patients matching clinical criteria. Designed for clinical trial matching.**
                """)
            with gr.Column(scale=1):
                # Intentionally empty spacer column.
                pass

        with gr.Tabs():

            # ===== Tab 1: Search =====
            with gr.Tab("1️⃣ Search"):
                gr.Markdown("""
                ### Define Your Search Criteria

                Enter the clinical criteria to search for matching patients.
                """)

                with gr.Row():
                    with gr.Column():
                        clinical_space_input = gr.Textbox(
                            label="Clinical Criteria",
                            placeholder="Enter eligibility criteria...",
                            value=clinical_space_template,
                            lines=12,
                            info="Define age, sex, cancer type, histology, treatments, biomarkers, etc."
                        )

                    with gr.Column():
                        boilerplate_input = gr.Textbox(
                            label="Boilerplate Exclusion Criteria",
                            placeholder="Enter boilerplate exclusions...",
                            value=boilerplate_template,
                            lines=12,
                            info="Common exclusions like organ dysfunction, infections, etc."
                        )

                gr.Markdown("---")

                with gr.Row():
                    with gr.Column(scale=1):
                        match_btn = gr.Button("🔍 Find Matching Patients", variant="primary", size="lg")
                    with gr.Column(scale=3):
                        with gr.Accordion("Search Settings", open=False):
                            top_k_check_slider = gr.Slider(
                                minimum=5, maximum=10000, value=500, step=50,
                                label="Patients to Check with Classifiers",
                                info="Number of top-ranked patients to run through eligibility/boilerplate models (larger queries take more time)"
                            )
                            eligibility_threshold_slider = gr.Slider(
                                minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                                label="Eligibility Threshold",
                                info="Threshold for counting patients as 'eligible'"
                            )

                gr.Markdown("### 📊 Results")

                # One-line summary of the latest search, filled in by match_patients.
                bottom_line_output = gr.Markdown(
                    value="*Run a search to see results*"
                )

                with gr.Row():
                    with gr.Column(scale=7):
                        results_df = gr.Dataframe(
                            label="Matched Patients",
                            interactive=False,
                            wrap=True,
                            datatype=["str", "markdown", "markdown", "str", "str"],
                            column_widths=["12%", "12%", "12%", "10%", "54%"]
                        )

                    with gr.Column(scale=5):
                        patient_details = gr.Markdown(
                            label="Patient Details",
                            value="<div style='text-align: center; padding: 50px; color: #666;'>👈 Click on a patient row to see full details here</div>"
                        )

                with gr.Row():
                    request_btn = gr.Button("📋 Request Identified Patient List", variant="secondary")

                # --- Search tab wiring ---
                match_btn.click(
                    fn=match_patients,
                    inputs=[clinical_space_input, boilerplate_input, top_k_check_slider, eligibility_threshold_slider],
                    outputs=[results_df, bottom_line_output]
                )

                # Selecting a row in the results table shows that patient's details.
                results_df.select(
                    fn=get_patient_details,
                    inputs=[results_df],
                    outputs=[patient_details]
                )

                request_btn.click(
                    fn=request_identified_patients,
                    inputs=[],
                    outputs=[]
                )

            # ===== Tab 2: Patient Database =====
            with gr.Tab("2️⃣ Patient Database"):
                gr.Markdown("### 📊 Patient Database Management")

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Load Pre-embedded Patients (Fast)")
                        preembed_prefix = gr.Textbox(
                            label="Pre-embedded Prefix",
                            placeholder="patient_embeddings",
                            # Parenthesisation matters: the original
                            # `getattr(...) or "" if HAS_CONFIG else ""`
                            # evaluated getattr(config, ...) even without config.
                            value=_config_attr("PREEMBEDDED_PATIENTS", "") or ""
                        )
                        preembed_btn = gr.Button("Load Pre-embedded", variant="secondary")

                    with gr.Column():
                        gr.Markdown("#### Upload & Embed New Database")
                        patient_file = gr.File(
                            label="Upload Patient Database (Parquet/CSV/Excel)",
                            file_types=[".parquet", ".csv", ".xlsx", ".xls"]
                        )
                        patient_upload_btn = gr.Button("Process & Embed", variant="secondary")

                patient_status = gr.Textbox(
                    label="Status",
                    interactive=False,
                    # auto_load_status always contains the "patients" key
                    # (initialised to ""), so a .get() default would never
                    # apply. Fall back on any empty value instead.
                    value=state.auto_load_status.get("patients") or "No patients loaded"
                )

                patient_preview = gr.Dataframe(
                    label="Patient Preview (first 10)",
                    value=state.patient_preview_df,
                    wrap=True
                )

                # --- Patient database wiring ---
                preembed_btn.click(
                    fn=load_preembedded_patients,
                    inputs=[preembed_prefix],
                    outputs=[patient_status, patient_preview]
                )

                patient_upload_btn.click(
                    fn=load_and_embed_patients,
                    inputs=[patient_file],
                    outputs=[patient_status, patient_preview]
                )

            # ===== Tab 3: Model Configuration =====
            with gr.Tab("3️⃣ Model Configuration"):
                gr.Markdown("### 🧠 Model Management")

                # Rendered as Markdown: gr.Info() toasts only appear when called
                # from inside an event handler, not while building the layout.
                status_msg = (
                    "**Config file detected** - Models will auto-load on startup."
                    if HAS_CONFIG
                    else "**No config file found** - Please load models manually below."
                )
                gr.Markdown(status_msg)

                with gr.Group():
                    with gr.Row():
                        with gr.Column():
                            embedder_input = gr.Textbox(
                                label="Embedder Model",
                                placeholder="Qwen/Qwen3-Embedding-0.6B",
                                value=_model_config_value("embedder")
                            )
                            embedder_btn = gr.Button("Load Embedder")
                            embedder_status = gr.Textbox(
                                label="Status",
                                interactive=False,
                                value=state.auto_load_status.get("embedder", ""),
                                elem_classes=["model-status"]
                            )
                            embedder_warning = gr.Textbox(visible=False)

                        with gr.Column():
                            trial_checker_input = gr.Textbox(
                                label="Trial Checker Model",
                                placeholder="answerdotai/ModernBERT-large",
                                value=_model_config_value("trial_checker")
                            )
                            trial_checker_btn = gr.Button("Load Trial Checker")
                            trial_checker_status = gr.Textbox(
                                label="Status",
                                interactive=False,
                                value=state.auto_load_status.get("trial_checker", ""),
                                elem_classes=["model-status"]
                            )

                    with gr.Row():
                        with gr.Column(scale=1):
                            boilerplate_checker_input = gr.Textbox(
                                label="Boilerplate Checker Model",
                                placeholder="answerdotai/ModernBERT-large",
                                value=_model_config_value("boilerplate_checker")
                            )
                            boilerplate_checker_btn = gr.Button("Load Boilerplate Checker")
                            boilerplate_checker_status = gr.Textbox(
                                label="Status",
                                interactive=False,
                                value=state.auto_load_status.get("boilerplate_checker", ""),
                                elem_classes=["model-status"]
                            )
                        with gr.Column(scale=1):
                            # Intentionally empty spacer column.
                            pass

                # --- Model loading wiring ---
                # Each loader returns one more value than is shown. The extra
                # value goes to an invisible Textbox sink.
                embedder_btn.click(
                    fn=load_embedder_model,
                    inputs=[embedder_input],
                    outputs=[embedder_status, gr.Textbox(visible=False), embedder_warning]
                )

                trial_checker_btn.click(
                    fn=load_trial_checker,
                    inputs=[trial_checker_input],
                    outputs=[trial_checker_status, gr.Textbox(visible=False)]
                )

                boilerplate_checker_btn.click(
                    fn=load_boilerplate_checker,
                    inputs=[boilerplate_checker_input],
                    outputs=[boilerplate_checker_status, gr.Textbox(visible=False)]
                )

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Report the compute environment before any models are loaded.
    print(f"Device: {state.device}")
    print(f"GPU Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU Count: {torch.cuda.device_count()}")

    # `config` is only bound when config.py was importable, so every access to
    # it (including the patient auto-load check) stays behind this guard.
    if HAS_CONFIG:
        auto_load_models_from_config()

        # Attempt the patient auto-load when an embedder ended up loaded, or
        # when config names a (truthy) pre-embedded patient set.
        if state.embedder_model is not None or getattr(config, "PREEMBEDDED_PATIENTS", None):
            auto_load_patients_from_config()

    demo = create_interface()

    # NOTE(review): server_name="0.0.0.0" listens on every network interface,
    # and this UI displays patient records. Confirm the host is network-
    # restricted (or bind to 127.0.0.1) before deploying.
    demo.launch(
        server_name="0.0.0.0",
        share=False
    )
|
|
|