ahuggingface01 commited on
Commit
a5b6ba6
·
verified ·
1 Parent(s): 4f75e01

Upload 8 files

Browse files
main.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import argparse
4
+ import os
5
+ import math
6
+
7
+ # Imports from your specific file structure
8
+ from src.preprocess.wtts_builder import WTTSBuilder
9
+ from src.utils.data_loader import DataLoader
10
+
11
# RAG pipeline imports (lazy — only used when --use_rag is set)
# Feature probe: these modules pull in heavy optional dependencies
# (sentence-transformers / faiss, per the install hint printed in main()),
# so a failed import simply disables the RAG code path instead of crashing.
RAG_AVAILABLE = False
try:
    from src.rag.embedder import WTTSEmbedder
    from src.rag.rag_pipeline import RAGCRFExtractor
    RAG_AVAILABLE = True
except ImportError:
    pass
19
+
20
+ import google.generativeai as genai
21
+
22
# --- CONFIG ---
# SECURITY FIX: a literal Google API key was previously committed here; any
# key that has been in version control must be treated as compromised and
# rotated. Read the key from the environment instead — the `--api_key` CLI
# flag (which defaults to this value) still overrides it.
API_KEY = os.environ.get("GOOGLE_API_KEY", "")
24
+
25
# --- PROMPTS ---

# Pass 1 prompt: condenses the WTTS tuple stream into a chronological
# "skeleton" while preserving the [S_xx] sentence IDs needed as evidence
# pointers in pass 2. Placeholder: {wtts_string}.
SKELETON_PROMPT = """
You are a Clinical Data Specialist.
Convert the Weighted Time Series (WTTS) below into a "Clinical Chronology Skeleton".

INPUT (WTTS):
{wtts_string}

INSTRUCTIONS:
1. Create a strict chronological timeline (Admission to Discharge).
2. IMPORTANT: You MUST retain the [S_xx] ID for every event you list.
3. Filter out "Routine" (Weight 0.1) events unless they indicate a status change.
4. Keep exact values (e.g., "BP 90/60", "Temp 102.5").

OUTPUT FORMAT:
[Date] [S_xx]: Event details
[Date] [S_xx]: Event details
...
"""

# Pass 2 prompt: closed-vocabulary CRF item extraction against the skeleton.
# Placeholders: {skeleton}, {chunk_schema_json}. The doubled braces ({{ }})
# are literal braces escaped for str.format().
EXTRACTION_PROMPT = """
You are a Clinical Coding Expert.
Review the Patient Skeleton and the Valid Options for the requested items.

PATIENT SKELETON:
{skeleton}

TASK:
For each Clinical Item listed below, determine the value AND the supporting Sentence ID.
1. **Value**: Must come strictly from the "Valid Options" provided.
2. **Evidence**: Must be the specific [S_xx] ID from the skeleton that proves the value.

ITEMS TO EXTRACT & THEIR OPTIONS:
{chunk_schema_json}

OUTPUT FORMAT (JSON Object):
{{
    "item_name": {{
        "value": "Selected Option",
        "evidence": "S_xx",
        "reasoning": "Brief explanation"
    }},
    ...
}}
"""
71
+
72
def chunk_data(data, size):
    """Yield consecutive slices of *data*, each holding at most *size* items."""
    bounds = range(0, len(data), size)
    for lo in bounds:
        hi = lo + size
        yield data[lo:hi]
76
+
77
+
78
async def generate_async(prompt, model, max_retries=3, initial_delay=1):
    """Call Gemini via the google-generativeai SDK without blocking the event loop.

    The synchronous ``model.generate_content`` call is pushed onto the default
    thread-pool executor. The response text is parsed as JSON; an invalid-JSON
    response consumes a retry attempt. Transient failures (HTTP 429 / 500 in
    the error text) are retried with exponential backoff; any other exception
    is returned immediately.

    Args:
        prompt: Prompt string sent to the model.
        model: ``genai.GenerativeModel`` instance (anything exposing
            ``generate_content``).
        max_retries: Total attempts before giving up.
        initial_delay: Backoff base in seconds (delay = initial_delay * 2**attempt).

    Returns:
        dict: Parsed JSON response on success, else ``{"error": <message>}``.
    """
    # FIX: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    for attempt in range(max_retries):
        try:
            response = await loop.run_in_executor(
                None,
                lambda: model.generate_content(
                    contents=prompt,
                    generation_config=genai.GenerationConfig(
                        response_mime_type="application/json"
                    ),
                ),
            )
            try:
                return json.loads(response.text)
            except json.JSONDecodeError:
                # Bad JSON burns one attempt but retries immediately.
                print("Generated content is not valid JSON. Retrying...")
                continue

        except Exception as e:
            error_message = str(e)
            # Only rate-limit (429) and server (500) errors are retryable.
            if "429" in error_message or "500" in error_message:
                if attempt < max_retries - 1:
                    delay = initial_delay * (2 ** attempt)  # exponential backoff
                    print(f"Rate limit / server error. Retrying in {delay}s...")
                    await asyncio.sleep(delay)
                else:
                    print("Max retries reached.")
                    return {"error": f"Max retries reached - {error_message}"}
            else:
                print(f"Error in generate_async: {error_message}")
                return {"error": error_message}

    return {"error": "Failed to generate valid JSON after multiple attempts"}
114
+
115
+
116
async def process_patient(model, builder, patient_data, target_items, valid_options, semaphore):
    """Run the two-pass (skeleton -> extraction) pipeline for one patient."""
    async with semaphore:
        pid = str(patient_data.get('document_id') or patient_data.get('patient_id')
                  or patient_data.get('hadm_id') or 'unknown')

        try:
            # Phase 1: serialize the record into a WTTS timeline string.
            timeline = builder.build_wtts_string(patient_data)

            # Phase 2 (pass 1): condense the timeline into a chronology skeleton.
            skeleton_resp = await generate_async(
                SKELETON_PROMPT.format(wtts_string=timeline), model
            )
            if isinstance(skeleton_resp, dict):
                skeleton_text = json.dumps(skeleton_resp)
            else:
                skeleton_text = str(skeleton_resp)

            # Phase 3 (pass 2): extract CRF items in groups of 10 against it.
            predictions = {}
            for group in list(chunk_data(target_items, 10)):
                schema = {
                    name: valid_options.get(name, ["Yes", "No", "Unknown"])
                    for name in group
                }
                group_resp = await generate_async(
                    EXTRACTION_PROMPT.format(
                        skeleton=skeleton_text,
                        chunk_schema_json=json.dumps(schema)
                    ),
                    model,
                )

                if isinstance(group_resp, dict):
                    if 'error' in group_resp:
                        print(f" [WARN] LLM error for {pid}, chunk {group[:3]}...: {group_resp['error']}")
                    else:
                        predictions.update(group_resp)

            return {
                "patient_id": pid,
                "skeleton_debug": skeleton_text[:500] + "...",
                "predictions": predictions
            }

        except Exception as e:
            print(f"Error processing {pid}: {e}")
            return None
166
+
167
+
168
+ # ---------------------------------------------------------------------------
169
+ # EVALUATION -- Accuracy & F1 Scoring
170
+ # ---------------------------------------------------------------------------
171
+
172
+ def _normalise(value):
173
+ """Lowercase + strip for fair comparison."""
174
+ if value is None:
175
+ return ""
176
+ return str(value).strip().lower()
177
+
178
+
179
def evaluate_predictions(results, gt_path):
    """
    Compare pipeline results against dev_gt.jsonl.
    Prints accuracy, macro-F1, per-item breakdown, and sample errors.
    Returns (overall_dict, per_item_dict).
    """
    # --- Load GT ---
    # gt: document_id -> {item name -> ground-truth value}
    gt = {}
    with open(gt_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            doc_id = str(rec['document_id'])
            gt[doc_id] = {a['item']: a['ground_truth'] for a in rec.get('annotations', [])}

    # --- Build prediction lookup ---
    # preds: patient_id -> {item name -> predicted value as plain string}.
    # Dict predictions ({"value":..., "evidence":...}) are reduced to "value".
    preds = {}
    for r in results:
        doc_id = str(r.get('patient_id', 'unknown'))
        items = {}
        for item_name, item_val in r.get('predictions', {}).items():
            if isinstance(item_val, dict):
                items[item_name] = item_val.get('value', str(item_val))
            else:
                items[item_name] = str(item_val)
        preds[doc_id] = items

    # --- Collect all unique items ---
    # The item universe comes from ground truth only; extra predicted items
    # are ignored.
    all_items = set()
    for doc_items in gt.values():
        all_items.update(doc_items.keys())

    # --- Score ---
    item_stats = {item: {'tp': 0, 'fp': 0, 'fn': 0, 'total': 0, 'correct': 0}
                  for item in all_items}
    total_comparisons = 0
    total_correct = 0
    matched_patients = 0
    errors = []

    for doc_id, gt_items in gt.items():
        pred_items = preds.get(doc_id, {})
        if pred_items:
            matched_patients += 1

        for item_name, gt_val in gt_items.items():
            gt_norm = _normalise(gt_val)
            pred_val = pred_items.get(item_name)
            pred_norm = _normalise(pred_val) if pred_val is not None else ""

            total_comparisons += 1
            item_stats[item_name]['total'] += 1

            # NOTE(review): counting scheme is degenerate — a match is both
            # "correct" and a TP, any miss is an FN, and a non-empty wrong
            # prediction additionally counts as an FP. Recall therefore
            # equals accuracy per item; confirm this is the intended metric.
            if gt_norm == pred_norm:
                total_correct += 1
                item_stats[item_name]['correct'] += 1
                item_stats[item_name]['tp'] += 1
            else:
                item_stats[item_name]['fn'] += 1
                if pred_norm:
                    item_stats[item_name]['fp'] += 1
                errors.append((doc_id, item_name, gt_val,
                               pred_val if pred_val is not None else '<MISSING>'))

    accuracy = total_correct / total_comparisons if total_comparisons > 0 else 0.0

    # --- Per-item P/R/F1 ---
    f1s = []
    per_item = {}
    for item_name in sorted(all_items):
        s = item_stats[item_name]
        tp, fp, fn = s['tp'], s['fp'], s['fn']
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        item_acc = s['correct'] / s['total'] if s['total'] > 0 else 0.0
        per_item[item_name] = {'accuracy': item_acc, 'precision': prec,
                               'recall': rec, 'f1': f1, 'total': s['total']}
        f1s.append(f1)

    # Macro-F1: unweighted mean over items (items with zero samples excluded
    # implicitly because only GT items enter all_items).
    macro_f1 = sum(f1s) / len(f1s) if f1s else 0.0

    # --- Print report ---
    print("\n" + "=" * 70)
    print(" CL4Health CRF Filling -- Evaluation Report")
    print("=" * 70)
    print(f"\n GT Patients: {len(gt)}")
    print(f" Pred Patients: {len(preds)}")
    print(f" Matched Patients: {matched_patients}")
    print(f"\n Total Comparisons: {total_comparisons}")
    print(f" Total Correct: {total_correct}")
    print(f"\n {'Accuracy':>20s}: {accuracy:.4f}")
    print(f" {'Macro F1':>20s}: {macro_f1:.4f}")

    # Top / bottom items
    sorted_items = sorted(per_item.items(), key=lambda x: x[1]['f1'], reverse=True)
    n_show = min(15, len(sorted_items))

    print(f"\n Top {n_show} Items by F1:")
    print(f" {'Item':<45s} {'Acc':>6s} {'P':>6s} {'R':>6s} {'F1':>6s}")
    print(f" {'-'*45} {'---':>6s} {'---':>6s} {'---':>6s} {'---':>6s}")
    for name, s in sorted_items[:n_show]:
        print(f" {name:<45s} {s['accuracy']:>6.2f} {s['precision']:>6.2f} {s['recall']:>6.2f} {s['f1']:>6.2f}")

    print(f"\n Bottom {n_show} Items by F1:")
    print(f" {'Item':<45s} {'Acc':>6s} {'P':>6s} {'R':>6s} {'F1':>6s}")
    print(f" {'-'*45} {'---':>6s} {'---':>6s} {'---':>6s} {'---':>6s}")
    for name, s in sorted_items[-n_show:]:
        print(f" {name:<45s} {s['accuracy']:>6.2f} {s['precision']:>6.2f} {s['recall']:>6.2f} {s['f1']:>6.2f}")

    # Sample errors
    if errors:
        n_err = min(15, len(errors))
        print(f"\n Sample Mismatches ({n_err} of {len(errors)}):")
        print(f" {'DocID':<12s} {'Item':<40s} {'GT':<20s} {'Pred':<20s}")
        print(f" {'-'*12} {'-'*40} {'-'*20} {'-'*20}")
        for doc_id, item, gt_v, pred_v in errors[:n_err]:
            print(f" {doc_id:<12s} {item:<40s} {str(gt_v):<20s} {str(pred_v):<20s}")

    print("=" * 70)

    return {'accuracy': round(accuracy, 4), 'macro_f1': round(macro_f1, 4)}, per_item
303
+
304
+
305
+ # ---------------------------------------------------------------------------
306
+ # MAIN
307
+ # ---------------------------------------------------------------------------
308
+
309
async def main():
    """CLI entry point: load data, run the two-pass (or RAG) pipeline for every
    patient, save predictions as JSON, and optionally score them against GT.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--api_key", default=API_KEY,
                        help="Google AI Studio API key")
    parser.add_argument("--model_name", default="gemini-1.5-pro",
                        help="Gemini model name")
    # NOTE(review): absolute Windows paths as defaults make this script
    # non-portable — consider relative defaults or making these required.
    parser.add_argument("--data_folders", nargs="+",
                        default=[
                            r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-clinical-notes",
                            r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-crf-development",
                        ],
                        help="Directories containing .parquet shards (searched recursively)")
    parser.add_argument("--gt_file",
                        default=r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dev_gt.jsonl")
    parser.add_argument("--options_folder",
                        default=r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-valid-options\dyspnea-valid-options\data")
    parser.add_argument("--output_file",
                        default="data/processed/materialized_ehr/submission.json")
    parser.add_argument("--skip_eval", action="store_true",
                        help="Skip evaluation after generating predictions")
    parser.add_argument("--concurrency", type=int, default=5,
                        help="Max concurrent LLM calls (free tier: keep at 5)")
    # --- RAG options ---
    parser.add_argument("--use_rag", action="store_true",
                        help="Use RAG-guided extraction (retrieves relevant tuples per CRF item)")
    parser.add_argument("--rag_top_k", type=int, default=15,
                        help="Number of WTTS tuples to retrieve per CRF item group (RAG mode)")
    parser.add_argument("--rag_model", type=str, default="all-MiniLM-L6-v2",
                        help="SentenceTransformer model for embeddings (swap to clinical model on GPU)")
    parser.add_argument("--rag_device", type=str, default="cpu",
                        help="Device for embedding model: 'cpu' or 'cuda'")
    args = parser.parse_args()

    # 1. Setup — Configure Gemini API
    genai.configure(api_key=args.api_key)
    model = genai.GenerativeModel(args.model_name)
    print(f"Using model: {args.model_name} (Google AI Studio)")

    # Limit concurrency (free tier = 15 RPM, so keep low)
    semaphore = asyncio.Semaphore(args.concurrency)

    # 2. Load Data
    loader = DataLoader(data_folders=args.data_folders, gt_path=args.gt_file)

    target_items = loader.get_target_schema()
    valid_options = loader.load_valid_options(args.options_folder)

    merged_data = loader.load_and_merge()

    if not merged_data:
        print("No data found. Exiting.")
        return

    # 3. Process
    builder = WTTSBuilder()
    print(f"Starting pipeline for {len(merged_data)} patients...")
    print(f"Schema: {len(target_items)} items per patient.")

    if args.use_rag:
        # --- RAG Pipeline ---
        # RAG_AVAILABLE is False when the optional imports at module top failed.
        if not RAG_AVAILABLE:
            print("ERROR: RAG dependencies not installed. Run:")
            print(" pip install sentence-transformers faiss-cpu")
            return

        print(f"\n [RAG MODE] Embedding model: {args.rag_model}")
        print(f" [RAG MODE] Device: {args.rag_device}")
        print(f" [RAG MODE] Top-k: {args.rag_top_k}\n")

        embedder = WTTSEmbedder(model_name=args.rag_model, device=args.rag_device)
        extractor = RAGCRFExtractor(
            embedder=embedder,
            generate_fn=generate_async,
            top_k=args.rag_top_k,
        )

        tasks = [
            extractor.extract_patient(
                p, builder, target_items, valid_options, semaphore, model
            )
            for p in merged_data
        ]
    else:
        # --- Original Two-Pass Pipeline ---
        tasks = [
            process_patient(model, builder, p, target_items, valid_options, semaphore)
            for p in merged_data
        ]

    # Failed patients return None and are dropped here.
    results = await asyncio.gather(*tasks)
    results = [r for r in results if r is not None]

    # 4. Save
    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nDone! {len(results)} results saved to {args.output_file}")

    # 5. Evaluate against GT
    if not args.skip_eval:
        print("\nRunning evaluation against ground truth...")
        overall, _ = evaluate_predictions(results, args.gt_file)
        print(f"\n >>> Final Accuracy: {overall['accuracy']:.4f} | Macro F1: {overall['macro_f1']:.4f}")
413
+
414
+
415
# Script entry point: run the async pipeline to completion.
if __name__ == "__main__":
    asyncio.run(main())
predictor.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import asyncio
4
+ import glob
5
+ import argparse
6
+ from argparse import ArgumentTypeError
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+ import google.generativeai as genai
10
+ from typing import Dict, Optional
11
+
12
+ import src.utils.utils as utils
13
+
14
# Default Gemini generation parameters.
# NOTE(review): this module-level dict is not referenced by generate_async(),
# which builds its own inline generation_config — confirm whether this is
# still needed or is dead configuration.
generation_config = {
    "max_output_tokens": 8192,
    "temperature": 1,
    "top_p": 0.95,
}
19
+
20
def merge_jsonl_files(input_folder, output_file):
    """Concatenate every *.jsonl file in *input_folder* into *output_file*,
    keeping only the first record seen for each person_id."""
    seen_ids = set()
    with open(output_file, 'w') as sink:
        for path in glob.glob(f"{input_folder}/*.jsonl"):
            with open(path, 'r') as source:
                for raw_line in source:
                    pid = str(json.loads(raw_line).get('person_id', ''))
                    if pid in seen_ids:
                        continue
                    seen_ids.add(pid)
                    sink.write(raw_line)
    print(f"Merged {len(seen_ids)} unique records into {output_file}")
32
+
33
def get_context_path_safe(ehr_context):
    """Map an ehr_context value to a filesystem-safe path component."""
    looks_numeric = isinstance(ehr_context, int) or (
        isinstance(ehr_context, str) and ehr_context.isdigit()
    )
    if looks_numeric:
        return f"context_{ehr_context}"
    return str(ehr_context)
38
+
39
def validate_ehr_context(value):
    """argparse type-validator for --ehr_context.

    Accepts either predefined strings or positive integers.

    Returns:
        - String values unchanged ("full", "last_five")
        - Integer values as integers
        - For path construction, numeric values are converted to
          "context_{value}" format elsewhere (see get_context_path_safe).

    Raises:
        ArgumentTypeError: for non-numeric strings other than the two
            literals, and for non-positive integers.
    """
    if value in ["full", "last_five"]:
        return value
    try:
        context_length = int(value)
    except ValueError:
        # BUGFIX: the two implicit string fragments previously concatenated
        # as "...'last_five'or a positive..." — a separating space was missing.
        raise ArgumentTypeError(
            f"Invalid ehr_context value: {value}. Must be either 'full', 'last_five' "
            "or a positive integer for context length."
        )
    # Raised outside the try so it cannot be confused with the parse failure
    # (ArgumentTypeError is not a ValueError, so behavior is unchanged).
    if context_length <= 0:
        raise ArgumentTypeError(f"Context length must be positive, got {value}")
    return context_length
60
+
61
+
62
async def generate_async(prompt, model, max_retries=3, initial_delay=1):
    """Call Gemini and return its response parsed as JSON.

    The blocking SDK call runs on the default thread-pool executor so the
    event loop stays responsive. Invalid-JSON responses consume a retry;
    transient HTTP 429/500 errors are retried with exponential backoff,
    while any other error is returned immediately.

    Args:
        prompt: Prompt text passed to ``model.generate_content``.
        model: A ``genai.GenerativeModel`` (anything exposing generate_content).
        max_retries: Total number of attempts.
        initial_delay: Backoff base in seconds (delay doubles each attempt).

    Returns:
        dict: Parsed JSON on success, else ``{"error": <message>}``.
    """
    # FIX: get_running_loop() replaces get_event_loop(), which is deprecated
    # inside a coroutine since Python 3.10.
    loop = asyncio.get_running_loop()
    for attempt in range(max_retries):
        try:
            # Wrap the synchronous SDK call so it doesn't block the event loop
            response = await loop.run_in_executor(
                None,
                lambda: model.generate_content(
                    contents=prompt,
                    generation_config={"response_mime_type": "application/json"},
                    stream=False
                )
            )
            try:
                return json.loads(response.text)
            except json.JSONDecodeError:
                print("Generated content is not valid JSON. Retrying...")
                continue

        except Exception as e:
            error_message = str(e)
            # FIX: also retry rate-limit (429) errors, matching the retry
            # policy of main.py's generate_async; previously only 500s were
            # retried and a 429 aborted the whole call immediately.
            if "429" in error_message or "500" in error_message:
                if attempt < max_retries - 1:
                    delay = initial_delay * (2 ** attempt)  # Exponential backoff
                    print(f"Encountered transient error. Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
                else:
                    print("Max retries reached. Unable to generate content.")
                    return {"error": f"Max retries reached - {error_message}"}
            else:
                print(f"Error in generate_async: {error_message}")
                return {"error": error_message}

    return {"error": "Failed to generate valid JSON content after multiple attempts"}
97
+
98
+
99
async def process_single_timeline(timeline, args, model, semaphore, windowed_jsonl_file):
    """Generate LLM responses for every context window of one patient timeline.

    Args:
        timeline: One merged-timeline record (dict with at least 'person_id').
        args: Parsed CLI args (uses prompt_method, prompt_template, ehr_context).
        model: Gemini model passed through to generate_async.
        semaphore: asyncio.Semaphore bounding concurrent LLM calls.
        windowed_jsonl_file: Destination file handed to utils.process_ehr_context.

    Returns:
        dict with person_id / visit_occurrence_id / responses / specialty,
        or None on failure (the error is printed, not raised).
    """
    person_id = timeline.get('person_id', 'unknown')
    print(f"Processing timeline for person_id: {person_id}")

    try:
        context_windows = utils.process_ehr_context(
            timeline,
            args.ehr_context,
            windowed_jsonl_file
        )
        responses = []
        # BUGFIX: initialise before the loop. These were previously bound
        # only inside the loop body, so an empty context_windows list caused
        # a NameError at the return statement below (swallowed into a None
        # result by the except clause).
        specialty = None
        visit_occurrence_id = None

        def _window_record(window, response):
            # Response payload shared by both prompt methods (was duplicated).
            return {
                "window_index": window.get("window_index", 0),
                "response": response,
                "window_token_count": window.get("window_token_count"),
                "window_percent_full": window.get("window_percent_full"),
                "start_date": window.get("start_date"),
                "end_date": window.get("end_date")
            }

        for window in context_windows:
            if args.prompt_method == "general":
                prompt = utils.create_prompt_from_timeline(
                    window,
                    args.prompt_template
                )
                specialty = None
                visit_occurrence_id = None
            elif args.prompt_method == "persona":
                # Persona mode also yields the specialty / visit this window
                # belongs to; the last window's values are returned.
                prompt, specialty, visit_occurrence_id = utils.create_persona_prompt_from_timeline(
                    window,
                    args.prompt_template,
                    person_id
                )
            else:
                raise ValueError(f"Unknown prompt method: {args.prompt_method}")

            # Hold the semaphore only for the duration of the LLM call.
            async with semaphore:
                response = await generate_async(prompt, model)
            responses.append(_window_record(window, response))

        print(f"Completed processing for person_id: {person_id} with {len(responses)} context windows")
        return {
            "person_id": person_id,
            "visit_occurrence_id": visit_occurrence_id,
            "responses": responses,
            "specialty": specialty
        }
    except Exception as e:
        print(f"Error processing timeline for person_id {person_id}: {str(e)}")
        return None
158
+
159
async def process_timelines(merged_jsonl_file, args, windowed_jsonl_file):
    """Configure Gemini, then run process_single_timeline over each timeline
    in *merged_jsonl_file* (optionally capped at args.max_samples).

    Returns:
        list: the non-None per-timeline result dicts.
    """
    genai.configure(api_key=args.api_key)
    model = genai.GenerativeModel(args.model_name)
    semaphore = asyncio.Semaphore(args.max_concurrent_calls)

    with open(merged_jsonl_file, 'r') as f:
        timelines = [json.loads(line) for line in f]
    # Cap the workload when --max_samples is given (falsy means "all").
    total_timelines = min(len(timelines), args.max_samples) if args.max_samples else len(timelines)
    print(f"Found {total_timelines} total timelines")

    results = []
    processed_count = 0
    # NOTE(review): timelines are awaited one at a time, so the semaphore
    # only limits the per-window calls inside a single timeline — there is
    # no cross-patient concurrency here.
    for timeline in tqdm(timelines[:total_timelines], total=total_timelines, desc="Processing timelines"):
        try:
            result = await process_single_timeline(timeline, args, model, semaphore, windowed_jsonl_file)
            if result is not None:
                results.append(result)
                processed_count += 1
                # Periodic progress line every 10 successful timelines.
                if processed_count % 10 == 0 or processed_count == total_timelines:
                    print(f"Processed {processed_count}/{total_timelines} samples")
        except Exception as e:
            print(f"Error processing timeline: {e}")

    return results
183
+
184
def save_single_response(resp_data: Dict, output_file: Path, person_id: str, template_name: Optional[str] = None):
    """Write one response payload to *output_file* as pretty-printed JSON,
    logging success or failure (errors are reported, never raised)."""
    suffix = f", template: {template_name}" if template_name else ""
    try:
        with open(output_file, 'w') as f:
            json.dump(resp_data, f, indent=2)
        print(f"Saved response for person_id: {person_id}" + suffix)
    except Exception as e:
        print(f"Error saving response for person_id: {person_id}" + suffix)
        print(f"Error details: {str(e)}")
194
+
195
def save_responses(responses, output_folder, prompt_method):
    """
    Persist every window response, together with its window metadata and
    date range, as one JSON file per (person, window) under *output_folder*.
    """
    for entry in responses:
        person_id = entry["person_id"]

        for win in entry["responses"]:
            # File name encodes person, (persona) specialty, and window index.
            if prompt_method == "general":
                output_file = Path(output_folder) / f"{person_id}_{win['window_index']}.json"
            elif prompt_method == "persona":
                specialty = entry.get("specialty", "unknown")
                output_file = Path(output_folder) / f"{person_id}_{specialty}_{win['window_index']}.json"

            metadata = {
                "person_id": person_id,
                "window_index": win["window_index"],
                "window_token_count": win["window_token_count"],
                "window_percent_full": win["window_percent_full"],
                "start_date": win["start_date"],
                "end_date": win["end_date"],
                "response_data": win["response"]
            }

            output_file.parent.mkdir(parents=True, exist_ok=True)
            save_single_response(metadata, output_file, person_id)

            print(f"Saved response for person_id: {person_id}, "
                  f"window: {win['window_index']}, "
                  f"dates: {win['start_date']} to {win['end_date']}")

    print(f"Saved responses from {len(responses)} timelines to {output_folder}")
227
+
228
+
229
async def main(args):
    """Pipeline driver: merge timelines, generate responses per window,
    save them to a model/method/context-keyed folder tree, report timing."""
    start_time = time.time()

    merged_jsonl_file = Path(args.materialized_ehr_folder) / f"merged_timelines.jsonl"
    windowed_dir = Path(args.materialized_ehr_folder) / "windowed"
    windowed_dir.mkdir(parents=True, exist_ok=True)
    windowed_jsonl_file = windowed_dir / f"windowed_timelines_{args.prompt_method}_{args.ehr_context}.jsonl"

    # Deduplicate per-patient JSONL shards into one merged file.
    merge_jsonl_files(args.materialized_ehr_folder, merged_jsonl_file)

    print("Processing timelines...")
    responses = await process_timelines(merged_jsonl_file, args, windowed_jsonl_file)
    print(f"Processed {len(responses)} samples")

    if args.max_samples:
        print(f"Limited to {args.max_samples} samples for testing")

    if responses:
        # Output layout: <output>/<model>/<method>/<context>/<template-stem>/
        context_path = get_context_path_safe(args.ehr_context)
        output_folder = Path(args.output_folder) / args.model_name / args.prompt_method / context_path / Path(args.prompt_template).stem
        save_responses(responses, output_folder, args.prompt_method)

        total_time = time.time() - start_time
        print(f"Total time taken: {total_time:.2f} seconds")
        print(f"Average time per sample: {total_time / len(responses):.2f} seconds")
    else:
        print("No responses were generated. Check the logs for more information.")
256
+
257
+
258
# CLI entry point: parse arguments and run the async pipeline.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate synthetic instruction-response pairs from materialized patient timelines using async processing.")
    parser.add_argument("--materialized_ehr_folder", type=str, help="Folder containing materialized EHR JSONL files")
    parser.add_argument("--prompt_template", type=str)
    parser.add_argument("--prompt_method", type=str, default="persona", help="Prompt method: general, persona")
    parser.add_argument("--ehr_context", type=validate_ehr_context, default=16384,
                        help="EHR context to use: 'full', 'last_five', or a positive integer for context length (e.g., 8192)")
    parser.add_argument("--output_folder", type=str)
    parser.add_argument("--project_id", type=str)
    parser.add_argument("--location", type=str, default="us-central1", help="Google Cloud location")
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--max_samples", type=int)
    # BUGFIX: this value is fed directly to asyncio.Semaphore(), which needs
    # an int. Without type=int argparse handed over a string, and with no
    # default the value was None — either way Semaphore() raised TypeError.
    parser.add_argument("--max_concurrent_calls", type=int, default=5,
                        help="Maximum number of concurrent LLM calls")
    parser.add_argument("--csv_file", type=str)
    parser.add_argument("--dataset_id", type=str)
    parser.add_argument("--table_id", type=str)
    args = parser.parse_args()

    asyncio.run(main(args))
preprocess/__init__.py ADDED
File without changes
preprocess/wtts_builder.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import os
4
+ import pandas as pd
5
+ import argparse
6
+ from datetime import datetime
7
+ from src.utils.data_loader import DataLoader
8
+
9
class WTTSBuilder:
    """Builds a Weighted Time-annotated Tuple Series (WTTS) string from a
    patient's clinical notes.

    Each sentence becomes one tuple `[S_i] ("timestamp", "sentence", P_j, W)`
    where P_j is the note's normalized position within the admission (0..1)
    and W is a keyword-derived importance weight: 1.0 critical, 0.5 chronic
    (also the default), 0.1 routine.
    """

    def __init__(self):
        # --- WEIGHTING RULES (W) ---
        # Pattern groups checked in priority order:
        # critical (1.0) > chronic (0.5) > routine (0.1).
        self.critical_patterns = [
            r'respiratory failure', r'seizure', r'cardiac arrest', r'intubat',
            r'abnormal', r'critical', r'hemorrhage', r'positive',
            r'emergency', r'acute', r'hypoxia', r'flagged', r'icu',
            r'dyspnea', r'shortness of breath', r'sob', r'mrc grade', r'nyha'
        ]
        self.chronic_patterns = [
            r'history of', r'chronic', r'stable', r'continued',
            r'maintained', r'diagnosed with', r'previous', r'known'
        ]
        self.routine_patterns = [
            r'routine', r'normal', r'negative', r'unremarkable',
            r'no acute', r'clear', r'regular diet', r'resting'
        ]

    def _get_normalized_time(self, event_time_str, admit_str, disch_str):
        """Return P_j in [0.0, 1.0]: fraction of the stay elapsed at event time.

        Falls back to 0.5 (mid-stay) when any timestamp cannot be parsed,
        and clamps to 1.0 when the stay has non-positive duration.
        """
        try:
            e_dt = pd.to_datetime(event_time_str)
            a_dt = pd.to_datetime(admit_str)
            d_dt = pd.to_datetime(disch_str)

            total_duration = (d_dt - a_dt).total_seconds()
            elapsed = (e_dt - a_dt).total_seconds()

            if total_duration <= 0:
                return 1.0
            return round(max(0.0, min(1.0, elapsed / total_duration)), 2)
        except Exception:
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Exception still covers all
            # parsing/arithmetic failures.
            return 0.5

    def _get_weight(self, text):
        """Return importance weight W for a sentence (critical > chronic > routine)."""
        t = text.lower()
        if any(re.search(p, t) for p in self.critical_patterns):
            return 1.0
        if any(re.search(p, t) for p in self.chronic_patterns):
            return 0.5
        if any(re.search(p, t) for p in self.routine_patterns):
            return 0.1
        return 0.5  # unmatched sentences default to medium weight

    def _extract_sentences_with_ids(self, text, start_index):
        """Split a note into sentences and assign globally unique S_xx IDs.

        Newlines/tabs are flattened first so each tuple stays on one output
        line, and de-identification brackets ([** ... **]) are stripped.

        Returns:
            tuple: (list of (sid, sentence) pairs, next free sentence index).
        """
        # 1. Replace newlines/tabs with spaces to keep tuple on one line
        text = text.replace('\n', ' ').replace('\r', '').replace('\t', ' ')

        # 2. Remove de-id brackets
        text = re.sub(r'\[\*\*.*?\*\*\]', '', text)

        # 3. Split by sentence boundaries (lookbehinds avoid splitting on
        #    common abbreviation shapes like "e.g." or "Dr.")
        sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', text)

        results = []
        current_idx = start_index
        for s in sentences:
            clean_s = s.strip()
            # Ignore very short/empty fragments
            if len(clean_s) > 5:
                sid = f"S_{current_idx}"
                results.append((sid, clean_s))
                current_idx += 1
        return results, current_idx

    def build_wtts_string(self, patient_data):
        """Serialize one patient's notes into a single pipe-delimited WTTS string.

        Args:
            patient_data: dict with optional 'admission_time', 'discharge_time'
                and a 'notes' list of {'timestamp', 'text'} dicts.
        """
        tuples = []
        admit = patient_data.get('admission_time')
        disch = patient_data.get('discharge_time')

        # Chronological order so S_xx IDs increase with time.
        sorted_notes = sorted(patient_data.get('notes', []), key=lambda x: x['timestamp'])

        global_sent_idx = 0

        for note in sorted_notes:
            raw_ts = note['timestamp']
            p_j = self._get_normalized_time(raw_ts, admit, disch)

            events, global_sent_idx = self._extract_sentences_with_ids(note['text'], global_sent_idx)

            for (sid, event) in events:
                w = self._get_weight(event)
                # Format: [ID] ("Date", "Event", P_j, W)
                tuples.append(f'[{sid}] ("{raw_ts}", "{event}", {p_j}, {w})')

        return " | ".join(tuples)
95
+
96
# --- EXECUTION LOGIC ---
# Standalone preprocessing entry point: materialize one WTTS timeline file
# per patient so downstream scripts can consume them as JSONL.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process raw clinical notes into WTTS format.")
    # Set defaults for easier running
    parser.add_argument("--input_dirs", nargs="+",
                        default=[
                            "data/raw/dyspnea-clinical-notes",
                            "data/raw/dyspnea-crf-development",
                        ],
                        help="Directories containing .parquet shards (searched recursively)")
    parser.add_argument("--gt_file", type=str,
                        default="data/raw/dev_gt.jsonl",
                        help="Path to the ground truth JSONL file.")
    parser.add_argument("--output_dir", type=str,
                        default="data/processed/materialized_ehr",
                        help="Path to store processed JSONL files.")

    args = parser.parse_args()

    loader = DataLoader(data_folders=args.input_dirs, gt_path=args.gt_file)
    builder = WTTSBuilder()

    patients = loader.load_and_merge()
    os.makedirs(args.output_dir, exist_ok=True)

    if not patients:
        print("No patients found! Check paths.")
    else:
        print(f"Materializing timelines for {len(patients)} patients...")
        for p in patients:
            wtts_output = builder.build_wtts_string(p)

            # FIX: Prioritize document_id to match DataLoader logic
            pid = str(p.get('document_id') or p.get('patient_id') or p.get('hadm_id'))

            # One single-record JSONL file per patient, keyed as "person_id"
            # to match the merge step in predictor.py.
            output_path = os.path.join(args.output_dir, f"{pid}.jsonl")
            with open(output_path, 'w') as f:
                json.dump({"person_id": pid, "text": wtts_output}, f)

        print(f"Successfully stored outputs in: {args.output_dir}")
rag/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # RAG-guided CRF extraction module
2
+ from src.rag.embedder import WTTSEmbedder
3
+ from src.rag.retriever import WTTSRetriever
4
+ from src.rag.rag_pipeline import RAGCRFExtractor
rag/embedder.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WTTS Tuple Embedder — Embeds clinical event tuples and CRF queries
3
+ into vector space using SentenceTransformers.
4
+
5
+ Swap the model_name to a clinical model (e.g., MedCPT) when GPU is available.
6
+ """
7
+
8
+ import re
9
+ import numpy as np
10
+ from typing import List, Dict, Optional
11
+ from sentence_transformers import SentenceTransformer
12
+
13
+
14
class WTTSEmbedder:
    """Embeds WTTS tuples and CRF item queries into dense vectors."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", device: str = "cpu"):
        """
        Args:
            model_name: SentenceTransformer model ID.
                        CPU default: 'all-MiniLM-L6-v2' (384-dim, fast)
                        GPU clinical: 'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb'
                        or 'medicalai/ClinicalBERT'
            device: 'cpu' or 'cuda'
        """
        print(f"Loading embedding model: {model_name} on {device}...")
        self.model = SentenceTransformer(model_name, device=device)
        self.device = device
        # Cached so callers (and embed_tuples on empty input) can shape arrays
        # without touching the model again.
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f" Embedding dimension: {self.embedding_dim}")

    # ------------------------------------------------------------------
    # Parse WTTS string → structured list of tuple dicts
    # ------------------------------------------------------------------
    def parse_wtts_string(self, wtts_string: str) -> List[Dict]:
        """
        Parses the WTTS pipe-delimited string back into structured dicts.

        Input format: [S_0] ("2026-01-01", "event text", 0.5, 1.0) | [S_1] (...)
        Output: [
            {"sid": "S_0", "timestamp": "2026-01-01", "event": "event text", "p_j": 0.5, "weight": 1.0},
            ...
        ]

        Entries that do not match the expected pattern are silently skipped.
        """
        tuples = []
        # Split by pipe separator
        raw_entries = wtts_string.split(" | ")

        # Pattern to extract: [S_xx] ("timestamp", "event", P_j, W)
        # NOTE: assumes neither timestamp nor event text contains a double quote.
        pattern = re.compile(
            r'\[(?P<sid>S_\d+)\]\s*\('
            r'"(?P<timestamp>[^"]*)",\s*'
            r'"(?P<event>[^"]*)",\s*'
            r'(?P<p_j>[\d.]+),\s*'
            r'(?P<weight>[\d.]+)\)'
        )

        for entry in raw_entries:
            entry = entry.strip()
            if not entry:
                continue
            match = pattern.search(entry)
            if match:
                tuples.append({
                    "sid": match.group("sid"),
                    "timestamp": match.group("timestamp"),
                    "event": match.group("event"),
                    "p_j": float(match.group("p_j")),
                    "weight": float(match.group("weight")),
                })

        return tuples

    # ------------------------------------------------------------------
    # Embed tuple event texts
    # ------------------------------------------------------------------
    def embed_tuples(self, tuples: List[Dict]) -> np.ndarray:
        """
        Embed the event text from each WTTS tuple.

        Args:
            tuples: List of parsed tuple dicts (from parse_wtts_string)

        Returns:
            np.ndarray of shape (n_tuples, embedding_dim), float32.
            Empty input yields a (0, embedding_dim) matrix.
        """
        if not tuples:
            # FIX: return a correctly-shaped empty matrix instead of np.array([])
            # (shape (0,)), so consumers reading .shape[1] (e.g. FAISS index
            # construction) do not crash on an empty timeline.
            return np.zeros((0, self.embedding_dim), dtype=np.float32)

        texts = [t["event"] for t in tuples]
        embeddings = self.model.encode(
            texts,
            show_progress_bar=False,
            normalize_embeddings=True,  # L2-normalize for cosine similarity via dot product
            batch_size=64,
        )
        return np.array(embeddings, dtype=np.float32)

    # ------------------------------------------------------------------
    # Embed a CRF item query
    # ------------------------------------------------------------------
    def embed_query(self, crf_item: str, valid_options: Optional[List[str]] = None) -> np.ndarray:
        """
        Create an embedding for a CRF item query.
        Combines the item name with its valid options to create a richer query.

        Args:
            crf_item: e.g., "mrc_grade" or "administration of bronchodilators"
            valid_options: e.g., ["y", "n", "unknown"]

        Returns:
            np.ndarray of shape (embedding_dim,)
        """
        # Build a descriptive query string
        query_parts = [crf_item.replace("_", " ")]

        if valid_options:
            # Add option context to help embedding understand what we're looking for
            opts_str = ", ".join(str(o) for o in valid_options[:10])  # limit to avoid huge queries
            query_parts.append(f"options: {opts_str}")

        query_text = " | ".join(query_parts)

        embedding = self.model.encode(
            [query_text],
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return np.array(embedding[0], dtype=np.float32)

    # ------------------------------------------------------------------
    # Batch embed multiple CRF queries at once
    # ------------------------------------------------------------------
    def embed_queries_batch(
        self,
        crf_items: List[str],
        valid_options_map: Dict[str, List[str]]
    ) -> Dict[str, np.ndarray]:
        """
        Embed all CRF items in one batch for efficiency.

        Returns:
            Dict mapping crf_item name → embedding vector
        """
        # FIX: guard the degenerate case instead of calling model.encode([]).
        if not crf_items:
            return {}

        query_texts = []
        item_names = []

        for item in crf_items:
            item_names.append(item)
            # Same query construction as embed_query, kept in sync.
            parts = [item.replace("_", " ")]
            opts = valid_options_map.get(item, [])
            if opts:
                opts_str = ", ".join(str(o) for o in opts[:10])
                parts.append(f"options: {opts_str}")
            query_texts.append(" | ".join(parts))

        embeddings = self.model.encode(
            query_texts,
            show_progress_bar=False,
            normalize_embeddings=True,
            batch_size=64,
        )

        return {
            name: np.array(emb, dtype=np.float32)
            for name, emb in zip(item_names, embeddings)
        }
rag/rag_pipeline.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG CRF Extractor — Orchestrates the full RAG-guided pipeline:
3
+ 1. Build WTTS tuples from patient data
4
+ 2. Embed all tuples into FAISS index
5
+ 3. For each CRF item: retrieve relevant tuples → LLM extraction
6
+ 4. Return predictions
7
+
8
+ This replaces the two-pass (Skeleton → Extraction) approach in main.py.
9
+ """
10
+
11
+ import asyncio
12
+ import json
13
+ from typing import Dict, List, Optional, Any
14
+
15
+ from src.rag.embedder import WTTSEmbedder
16
+ from src.rag.retriever import WTTSRetriever
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # RAG-optimized prompt — shorter, focused on retrieved evidence only
21
+ # ---------------------------------------------------------------------------
22
+
23
+ RAG_EXTRACTION_PROMPT = """\
24
+ You are a Clinical Coding Expert.
25
+
26
+ RETRIEVED CLINICAL EVIDENCE (sorted chronologically, most relevant events for these items):
27
+ {retrieved_evidence}
28
+
29
+ TASK:
30
+ For each Clinical Item below, determine the value based ONLY on the evidence above.
31
+ 1. **Value**: Must come strictly from the "Valid Options".
32
+ 2. **Evidence**: Cite the [S_xx] ID that supports your choice.
33
+ 3. If no evidence supports any option, choose "unknown".
34
+
35
+ ITEMS TO EXTRACT & THEIR OPTIONS:
36
+ {chunk_schema_json}
37
+
38
+ OUTPUT FORMAT (JSON Object):
39
+ {{
40
+ "item_name": {{
41
+ "value": "Selected Option",
42
+ "evidence": "S_xx",
43
+ "reasoning": "Brief explanation"
44
+ }},
45
+ ...
46
+ }}
47
+ """
48
+
49
+
50
class RAGCRFExtractor:
    """
    Orchestrates RAG-guided CRF extraction for clinical notes.
    Replaces the two-pass (Skeleton → Extraction) pipeline with
    per-item retrieval for focused, temporally-ordered evidence.
    """

    def __init__(
        self,
        embedder: "WTTSEmbedder",  # string forward ref: avoids eager evaluation at class-creation time
        generate_fn,
        top_k: int = 15,
        weight_boost: float = 0.3,
        items_per_chunk: int = 5,
    ):
        """
        Args:
            embedder: WTTSEmbedder instance (shared across patients)
            generate_fn: Async function to call LLM (the generate_async from main.py)
            top_k: Number of tuples to retrieve per CRF item group
            weight_boost: Re-ranking boost for critical events
            items_per_chunk: How many CRF items to group per LLM call.
                             Grouped items share the same retrieved evidence pool.
        """
        self.embedder = embedder
        self.generate_fn = generate_fn
        self.top_k = top_k
        self.weight_boost = weight_boost
        # FIX: clamp to >= 1 — a zero or negative chunk size would make
        # range(start, stop, step) raise ValueError in _group_crf_items.
        self.items_per_chunk = max(1, items_per_chunk)

    # ------------------------------------------------------------------
    # Group CRF items by semantic similarity for batched retrieval
    # ------------------------------------------------------------------
    def _group_crf_items(
        self,
        target_items: List[str],
        valid_options: Dict[str, List[str]],
    ) -> List[List[str]]:
        """
        Group CRF items into chunks. Items in the same chunk will share
        a combined retrieval query, so similar items get grouped together.

        For now: simple sequential chunking (items_per_chunk at a time);
        valid_options is unused but kept in the signature for a future
        similarity-based grouping strategy.
        """
        chunks = []
        for i in range(0, len(target_items), self.items_per_chunk):
            chunk = target_items[i : i + self.items_per_chunk]
            chunks.append(chunk)
        return chunks

    # ------------------------------------------------------------------
    # Build combined query for a group of CRF items
    # ------------------------------------------------------------------
    def _build_group_query(
        self,
        items: List[str],
        valid_options: Dict[str, List[str]],
    ):
        """
        Create a combined query embedding for a group of CRF items.
        Averages the individual item query embeddings.

        Raises:
            ValueError: if `items` is empty (averaging nothing would
            otherwise silently produce NaNs).
        """
        query_embeddings = self.embedder.embed_queries_batch(items, valid_options)

        # Average the embeddings for a combined query
        import numpy as np
        all_embs = list(query_embeddings.values())
        # FIX: fail loudly instead of letting np.mean over an empty list
        # propagate NaN into the FAISS search.
        if not all_embs:
            raise ValueError("Cannot build a query embedding for an empty item group.")
        combined = np.mean(all_embs, axis=0).astype(np.float32)

        # Re-normalize after averaging (mean of unit vectors is not unit length)
        norm = np.linalg.norm(combined)
        if norm > 0:
            combined = combined / norm

        return combined

    # ------------------------------------------------------------------
    # Main extraction method — full RAG pipeline for one patient
    # ------------------------------------------------------------------
    async def extract_patient(
        self,
        patient_data: Dict,
        builder,  # WTTSBuilder instance
        target_items: List[str],
        valid_options: Dict[str, List[str]],
        semaphore: asyncio.Semaphore,
        model: Any = None,  # Gemini model (passed to generate_fn)
    ) -> Optional[Dict]:
        """
        Full RAG pipeline for a single patient:
        1. Build WTTS string
        2. Parse & embed tuples → FAISS index
        3. For each CRF item group: retrieve → prompt → extract
        4. Return predictions

        Args:
            patient_data: Merged patient dict from DataLoader
            builder: WTTSBuilder instance
            target_items: List of CRF item names to extract
            valid_options: Dict mapping item name → list of valid values
            semaphore: Concurrency limiter for LLM calls
            model: Gemini model instance

        Returns:
            Dict with patient_id, predictions, and debug info,
            or None when the patient yields no usable tuples or errors out.
        """
        # Same id-priority as main.py: document_id first, matching DataLoader.
        pid = str(
            patient_data.get('document_id')
            or patient_data.get('patient_id')
            or patient_data.get('hadm_id')
            or 'unknown'
        )

        try:
            # --- Step 1: Build WTTS tuples ---
            wtts_string = builder.build_wtts_string(patient_data)

            if not wtts_string.strip():
                print(f"  [{pid}] No WTTS tuples generated, skipping.")
                return None

            # --- Step 2: Parse and embed tuples ---
            tuples = self.embedder.parse_wtts_string(wtts_string)

            if not tuples:
                print(f"  [{pid}] Failed to parse WTTS tuples, skipping.")
                return None

            tuple_embeddings = self.embedder.embed_tuples(tuples)

            # --- Step 3: Build FAISS index for this patient ---
            retriever = WTTSRetriever(self.embedder)
            retriever.build_index(tuples, tuple_embeddings)

            print(f"  [{pid}] Indexed {len(tuples)} tuples. "
                  f"Retrieving for {len(target_items)} CRF items...")

            # --- Step 4: Group CRF items and extract ---
            item_groups = self._group_crf_items(target_items, valid_options)
            final_predictions = {}

            for group_items in item_groups:
                # Build combined query for this group
                group_query = self._build_group_query(group_items, valid_options)

                # Retrieve relevant tuples
                retrieved = retriever.retrieve_and_rerank(
                    group_query,
                    top_k=self.top_k,
                    weight_boost=self.weight_boost,
                )

                # Format for LLM
                evidence_str = WTTSRetriever.format_retrieved_tuples(retrieved)

                chunk_schema = {
                    item: valid_options.get(item, ["y", "n", "unknown"])
                    for item in group_items
                }

                prompt = RAG_EXTRACTION_PROMPT.format(
                    retrieved_evidence=evidence_str,
                    chunk_schema_json=json.dumps(chunk_schema, indent=2),
                )

                # Call LLM with concurrency control
                async with semaphore:
                    response = await self.generate_fn(prompt, model)

                # NOTE(review): non-dict responses are silently dropped here;
                # generate_fn is assumed to return a dict (possibly {"error": ...}).
                if isinstance(response, dict):
                    if "error" in response:
                        print(f"  [{pid}] LLM error for items {group_items[:2]}...: "
                              f"{response['error']}")
                    else:
                        final_predictions.update(response)

            # --- Step 5: Return results ---
            return {
                "patient_id": pid,
                "predictions": final_predictions,
                "rag_debug": {
                    "total_tuples": len(tuples),
                    "top_k": self.top_k,
                    "item_groups": len(item_groups),
                },
            }

        except Exception as e:
            # Boundary handler: one bad patient must not kill the batch run.
            print(f"  [{pid}] RAG extraction error: {e}")
            import traceback
            traceback.print_exc()
            return None
rag/retriever.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WTTS Retriever — Builds per-patient FAISS index and retrieves
3
+ relevant tuples per CRF item with weight-based re-ranking
4
+ and P_j temporal sorting.
5
+ """
6
+
7
+ import numpy as np
8
+ import faiss
9
+ from typing import List, Dict, Optional, Tuple
10
+ from src.rag.embedder import WTTSEmbedder
11
+
12
+
13
class WTTSRetriever:
    """Per-patient FAISS index for retrieving relevant WTTS tuples."""

    def __init__(self, embedder: "WTTSEmbedder"):
        self.embedder = embedder
        # String forward ref: attribute annotations on complex targets are
        # evaluated at runtime, so keep faiss out of the eager path.
        self.index: Optional["faiss.IndexFlatIP"] = None  # Inner product (cosine on normalized vecs)
        self.tuples: List[Dict] = []
        self.embeddings: Optional[np.ndarray] = None

    # ------------------------------------------------------------------
    # Build index for one patient's tuples
    # ------------------------------------------------------------------
    def build_index(self, tuples: List[Dict], embeddings: np.ndarray):
        """
        Build a FAISS index from pre-computed tuple embeddings.

        Args:
            tuples: Parsed WTTS tuple dicts
            embeddings: np.ndarray of shape (n_tuples, embedding_dim)
        """
        self.tuples = tuples

        if len(tuples) == 0:
            self.embeddings = embeddings
            self.index = None
            return

        # FIX: FAISS requires C-contiguous float32 input; coerce defensively
        # in case a caller passes float64 or a non-contiguous slice.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        self.embeddings = embeddings

        dim = embeddings.shape[1]
        # Use Inner Product (IP) since embeddings are L2-normalized
        # This makes IP equivalent to cosine similarity
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)

    # ------------------------------------------------------------------
    # Raw retrieval (top-k by cosine similarity)
    # ------------------------------------------------------------------
    def retrieve(self, query_embedding: np.ndarray, top_k: int = 15) -> List[Dict]:
        """
        Retrieve top-k most similar tuples to the query.

        Args:
            query_embedding: 1D vector of shape (embedding_dim,)
            top_k: Number of tuples to retrieve

        Returns:
            List of tuple dicts with added 'similarity_score' field
        """
        if self.index is None or len(self.tuples) == 0:
            return []

        # Clamp top_k to available tuples
        top_k = min(top_k, len(self.tuples))

        # FAISS expects 2D input
        query = query_embedding.reshape(1, -1).astype(np.float32)
        scores, indices = self.index.search(query, top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:  # FAISS returns -1 for missing results
                continue
            result = self.tuples[idx].copy()
            result["similarity_score"] = float(score)
            results.append(result)

        return results

    # ------------------------------------------------------------------
    # Retrieve + re-rank by weight + sort by P_j
    # ------------------------------------------------------------------
    def retrieve_and_rerank(
        self,
        query_embedding: np.ndarray,
        top_k: int = 15,
        weight_boost: float = 0.3,
        fetch_multiplier: int = 3,
    ) -> List[Dict]:
        """
        Retrieve, re-rank using weight W, then sort by P_j for temporal order.

        Strategy:
        1. Over-fetch (top_k * fetch_multiplier) candidates from FAISS
        2. Re-score: final_score = similarity + weight_boost * W
        3. Take top_k by final_score
        4. Sort the final set by P_j (ascending) to preserve temporal order

        Args:
            query_embedding: 1D vector
            top_k: Final number of tuples to return
            weight_boost: How much to boost critical events (W=1.0 gets +0.3)
            fetch_multiplier: How many extra candidates to fetch for re-ranking

        Returns:
            List of tuple dicts sorted by P_j (temporal order),
            each with 'similarity_score', 'rerank_score' fields
        """
        if self.index is None or len(self.tuples) == 0:
            return []

        # Step 1: Over-fetch candidates
        fetch_k = min(top_k * fetch_multiplier, len(self.tuples))
        candidates = self.retrieve(query_embedding, top_k=fetch_k)

        # Step 2: Re-rank with weight boost
        for candidate in candidates:
            sim = candidate["similarity_score"]
            w = candidate.get("weight", 0.5)
            candidate["rerank_score"] = sim + (weight_boost * w)

        # Step 3: Take top_k by re-rank score
        candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
        top_candidates = candidates[:top_k]

        # Step 4: Sort by P_j (temporal order) — THIS is what preserves continuity
        top_candidates.sort(key=lambda x: x.get("p_j", 0.5))

        return top_candidates

    # ------------------------------------------------------------------
    # Format retrieved tuples back into a readable string for the LLM
    # ------------------------------------------------------------------
    @staticmethod
    def format_retrieved_tuples(tuples: List[Dict]) -> str:
        """
        Format retrieved tuples into a clean string for the LLM prompt.
        Sorted by P_j (temporal order) by this point.

        Output format (FIX: example now matches the actual labels emitted):
            [S_14] (EARLY | CRITICAL) [2026-01-01] "Patient reports increasing dyspnea"
            [S_80] (MID | CRITICAL) [2026-01-03] "SpO2 dropped to 85%, intubated"
            [S_155] (LATE | MODERATE) [2026-01-09] "Stable on room air at discharge"
        """
        if not tuples:
            return "(No relevant clinical events found)"

        lines = []
        for t in tuples:
            p_j = t.get("p_j", 0.5)
            w = t.get("weight", 0.5)

            # Temporal phase label (buckets over normalized admission→discharge time)
            if p_j <= 0.15:
                phase = "ADMISSION"
            elif p_j <= 0.35:
                phase = "EARLY"
            elif p_j <= 0.65:
                phase = "MID"
            elif p_j <= 0.85:
                phase = "LATE"
            else:
                phase = "DISCHARGE"

            # Weight label
            if w >= 0.8:
                w_label = "CRITICAL"
            elif w >= 0.4:
                w_label = "MODERATE"
            else:
                w_label = "ROUTINE"

            sid = t.get("sid", "S_?")
            event = t.get("event", "")
            ts = t.get("timestamp", "")

            lines.append(f'[{sid}] ({phase} | {w_label}) [{ts}] "{event}"')

        return "\n".join(lines)