kenlkehl committed on
Commit
2ed7323
·
verified ·
1 Parent(s): 3a186e7

Upload 8 files

Browse files
config.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Configuration for Clinical Trial Matching Pipeline.
#
# Edit the values below to choose the default models and the trial database;
# everything configured here is loaded automatically when the app starts.

# ---------------------------------------------------------------------------
# Model paths
# ---------------------------------------------------------------------------
# Each entry is either a local filesystem path or a HuggingFace model ID.
# Set an entry to None to skip auto-loading that model.
MODEL_CONFIG = {
    # TinyBERT tagger: extracts relevant excerpts from clinical notes.
    # (e.g. "prajjwal1/bert-tiny" or "./auto-tiny-bert-tagger")
    "tagger": "/ksg/kehl_mm_data/meta/2024/v17/v17_models/auto-tiny-bert-tagger",

    # Sentence transformer: embeds patient summaries and trials.
    # (e.g. "Qwen/Qwen3-Embedding-0.6B" or "./reranker_round2.model")
    "embedder": "/ksg/kehl_mm_data/meta/2024/v17/v17_models/reranker_round2.model",

    # Large language model used to summarize the patient history.
    # (e.g. "microsoft/Phi-3-mini-4k-instruct" or "openai/gpt-oss-120b")
    "llm": "meta-llama/Llama-3.2-3B-Instruct",

    # ModernBERT classifier: trial eligibility prediction.
    # (e.g. "answerdotai/ModernBERT-large" or "./modernbert-trial-checker")
    "trial_checker": "/ksg/kehl_mm_data/meta/2024/v17/v17_models/modernbert-trial-checker",

    # ModernBERT classifier: boilerplate-exclusion prediction.
    # (e.g. "answerdotai/ModernBERT-large" or "./modernbert-boilerplate-checker")
    "boilerplate_checker": "/ksg/kehl_mm_data/meta/2024/v17/v17_models/modernbert-boilerplate-checker",
}

# Example: base (untrained) models.
# MODEL_CONFIG = {
#     "tagger": "prajjwal1/bert-tiny",
#     "embedder": "Qwen/Qwen3-Embedding-0.6B",
#     "llm": "microsoft/Phi-3-mini-4k-instruct",
#     "trial_checker": "answerdotai/ModernBERT-large",
#     "boilerplate_checker": "answerdotai/ModernBERT-large",
# }

# Example: locally fine-tuned models.
# MODEL_CONFIG = {
#     "tagger": "./auto-tiny-bert-tagger",
#     "embedder": "./reranker_round2.model",
#     "llm": "/data/models/gpt-oss-120b",
#     "trial_checker": "./modernbert-trial-checker",
#     "boilerplate_checker": "./modernbert-boilerplate-checker",
# }

# ---------------------------------------------------------------------------
# Default trial database
# ---------------------------------------------------------------------------
# CSV/Excel file auto-loaded and embedded once the embedder model is ready.
# Set to None to disable auto-loading. (e.g. "./my_trials.csv")
DEFAULT_TRIAL_DB = "/data1/ken/meta/2024/v17b/trial_space_lineitems.csv"

# Prefix of pre-computed trial embeddings produced by preembed_trials.py.
PREEMBEDDED_TRIALS = "trial_embeddings"

# ---------------------------------------------------------------------------
# Usage
# ---------------------------------------------------------------------------
# 1. Point the model paths above at your preferred models.
# 2. Optionally point DEFAULT_TRIAL_DB at your trial database file.
# 3. Save this file and run: python trial_matching_app.py
# 4. Models load automatically on startup.
#
# Different models can still be loaded manually through the web interface
# if you need to switch models during a session.
+ #
create_sample_data.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate sample data for testing the Clinical Trial Matching Pipeline
4
+ """
5
+
6
+ import pandas as pd
7
+ from datetime import datetime, timedelta
8
+
9
def create_sample_trials():
    """Create a sample trial database CSV.

    Writes ``sample_trials.csv`` to the current working directory and
    returns the dataframe.  The columns match what the matching pipeline
    expects: ``nct_id``, ``this_space`` (key eligibility criteria),
    ``trial_text`` (free-text description), and ``trial_boilerplate_text``
    (common exclusion criteria).
    """

    # Five hand-written oncology trials spanning different tumor types and
    # biomarkers.  Continuation lines of the multiline strings sit at
    # column 0 so the CSV text carries no stray leading indentation.
    trials = [
        {
            'nct_id': 'NCT12345678',
            'this_space': '''Metastatic non-small cell lung cancer (NSCLC) with EGFR exon 19 deletion or L858R mutation
Prior treatment: At least one prior platinum-based chemotherapy regimen
ECOG performance status: 0-2
Measurable disease per RECIST v1.1
Adequate organ function''',
            'trial_text': '''Phase III randomized study of osimertinib versus platinum-based chemotherapy in patients with
EGFR-mutated metastatic NSCLC who have progressed on first-line EGFR TKI therapy. Primary endpoint is progression-free
survival. Secondary endpoints include overall survival, objective response rate, and quality of life.''',
            'trial_boilerplate_text': '''No active brain metastases requiring immediate intervention
No prior treatment with third-generation EGFR TKIs
No interstitial lung disease or pneumonitis
No congestive heart failure NYHA class III-IV
No HIV, hepatitis B, or hepatitis C infection'''
        },
        {
            'nct_id': 'NCT23456789',
            'this_space': '''HER2-positive metastatic breast cancer
Prior treatment: Trastuzumab and pertuzumab in any setting
ECOG performance status: 0-1
Brain metastases allowed if treated and stable
LVEF ≥50%''',
            'trial_text': '''Phase II study of trastuzumab deruxtecan in HER2-positive metastatic breast cancer patients
who have received prior trastuzumab and pertuzumab. Primary endpoint is objective response rate. Key secondary endpoints
include duration of response, progression-free survival, and safety.''',
            'trial_boilerplate_text': '''No history of pneumonitis or interstitial lung disease
No concurrent cardiac dysfunction
No active hepatitis B or C infection
No pregnancy or breastfeeding'''
        },
        {
            'nct_id': 'NCT34567890',
            'this_space': '''Advanced melanoma with BRAF V600E or V600K mutation
Treatment-naive for metastatic disease (adjuvant therapy allowed if completed >6 months prior)
ECOG performance status: 0-1
No active autoimmune disease requiring systemic therapy
Adequate bone marrow, hepatic, and renal function''',
            'trial_text': '''Phase III randomized trial comparing dabrafenib plus trametinib versus vemurafenib monotherapy
in previously untreated BRAF-mutant metastatic melanoma. Primary endpoint is overall survival. Secondary endpoints include
progression-free survival, response rate, and toxicity.''',
            'trial_boilerplate_text': '''No prior systemic therapy for metastatic melanoma
No active brain metastases (treated and stable brain metastases allowed)
No history of inflammatory bowel disease
No significant cardiac disease
No HIV infection on antiretroviral therapy'''
        },
        {
            'nct_id': 'NCT45678901',
            'this_space': '''Microsatellite instability-high (MSI-H) or mismatch repair deficient (dMMR) advanced solid tumors
Progressive disease on or after prior standard therapy
ECOG performance status: 0-2
Measurable disease per RECIST v1.1
No prior checkpoint inhibitor therapy''',
            'trial_text': '''Phase II basket study of pembrolizumab in patients with MSI-H/dMMR advanced solid tumors.
Primary endpoint is objective response rate by tumor type. Secondary endpoints include duration of response,
progression-free survival, and overall survival.''',
            'trial_boilerplate_text': '''No active autoimmune disease requiring systemic therapy
No history of severe immune-related adverse events
No active pneumonitis or interstitial lung disease
No concurrent systemic corticosteroids (>10mg prednisone equivalent daily)
No HIV, hepatitis B, or hepatitis C infection'''
        },
        {
            'nct_id': 'NCT56789012',
            'this_space': '''Advanced or metastatic renal cell carcinoma (RCC), clear cell histology
No prior systemic therapy for advanced disease
Intermediate or poor risk per IMDC criteria
ECOG performance status: 0-1
Measurable disease per RECIST v1.1''',
            'trial_text': '''Phase III randomized study of cabozantinib plus nivolumab versus sunitinib in previously
untreated advanced RCC. Primary endpoint is progression-free survival. Secondary endpoints include overall survival,
objective response rate, and safety.''',
            'trial_boilerplate_text': '''No prior systemic therapy for metastatic RCC
No active brain metastases
No history of bowel perforation or fistula
No poorly controlled hypertension
No active hepatitis B or C infection
No significant cardiovascular disease'''
        }
    ]

    # Persist as CSV for the Gradio app to pick up.
    df = pd.DataFrame(trials)
    df.to_csv('sample_trials.csv', index=False)
    print(f"✓ Created sample_trials.csv with {len(df)} trials")
    return df
99
+
100
def create_sample_patient_notes():
    """Create sample patient clinical notes CSV.

    Writes ``sample_patient_notes.csv`` (columns: ``date``, ``text``,
    ``note_type``) and returns the dataframe.  The notes trace a single
    synthetic patient's course: EGFR-mutant NSCLC diagnosed at stage IIIA,
    progressing to stage IV, responding to then progressing on osimertinib.
    """

    # All note dates are offsets from a fixed anchor so the timeline is
    # deterministic.
    base_date = datetime(2023, 1, 1)

    notes = [
        {
            'date': base_date,
            'text': 'Patient is a 67-year-old male with a 40 pack-year smoking history presenting with cough and weight loss. CT chest shows a 4.5 cm right upper lobe mass with mediastinal lymphadenopathy.',
            'note_type': 'clinical_note'
        },
        {
            'date': base_date + timedelta(days=7),
            'text': 'CT-guided lung biopsy performed. Pathology shows adenocarcinoma, moderately differentiated.',
            'note_type': 'pathology_report'
        },
        {
            'date': base_date + timedelta(days=14),
            'text': 'PET/CT shows FDG-avid right upper lobe mass (SUVmax 12.3), right hilar nodes (SUVmax 8.7), and mediastinal nodes (SUVmax 9.2). No distant metastatic disease identified.',
            'note_type': 'imaging_report'
        },
        {
            'date': base_date + timedelta(days=21),
            # Multiline NGS report; continuation lines sit at column 0 so the
            # stored text carries no stray indentation.
            'text': '''Next-generation sequencing (NGS) performed on lung biopsy specimen.
Results: EGFR exon 19 deletion (L747_A750delinsP) detected.
Other findings: TP53 p.R273H mutation, MYC amplification (copy number gain).
PD-L1 expression by immunohistochemistry: 75% tumor proportion score.
TMB: 4 mutations/Mb (low).
No ALK, ROS1, BRAF, MET, RET, or KRAS alterations detected.''',
            'note_type': 'ngs_report'
        },
        {
            'date': base_date + timedelta(days=28),
            'text': 'Mediastinoscopy with biopsy of station 4R and 7 lymph nodes. Pathology confirms metastatic adenocarcinoma. Clinical stage: T2aN2M0, stage IIIA.',
            'note_type': 'pathology_report'
        },
        {
            'date': base_date + timedelta(days=42),
            'text': 'Patient underwent concurrent chemoradiation with carboplatin/pemetrexed and 60 Gy radiation to primary tumor and mediastinum. Tolerated well with grade 2 esophagitis.',
            'note_type': 'clinical_note'
        },
        {
            'date': base_date + timedelta(days=112),
            'text': 'Post-treatment CT chest shows near-complete response of primary tumor (now 1.2 cm) and resolution of lymphadenopathy. Started consolidation durvalumab.',
            'note_type': 'imaging_report'
        },
        {
            'date': base_date + timedelta(days=280),
            'text': 'Surveillance CT shows new liver lesions (segment 6 and 7, largest 2.3 cm) and increase in size of lung primary to 3.1 cm. Progression of disease.',
            'note_type': 'imaging_report'
        },
        {
            'date': base_date + timedelta(days=287),
            'text': 'Patient now has metastatic NSCLC (stage IV). ECOG performance status 1. Discussed treatment options. Given EGFR mutation, recommend EGFR TKI therapy.',
            'note_type': 'clinical_note'
        },
        {
            'date': base_date + timedelta(days=294),
            'text': 'Started osimertinib 80 mg daily for EGFR-mutant metastatic NSCLC.',
            'note_type': 'clinical_note'
        },
        {
            'date': base_date + timedelta(days=378),
            'text': 'Restaging CT shows partial response. Liver lesions decreased to 1.2 and 0.9 cm. Primary lung tumor stable at 2.8 cm. Tolerating osimertinib well with mild diarrhea and dry skin.',
            'note_type': 'imaging_report'
        },
        {
            'date': base_date + timedelta(days=560),
            'text': 'Patient reports increased fatigue and back pain over past 3 weeks.',
            'note_type': 'clinical_note'
        },
        {
            'date': base_date + timedelta(days=567),
            'text': '''CT chest/abdomen/pelvis shows:
- Progression of liver metastases (segment 6: 3.8 cm, previously 1.2 cm; segment 7: 2.9 cm, previously 0.9 cm)
- New liver lesions in segments 4 and 5
- Lung primary increased to 4.2 cm
- New small pleural effusion
Assessment: Progressive disease on osimertinib.''',
            'note_type': 'imaging_report'
        },
        {
            'date': base_date + timedelta(days=574),
            'text': 'MRI brain with contrast shows no brain metastases. Patient has progressive EGFR-mutant NSCLC after first-line osimertinib. ECOG PS 1. Discussing clinical trial options for second-line therapy.',
            'note_type': 'clinical_note'
        }
    ]

    # Persist as CSV for the Gradio app to pick up.
    df = pd.DataFrame(notes)
    df.to_csv('sample_patient_notes.csv', index=False)
    print(f"✓ Created sample_patient_notes.csv with {len(df)} notes")
    return df
192
+
193
def create_sample_patient_summary():
    """Create a sample patient summary text file.

    Writes ``sample_patient_summary.txt`` and returns the summary string.
    The summary is the pre-made counterpart to the notes produced by
    ``create_sample_patient_notes`` (same synthetic EGFR-mutant NSCLC
    patient), including a trailing "Boilerplate:" section used for
    boilerplate-exclusion checking.
    """

    # Continuation lines sit at column 0 so the written file carries no
    # stray indentation.  The '#' characters inside the string are part of
    # the summary text, not Python comments.
    summary = """Cancer type: Non-small cell lung cancer (NSCLC)
Histology: Adenocarcinoma, moderately differentiated
Stage at diagnosis: Stage IIIA (T2aN2M0)
Current extent: Metastatic (stage IV) with liver metastases

Biomarkers:
- EGFR exon 19 deletion (L747_A750delinsP)
- TP53 p.R273H mutation
- MYC amplification
- PD-L1 75% TPS
- TMB: 4 mutations/Mb (low)

Treatment history:
# 1/28/2023 - 4/15/2023: Concurrent chemoradiation (carboplatin/pemetrexed with 60 Gy)
# 4/22/2023 - 10/5/2023: Consolidation durvalumab
# 10/19/2023 - present: Osimertinib 80 mg daily for metastatic disease

Disease course:
- Initial diagnosis: January 2023, stage IIIA
- Near-complete response to chemoradiation
- Progression to stage IV in September 2023 (liver metastases)
- Partial response to osimertinib
- Current progression on osimertinib (July 2024) after ~9 months of therapy

Current status:
- ECOG performance status: 1
- Progressive disease with liver metastases
- No brain metastases on recent MRI

Boilerplate:
No evidence of brain metastases (MRI brain 7/22/2024).
No history of pneumonitis, interstitial lung disease, congestive heart failure, HIV, or hepatitis infection documented.
Adequate performance status (ECOG 1).
"""

    with open('sample_patient_summary.txt', 'w') as f:
        f.write(summary)

    print(f"✓ Created sample_patient_summary.txt")
    return summary
236
+
237
+ if __name__ == "__main__":
238
+ print("Generating sample data for Clinical Trial Matching Pipeline...\n")
239
+
240
+ create_sample_trials()
241
+ create_sample_patient_notes()
242
+ create_sample_patient_summary()
243
+
244
+ print("\n✓ All sample files created successfully!")
245
+ print("\nFiles generated:")
246
+ print(" - sample_trials.csv (5 clinical trials)")
247
+ print(" - sample_patient_notes.csv (14 clinical notes)")
248
+ print(" - sample_patient_summary.txt (pre-made summary)")
249
+ print("\nYou can now use these files to test the Gradio application.")
launch.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Launch script for Clinical Trial Matching Pipeline
4
+
5
+ Checks dependencies and provides helpful startup information.
6
+ """
7
+
8
+ import sys
9
+ import subprocess
10
+ import importlib.util
11
+
12
def check_package(package_name, display_name=None):
    """Check whether *package_name* is importable.

    Returns a ``(is_installed, display_name)`` pair; the display name
    defaults to the import name when not given.
    """
    label = package_name if display_name is None else display_name
    found = importlib.util.find_spec(package_name) is not None
    return found, label
21
+
22
def check_dependencies():
    """Check that all required third-party packages are importable.

    Prints one status line per required and optional package.  When
    required packages are missing, prints the exact ``pip install``
    command needed.

    Returns:
        bool: True if every required package is installed, else False.
    """

    # (import name, display name, pip package name).  The import name and
    # the pip name differ for some packages — e.g. the module
    # `sentence_transformers` installs as `sentence-transformers` — so the
    # install hint must use the pip name, not the import name.
    required_packages = [
        ('gradio', 'gradio', 'gradio'),
        ('pandas', 'pandas', 'pandas'),
        ('numpy', 'numpy', 'numpy'),
        ('torch', 'PyTorch', 'torch'),
        ('transformers', 'transformers', 'transformers'),
        ('sentence_transformers', 'sentence-transformers', 'sentence-transformers'),
    ]

    optional_packages = [
        ('vllm', 'vLLM (for faster LLM inference)'),
    ]

    print("Checking dependencies...\n")

    missing = []  # pip names of required packages that are not installed
    for package, display, pip_name in required_packages:
        installed, name = check_package(package, display)
        status = "✓" if installed else "✗"
        print(f"  {status} {name}")
        if not installed:
            missing.append(pip_name)

    print("\nOptional packages:")
    for package, display in optional_packages:
        installed, name = check_package(package, display)
        # Optional packages get a hollow marker instead of a failure cross.
        status = "✓" if installed else "○"
        print(f"  {status} {name}")

    if missing:
        print(f"\n❌ Missing required packages: {', '.join(missing)}")
        print("\nInstall with:")
        print(f"  pip install {' '.join(missing)}")
        print("\nOr install all requirements:")
        print("  pip install -r requirements.txt")
        return False

    print("\n✓ All required dependencies installed!")
    return True
64
+
65
def check_cuda():
    """Report CUDA availability.

    Prints GPU details when CUDA is usable, or install guidance when it is
    not.  Returns True only when torch is importable and CUDA is available.
    """
    try:
        import torch
    except ImportError:
        return False

    if not torch.cuda.is_available():
        print("\n⚠️ CUDA not available - running on CPU")
        print("  For better performance, install PyTorch with CUDA:")
        print("  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
        return False

    gpu_total = torch.cuda.device_count()
    print(f"\n🚀 CUDA available!")
    print(f"  GPU count: {gpu_total}")
    for idx in range(gpu_total):
        print(f"  GPU {idx}: {torch.cuda.get_device_name(idx)}")
    return True
82
+
83
def print_startup_info():
    """Print the startup banner with the URLs where the UI will be served."""
    banner = "=" * 70
    print("\n" + banner)
    print("Clinical Trial Matching Pipeline")
    print(banner)
    print("\nStarting Gradio web interface...")
    print("\nOnce started, the interface will be available at:")
    print("  Local: http://localhost:7860")
    print("  Network: http://0.0.0.0:7860")
    print("\nPress Ctrl+C to stop the server.")
    print("\n" + banner + "\n")
94
+
95
def main():
    """Entry point: verify the environment, then start the Gradio app."""

    # Abort early when required packages are missing.
    if not check_dependencies():
        sys.exit(1)

    # Informational only — CPU fallback is allowed.
    check_cuda()

    print_startup_info()

    try:
        # Importing the module launches the app as a side effect.
        import trial_matching_app
    except KeyboardInterrupt:
        print("\n\nShutting down gracefully...")
        sys.exit(0)
    except Exception as e:
        print(f"\n❌ Error launching application: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    main()
preembed_trials.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Pre-embed Clinical Trials Script
6
+
7
+ This script pre-processes and embeds a clinical trial database,
8
+ saving the results to disk for faster loading in the main application.
9
+
10
+ Usage:
11
+ python preembed_trials.py --trials trials.csv --embedder path/to/embedder --output trial_embeddings
12
+ python preembed_trials.py --trials /data1/ken/meta/2024/v17b/trial_space_lineitems.csv --embedder /ksg/kehl_mm_data/meta/2024/v17/v17_models/reranker_round2.model --output trial_embeddings --device cuda:2
13
+
14
+
15
+ This will create:
16
+ - trial_embeddings_data.pkl: Trial dataframe
17
+ - trial_embeddings_vectors.npy: Embedding vectors
18
+ - trial_embeddings_metadata.json: Metadata about the embedding process
19
+ """
20
+
21
+ import argparse
22
+ import pandas as pd
23
+ import numpy as np
24
+ import torch
25
+ import json
26
+ import re
27
+ from pathlib import Path
28
+ from datetime import datetime
29
+ from typing import Tuple
30
+ from sentence_transformers import SentenceTransformer
31
+ from transformers import AutoTokenizer
32
+
33
def truncate_text(text: str, tokenizer, max_tokens: int = 1500) -> str:
    """Clip *text* so it spans at most *max_tokens* tokenizer tokens.

    Round-trips through the tokenizer: encode with truncation, then decode
    without special tokens, so the result is valid text at a token boundary.
    """
    token_ids = tokenizer.encode(
        text, add_special_tokens=True, truncation=True, max_length=max_tokens
    )
    return tokenizer.decode(token_ids, skip_special_tokens=True)
39
+
40
def load_trials(file_path: str) -> pd.DataFrame:
    """Load a trial database from a CSV or Excel file and clean it.

    Rows with a missing 'this_space' are dropped (they cannot be embedded),
    and missing boilerplate text is replaced with an empty string.

    Raises:
        ValueError: for unsupported file extensions or missing columns.
    """
    print(f"\n{'='*70}")
    print(f"Loading trial database from: {file_path}")
    print(f"{'='*70}")

    if file_path.endswith('.csv'):
        df = pd.read_csv(file_path)
    elif file_path.endswith(('.xlsx', '.xls')):
        df = pd.read_excel(file_path)
    else:
        raise ValueError("Unsupported file format. Use CSV or Excel.")

    # Fail fast when the schema is wrong.
    required_cols = ['nct_id', 'this_space', 'trial_text', 'trial_boilerplate_text']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    print(f"✓ Loaded {len(df)} trials")
    print(f"  Columns: {', '.join(df.columns.tolist())}")

    # Clean: drop rows we cannot embed, blank out missing boilerplate.
    before = len(df)
    df = df[df['this_space'].notna()].copy()
    df['trial_boilerplate_text'] = df['trial_boilerplate_text'].fillna('')

    dropped = before - len(df)
    if dropped:
        print(f"  ⚠ Removed {dropped} trials with missing 'this_space'")

    return df
71
+
72
def embed_trials(df: pd.DataFrame, embedder_path: str, device: str = None) -> Tuple[np.ndarray, str]:
    """Embed the 'this_space' text of every trial with a SentenceTransformer.

    Args:
        df: Trial dataframe; must contain a 'this_space' column.  A
            'this_space_trunc' column is added to *df* as a side effect.
        embedder_path: Local path or HuggingFace ID of the embedder model.
        device: Torch device string; auto-detects CUDA when None.

    Returns:
        Tuple of (normalized embedding matrix, the embedder path used).
    """
    print(f"\n{'='*70}")
    print(f"Loading embedder model: {embedder_path}")
    print(f"{'='*70}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Device: {device}")

    # Load embedder and its tokenizer (tokenizer used only for truncation).
    embedder_model = SentenceTransformer(embedder_path, device=device, trust_remote_code=True)
    embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_path, trust_remote_code=True)

    print(f"✓ Embedder loaded")

    # Configure the instruction prompt and sequence length.  Some embedder
    # models/versions may not support these attributes, so failures are
    # tolerated — but only Exception is caught; a bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit.
    try:
        embedder_model.prompts['query'] = (
            "Instruct: Given a cancer patient summary, retrieve clinical trial options "
            "that are reasonable for that patient; or, given a clinical trial option, "
            "retrieve cancer patients who are reasonable candidates for that trial."
        )
    except Exception:
        pass

    try:
        embedder_model.max_seq_length = 1500
    except Exception:
        pass

    print(f"\n{'='*70}")
    print(f"Embedding {len(df)} trials")
    print(f"{'='*70}")

    # Truncate each trial's eligibility text to the model's token window.
    df['this_space_trunc'] = df['this_space'].apply(
        lambda x: truncate_text(str(x), embedder_tokenizer, max_tokens=1500)
    )

    # Instruction prefix prepended to every text (mirrors the 'query'
    # prompt configured above).
    prefix = (
        "Instruct: Given a cancer patient summary, retrieve clinical trial options "
        "that are reasonable for that patient; or, given a clinical trial option, "
        "retrieve cancer patients who are reasonable candidates for that trial. "
    )
    texts_to_embed = [prefix + txt for txt in df['this_space_trunc'].tolist()]

    print(f"  Text length stats:")
    print(f"    Mean: {np.mean([len(t) for t in texts_to_embed]):.0f} chars")
    print(f"    Max: {max(len(t) for t in texts_to_embed)} chars")

    # Embed with a progress bar; no_grad avoids building autograd graphs.
    with torch.no_grad():
        embeddings = embedder_model.encode(
            texts_to_embed,
            batch_size=64,
            convert_to_tensor=True,
            normalize_embeddings=True,
            show_progress_bar=True,
            prompt='query'
        )

    embeddings_np = embeddings.cpu().numpy()

    print(f"✓ Embedding complete")
    print(f"  Shape: {embeddings_np.shape}")
    print(f"  Dtype: {embeddings_np.dtype}")

    return embeddings_np, embedder_path
143
+
144
def save_embeddings(df: pd.DataFrame, embeddings: np.ndarray, output_prefix: str, embedder_path: str):
    """Persist trial data, embedding vectors, and run metadata to disk.

    Writes three files sharing *output_prefix*: ``_data.pkl`` (dataframe),
    ``_vectors.npy`` (embeddings), and ``_metadata.json`` (run metadata).
    """
    print(f"\n{'='*70}")
    print(f"Saving to: {output_prefix}_*")
    print(f"{'='*70}")

    # Ensure the destination directory exists.
    Path(output_prefix).parent.mkdir(parents=True, exist_ok=True)

    def _report(label, path):
        # Confirm a file was written and show its on-disk size.
        print(f"✓ Saved {label}: {path}")
        print(f"  Size: {Path(path).stat().st_size / 1024 / 1024:.2f} MB")

    df_file = f"{output_prefix}_data.pkl"
    df.to_pickle(df_file)
    _report("trial dataframe", df_file)

    vectors_file = f"{output_prefix}_vectors.npy"
    np.save(vectors_file, embeddings)
    _report("embeddings", vectors_file)

    # Keep the metadata readable: preview at most the first 10 NCT IDs.
    if len(df) > 10:
        id_preview = df['nct_id'].tolist()[:10] + ["..."]
    else:
        id_preview = df['nct_id'].tolist()

    metadata = {
        "created_at": datetime.now().isoformat(),
        "embedder_model": embedder_path,
        "num_trials": len(df),
        "embedding_dim": embeddings.shape[1],
        "nct_ids": id_preview,
        "embedding_dtype": str(embeddings.dtype),
        "normalized": True,
    }

    metadata_file = f"{output_prefix}_metadata.json"
    with open(metadata_file, 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f"✓ Saved metadata: {metadata_file}")

    print(f"\n{'='*70}")
    print(f"PRE-EMBEDDING COMPLETE")
    print(f"{'='*70}")
    print(f"\nTo use these pre-embedded trials in your app:")
    print(f"1. Update config.py with:")
    print(f"   PREEMBEDDED_TRIALS = '{output_prefix}'")
    print(f"2. Restart the application")
    print(f"\nThe app will automatically load these embeddings on startup!")
189
+
190
def main():
    """Parse CLI arguments and run the load → embed → save pipeline.

    Returns:
        int: 0 on success, 1 on failure (suitable as a process exit code).
    """
    parser = argparse.ArgumentParser(
        description="Pre-embed clinical trials for faster loading",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python preembed_trials.py --trials data/trials.csv --embedder models/embedder --output embeddings/trial_embeddings
  python preembed_trials.py --trials trials.xlsx --embedder Qwen/Qwen3-Embedding-0.6B --output trial_embeddings --device cuda
        """
    )

    parser.add_argument(
        '--trials',
        type=str,
        required=True,
        help='Path to trial database (CSV or Excel)'
    )

    parser.add_argument(
        '--embedder',
        type=str,
        required=True,
        help='Path to embedder model or HuggingFace model name'
    )

    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output prefix for saved files (e.g., "trial_embeddings" will create trial_embeddings_data.pkl, etc.)'
    )

    parser.add_argument(
        '--device',
        type=str,
        default=None,
        # Unconstrained on purpose: allows device strings like "cuda:2".
        help='Device to use for embedding (default: auto-detect)'
    )

    args = parser.parse_args()

    print(f"\n{'='*70}")
    print(f"CLINICAL TRIAL PRE-EMBEDDING SCRIPT")
    print(f"{'='*70}")
    print(f"Trial Database: {args.trials}")
    print(f"Embedder Model: {args.embedder}")
    print(f"Output Prefix: {args.output}")
    print(f"{'='*70}\n")

    try:
        # Load, embed, then persist — each step prints its own progress.
        df = load_trials(args.trials)
        embeddings, embedder_path = embed_trials(df, args.embedder, args.device)
        save_embeddings(df, embeddings, args.output, embedder_path)
    except Exception as e:
        print(f"\n✗ ERROR: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    print(f"\n✓ SUCCESS!")
    return 0

if __name__ == "__main__":
    # `raise SystemExit(...)` instead of `exit(...)`: the `exit` builtin is
    # installed by the `site` module and is absent under `python -S`.
    raise SystemExit(main())
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ pandas>=2.0.0
3
+ numpy>=1.24.0
4
+ torch>=2.0.0
5
+ transformers>=4.35.0
6
+ sentence-transformers>=2.2.0
7
+ openpyxl>=3.1.0
8
+ xlrd>=2.0.0
9
+
10
+ # Optional but recommended for faster LLM inference
11
+ vllm>=0.5.0
12
+
13
+ # For CUDA support (if using GPU)
14
+ # Install PyTorch with CUDA separately:
15
+ # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
trial_embeddings_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c733ad315bba227bca0aecf3f7e4947a9383f312520ec87c785d6110afb826a3
3
+ size 362844655
trial_embeddings_metadata.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "created_at": "2025-10-28T16:17:07.203636",
3
+ "embedder_model": "/ksg/kehl_mm_data/meta/2024/v17/v17_models/reranker_round2.model",
4
+ "num_trials": 39266,
5
+ "embedding_dim": 1024,
6
+ "nct_ids": [
7
+ "NCT00001160",
8
+ "NCT00001160",
9
+ "NCT00001186",
10
+ "NCT00001186",
11
+ "NCT00001238",
12
+ "NCT00001238",
13
+ "NCT00001238",
14
+ "NCT00001238",
15
+ "NCT00001238",
16
+ "NCT00001238",
17
+ "..."
18
+ ],
19
+ "embedding_dtype": "float32",
20
+ "normalized": true
21
+ }
trial_embeddings_vectors.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42f6e5f85a93a8b95ec04ea59c57d4be73d042fb483522c144a7a6f0720c4379
3
+ size 160833664