kenlkehl committed (verified)
Commit 0fe62c7 · 1 Parent(s): f064c8d

Upload 3 files

Files changed (3)
  1. app.py +1128 -0
  2. config.py +105 -0
  3. preembed_patients.py +392 -0
app.py ADDED
@@ -0,0 +1,1128 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Patient Matching Pipeline - Gradio Web Interface
6
+
7
+ This interface allows users to:
8
+ 1. Configure models (embedder, trial_checker, boilerplate_checker)
9
+ 2. Upload patient database OR load pre-embedded patients
10
+ 3. Enter a set of clinical criteria (trial eligibility criteria)
11
+ 4. Get ranked patient recommendations with eligibility predictions
12
+ """
13
+
14
+ import gradio as gr
15
+ import pandas as pd
16
+ import numpy as np
17
+ import torch
18
+ import os
19
+ import json
20
+ import pickle
21
+ import html
22
+ from typing import List, Tuple
23
+ from pathlib import Path
24
+ import pyarrow.parquet as pq
25
+
26
+ # HuggingFace imports
27
+ from transformers import (
28
+ AutoTokenizer,
29
+ AutoModelForSequenceClassification,
30
+ )
31
+ from sentence_transformers import SentenceTransformer
32
+
33
+ # Try to import configuration
34
+ try:
35
+ import config
36
+ HAS_CONFIG = True
37
+ print("✓ Found config.py - will auto-load models on startup")
38
+ except ImportError:
39
+ HAS_CONFIG = False
40
+ print("○ No config.py found - using manual model loading")
41
+
42
+ # ============================================================================
43
+ # GLOBAL STATE
44
+ # ============================================================================
45
+
46
+ class AppState:
47
+ def __init__(self):
48
+ self.embedder_model = None
49
+ self.embedder_tokenizer = None
50
+ self.trial_checker_model = None
51
+ self.trial_checker_tokenizer = None
52
+ self.boilerplate_checker_model = None
53
+ self.boilerplate_checker_tokenizer = None
54
+
55
+ self.patient_df = None
56
+ self.patient_embeddings = None
57
+ self.patient_preview_df = None
58
+
59
+ # Store last results for export
60
+ self.last_results_df = None
61
+
62
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
63
+
64
+ self.auto_load_status = {
65
+ "embedder": "",
66
+ "trial_checker": "",
67
+ "boilerplate_checker": "",
68
+ "patients": ""
69
+ }
70
+
71
+ def reset_patients(self):
72
+ self.patient_df = None
73
+ self.patient_embeddings = None
74
+ self.patient_preview_df = None
75
+
76
+ state = AppState()
77
+
78
+ # ============================================================================
79
+ # CONSTANTS
80
+ # ============================================================================
81
+
82
+ MAX_EMBEDDER_SEQ_LEN = 2500
83
+ MAX_TRIAL_CHECKER_LENGTH = 4096
84
+ MAX_BOILERPLATE_CHECKER_LENGTH = 3192
85
+ CLASSIFIER_BATCH_SIZE = 32 # Batch size for trial_checker and boilerplate_checker inference
86
+
87
+ # Default templates
88
+ DEFAULT_CLINICAL_SPACE_TEMPLATE = """Age range allowed:
89
+ Sex allowed:
90
+ Cancer type allowed:
91
+ Histology allowed:
92
+ Cancer burden allowed:
93
+ Prior treatment required:
94
+ Prior treatment excluded:
95
+ Biomarkers required:
96
+ Biomarkers excluded: """
97
+
98
+ DEFAULT_BOILERPLATE_TEMPLATE = """History of pneumonitis:
99
+ Heart failure or cardiac dysfunction:
100
+ Renal dysfunction:
101
+ Liver dysfunction:
102
+ Uncontrolled brain metastases:
103
+ HIV or hepatitis infection:
104
+ Poor performance status (ECOG >= 2):
105
+ Other relevant exclusions: """
106
+
107
+ # ============================================================================
108
+ # UTILITY FUNCTIONS
109
+ # ============================================================================
110
+
111
+ def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
112
+ """Truncate text to a maximum number of tokens."""
113
+ return tokenizer.decode(
114
+ tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=max_tokens),
115
+ skip_special_tokens=True
116
+ )
117
+
118
+
119
+ def format_probability_visual(val, is_exclusion=False):
120
+ """Format probabilities with visual indicators."""
121
+ try:
122
+ val_float = float(val)
123
+ except (TypeError, ValueError):
124
+ return val
125
+
126
+ if not is_exclusion:
127
+ # High eligibility is good
128
+ if val_float >= 0.8:
129
+ return f"🟢 **{val_float:.2f}**"
130
+ elif val_float >= 0.5:
131
+ return f"🟡 {val_float:.2f}"
132
+ else:
133
+ return f"🔴 {val_float:.2f}"
134
+ else:
135
+ # High exclusion is bad
136
+ if val_float >= 0.5:
137
+ return f"🔴 **{val_float:.2f}**"
138
+ elif val_float >= 0.2:
139
+ return f"🟡 {val_float:.2f}"
140
+ else:
141
+ return f"🟢 {val_float:.2f}"
142
+
143
+
144
+ # ============================================================================
145
+ # AUTO-LOADING FROM CONFIG
146
+ # ============================================================================
147
+
148
+ def auto_load_models_from_config():
149
+ """Auto-load models specified in config.py"""
150
+ if not HAS_CONFIG:
151
+ return
152
+
153
+ print("\n" + "="*70)
154
+ print("AUTO-LOADING MODELS FROM CONFIG")
155
+ print("="*70)
156
+
157
+ # Load embedder
158
+ if config.MODEL_CONFIG.get("embedder"):
159
+ print(f"\n[1/3] Loading embedder: {config.MODEL_CONFIG['embedder']}")
160
+ status, _, _ = load_embedder_model(config.MODEL_CONFIG["embedder"])
161
+ state.auto_load_status["embedder"] = status
162
+ print(status)
163
+
164
+ # Load trial checker
165
+ if config.MODEL_CONFIG.get("trial_checker"):
166
+ print(f"\n[2/3] Loading trial checker: {config.MODEL_CONFIG['trial_checker']}")
167
+ status, _ = load_trial_checker(config.MODEL_CONFIG["trial_checker"])
168
+ state.auto_load_status["trial_checker"] = status
169
+ print(status)
170
+
171
+ # Load boilerplate checker
172
+ if config.MODEL_CONFIG.get("boilerplate_checker"):
173
+ print(f"\n[3/3] Loading boilerplate checker: {config.MODEL_CONFIG['boilerplate_checker']}")
174
+ status, _ = load_boilerplate_checker(config.MODEL_CONFIG["boilerplate_checker"])
175
+ state.auto_load_status["boilerplate_checker"] = status
176
+ print(status)
177
+
178
+ print("\n" + "="*70)
179
+ print("MODEL AUTO-LOADING COMPLETE")
180
+ print("="*70 + "\n")
181
+
182
+
183
+ def auto_load_patients_from_config():
184
+ """Auto-load patient database from config.py - prefers pre-embedded over fresh embedding."""
185
+ if not HAS_CONFIG:
186
+ return
187
+
188
+ # Check for pre-embedded patients first (much faster)
189
+ if hasattr(config, 'PREEMBEDDED_PATIENTS') and config.PREEMBEDDED_PATIENTS:
190
+ preembed_path = config.PREEMBEDDED_PATIENTS
191
+
192
+ # Handle URL paths for Hugging Face datasets
193
+ if preembed_path.startswith("http://") or preembed_path.startswith("https://"):
194
+ print("\n" + "="*70)
195
+ print(f"AUTO-LOADING PRE-EMBEDDED PATIENTS (URL): {preembed_path}")
196
+ print("="*70)
197
+
198
+ status, preview = load_preembedded_patients(preembed_path)
199
+ state.auto_load_status["patients"] = status
200
+ state.patient_preview_df = preview
201
+
202
+ print("="*70)
203
+ print("PRE-EMBEDDED PATIENTS AUTO-LOADING COMPLETE")
204
+ print("="*70 + "\n")
205
+ return
206
+
207
+ # Check for new parquet format first, then fall back to old format
208
+ parquet_path = preembed_path if preembed_path.endswith('.parquet') else f"{preembed_path}.parquet"
209
+ old_format_data = f"{preembed_path}_data.pkl"
210
+
211
+ if os.path.exists(parquet_path):
212
+ # New parquet format
213
+ print("\n" + "="*70)
214
+ print(f"AUTO-LOADING PRE-EMBEDDED PATIENTS (parquet): {parquet_path}")
215
+ print("="*70)
216
+
217
+ status, preview = load_preembedded_patients(parquet_path)
218
+ state.auto_load_status["patients"] = status
219
+ state.patient_preview_df = preview
220
+
221
+ print("="*70)
222
+ print("PRE-EMBEDDED PATIENTS AUTO-LOADING COMPLETE")
223
+ print("="*70 + "\n")
224
+ return
225
+ elif os.path.exists(old_format_data):
226
+ # Old format (pkl + npy + json)
227
+ print("\n" + "="*70)
228
+ print(f"AUTO-LOADING PRE-EMBEDDED PATIENTS (legacy): {preembed_path}")
229
+ print("="*70)
230
+
231
+ status, preview = load_preembedded_patients(preembed_path)
232
+ state.auto_load_status["patients"] = status
233
+ state.patient_preview_df = preview
234
+
235
+ print("="*70)
236
+ print("PRE-EMBEDDED PATIENTS AUTO-LOADING COMPLETE")
237
+ print("="*70 + "\n")
238
+ return
239
+ else:
240
+ print(f"✗ Pre-embedded patient files not found: {preembed_path}")
241
+ state.auto_load_status["patients"] = f"✗ Pre-embedded files not found: {preembed_path}"
242
+ return
243
+
244
+ # Fall back to fresh embedding if no pre-embedded patients specified
245
+ if not hasattr(config, 'DEFAULT_PATIENT_DB') or not config.DEFAULT_PATIENT_DB:
246
+ print("○ No patient database specified in config")
247
+ return
248
+
249
+ if not os.path.exists(config.DEFAULT_PATIENT_DB):
250
+ print(f"✗ Default patient database not found: {config.DEFAULT_PATIENT_DB}")
251
+ state.auto_load_status["patients"] = f"✗ Patient database file not found: {config.DEFAULT_PATIENT_DB}"
252
+ return
253
+
254
+ if state.embedder_model is None:
255
+ print("○ Embedder not loaded yet - skipping patient database auto-load")
256
+ state.auto_load_status["patients"] = "○ Waiting for embedder model to be loaded..."
257
+ return
258
+
259
+ print("\n" + "="*70)
260
+ print(f"AUTO-LOADING PATIENT DATABASE: {config.DEFAULT_PATIENT_DB}")
261
+ print("="*70)
262
+
263
+ class FilePath:
264
+ def __init__(self, path):
265
+ self.name = path
266
+
267
+ status, preview = load_and_embed_patients(FilePath(config.DEFAULT_PATIENT_DB), show_progress=True)
268
+ state.auto_load_status["patients"] = status
269
+ state.patient_preview_df = preview
270
+
271
+ print("="*70)
272
+ print("PATIENT DATABASE AUTO-LOADING COMPLETE")
273
+ print("="*70 + "\n")
274
+
275
+
276
+ # ============================================================================
277
+ # MODEL LOADING FUNCTIONS
278
+ # ============================================================================
279
+
280
+ def load_embedder_model(model_path: str) -> Tuple[str, str, str]:
281
+ """Load sentence transformer embedder model."""
282
+ try:
283
+ will_need_reembed = state.patient_df is not None and len(state.patient_df) > 0
284
+
285
+ if will_need_reembed:
286
+ warning_msg = f"\n⚠️ Warning: {len(state.patient_df)} patients are currently loaded. They will need to be re-embedded with the new model."
287
+ else:
288
+ warning_msg = ""
289
+
290
+ state.embedder_model = SentenceTransformer(model_path, device=state.device, trust_remote_code=True)
291
+ state.embedder_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
292
+
293
+ # Set the instruction prompt
294
+ try:
295
+ state.embedder_model.prompts['query'] = (
296
+ "Instruct: Given a cancer patient summary, retrieve clinical trial options "
297
+ "that are reasonable for that patient; or, given a clinical trial option, "
298
+ "retrieve cancer patients who are reasonable candidates for that trial."
299
+ )
300
+ except Exception:
301
+ pass
302
+
303
+ try:
304
+ state.embedder_model.max_seq_length = MAX_EMBEDDER_SEQ_LEN
305
+ except Exception:
306
+ pass
307
+
308
+ success_msg = f"✓ Embedder model loaded from {model_path}{warning_msg}"
309
+
310
+ if will_need_reembed:
311
+ state.patient_embeddings = None
312
+ success_msg += "\n→ Patient embeddings cleared. Please reload patient database to re-embed."
313
+
314
+ return success_msg, "", warning_msg
315
+ except Exception as e:
316
+ return f"✗ Error loading embedder model: {str(e)}", str(e), ""
317
+
318
+
319
+ def load_trial_checker(model_path: str) -> Tuple[str, str]:
320
+ """Load ModernBERT trial checker."""
321
+ try:
322
+ state.trial_checker_tokenizer = AutoTokenizer.from_pretrained(model_path)
323
+ state.trial_checker_model = AutoModelForSequenceClassification.from_pretrained(
324
+ model_path,
325
+ torch_dtype=torch.float16 if state.device == "cuda" else torch.float32
326
+ ).to(state.device)
327
+ state.trial_checker_model.eval()
328
+ return f"✓ Trial checker loaded from {model_path}", ""
329
+ except Exception as e:
330
+ return f"✗ Error loading trial checker: {str(e)}", str(e)
331
+
332
+
333
+ def load_boilerplate_checker(model_path: str) -> Tuple[str, str]:
334
+ """Load ModernBERT boilerplate checker."""
335
+ try:
336
+ state.boilerplate_checker_tokenizer = AutoTokenizer.from_pretrained(model_path)
337
+ state.boilerplate_checker_model = AutoModelForSequenceClassification.from_pretrained(
338
+ model_path,
339
+ torch_dtype=torch.float16 if state.device == "cuda" else torch.float32
340
+ ).to(state.device)
341
+ state.boilerplate_checker_model.eval()
342
+ return f"✓ Boilerplate checker loaded from {model_path}", ""
343
+ except Exception as e:
344
+ return f"✗ Error loading boilerplate checker: {str(e)}", str(e)
345
+
346
+
347
+ # ============================================================================
348
+ # PATIENT DATA LOADING
349
+ # ============================================================================
350
+
351
+ def load_preembedded_patients(preembedded_path: str) -> Tuple[str, pd.DataFrame]:
352
+ """Load pre-embedded patient database from disk.
353
+
354
+ Supports two formats:
355
+ 1. New format: Single parquet file with patient_embedding column
356
+ - Path should end with .parquet
357
+ - Embeddings stored as lists in patient_embedding column
358
+ - Metadata stored in parquet file metadata
359
+
360
+ 2. Legacy format: Separate pkl/npy/json files
361
+ - Path is a prefix (e.g., "patient_embeddings")
362
+ - Creates patient_embeddings_data.pkl, _vectors.npy, _metadata.json
363
+ """
364
+ try:
365
+ # Determine format based on path
366
+ is_parquet = preembedded_path.endswith('.parquet') or os.path.exists(f"{preembedded_path}.parquet")
367
+
368
+ if is_parquet:
369
+ return _load_preembedded_parquet(preembedded_path)
370
+ else:
371
+ return _load_preembedded_legacy(preembedded_path)
372
+
373
+ except Exception as e:
374
+ import traceback
375
+ traceback.print_exc()
376
+ return f"✗ Error loading pre-embedded patients: {str(e)}", None
377
+
378
+
379
+ def _load_preembedded_parquet(parquet_path: str) -> Tuple[str, pd.DataFrame]:
380
+ """Load pre-embedded patients from new single parquet format."""
381
+ is_url = parquet_path.startswith("http://") or parquet_path.startswith("https://")
382
+
383
+ # Ensure .parquet extension for local files
384
+ if not is_url and not parquet_path.endswith('.parquet'):
385
+ parquet_path = f"{parquet_path}.parquet"
386
+
387
+ if not is_url and not os.path.exists(parquet_path):
388
+ return f"✗ Pre-embedded parquet file not found: {parquet_path}", None
389
+
390
+ print(f"\n{'='*70}")
391
+ print(f"LOADING PRE-EMBEDDED PATIENTS (Parquet Format)")
392
+ print(f"{'='*70}")
393
+ print(f"Loading from: {parquet_path}")
394
+
395
+ try:
396
+ # Read parquet file - from URL or local path
397
+ if is_url:
398
+ df = pd.read_parquet(parquet_path)
399
+ # For remote files, we can't easily read pyarrow metadata without downloading
400
+ # the file first, so we'll just load the dataframe directly.
401
+ print(f"Metadata: (Skipped for URL)")
402
+ else:
403
+ # Read local parquet file with pyarrow to access metadata
404
+ parquet_file = pq.read_table(parquet_path)
405
+
406
+ # Extract metadata if available
407
+ if parquet_file.schema.metadata and b'patient_embedding_metadata' in parquet_file.schema.metadata:
408
+ metadata = json.loads(parquet_file.schema.metadata[b'patient_embedding_metadata'].decode('utf-8'))
409
+ print(f"Metadata:")
410
+ print(f" Created: {metadata.get('created_at', 'unknown')}")
411
+ print(f" Embedder: {metadata.get('embedder_model', 'unknown')}")
412
+ print(f" Patients: {metadata.get('num_patients', 'unknown')}")
413
+ print(f" Embedding dim: {metadata.get('embedding_dim', 'unknown')}")
414
+
415
+ # Convert to pandas
416
+ df = parquet_file.to_pandas()
417
+
418
+ except Exception as e:
419
+ error_msg = f"✗ Failed to read parquet file from {parquet_path}: {str(e)}"
420
+ print(error_msg)
421
+ return error_msg, None
422
+
423
+ print(f"✓ Loaded {len(df)} patients")
424
+ print(f" Columns: {', '.join(df.columns.tolist())}")
425
+
426
+ # Check for required columns
427
+ if 'patient_embedding' not in df.columns:
428
+ return f"✗ Parquet file missing 'patient_embedding' column: {parquet_path}", None
429
+
430
+ if 'patient_id' not in df.columns:
431
+ return f"✗ Parquet file missing 'patient_id' column: {parquet_path}", None
432
+
433
+ if 'patient_summary' not in df.columns:
434
+ return f"✗ Parquet file missing 'patient_summary' column: {parquet_path}", None
435
+
436
+ # Check boilerplate column
437
+ if 'patient_boilerplate' in df.columns:
438
+ non_empty_bp = (df['patient_boilerplate'].astype(str).str.strip().str.len() > 0).sum()
439
+ print(f" ✓ patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text")
440
+ else:
441
+ print(f" ⚠ No patient_boilerplate column found")
442
+ df['patient_boilerplate'] = ''
443
+
444
+ # Extract embeddings from column and convert to numpy array
445
+ print(f"Converting embeddings to numpy array...")
446
+ embeddings = np.array(df['patient_embedding'].tolist(), dtype=np.float32)
447
+ print(f"✓ Loaded embeddings: {embeddings.shape}")
448
+
449
+ # Remove embedding column from dataframe (we store it separately in memory)
450
+ df_without_embeddings = df.drop(columns=['patient_embedding'])
451
+
452
+ state.patient_df = df_without_embeddings
453
+ state.patient_embeddings = embeddings
454
+
455
+ print(f"{'='*70}")
456
+ print(f"PRE-EMBEDDED PATIENTS LOADED SUCCESSFULLY")
457
+ print(f"{'='*70}\n")
458
+
459
+ preview = df_without_embeddings[['patient_id', 'patient_summary']].head(10)
460
+ return f"✓ Loaded {len(df)} pre-embedded patients from {os.path.basename(parquet_path)}", preview
461
+
462
+
463
+ def _load_preembedded_legacy(preembedded_prefix: str) -> Tuple[str, pd.DataFrame]:
464
+ """Load pre-embedded patients from legacy format (pkl + npy + json files)."""
465
+ data_file = f"{preembedded_prefix}_data.pkl"
466
+ vectors_file = f"{preembedded_prefix}_vectors.npy"
467
+ metadata_file = f"{preembedded_prefix}_metadata.json"
468
+
469
+ if not os.path.exists(data_file):
470
+ return f"✗ Pre-embedded data file not found: {data_file}", None
471
+ if not os.path.exists(vectors_file):
472
+ return f"✗ Pre-embedded vectors file not found: {vectors_file}", None
473
+
474
+ print(f"\n{'='*70}")
475
+ print(f"LOADING PRE-EMBEDDED PATIENTS (Legacy Format)")
476
+ print(f"{'='*70}")
477
+ print(f"Loading from: {preembedded_prefix}_*")
478
+
479
+ if os.path.exists(metadata_file):
480
+ with open(metadata_file, 'r') as f:
481
+ metadata = json.load(f)
482
+ print(f"Metadata:")
483
+ print(f" Created: {metadata.get('created_at', 'unknown')}")
484
+ print(f" Embedder: {metadata.get('embedder_model', 'unknown')}")
485
+ print(f" Patients: {metadata.get('num_patients', 'unknown')}")
486
+ print(f" Embedding dim: {metadata.get('embedding_dim', 'unknown')}")
487
+
488
+ print(f"Loading patient dataframe...")
489
+ with open(data_file, 'rb') as f:
490
+ df = pickle.load(f)
491
+ print(f"✓ Loaded {len(df)} patients")
492
+ print(f" Columns: {', '.join(df.columns.tolist())}")
493
+
494
+ # Check boilerplate column
495
+ if 'patient_boilerplate' in df.columns:
496
+ non_empty_bp = (df['patient_boilerplate'].astype(str).str.strip().str.len() > 0).sum()
497
+ print(f" ✓ patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text")
498
+ else:
499
+ print(f" ⚠ No patient_boilerplate column found")
500
+ df['patient_boilerplate'] = ''
501
+
502
+ print(f"Loading embeddings...")
503
+ embeddings = np.load(vectors_file)
504
+ print(f"✓ Loaded embeddings: {embeddings.shape}")
505
+
506
+ if len(df) != embeddings.shape[0]:
507
+ return (
508
+ f"✗ Mismatch: {len(df)} patients but {embeddings.shape[0]} embeddings",
509
+ None
510
+ )
511
+
512
+ state.patient_df = df
513
+ state.patient_embeddings = embeddings
514
+
515
+ print(f"{'='*70}")
516
+ print(f"PRE-EMBEDDED PATIENTS LOADED SUCCESSFULLY")
517
+ print(f"{'='*70}\n")
518
+
519
+ preview = df[['patient_id', 'patient_summary']].head(10)
520
+ return f"✓ Loaded {len(df)} pre-embedded patients from {preembedded_prefix}_*", preview
521
+
522
+
523
+ def load_and_embed_patients(file, show_progress: bool = False) -> Tuple[str, pd.DataFrame]:
524
+ """Load patient database and embed summaries."""
525
+ try:
526
+ if state.embedder_model is None:
527
+ return "✗ Please load the embedder model first!", None
528
+
529
+ # Read file
530
+ if file.name.endswith('.parquet'):
531
+ df = pd.read_parquet(file.name)
532
+ elif file.name.endswith('.csv'):
533
+ df = pd.read_csv(file.name)
534
+ elif file.name.endswith(('.xlsx', '.xls')):
535
+ df = pd.read_excel(file.name)
536
+ else:
537
+ return "✗ Unsupported format. Use Parquet, CSV, or Excel.", None
538
+
539
+ # Check required columns
540
+ required_cols = ['patient_id', 'patient_summary']
541
+ missing = [col for col in required_cols if col not in df.columns]
542
+ if missing:
543
+ return f"✗ Missing columns: {', '.join(missing)}", None
544
+
545
+ # Clean data
546
+ df = df[~df['patient_summary'].isnull()].copy()
547
+ df = df[df['patient_summary'].astype(str).str.strip().str.len() > 0].copy()
548
+
549
+ if 'patient_boilerplate' not in df.columns:
550
+ df['patient_boilerplate'] = ''
551
+ else:
552
+ df['patient_boilerplate'] = df['patient_boilerplate'].fillna('')
553
+
554
+ # Prepare texts for embedding
555
+ df['patient_summary_trunc'] = df['patient_summary'].apply(
556
+ lambda x: truncate_text(str(x), state.embedder_tokenizer, max_tokens=1500)
557
+ )
558
+
559
+ prefix = (
560
+ "Instruct: Given a cancer patient summary, retrieve clinical trial options "
561
+ "that are reasonable for that patient; or, given a clinical trial option, "
562
+ "retrieve cancer patients who are reasonable candidates for that trial. "
563
+ )
564
+ texts_to_embed = [prefix + txt for txt in df['patient_summary_trunc'].tolist()]
565
+
566
+ if not show_progress:
567
+ gr.Info(f"Embedding {len(df)} patient summaries...")
568
+ else:
569
+ print(f"Embedding {len(df)} patient summaries...")
570
+
571
+ with torch.no_grad():
572
+ embeddings = state.embedder_model.encode(
573
+ texts_to_embed,
574
+ batch_size=64,
575
+ convert_to_tensor=True,
576
+ normalize_embeddings=True,
577
+ show_progress_bar=show_progress,
578
+ prompt='query'
579
+ )
580
+
581
+ state.patient_df = df
582
+ state.patient_embeddings = embeddings.cpu().numpy()
583
+
584
+ preview = df[['patient_id', 'patient_summary']].head(10)
585
+
586
+ success_msg = f"✓ Loaded and embedded {len(df)} patients"
587
+ if show_progress:
588
+ print(success_msg)
589
+
590
+ return success_msg, preview
591
+
592
+ except Exception as e:
593
+ return f"✗ Error processing patients: {str(e)}", None
594
+
595
+
596
+ # ============================================================================
597
+ # PATIENT MATCHING
598
+ # ============================================================================
599
+
600
+ def match_patients(
601
+ clinical_space: str,
602
+ boilerplate_criteria: str,
603
+ top_k_check: int = 1000,
604
+ eligibility_threshold: float = 0.5
605
+ ) -> Tuple[pd.DataFrame, str]:
606
+ """Match clinical query to patients and run eligibility checks."""
607
+ try:
608
+ if state.embedder_model is None:
609
+ raise ValueError("Embedder model not loaded")
610
+ if state.patient_embeddings is None:
611
+ raise ValueError("Patient database not loaded")
612
+ if state.trial_checker_model is None:
613
+ raise ValueError("Trial checker model not loaded")
614
+ if state.boilerplate_checker_model is None:
615
+ raise ValueError("Boilerplate checker model not loaded")
616
+
617
+ if not clinical_space or not clinical_space.strip():
618
+ raise ValueError("Please enter clinical criteria")
619
+
620
+ # Embed clinical query
621
+ prefix = (
622
+ "Instruct: Given a cancer patient summary, retrieve clinical trial options "
623
+ "that are reasonable for that patient; or, given a clinical trial option, "
624
+ "retrieve cancer patients who are reasonable candidates for that trial. "
625
+ )
626
+
627
+ query_text = truncate_text(clinical_space, state.embedder_tokenizer, max_tokens=MAX_EMBEDDER_SEQ_LEN)
628
+ query_text_with_prefix = prefix + query_text
629
+
630
+ gr.Info("Ranking all patients by similarity...")
631
+
632
+ with torch.no_grad():
633
+ query_emb = state.embedder_model.encode(
634
+ [query_text_with_prefix],
635
+ convert_to_tensor=True,
636
+ normalize_embeddings=True,
637
+ prompt='query'
638
+ )
639
+
640
+ # Calculate similarities for all patients
641
+ query_emb_np = query_emb.cpu().numpy()
642
+ similarities = np.dot(state.patient_embeddings, query_emb_np.T).squeeze()
643
+
644
+ # Rank all patients by similarity
645
+ sorted_indices = np.argsort(similarities)[::-1]
646
+
647
+ # Get all patients ranked
648
+ all_patients_ranked = state.patient_df.iloc[sorted_indices].copy()
649
+ all_patients_ranked['similarity_score'] = similarities[sorted_indices]
650
+
651
+ # Limit to top_k_check for classifier models
652
+ top_k_check = min(top_k_check, len(all_patients_ranked))
653
+ patients_to_check = all_patients_ranked.head(top_k_check).copy()
654
+
655
+ gr.Info(f"Running eligibility checks on top {len(patients_to_check)} patients...")
656
+
657
+ # Run trial checker in batches
658
+ trial_check_inputs = [
659
+ f"{clinical_space}\nNow here is the patient summary:{row['patient_summary']}"
660
+ for _, row in patients_to_check.iterrows()
661
+ ]
662
+
663
+ trial_probs_list = []
664
+ for i in range(0, len(trial_check_inputs), CLASSIFIER_BATCH_SIZE):
665
+ batch_inputs = trial_check_inputs[i:i + CLASSIFIER_BATCH_SIZE]
666
+
667
+ batch_encodings = state.trial_checker_tokenizer(
668
+ batch_inputs,
669
+ truncation=True,
670
+ max_length=MAX_TRIAL_CHECKER_LENGTH,
671
+ padding=True,
672
+ return_tensors='pt'
673
+ ).to(state.device)
674
+
675
+ with torch.no_grad():
676
+ batch_outputs = state.trial_checker_model(**batch_encodings)
677
+ batch_probs = torch.softmax(batch_outputs.logits, dim=1)[:, 1].cpu().numpy()
678
+ trial_probs_list.append(batch_probs)
679
+
680
+ trial_probs = np.concatenate(trial_probs_list)
681
+ patients_to_check['eligibility_probability'] = trial_probs
682
+
683
+ # Run boilerplate checker in batches
684
+ # Use patient_boilerplate if available, otherwise fall back to patient_summary
685
+ def get_boilerplate_text(row):
686
+ bp = row.get('patient_boilerplate', '')
687
+ if bp and isinstance(bp, str) and bp.strip():
688
+ return bp
689
+ return row['patient_summary']
690
+
691
+ boilerplate_check_inputs = [
692
+ f"Patient history: {get_boilerplate_text(row)}\nTrial exclusions:{boilerplate_criteria}"
693
+ for _, row in patients_to_check.iterrows()
694
+ ]
695
+
696
+ boilerplate_probs_list = []
697
+ for i in range(0, len(boilerplate_check_inputs), CLASSIFIER_BATCH_SIZE):
698
+ batch_inputs = boilerplate_check_inputs[i:i + CLASSIFIER_BATCH_SIZE]
699
+
700
+ batch_encodings = state.boilerplate_checker_tokenizer(
701
+ batch_inputs,
702
+ truncation=True,
703
+ max_length=MAX_BOILERPLATE_CHECKER_LENGTH,
704
+ padding=True,
705
+ return_tensors='pt'
706
+ ).to(state.device)
707
+
708
+ with torch.no_grad():
709
+ batch_outputs = state.boilerplate_checker_model(**batch_encodings)
710
+ batch_probs = torch.softmax(batch_outputs.logits, dim=1)[:, 1].cpu().numpy()
711
+ boilerplate_probs_list.append(batch_probs)
712
+
713
+ boilerplate_probs = np.concatenate(boilerplate_probs_list)
714
+ patients_to_check['exclusion_probability'] = boilerplate_probs
715
+
716
+ # Sort by eligibility probability
717
+ patients_to_check = patients_to_check.sort_values('eligibility_probability', ascending=False)
718
+
719
+ # Store full results for export
720
+ state.last_results_df = patients_to_check.copy()
721
+
722
+ # Calculate bottom line stats
723
+ num_eligible = (patients_to_check['eligibility_probability'] >= eligibility_threshold).sum()
724
+ num_no_exclusion = (patients_to_check['exclusion_probability'] < 0.5).sum()
725
+ num_both = ((patients_to_check['eligibility_probability'] >= eligibility_threshold) &
726
+ (patients_to_check['exclusion_probability'] < 0.5)).sum()
727
+
728
+ bottom_line = f"""
729
+ ### 📊 Summary: Patients Meeting Your Criteria
730
+ | Metric | Count |
731
+ |--------|-------|
732
+ | Total patients in database | **{len(state.patient_df)}** |
733
+ | Top patients checked with classifiers | **{len(patients_to_check)}** |
734
+ | Meeting eligibility criteria (≥{eligibility_threshold}) | **{num_eligible}** |
735
+ | Without boilerplate exclusions (<0.5) | **{num_no_exclusion}** |
736
+ | **Meeting BOTH criteria** | **{num_both}** |
737
+ """
738
+
739
+ # Format for display
740
+ patients_to_check['eligibility_display'] = patients_to_check['eligibility_probability'].apply(
741
+ lambda x: format_probability_visual(x, is_exclusion=False)
742
+ )
743
+ patients_to_check['exclusion_display'] = patients_to_check['exclusion_probability'].apply(
744
+ lambda x: format_probability_visual(x, is_exclusion=True)
745
+ )
746
+ patients_to_check['similarity_display'] = patients_to_check['similarity_score'].apply(
747
+ lambda x: f"{x:.3f}"
748
+ )
749
+
750
+ # Truncate summary for display
751
+ patients_to_check['summary_preview'] = patients_to_check['patient_summary'].apply(
752
+ lambda x: str(x)[:300] + "..." if len(str(x)) > 300 else str(x)
753
+ )
754
+
755
+ # Select columns for display
756
+ display_cols = [
757
+ 'patient_id',
758
+ 'eligibility_display',
759
+ 'exclusion_display',
760
+ 'similarity_display',
761
+ 'summary_preview'
762
+ ]
763
+
764
+ result_df = patients_to_check[display_cols].reset_index(drop=True)
765
+ result_df.columns = [
766
+ 'Patient ID',
767
+ 'Eligibility',
768
+ 'Exclusion',
769
+ 'Similarity',
770
+ 'Summary Preview'
771
+ ]
772
+
773
+ return result_df, bottom_line
774
+
775
+ except Exception as e:
776
+ gr.Warning(f"Error matching patients: {str(e)}")
777
+ return pd.DataFrame(), f"**Error:** {str(e)}"
778
+
779
+
780
+ def get_patient_details(df: pd.DataFrame, evt: gr.SelectData) -> str:
781
+ """Get full patient details when user clicks on a row."""
782
+ try:
783
+ if df is None or len(df) == 0:
784
+ return "No patient selected"
785
+
786
+ row_idx = evt.index[0]
787
+ patient_id = df.iloc[row_idx]['Patient ID']
788
+
789
+ # Find in full results
790
+ if state.last_results_df is None:
791
+ return "No results available"
792
+
793
+ matching_rows = state.last_results_df[
794
+ state.last_results_df['patient_id'] == patient_id
795
+ ]
796
+
797
+ if len(matching_rows) == 0:
798
+ return f"Error: Could not find patient {patient_id}"
799
+
800
+ patient_row = matching_rows.iloc[0]
801
+
802
+ # Get boilerplate text - use same fallback logic as the checker
803
+ raw_boilerplate = patient_row.get('patient_boilerplate', '')
804
+ has_separate_boilerplate = raw_boilerplate and isinstance(raw_boilerplate, str) and raw_boilerplate.strip()
805
+
806
+ if has_separate_boilerplate:
807
+ boilerplate_text = raw_boilerplate
808
+ else:
809
+ boilerplate_text = "(No separate boilerplate column - patient summary was used for boilerplate exclusion check)"
810
+
811
+ # Escape any HTML characters in the text
812
+ summary_escaped = html.escape(str(patient_row['patient_summary']))
813
+ boilerplate_escaped = html.escape(str(boilerplate_text))
814
+
815
+ details = f"""
816
+ # Patient Details: {patient_id}
817
+
818
+ ---
819
+
820
+ ## Scores
821
+ - **Eligibility Probability:** {patient_row['eligibility_probability']:.3f}
822
+ - **Exclusion Probability:** {patient_row['exclusion_probability']:.3f}
823
+ - **Similarity Score:** {patient_row['similarity_score']:.3f}
824
+
825
+ ---
826
+
827
+ ## Full Patient Summary
828
+ <pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{summary_escaped}</pre>
829
+
830
+ ---
831
+
832
+ ## Boilerplate Exclusion Check Input
833
+ <pre style="white-space: pre-wrap; word-wrap: break-word; background-color: #1a1a1a; color: #ffffff; padding: 10px; border-radius: 5px; font-family: monospace; font-size: 0.9em;">{boilerplate_escaped}</pre>
834
+ """
835
+ return details
836
+
837
+ except Exception as e:
838
+ return f"Error retrieving patient details: {str(e)}"
839
+
840
+
841
+ def request_identified_patients():
842
+ """Placeholder for requesting identified patient list."""
843
+ if state.last_results_df is None or len(state.last_results_df) == 0:
844
+ gr.Warning("No results to request - run a search first")
845
+ return
846
+
847
+ # TODO: Implement actual request functionality
848
+ gr.Info("Request functionality not yet implemented")
849
+
850
+
851
+ # ============================================================================
852
+ # GRADIO INTERFACE
853
+ # ============================================================================
854
+
855
+ def create_interface():
856
+
857
+ theme = gr.themes.Soft(
858
+ primary_hue="teal",
859
+ secondary_hue="slate",
860
+ ).set(
861
+ body_background_fill="*neutral_50",
862
+ block_background_fill="white",
863
+ block_border_width="1px",
864
+ block_label_background_fill="*primary_50",
865
+ )
866
+
867
+ custom_css = """
868
+ .gradio-container { font-family: 'Inter', Arial, sans-serif !important; }
869
+ .model-status { min-height: 80px !important; font-size: 0.9em; }
870
+ .status-box { background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; padding: 10px; }
871
+ h1 { color: #0d9488; }
872
+ """
873
+
874
+ # Get templates from config or use defaults
875
+ clinical_space_template = getattr(config, 'CLINICAL_SPACE_TEMPLATE', DEFAULT_CLINICAL_SPACE_TEMPLATE) if HAS_CONFIG else DEFAULT_CLINICAL_SPACE_TEMPLATE
876
+ boilerplate_template = getattr(config, 'BOILERPLATE_TEMPLATE', DEFAULT_BOILERPLATE_TEMPLATE) if HAS_CONFIG else DEFAULT_BOILERPLATE_TEMPLATE
877
+
878
+ with gr.Blocks(title="Patient Search Prototype", theme=theme, css=custom_css) as demo:
879
+
880
+ with gr.Row(variant="panel"):
881
+ with gr.Column(scale=4):
882
+ gr.Markdown("""
883
+ # 🔬 Patient Search Prototype
884
+ **Find patients matching clinical criteria. Designed for clinical trial matching.**
885
+ """)
886
+ with gr.Column(scale=1):
887
+ pass
888
+
889
+ with gr.Tabs():
890
+ # ============= TAB 1: SEARCH =============
891
+ with gr.Tab("1️⃣ Search"):
892
+ gr.Markdown("""
893
+ ### Define Your Search Criteria
894
+ Enter the clinical criteria to search for matching patients.
895
+ """)
896
+
897
+ with gr.Row():
898
+ with gr.Column():
899
+ clinical_space_input = gr.Textbox(
900
+ label="Clinical Criteria",
901
+ placeholder="Enter eligibility criteria...",
902
+ value=clinical_space_template,
903
+ lines=12,
904
+ info="Define age, sex, cancer type, histology, treatments, biomarkers, etc."
905
+ )
906
+
907
+ with gr.Column():
908
+ boilerplate_input = gr.Textbox(
909
+ label="Boilerplate Exclusion Criteria",
910
+ placeholder="Enter boilerplate exclusions...",
911
+ value=boilerplate_template,
912
+ lines=12,
913
+ info="Common exclusions like organ dysfunction, infections, etc."
914
+ )
915
+
916
+ gr.Markdown("---")
917
+
918
+ with gr.Row():
919
+ with gr.Column(scale=1):
920
+ match_btn = gr.Button("🔍 Find Matching Patients", variant="primary", size="lg")
921
+ with gr.Column(scale=3):
922
+ with gr.Accordion("Search Settings", open=False):
923
+ top_k_check_slider = gr.Slider(
924
+ minimum=5, maximum=10000, value=500, step=50,
925
+ label="Patients to Check with Classifiers",
926
+ info="Number of top-ranked patients to run through eligibility/boilerplate models (larger queries take more time)"
927
+ )
928
+ eligibility_threshold_slider = gr.Slider(
929
+ minimum=0.0, maximum=1.0, value=0.5, step=0.05,
930
+ label="Eligibility Threshold",
931
+ info="Threshold for counting patients as 'eligible'"
932
+ )
933
+
934
+ gr.Markdown("### 📊 Results")
935
+
936
+ # Bottom line summary
937
+ bottom_line_output = gr.Markdown(
938
+ value="*Run a search to see results*"
939
+ )
940
+
941
+ with gr.Row():
942
+ with gr.Column(scale=7):
943
+ results_df = gr.Dataframe(
944
+ label="Matched Patients",
945
+ interactive=False,
946
+ wrap=True,
947
+ datatype=["str", "markdown", "markdown", "str", "str"],
948
+ column_widths=["12%", "12%", "12%", "10%", "54%"]
949
+ )
950
+
951
+ with gr.Column(scale=5):
952
+ patient_details = gr.Markdown(
953
+ label="Patient Details",
954
+ value="<div style='text-align: center; padding: 50px; color: #666;'>👈 Click on a patient row to see full details here</div>"
955
+ )
956
+
957
+ # Request identified patients button
958
+ with gr.Row():
959
+ request_btn = gr.Button("📋 Request Identified Patient List", variant="secondary")
960
+
961
+ # Wire up matching
962
+ match_btn.click(
963
+ fn=match_patients,
964
+ inputs=[clinical_space_input, boilerplate_input, top_k_check_slider, eligibility_threshold_slider],
965
+ outputs=[results_df, bottom_line_output]
966
+ )
967
+
968
+ results_df.select(
969
+ fn=get_patient_details,
970
+ inputs=[results_df],
971
+ outputs=[patient_details]
972
+ )
973
+
974
+ request_btn.click(
975
+ fn=request_identified_patients,
976
+ inputs=[],
977
+ outputs=[]
978
+ )
979
+
980
+ # ============= TAB 2: PATIENT DATABASE =============
981
+ with gr.Tab("2️⃣ Patient Database"):
982
+ gr.Markdown("### 📊 Patient Database Management")
983
+
984
+ with gr.Row():
985
+ with gr.Column():
986
+ gr.Markdown("#### Load Pre-embedded Patients (Fast)")
987
+ preembed_prefix = gr.Textbox(
988
+ label="Pre-embedded Prefix",
989
+ placeholder="patient_embeddings",
990
+ value=(getattr(config, 'PREEMBEDDED_PATIENTS', '') or "") if HAS_CONFIG else ""
991
+ )
992
+ preembed_btn = gr.Button("Load Pre-embedded", variant="secondary")
993
+
994
+ with gr.Column():
995
+ gr.Markdown("#### Upload & Embed New Database")
996
+ patient_file = gr.File(
997
+ label="Upload Patient Database (Parquet/CSV/Excel)",
998
+ file_types=[".parquet", ".csv", ".xlsx", ".xls"]
999
+ )
1000
+ patient_upload_btn = gr.Button("Process & Embed", variant="secondary")
1001
+
1002
+ patient_status = gr.Textbox(
1003
+ label="Status",
1004
+ interactive=False,
1005
+ value=state.auto_load_status.get("patients", "No patients loaded")
1006
+ )
1007
+
1008
+ patient_preview = gr.Dataframe(
1009
+ label="Patient Preview (first 10)",
1010
+ value=state.patient_preview_df,
1011
+ wrap=True
1012
+ )
1013
+
1014
+ preembed_btn.click(
1015
+ fn=load_preembedded_patients,
1016
+ inputs=[preembed_prefix],
1017
+ outputs=[patient_status, patient_preview]
1018
+ )
1019
+
1020
+ patient_upload_btn.click(
1021
+ fn=load_and_embed_patients,
1022
+ inputs=[patient_file],
1023
+ outputs=[patient_status, patient_preview]
1024
+ )
1025
+
1026
+ # ============= TAB 3: MODEL CONFIGURATION =============
1027
+ with gr.Tab("3️⃣ Model Configuration"):
1028
+ gr.Markdown("### 🧠 Model Management")
1029
+
1030
+ status_msg = """
1031
+ **Config file detected** - Models will auto-load on startup.
1032
+ """ if HAS_CONFIG else """
1033
+ **No config file found** - Please load models manually below.
1034
+ """
1035
+ gr.Markdown(status_msg)
1036
+
1037
+ with gr.Group():
1038
+ with gr.Row():
1039
+ with gr.Column():
1040
+ embedder_input = gr.Textbox(
1041
+ label="Embedder Model",
1042
+ placeholder="Qwen/Qwen3-Embedding-0.6B",
1043
+ value=config.MODEL_CONFIG.get("embedder", "") if HAS_CONFIG else ""
1044
+ )
1045
+ embedder_btn = gr.Button("Load Embedder")
1046
+ embedder_status = gr.Textbox(
1047
+ label="Status",
1048
+ interactive=False,
1049
+ value=state.auto_load_status.get("embedder", ""),
1050
+ elem_classes=["model-status"]
1051
+ )
1052
+ embedder_warning = gr.Textbox(visible=False)
1053
+
1054
+ with gr.Column():
1055
+ trial_checker_input = gr.Textbox(
1056
+ label="Trial Checker Model",
1057
+ placeholder="answerdotai/ModernBERT-large",
1058
+ value=config.MODEL_CONFIG.get("trial_checker", "") if HAS_CONFIG else ""
1059
+ )
1060
+ trial_checker_btn = gr.Button("Load Trial Checker")
1061
+ trial_checker_status = gr.Textbox(
1062
+ label="Status",
1063
+ interactive=False,
1064
+ value=state.auto_load_status.get("trial_checker", ""),
1065
+ elem_classes=["model-status"]
1066
+ )
1067
+
1068
+ with gr.Row():
1069
+ with gr.Column(scale=1):
1070
+ boilerplate_checker_input = gr.Textbox(
1071
+ label="Boilerplate Checker Model",
1072
+ placeholder="answerdotai/ModernBERT-large",
1073
+ value=config.MODEL_CONFIG.get("boilerplate_checker", "") if HAS_CONFIG else ""
1074
+ )
1075
+ boilerplate_checker_btn = gr.Button("Load Boilerplate Checker")
1076
+ boilerplate_checker_status = gr.Textbox(
1077
+ label="Status",
1078
+ interactive=False,
1079
+ value=state.auto_load_status.get("boilerplate_checker", ""),
1080
+ elem_classes=["model-status"]
1081
+ )
1082
+ with gr.Column(scale=1):
1083
+ pass
1084
+
1085
+ # Wire up model loading
1086
+ embedder_btn.click(
1087
+ fn=load_embedder_model,
1088
+ inputs=[embedder_input],
1089
+ outputs=[embedder_status, gr.Textbox(visible=False), embedder_warning]
1090
+ )
1091
+ trial_checker_btn.click(
1092
+ fn=load_trial_checker,
1093
+ inputs=[trial_checker_input],
1094
+ outputs=[trial_checker_status, gr.Textbox(visible=False)]
1095
+ )
1096
+ boilerplate_checker_btn.click(
1097
+ fn=load_boilerplate_checker,
1098
+ inputs=[boilerplate_checker_input],
1099
+ outputs=[boilerplate_checker_status, gr.Textbox(visible=False)]
1100
+ )
1101
+
1102
+ return demo
1103
+
1104
+
1105
+ # ============================================================================
1106
+ # MAIN
1107
+ # ============================================================================
1108
+
1109
+ if __name__ == "__main__":
1110
+ print(f"Device: {state.device}")
1111
+ print(f"GPU Available: {torch.cuda.is_available()}")
1112
+ if torch.cuda.is_available():
1113
+ print(f"GPU Count: {torch.cuda.device_count()}")
1114
+
1115
+ # Auto-load models from config if available
1116
+ if HAS_CONFIG:
1117
+ auto_load_models_from_config()
1118
+
1119
+ # Auto-load patients after embedder is ready
1120
+ if state.embedder_model is not None or (hasattr(config, 'PREEMBEDDED_PATIENTS') and config.PREEMBEDDED_PATIENTS):
1121
+ auto_load_patients_from_config()
1122
+
1123
+ demo = create_interface()
1124
+ demo.launch(
1125
+ server_name="0.0.0.0",
1126
+ server_port=7861,
1127
+ share=False
1128
+ )
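
Note on the ranking step in match_patients above: because both the patient vectors and the query vector are encoded with normalize_embeddings=True, the np.dot call yields cosine similarities directly. A minimal, self-contained sketch of that retrieval step using toy vectors (placeholders, not real model outputs):

import numpy as np

# Toy stand-ins for L2-normalized embedder outputs (unit vectors).
patient_embeddings = np.array([[0.6, 0.8], [1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
query_embedding = np.array([[0.8, 0.6]], dtype=np.float32)

# Dot product of unit vectors equals cosine similarity.
similarities = np.dot(patient_embeddings, query_embedding.T).squeeze()

# Highest-similarity patients first, same argsort idiom as match_patients.
ranked_indices = np.argsort(similarities)[::-1]
print(ranked_indices, similarities[ranked_indices])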
config.py ADDED
@@ -0,0 +1,105 @@
1
+ # Configuration for Patient Matching Pipeline
2
+ #
3
+ # Edit the values below to set your default models and patient database.
4
+ # Models will auto-load on application startup.
5
+
6
+ # ============================================================================
7
+ # MODEL PATHS - Set your default models here
8
+ # ============================================================================
9
+
10
+ # Set to None to skip auto-loading, or provide model path/HuggingFace ID
11
+ MODEL_CONFIG = {
12
+ # Sentence transformer for embedding patient summaries and clinical spaces
13
+ "embedder": "ksg-dfci/TrialSpace-1225", # e.g., "Qwen/Qwen3-Embedding-0.6B" or "./reranker_round2.model"
14
+
15
+ # ModernBERT classifier for eligibility prediction
16
+ "trial_checker": "ksg-dfci/TrialChecker-1225", # e.g., "answerdotai/ModernBERT-large" or "./modernbert-trial-checker"
17
+
18
+ # ModernBERT classifier for boilerplate exclusion prediction
19
+ "boilerplate_checker": "ksg-dfci/BoilerplateChecker-1225", # e.g., "answerdotai/ModernBERT-large" or "./modernbert-boilerplate-checker"
20
+ }
21
+
22
+ # Example configuration with base models:
23
+ # MODEL_CONFIG = {
24
+ # "embedder": "Qwen/Qwen3-Embedding-0.6B",
25
+ # "trial_checker": "answerdotai/ModernBERT-large",
26
+ # "boilerplate_checker": "answerdotai/ModernBERT-large",
27
+ # }
28
+
29
+ # Example configuration with fine-tuned models:
30
+ # MODEL_CONFIG = {
31
+ # "embedder": "./reranker_round2.model",
32
+ # "trial_checker": "./modernbert-trial-checker",
33
+ # "boilerplate_checker": "./modernbert-boilerplate-checker",
34
+ # }
35
+
36
+ # ============================================================================
37
+ # DEFAULT PATIENT DATABASE
38
+ # ============================================================================
39
+
40
+ # Path to default patient database parquet file
41
+ # Required columns: patient_id, patient_summary
42
+ # Optional columns: patient_boilerplate (for boilerplate checking)
43
+ # Will auto-load and embed when embedder model is ready
44
+ # Set to None to disable auto-loading
45
+ #DEFAULT_PATIENT_DB = "./synthetic_patient_summary_sample.parquet" # e.g., "./patients.parquet" or "./patient_summaries.parquet"
46
+
47
+ # Path to pre-embedded patient database (faster loading)
48
+ #
49
+ # NEW FORMAT (recommended): Single parquet file with embedding column
50
+ # - Created by: python preembed_patients.py --output patient_embeddings.parquet
51
+ # - Contains all patient data + patient_embedding column (list of floats)
52
+ # - Compatible with Hugging Face datasets
53
+ # - Example: PREEMBEDDED_PATIENTS = "synthetic_patient_embeddings.parquet"
54
+ #
55
+ # LEGACY FORMAT (still supported): Prefix for pkl/npy/json files
56
+ # - Created by old version of preembed_patients.py
57
+ # - Files: {prefix}_data.pkl, {prefix}_vectors.npy, {prefix}_metadata.json
58
+ # - Example: PREEMBEDDED_PATIENTS = "synthetic_patient_embeddings"
59
+ #
60
+ PREEMBEDDED_PATIENTS = "https://huggingface.co/datasets/ksg-dfci/mmai-synthetic/resolve/main/synthetic_patient_embeddings.parquet" # e.g., "patient_embeddings.parquet" or "patient_embeddings" (legacy)
61
+
62
+ # ============================================================================
63
+ # CLINICAL SPACE TEMPLATE
64
+ # ============================================================================
65
+
66
+ # Default template for the clinical space query input
67
+ # Users will fill in these fields to define their search criteria
68
+ CLINICAL_SPACE_TEMPLATE = """Age range allowed: any
69
+ Sex allowed: Any
70
+ Cancer type allowed: Non-small cell lung cancer
71
+ Histology allowed: Adenocarcinoma
72
+ Cancer burden allowed: Metastatic
73
+ Prior treatment required: No requirements
74
+ Prior treatment excluded: No requirements
75
+ Biomarkers required: EGFR mutant
76
+ Biomarkers excluded: None"""
77
+
78
+ BOILERPLATE_TEMPLATE = "Patients must have no history of pneumonitis"
79
+
80
+
81
+ # ============================================================================
82
+ # USAGE NOTES
83
+ # ============================================================================
84
+ #
85
+ # 1. Set the model paths above to your preferred models
86
+ # 2. Optionally set DEFAULT_PATIENT_DB or PREEMBEDDED_PATIENTS
87
+ # 3. Customize CLINICAL_SPACE_TEMPLATE if needed
88
+ # 4. Save this file
89
+ # 5. Run: python app.py
90
+ # 6. Models will load automatically on startup
91
+ #
92
+ # To create pre-embedded patients (new parquet format, recommended):
93
+ # python preembed_patients.py --patients patients.parquet --embedder path/to/embedder --output patient_embeddings.parquet
94
+ #
95
+ # To upload pre-embedded patients to Hugging Face Hub:
96
+ # from datasets import Dataset
97
+ # ds = Dataset.from_parquet("patient_embeddings.parquet")
98
+ # ds.push_to_hub("your-username/patient-embeddings")
99
+ #
100
+ # To load pre-embedded patients from Hugging Face Hub in your app:
101
+ # from datasets import load_dataset
102
+ # ds = load_dataset("your-username/patient-embeddings")
103
+ # ds['train'].to_parquet("local_patient_embeddings.parquet")
104
+ # # Then set PREEMBEDDED_PATIENTS = "local_patient_embeddings.parquet"
105
+ #
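
For clarity, the patient database referenced in this config only needs the columns documented above (patient_id and patient_summary required, patient_boilerplate optional). A sketch of building such a file with pandas; the file name and summaries below are placeholders, not part of this repository:

import pandas as pd

patients = pd.DataFrame({
    "patient_id": ["P001", "P002"],
    "patient_summary": [
        "62-year-old woman with metastatic EGFR-mutant lung adenocarcinoma, prior osimertinib.",
        "70-year-old man with stage III colon adenocarcinoma, status post FOLFOX.",
    ],
    # Optional: used by the boilerplate exclusion checker when non-empty.
    "patient_boilerplate": ["No history of pneumonitis. ECOG 1.", ""],
})
patients.to_parquet("patients.parquet", index=False)
# Then, e.g.: python preembed_patients.py --patients patients.parquet --embedder ksg-dfci/TrialSpace-1225 --output patient_embeddings.parquet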
preembed_patients.py ADDED
@@ -0,0 +1,392 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Pre-embed Patient Summaries Script
6
+
7
+ This script pre-processes and embeds a patient database,
8
+ saving the results to a single Parquet file for faster loading
9
+ in the main application and compatibility with Hugging Face datasets.
10
+
11
+ Usage:
12
+ python preembed_patients.py --patients ../v20_public_data/patient_summaries_and_their_spaces.parquet --embedder ksg-dfci/TrialSpace-1225 --output synthetic_patient_embeddings.parquet --gpus 0,1 --patient-boilerplate-col patient_boilerplate_text --patient-id-col pseudo_mrn
13
+
14
+
15
+ This will create:
16
+ - synthetic_patient_embeddings.parquet: Patient dataframe with embedding vectors as a column
17
+
18
+ The parquet file contains:
19
+ - All original patient columns (patient_id, patient_summary, patient_boilerplate, etc.)
20
+ - patient_embedding: The embedding vector for each patient (stored as list of floats)
21
+ - Metadata stored in parquet file metadata (embedder model, creation date, etc.)
22
+
23
+ To upload to Hugging Face:
24
+ from datasets import Dataset
25
+ ds = Dataset.from_parquet("synthetic_patient_embeddings.parquet")
26
+ ds.push_to_hub("your-username/patient-embeddings")
27
+ """
28
+
29
+ import argparse
30
+ import pandas as pd
31
+ import numpy as np
32
+ import torch
33
+ import json
34
+ import pyarrow as pa
35
+ import pyarrow.parquet as pq
36
+ from pathlib import Path
37
+ from datetime import datetime
38
+ from typing import Tuple, List
39
+ from sentence_transformers import SentenceTransformer
40
+ from transformers import AutoTokenizer
41
+
42
+
43
+ def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
44
+ """Truncate text to a maximum number of tokens."""
45
+ return tokenizer.decode(
46
+ tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=max_tokens),
47
+ skip_special_tokens=True
48
+ )
49
+
50
+
51
+ def load_patients(file_path: str, patient_id_col: str = 'patient_id', patient_boilerplate_col: str = 'patient_boilerplate') -> pd.DataFrame:
52
+ """Load patients from parquet file."""
53
+ print(f"\n{'='*70}")
54
+ print(f"Loading patient database from: {file_path}")
55
+ print(f"{'='*70}")
56
+
57
+ if file_path.endswith('.parquet'):
58
+ df = pd.read_parquet(file_path)
59
+ elif file_path.endswith('.csv'):
60
+ df = pd.read_csv(file_path)
61
+ elif file_path.endswith(('.xlsx', '.xls')):
62
+ df = pd.read_excel(file_path)
63
+ else:
64
+ raise ValueError("Unsupported file format. Use Parquet, CSV, or Excel.")
65
+
66
+ # Check required columns
67
+ required_cols = [patient_id_col, 'patient_summary']
68
+ missing = [col for col in required_cols if col not in df.columns]
69
+ if missing:
70
+ raise ValueError(f"Missing required columns: {', '.join(missing)}")
71
+
72
+ # Rename patient_id column to standard name if different
73
+ if patient_id_col != 'patient_id':
74
+ df = df.rename(columns={patient_id_col: 'patient_id'})
75
+ print(f" Renamed column '{patient_id_col}' to 'patient_id'")
76
+
77
+ print(f"✓ Loaded {len(df)} patients")
78
+ print(f" Columns: {', '.join(df.columns.tolist())}")
79
+
80
+ # Clean data
81
+ original_count = len(df)
82
+ df = df[~df['patient_summary'].isnull()].copy()
83
+ df = df[df['patient_summary'].str.strip().str.len() > 0].copy()
84
+
85
+ # Handle boilerplate column
86
+ if patient_boilerplate_col and patient_boilerplate_col in df.columns:
87
+ if patient_boilerplate_col != 'patient_boilerplate':
88
+ df = df.rename(columns={patient_boilerplate_col: 'patient_boilerplate'})
89
+ print(f" Renamed column '{patient_boilerplate_col}' to 'patient_boilerplate'")
90
+ df['patient_boilerplate'] = df['patient_boilerplate'].fillna('')
91
+ non_empty_bp = (df['patient_boilerplate'].str.strip().str.len() > 0).sum()
92
+ print(f" ✓ Found patient_boilerplate column: {non_empty_bp}/{len(df)} patients have boilerplate text")
93
+ else:
94
+ df['patient_boilerplate'] = ''
95
+ if patient_boilerplate_col:
96
+ print(f" ⚠ Column '{patient_boilerplate_col}' not found - patient_boilerplate will be empty")
97
+ else:
98
+ print(f" ○ No boilerplate column specified - patient_boilerplate will be empty")
99
+
100
+ if len(df) < original_count:
101
+ print(f" ⚠ Removed {original_count - len(df)} patients with missing/empty 'patient_summary'")
102
+
103
+ return df
104
+
105
+
106
+ def embed_patients(df: pd.DataFrame, embedder_path: str, device: str = None, gpus: list = None) -> Tuple[np.ndarray, str]:
+     """Embed patient summaries using the specified embedder model.
+
+     Args:
+         df: DataFrame with patient data
+         embedder_path: Path to embedder model
+         device: Single device string (e.g., 'cuda:0', 'cpu') - used if gpus not specified
+         gpus: List of GPU indices for multi-GPU parallel processing (e.g., [0, 1, 2, 3])
+     """
+     print(f"\n{'='*70}")
+     print(f"Loading embedder model: {embedder_path}")
+     print(f"{'='*70}")
+
+     # Determine device configuration
+     use_multi_gpu = gpus is not None and len(gpus) > 1
+
+     if use_multi_gpu:
+         target_devices = [f"cuda:{gpu}" for gpu in gpus]
+         print(f"Multi-GPU mode: {target_devices}")
+         # Load model on CPU first for multi-process pool
+         embedder_model = SentenceTransformer(embedder_path, device='cpu', trust_remote_code=True)
+     else:
+         if gpus is not None and len(gpus) == 1:
+             device = f"cuda:{gpus[0]}"
+         elif device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         print(f"Device: {device}")
+         embedder_model = SentenceTransformer(embedder_path, device=device, trust_remote_code=True)
+
+     embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_path, trust_remote_code=True)
+
+     print(f"✓ Embedder loaded")
+
+     # Set the instruction prompt; ignored if the loaded model does not expose prompt configuration
+     try:
+         embedder_model.prompts['query'] = (
+             "Instruct: Given a cancer patient summary, retrieve clinical trial options "
+             "that are reasonable for that patient; or, given a clinical trial option, "
+             "retrieve cancer patients who are reasonable candidates for that trial."
+         )
+     except Exception:
+         pass
+
+     # Allow long summaries; ignored if the model does not support overriding max_seq_length
+     try:
+         embedder_model.max_seq_length = 2500
+     except Exception:
+         pass
+
+ print(f"\n{'='*70}")
155
+ print(f"Embedding {len(df)} patient summaries")
156
+ print(f"{'='*70}")
157
+
158
+ # Prepare texts for embedding
159
+ df['patient_summary_trunc'] = df['patient_summary'].apply(
160
+ lambda x: truncate_text(str(x), embedder_tokenizer, max_tokens=1500)
161
+ )
162
+
163
+ # Add instruction prefix
164
+ prefix = (
165
+ "Instruct: Given a cancer patient summary, retrieve clinical trial options "
166
+ "that are reasonable for that patient; or, given a clinical trial option, "
167
+ "retrieve cancer patients who are reasonable candidates for that trial. "
168
+ )
169
+ texts_to_embed = [prefix + txt for txt in df['patient_summary_trunc'].tolist()]
170
+
171
+ print(f" Text length stats:")
172
+ print(f" Mean: {np.mean([len(t) for t in texts_to_embed]):.0f} chars")
173
+ print(f" Max: {max([len(t) for t in texts_to_embed])} chars")
174
+
175
+ # Embed with progress bar
176
+ if use_multi_gpu:
177
+ print(f" Starting multi-GPU pool on {target_devices}...")
178
+ pool = embedder_model.start_multi_process_pool(target_devices=target_devices)
179
+
180
+ try:
181
+ embeddings_np = embedder_model.encode_multi_process(
182
+ texts_to_embed,
183
+ pool,
184
+ batch_size=64,
185
+ normalize_embeddings=True,
186
+ )
187
+ finally:
188
+ embedder_model.stop_multi_process_pool(pool)
189
+ else:
190
+ with torch.no_grad():
191
+ embeddings = embedder_model.encode(
192
+ texts_to_embed,
193
+ batch_size=64,
194
+ convert_to_tensor=True,
195
+ normalize_embeddings=True,
196
+ show_progress_bar=True,
197
+ prompt='query'
198
+ )
199
+ embeddings_np = embeddings.cpu().numpy()
200
+
201
+ print(f"✓ Embedding complete")
202
+ print(f" Shape: {embeddings_np.shape}")
203
+ print(f" Dtype: {embeddings_np.dtype}")
204
+
205
+ return embeddings_np, embedder_path
206
+
207
+
208
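+ # Note: load_patients() and embed_patients() can also be called directly from Python
+ # rather than via the CLI in main() below - an illustrative sketch with hypothetical paths:
+ #     df = load_patients("patients.parquet", "patient_id", "patient_boilerplate")
+ #     embeddings, model_path = embed_patients(df, "models/embedder", gpus=[0, 1])
+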
+ def save_embeddings(df: pd.DataFrame, embeddings: np.ndarray, output_path: str, embedder_path: str, gpus: list = None):
+     """Save patient data with embeddings to a single Parquet file.
+
+     The embeddings are stored as a column of lists, which is compatible with
+     Hugging Face datasets and PyArrow.
+     """
+     print(f"\n{'='*70}")
+     print(f"Saving to: {output_path}")
+     print(f"{'='*70}")
+
+     # Ensure output path ends with .parquet
+     if not output_path.endswith('.parquet'):
+         output_path = f"{output_path}.parquet"
+
+     output_dir = Path(output_path).parent
+     if str(output_dir) and str(output_dir) != '.':
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Add embeddings as a column (convert numpy arrays to lists for parquet compatibility)
+     df_out = df.copy()
+     df_out['patient_embedding'] = [emb.tolist() for emb in embeddings]
+
+     # Create metadata dictionary
+     metadata = {
+         "created_at": datetime.now().isoformat(),
+         "embedder_model": embedder_path,
+         "num_patients": str(len(df)),
+         "embedding_dim": str(embeddings.shape[1]),
+         "embedding_dtype": str(embeddings.dtype),
+         "normalized": "true",
+         "gpus_used": str(gpus) if gpus else "single device",
+         "format_version": "2.0",  # Version indicator for the new format
+     }
+
+     # Convert DataFrame to PyArrow Table
+     table = pa.Table.from_pandas(df_out)
+
+     # Add metadata to the table schema
+     existing_metadata = table.schema.metadata or {}
+     existing_metadata[b'patient_embedding_metadata'] = json.dumps(metadata).encode('utf-8')
+     table = table.replace_schema_metadata(existing_metadata)
+
+     # Write to parquet
+     pq.write_table(table, output_path)
+
+     file_size_mb = Path(output_path).stat().st_size / 1024 / 1024
+     print(f"✓ Saved parquet file: {output_path}")
+     print(f"  Size: {file_size_mb:.2f} MB")
+     print(f"  Columns: {', '.join(df_out.columns.tolist())}")
+     print(f"  Embedding column: patient_embedding (dim={embeddings.shape[1]})")
+
+     print(f"\n{'='*70}")
+     print(f"PRE-EMBEDDING COMPLETE")
+     print(f"{'='*70}")
+     print(f"\nTo use these pre-embedded patients in your app:")
+     print(f"1. Update config.py with:")
+     print(f"   PREEMBEDDED_PATIENTS = '{output_path}'")
+     print(f"2. Restart the application")
+     print(f"\nThe app will automatically load these embeddings on startup!")
+     print(f"\nTo upload to Hugging Face Hub:")
+     print(f"  from datasets import Dataset")
+     print(f"  ds = Dataset.from_parquet('{output_path}')")
+     print(f"  ds.push_to_hub('your-username/patient-embeddings')")
+
+
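+ # Illustrative sketch, not used by this script: one way a downstream consumer could read
+ # a file written by save_embeddings() back into memory. The column and metadata key names
+ # match what save_embeddings() writes; the float32 cast is an assumption about how the
+ # embeddings will be consumed.
+ def load_preembedded_patients(path: str):
+     """Load a parquet file produced by save_embeddings() and rebuild the embedding matrix."""
+     table = pq.read_table(path)
+     meta_bytes = (table.schema.metadata or {}).get(b'patient_embedding_metadata', b'{}')
+     metadata = json.loads(meta_bytes.decode('utf-8'))
+     df = table.to_pandas()
+     embeddings = np.array(df['patient_embedding'].tolist(), dtype=np.float32)
+     return df, embeddings, metadata
+
+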
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Pre-embed patient summaries for faster loading",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   python preembed_patients.py --patients data/patients.parquet --embedder models/embedder --output embeddings/patient_embeddings.parquet
+   python preembed_patients.py --patients patients.csv --embedder Qwen/Qwen3-Embedding-0.6B --output patient_embeddings.parquet --device cuda
+   python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --patient-id-col mrn
+   python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --gpus 0,1,2,3
+   python preembed_patients.py --patients data.parquet --embedder models/embedder --output out.parquet --patient-boilerplate-col boilerplate_summary
+
+ Hugging Face Upload:
+   After creating the parquet file, you can upload to Hugging Face Hub:
+     from datasets import Dataset
+     ds = Dataset.from_parquet("patient_embeddings.parquet")
+     ds.push_to_hub("your-username/patient-embeddings")
+ """
+     )
+
+     parser.add_argument(
+         '--patients',
+         type=str,
+         required=True,
+         help='Path to patient database (Parquet, CSV, or Excel). Required columns: patient_summary and the patient ID column (default: patient_id, or specify with --patient-id-col)'
+     )
+
+     parser.add_argument(
+         '--embedder',
+         type=str,
+         required=True,
+         help='Path to embedder model or HuggingFace model name'
+     )
+
+     parser.add_argument(
+         '--output',
+         type=str,
+         required=True,
+         help='Output path for the parquet file (e.g., "patient_embeddings.parquet")'
+     )
+
+     parser.add_argument(
+         '--device',
+         type=str,
+         default=None,
+         help='Device to use for embedding (default: auto-detect). Examples: cuda, cuda:0, cuda:3, cpu. Ignored if --gpus is specified.'
+     )
+
+     parser.add_argument(
+         '--patient-id-col',
+         type=str,
+         default='patient_id',
+         help='Name of the patient ID column in the input file (default: patient_id)'
+     )
+
+     parser.add_argument(
+         '--patient-boilerplate-col',
+         type=str,
+         default='patient_boilerplate',
+         help='Name of the patient boilerplate column in the input file (default: patient_boilerplate). Set to empty string to skip.'
+     )
+
+     parser.add_argument(
+         '--gpus',
+         type=str,
+         default=None,
+         help='Comma-separated list of GPU indices for multi-GPU parallel processing (e.g., "0,1,2,3"). Overrides --device if specified.'
+     )
+
+     args = parser.parse_args()
+
+     # Parse GPU list if provided
+     gpu_list = None
+     if args.gpus:
+         try:
+             gpu_list = [int(g.strip()) for g in args.gpus.split(',')]
+         except ValueError:
+             print(f"✗ ERROR: Invalid GPU list format: {args.gpus}")
+             print("  Use comma-separated integers, e.g., '0,1,2,3'")
+             return 1
+
+     print(f"\n{'='*70}")
+     print(f"PATIENT SUMMARY PRE-EMBEDDING SCRIPT")
+     print(f"{'='*70}")
+     print(f"Patient Database:  {args.patients}")
+     print(f"Embedder Model:    {args.embedder}")
+     print(f"Output File:       {args.output}")
+     print(f"Patient ID Col:    {args.patient_id_col}")
+     print(f"Boilerplate Col:   {args.patient_boilerplate_col or '(none)'}")
+     if gpu_list:
+         print(f"GPUs:              {gpu_list} (multi-GPU mode)")
+     elif args.device:
+         print(f"Device:            {args.device}")
+     else:
+         print(f"Device:            auto-detect")
+     print(f"{'='*70}\n")
+
+     try:
+         # Load patients
+         df = load_patients(args.patients, args.patient_id_col, args.patient_boilerplate_col)
+
+         # Embed patients
+         embeddings, embedder_path = embed_patients(df, args.embedder, args.device, gpu_list)
+
+         # Save everything to a single parquet file
+         save_embeddings(df, embeddings, args.output, embedder_path, gpu_list)
+
+         print(f"\n✓ SUCCESS!")
+
+     except Exception as e:
+         print(f"\n✗ ERROR: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         return 1
+
+     return 0
+
+
+ if __name__ == "__main__":
+     exit(main())