Spaces:
Build error
Build error
Upload 8 files
Browse files- main.py +416 -0
- predictor.py +275 -0
- preprocess/__init__.py +0 -0
- preprocess/wtts_builder.py +135 -0
- rag/__init__.py +4 -0
- rag/embedder.py +167 -0
- rag/rag_pipeline.py +242 -0
- rag/retriever.py +179 -0
main.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# Imports from your specific file structure
|
| 8 |
+
from src.preprocess.wtts_builder import WTTSBuilder
|
| 9 |
+
from src.utils.data_loader import DataLoader
|
| 10 |
+
|
| 11 |
+
# RAG pipeline imports (lazy — only used when --use_rag is set)
|
| 12 |
+
RAG_AVAILABLE = False
|
| 13 |
+
try:
|
| 14 |
+
from src.rag.embedder import WTTSEmbedder
|
| 15 |
+
from src.rag.rag_pipeline import RAGCRFExtractor
|
| 16 |
+
RAG_AVAILABLE = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
import google.generativeai as genai
|
| 21 |
+
|
| 22 |
+
# --- CONFIG ---
# SECURITY: never commit live credentials. The previous revision hard-coded a
# Google AI Studio API key here; that key must be considered leaked and rotated.
# The key is now read from the environment; the --api_key CLI flag still
# overrides this default.
API_KEY = os.environ.get("GOOGLE_API_KEY", "")
|
| 24 |
+
|
| 25 |
+
# --- PROMPTS ---
|
| 26 |
+
|
| 27 |
+
SKELETON_PROMPT = """
|
| 28 |
+
You are a Clinical Data Specialist.
|
| 29 |
+
Convert the Weighted Time Series (WTTS) below into a "Clinical Chronology Skeleton".
|
| 30 |
+
|
| 31 |
+
INPUT (WTTS):
|
| 32 |
+
{wtts_string}
|
| 33 |
+
|
| 34 |
+
INSTRUCTIONS:
|
| 35 |
+
1. Create a strict chronological timeline (Admission to Discharge).
|
| 36 |
+
2. IMPORTANT: You MUST retain the [S_xx] ID for every event you list.
|
| 37 |
+
3. Filter out "Routine" (Weight 0.1) events unless they indicate a status change.
|
| 38 |
+
4. Keep exact values (e.g., "BP 90/60", "Temp 102.5").
|
| 39 |
+
|
| 40 |
+
OUTPUT FORMAT:
|
| 41 |
+
[Date] [S_xx]: Event details
|
| 42 |
+
[Date] [S_xx]: Event details
|
| 43 |
+
...
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
EXTRACTION_PROMPT = """
|
| 47 |
+
You are a Clinical Coding Expert.
|
| 48 |
+
Review the Patient Skeleton and the Valid Options for the requested items.
|
| 49 |
+
|
| 50 |
+
PATIENT SKELETON:
|
| 51 |
+
{skeleton}
|
| 52 |
+
|
| 53 |
+
TASK:
|
| 54 |
+
For each Clinical Item listed below, determine the value AND the supporting Sentence ID.
|
| 55 |
+
1. **Value**: Must come strictly from the "Valid Options" provided.
|
| 56 |
+
2. **Evidence**: Must be the specific [S_xx] ID from the skeleton that proves the value.
|
| 57 |
+
|
| 58 |
+
ITEMS TO EXTRACT & THEIR OPTIONS:
|
| 59 |
+
{chunk_schema_json}
|
| 60 |
+
|
| 61 |
+
OUTPUT FORMAT (JSON Object):
|
| 62 |
+
{{
|
| 63 |
+
"item_name": {{
|
| 64 |
+
"value": "Selected Option",
|
| 65 |
+
"evidence": "S_xx",
|
| 66 |
+
"reasoning": "Brief explanation"
|
| 67 |
+
}},
|
| 68 |
+
...
|
| 69 |
+
}}
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def chunk_data(data, size):
    """Split *data* into consecutive chunks of at most *size* items.

    Yields slices of the input list in order; the final chunk may be shorter.
    """
    start = 0
    total = len(data)
    while start < total:
        yield data[start:start + size]
        start += size
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
async def generate_async(prompt, model, max_retries=3, initial_delay=1):
    """Call Gemini via the google-generativeai SDK without blocking the event loop.

    The synchronous ``model.generate_content`` call is pushed onto the default
    executor. Responses are requested as JSON; invalid JSON and transient
    429/500 errors are retried with exponential backoff.

    Args:
        prompt: Text prompt sent to the model.
        model: A ``genai.GenerativeModel`` instance.
        max_retries: Total attempts before giving up.
        initial_delay: Base delay in seconds for exponential backoff.

    Returns:
        The parsed JSON response, or a ``{"error": ...}`` dict on failure.
    """
    # get_running_loop() is the supported API inside a coroutine;
    # get_event_loop() here has been deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    for attempt in range(max_retries):
        try:
            response = await loop.run_in_executor(
                None,
                lambda: model.generate_content(
                    contents=prompt,
                    generation_config=genai.GenerationConfig(
                        response_mime_type="application/json"
                    ),
                ),
            )
            try:
                return json.loads(response.text)
            except json.JSONDecodeError:
                print("Generated content is not valid JSON. Retrying...")
                continue

        except Exception as e:
            error_message = str(e)
            # Only transient rate-limit / server errors are worth retrying.
            if "429" in error_message or "500" in error_message:
                if attempt < max_retries - 1:
                    delay = initial_delay * (2 ** attempt)
                    print(f"Rate limit / server error. Retrying in {delay}s...")
                    await asyncio.sleep(delay)
                else:
                    print("Max retries reached.")
                    return {"error": f"Max retries reached - {error_message}"}
            else:
                print(f"Error in generate_async: {error_message}")
                return {"error": error_message}

    return {"error": "Failed to generate valid JSON after multiple attempts"}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
async def process_patient(model, builder, patient_data, target_items, valid_options, semaphore):
    """Executes the Two-Pass Pipeline for a single patient.

    Pass 1 converts the patient's WTTS string into a chronology skeleton;
    Pass 2 extracts CRF items against that skeleton in chunks of 10.

    Returns:
        Dict with ``patient_id``, a truncated ``skeleton_debug`` string and the
        merged ``predictions``; ``None`` on unrecoverable error.
    """
    async with semaphore:
        # The source datasets disagree on the ID field name, so try the aliases.
        pid = str(patient_data.get('document_id') or patient_data.get('patient_id')
                  or patient_data.get('hadm_id') or 'unknown')

        try:
            # --- PHASE 1: WTTS Construction ---
            wtts_string = builder.build_wtts_string(patient_data)

            # --- PHASE 2: Skeleton Generation (Pass 1) ---
            skeleton_input = SKELETON_PROMPT.format(wtts_string=wtts_string)
            skeleton_resp = await generate_async(skeleton_input, model)

            if isinstance(skeleton_resp, dict):
                if 'error' in skeleton_resp:
                    # Previously an LLM failure here was silently formatted into
                    # the extraction prompt; surface it so runs are debuggable.
                    print(f" [WARN] Skeleton generation failed for {pid}: {skeleton_resp['error']}")
                skeleton_text = json.dumps(skeleton_resp)
            else:
                skeleton_text = str(skeleton_resp)

            # --- PHASE 3: Extraction (Pass 2) ---
            final_predictions = {}
            for chunk_items in chunk_data(target_items, 10):
                # Fall back to a generic Yes/No/Unknown set for items with no
                # declared valid options.
                chunk_schema = {
                    item: valid_options.get(item, ["Yes", "No", "Unknown"])
                    for item in chunk_items
                }

                extract_input = EXTRACTION_PROMPT.format(
                    skeleton=skeleton_text,
                    chunk_schema_json=json.dumps(chunk_schema)
                )

                chunk_resp = await generate_async(extract_input, model)

                if isinstance(chunk_resp, dict):
                    if 'error' in chunk_resp:
                        print(f" [WARN] LLM error for {pid}, chunk {chunk_items[:3]}...: {chunk_resp['error']}")
                    else:
                        final_predictions.update(chunk_resp)

            return {
                "patient_id": pid,
                # Only append an ellipsis when the skeleton was actually cut.
                "skeleton_debug": skeleton_text[:500] + ("..." if len(skeleton_text) > 500 else ""),
                "predictions": final_predictions
            }

        except Exception as e:
            print(f"Error processing {pid}: {e}")
            return None
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ---------------------------------------------------------------------------
|
| 169 |
+
# EVALUATION -- Accuracy & F1 Scoring
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
|
| 172 |
+
def _normalise(value):
|
| 173 |
+
"""Lowercase + strip for fair comparison."""
|
| 174 |
+
if value is None:
|
| 175 |
+
return ""
|
| 176 |
+
return str(value).strip().lower()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def evaluate_predictions(results, gt_path):
    """
    Compare pipeline results against dev_gt.jsonl.
    Prints accuracy, macro-F1, per-item breakdown, and sample errors.
    Returns (overall_dict, per_item_dict).
    """
    # --- Load GT ---
    # gt maps document_id -> {item_name: ground_truth_value}.
    gt = {}
    with open(gt_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            doc_id = str(rec['document_id'])
            gt[doc_id] = {a['item']: a['ground_truth'] for a in rec.get('annotations', [])}

    # --- Build prediction lookup ---
    # Each prediction may be a {"value": ..., "evidence": ...} dict or a bare
    # scalar; we only compare on the value.
    preds = {}
    for r in results:
        doc_id = str(r.get('patient_id', 'unknown'))
        items = {}
        for item_name, item_val in r.get('predictions', {}).items():
            if isinstance(item_val, dict):
                items[item_name] = item_val.get('value', str(item_val))
            else:
                items[item_name] = str(item_val)
        preds[doc_id] = items

    # --- Collect all unique items ---
    # The item universe comes from the ground truth only; predicted items that
    # never appear in GT are not scored.
    all_items = set()
    for doc_items in gt.values():
        all_items.update(doc_items.keys())

    # --- Score ---
    # Per-item counters: a match is a tp; a mismatch is a fn, and additionally
    # an fp when a non-empty (wrong) prediction was made.
    item_stats = {item: {'tp': 0, 'fp': 0, 'fn': 0, 'total': 0, 'correct': 0}
                  for item in all_items}
    total_comparisons = 0
    total_correct = 0
    matched_patients = 0
    errors = []

    for doc_id, gt_items in gt.items():
        pred_items = preds.get(doc_id, {})
        if pred_items:
            matched_patients += 1

        for item_name, gt_val in gt_items.items():
            gt_norm = _normalise(gt_val)
            pred_val = pred_items.get(item_name)
            pred_norm = _normalise(pred_val) if pred_val is not None else ""

            total_comparisons += 1
            item_stats[item_name]['total'] += 1

            if gt_norm == pred_norm:
                total_correct += 1
                item_stats[item_name]['correct'] += 1
                item_stats[item_name]['tp'] += 1
            else:
                item_stats[item_name]['fn'] += 1
                if pred_norm:
                    item_stats[item_name]['fp'] += 1
                errors.append((doc_id, item_name, gt_val,
                               pred_val if pred_val is not None else '<MISSING>'))

    accuracy = total_correct / total_comparisons if total_comparisons > 0 else 0.0

    # --- Per-item P/R/F1 ---
    f1s = []
    per_item = {}
    for item_name in sorted(all_items):
        s = item_stats[item_name]
        tp, fp, fn = s['tp'], s['fp'], s['fn']
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        item_acc = s['correct'] / s['total'] if s['total'] > 0 else 0.0
        per_item[item_name] = {'accuracy': item_acc, 'precision': prec,
                               'recall': rec, 'f1': f1, 'total': s['total']}
        f1s.append(f1)

    # Macro-F1: unweighted mean over items, so rare items count as much as
    # common ones.
    macro_f1 = sum(f1s) / len(f1s) if f1s else 0.0

    # --- Print report ---
    print("\n" + "=" * 70)
    print(" CL4Health CRF Filling -- Evaluation Report")
    print("=" * 70)
    print(f"\n GT Patients: {len(gt)}")
    print(f" Pred Patients: {len(preds)}")
    print(f" Matched Patients: {matched_patients}")
    print(f"\n Total Comparisons: {total_comparisons}")
    print(f" Total Correct: {total_correct}")
    print(f"\n {'Accuracy':>20s}: {accuracy:.4f}")
    print(f" {'Macro F1':>20s}: {macro_f1:.4f}")

    # Top / bottom items
    sorted_items = sorted(per_item.items(), key=lambda x: x[1]['f1'], reverse=True)
    n_show = min(15, len(sorted_items))

    print(f"\n Top {n_show} Items by F1:")
    print(f" {'Item':<45s} {'Acc':>6s} {'P':>6s} {'R':>6s} {'F1':>6s}")
    print(f" {'-'*45} {'---':>6s} {'---':>6s} {'---':>6s} {'---':>6s}")
    for name, s in sorted_items[:n_show]:
        print(f" {name:<45s} {s['accuracy']:>6.2f} {s['precision']:>6.2f} {s['recall']:>6.2f} {s['f1']:>6.2f}")

    print(f"\n Bottom {n_show} Items by F1:")
    print(f" {'Item':<45s} {'Acc':>6s} {'P':>6s} {'R':>6s} {'F1':>6s}")
    print(f" {'-'*45} {'---':>6s} {'---':>6s} {'---':>6s} {'---':>6s}")
    for name, s in sorted_items[-n_show:]:
        print(f" {name:<45s} {s['accuracy']:>6.2f} {s['precision']:>6.2f} {s['recall']:>6.2f} {s['f1']:>6.2f}")

    # Sample errors
    if errors:
        n_err = min(15, len(errors))
        print(f"\n Sample Mismatches ({n_err} of {len(errors)}):")
        print(f" {'DocID':<12s} {'Item':<40s} {'GT':<20s} {'Pred':<20s}")
        print(f" {'-'*12} {'-'*40} {'-'*20} {'-'*20}")
        for doc_id, item, gt_v, pred_v in errors[:n_err]:
            print(f" {doc_id:<12s} {item:<40s} {str(gt_v):<20s} {str(pred_v):<20s}")

    print("=" * 70)

    return {'accuracy': round(accuracy, 4), 'macro_f1': round(macro_f1, 4)}, per_item
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# ---------------------------------------------------------------------------
|
| 306 |
+
# MAIN
|
| 307 |
+
# ---------------------------------------------------------------------------
|
| 308 |
+
|
| 309 |
+
async def main():
    """CLI entry point: configure Gemini, load data, run the extraction
    pipeline (two-pass or RAG-guided), save results and optionally evaluate."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--api_key", default=API_KEY,
                        help="Google AI Studio API key")
    parser.add_argument("--model_name", default="gemini-1.5-pro",
                        help="Gemini model name")
    parser.add_argument("--data_folders", nargs="+",
                        default=[
                            r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-clinical-notes",
                            r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-crf-development",
                        ],
                        help="Directories containing .parquet shards (searched recursively)")
    parser.add_argument("--gt_file",
                        default=r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dev_gt.jsonl")
    parser.add_argument("--options_folder",
                        default=r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-valid-options\dyspnea-valid-options\data")
    parser.add_argument("--output_file",
                        default="data/processed/materialized_ehr/submission.json")
    parser.add_argument("--skip_eval", action="store_true",
                        help="Skip evaluation after generating predictions")
    parser.add_argument("--concurrency", type=int, default=5,
                        help="Max concurrent LLM calls (free tier: keep at 5)")
    # --- RAG options ---
    parser.add_argument("--use_rag", action="store_true",
                        help="Use RAG-guided extraction (retrieves relevant tuples per CRF item)")
    parser.add_argument("--rag_top_k", type=int, default=15,
                        help="Number of WTTS tuples to retrieve per CRF item group (RAG mode)")
    parser.add_argument("--rag_model", type=str, default="all-MiniLM-L6-v2",
                        help="SentenceTransformer model for embeddings (swap to clinical model on GPU)")
    parser.add_argument("--rag_device", type=str, default="cpu",
                        help="Device for embedding model: 'cpu' or 'cuda'")
    args = parser.parse_args()

    # 1. Setup — Configure Gemini API
    genai.configure(api_key=args.api_key)
    model = genai.GenerativeModel(args.model_name)
    print(f"Using model: {args.model_name} (Google AI Studio)")

    # Limit concurrency (free tier = 15 RPM, so keep low)
    semaphore = asyncio.Semaphore(args.concurrency)

    # 2. Load Data
    loader = DataLoader(data_folders=args.data_folders, gt_path=args.gt_file)

    target_items = loader.get_target_schema()
    valid_options = loader.load_valid_options(args.options_folder)

    merged_data = loader.load_and_merge()

    if not merged_data:
        print("No data found. Exiting.")
        return

    # 3. Process
    builder = WTTSBuilder()
    print(f"Starting pipeline for {len(merged_data)} patients...")
    print(f"Schema: {len(target_items)} items per patient.")

    if args.use_rag:
        # --- RAG Pipeline ---
        # RAG_AVAILABLE is set at import time; sentence-transformers / faiss
        # are optional dependencies.
        if not RAG_AVAILABLE:
            print("ERROR: RAG dependencies not installed. Run:")
            print(" pip install sentence-transformers faiss-cpu")
            return

        print(f"\n [RAG MODE] Embedding model: {args.rag_model}")
        print(f" [RAG MODE] Device: {args.rag_device}")
        print(f" [RAG MODE] Top-k: {args.rag_top_k}\n")

        embedder = WTTSEmbedder(model_name=args.rag_model, device=args.rag_device)
        extractor = RAGCRFExtractor(
            embedder=embedder,
            generate_fn=generate_async,
            top_k=args.rag_top_k,
        )

        tasks = [
            extractor.extract_patient(
                p, builder, target_items, valid_options, semaphore, model
            )
            for p in merged_data
        ]
    else:
        # --- Original Two-Pass Pipeline ---
        tasks = [
            process_patient(model, builder, p, target_items, valid_options, semaphore)
            for p in merged_data
        ]

    # Failed patients return None and are dropped from the output.
    results = await asyncio.gather(*tasks)
    results = [r for r in results if r is not None]

    # 4. Save
    # NOTE(review): os.path.dirname() is '' for a bare filename, which makes
    # makedirs raise — confirm output_file always contains a directory part.
    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nDone! {len(results)} results saved to {args.output_file}")

    # 5. Evaluate against GT
    if not args.skip_eval:
        print("\nRunning evaluation against ground truth...")
        overall, _ = evaluate_predictions(results, args.gt_file)
        print(f"\n >>> Final Accuracy: {overall['accuracy']:.4f} | Macro F1: {overall['macro_f1']:.4f}")
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
if __name__ == "__main__":
    # Script entry point: drive the async pipeline to completion.
    asyncio.run(main())
|
predictor.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
import asyncio
|
| 4 |
+
import glob
|
| 5 |
+
import argparse
|
| 6 |
+
from argparse import ArgumentTypeError
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import google.generativeai as genai
|
| 10 |
+
from typing import Dict, Optional
|
| 11 |
+
|
| 12 |
+
import src.utils.utils as utils
|
| 13 |
+
|
| 14 |
+
# Default Gemini generation settings.
# NOTE(review): nothing in this file references generation_config —
# generate_async builds its own inline config. Confirm it is used elsewhere
# before removing.
generation_config = {
    "max_output_tokens": 8192,
    "temperature": 1,
    "top_p": 0.95,
}
|
| 19 |
+
|
| 20 |
+
def merge_jsonl_files(input_folder, output_file):
    """Merge all ``*.jsonl`` shards in *input_folder* into *output_file*,
    keeping only the first record seen for each ``person_id``.

    Fixes over the previous version:
      * the output file is excluded from the input scan — the caller writes
        the merged file into the same folder, so the old glob could re-read
        the half-written output;
      * blank lines are skipped instead of crashing ``json.loads``;
      * input files are processed in sorted order so "first record wins"
        is deterministic across runs.
    """
    seen_records = set()
    out_path = Path(output_file).resolve()
    with open(output_file, 'w', encoding='utf-8') as outfile:
        for filename in sorted(Path(input_folder).glob("*.jsonl")):
            if filename.resolve() == out_path:
                continue  # never re-read the file we are writing
            with open(filename, 'r', encoding='utf-8') as infile:
                for line in infile:
                    if not line.strip():
                        continue
                    record = json.loads(line)
                    # NOTE(review): records without a person_id all map to the
                    # key '' and therefore collapse to a single record —
                    # confirm that is intended upstream.
                    unique_id = str(record.get('person_id', ''))
                    if unique_id not in seen_records:
                        outfile.write(line)
                        seen_records.add(unique_id)
    print(f"Merged {len(seen_records)} unique records into {output_file}")
|
| 32 |
+
|
| 33 |
+
def get_context_path_safe(ehr_context):
    """Return a filesystem-safe directory name for an ehr_context value.

    Numeric contexts (ints, or digit-only strings) become ``context_<n>``;
    anything else is stringified as-is.
    """
    is_numeric = isinstance(ehr_context, int) or (
        isinstance(ehr_context, str) and ehr_context.isdigit()
    )
    return f"context_{ehr_context}" if is_numeric else str(ehr_context)
|
| 38 |
+
|
| 39 |
+
def validate_ehr_context(value):
    """
    Custom validation for ehr_context argument.
    Accepts either predefined strings or positive integers.
    Returns:
        - String values unchanged ("full", "last_five")
        - Integer values as integers
        - For path construction, numeric values are converted to "context_{value}" format
    Raises:
        ArgumentTypeError: for non-positive integers or unrecognised strings.
    """
    if value in ["full", "last_five"]:
        return value
    try:
        context_length = int(value)
    except ValueError:
        # The two adjacent literals previously lacked a separating space,
        # producing "...'last_five'or a positive integer...".
        # `from None` hides the irrelevant int() traceback from argparse output.
        raise ArgumentTypeError(
            f"Invalid ehr_context value: {value}. Must be either 'full', 'last_five' "
            "or a positive integer for context length."
        ) from None
    if context_length <= 0:
        raise ArgumentTypeError(f"Context length must be positive, got {value}")
    return context_length
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
async def generate_async(prompt, model, max_retries=3, initial_delay=1):
    """Call Gemini asynchronously, expecting a JSON response.

    The blocking SDK call runs on the default executor so the event loop
    stays responsive. Invalid-JSON responses and transient 429/500 errors
    are retried with exponential backoff, up to *max_retries* attempts.

    Returns:
        The parsed JSON response, or a ``{"error": ...}`` dict on failure.
    """
    # get_running_loop() is the supported API inside a coroutine;
    # get_event_loop() here has been deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    for attempt in range(max_retries):
        try:
            # Wrap the synchronous SDK call so it doesn't block the event loop
            response = await loop.run_in_executor(
                None,
                lambda: model.generate_content(
                    contents=prompt,
                    generation_config={"response_mime_type": "application/json"},
                    stream=False
                )
            )
            try:
                return json.loads(response.text)
            except json.JSONDecodeError:
                print("Generated content is not valid JSON. Retrying...")
                continue

        except Exception as e:
            error_message = str(e)
            # Retry rate-limit (429) as well as server (500) errors — the old
            # check only matched "500", inconsistent with the sibling helper
            # in main.py.
            if "429" in error_message or "500" in error_message:
                if attempt < max_retries - 1:
                    delay = initial_delay * (2 ** attempt)  # Exponential backoff
                    print(f"Transient error. Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
                else:
                    print("Max retries reached. Unable to generate content.")
                    return {"error": f"Max retries reached - {error_message}"}
            else:
                print(f"Error in generate_async: {error_message}")
                return {"error": error_message}

    return {"error": "Failed to generate valid JSON content after multiple attempts"}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
async def process_single_timeline(timeline, args, model, semaphore, windowed_jsonl_file):
    """Run the configured prompt method over every context window of one timeline.

    Returns:
        Dict with ``person_id``, ``visit_occurrence_id``, per-window
        ``responses`` and ``specialty`` (persona mode only), or ``None`` if
        processing failed.
    """
    person_id = timeline.get('person_id', 'unknown')
    print(f"Processing timeline for person_id: {person_id}")

    # Initialise up-front: when there are zero context windows these names
    # were previously unbound at the return statement, raising NameError
    # (which the except clause then masked as a generic failure).
    specialty = None
    visit_occurrence_id = None

    try:
        context_windows = utils.process_ehr_context(
            timeline,
            args.ehr_context,
            windowed_jsonl_file
        )
        responses = []

        for window in context_windows:
            if args.prompt_method == "general":
                prompt = utils.create_prompt_from_timeline(
                    window,
                    args.prompt_template
                )
            elif args.prompt_method == "persona":
                prompt, specialty, visit_occurrence_id = utils.create_persona_prompt_from_timeline(
                    window,
                    args.prompt_template,
                    person_id
                )
            else:
                raise ValueError(f"Unknown prompt method: {args.prompt_method}")

            # The semaphore bounds concurrent LLM calls across tasks.
            async with semaphore:
                response = await generate_async(prompt, model)

            # Identical append logic was previously duplicated in both branches.
            responses.append({
                "window_index": window.get("window_index", 0),
                "response": response,
                "window_token_count": window.get("window_token_count"),
                "window_percent_full": window.get("window_percent_full"),
                "start_date": window.get("start_date"),
                "end_date": window.get("end_date")
            })

        print(f"Completed processing for person_id: {person_id} with {len(responses)} context windows")
        return {
            "person_id": person_id,
            "visit_occurrence_id": visit_occurrence_id,
            "responses": responses,
            "specialty": specialty
        }
    except Exception as e:
        print(f"Error processing timeline for person_id {person_id}: {str(e)}")
        return None
|
| 158 |
+
|
| 159 |
+
async def process_timelines(merged_jsonl_file, args, windowed_jsonl_file):
    """Load merged timelines and run ``process_single_timeline`` on each.

    Returns the list of non-None per-timeline result dicts.

    NOTE(review): each timeline is awaited in turn, so the semaphore only
    limits concurrency of window-level calls within one timeline — overall
    processing is sequential. Confirm whether that is intended before
    raising ``--max_concurrent_calls``.
    """
    genai.configure(api_key=args.api_key)
    model = genai.GenerativeModel(args.model_name)
    semaphore = asyncio.Semaphore(args.max_concurrent_calls)

    with open(merged_jsonl_file, 'r') as f:
        timelines = [json.loads(line) for line in f]
    # --max_samples acts as an upper bound when provided (falsy -> all).
    total_timelines = min(len(timelines), args.max_samples) if args.max_samples else len(timelines)
    print(f"Found {total_timelines} total timelines")

    results = []
    processed_count = 0
    for timeline in tqdm(timelines[:total_timelines], total=total_timelines, desc="Processing timelines"):
        try:
            result = await process_single_timeline(timeline, args, model, semaphore, windowed_jsonl_file)
            if result is not None:
                results.append(result)
                processed_count += 1
                # Progress heartbeat every 10 successful timelines.
                if processed_count % 10 == 0 or processed_count == total_timelines:
                    print(f"Processed {processed_count}/{total_timelines} samples")
        except Exception as e:
            print(f"Error processing timeline: {e}")

    return results
|
| 183 |
+
|
| 184 |
+
def save_single_response(resp_data: Dict, output_file: Path, person_id: str, template_name: Optional[str] = None):
    """Write one response dict to *output_file* as pretty-printed JSON.

    Logs success or failure (including the template name when given);
    never raises.
    """
    suffix = f", template: {template_name}" if template_name else ""
    try:
        with open(output_file, 'w') as f:
            json.dump(resp_data, f, indent=2)
        print(f"Saved response for person_id: {person_id}" + suffix)
    except Exception as e:
        print(f"Error saving response for person_id: {person_id}" + suffix)
        print(f"Error details: {str(e)}")
|
| 194 |
+
|
| 195 |
+
def save_responses(responses, output_folder, prompt_method):
    """
    Save responses with window metadata including date ranges.

    One JSON file is written per (person, window); persona runs additionally
    embed the specialty in the filename.

    Raises:
        ValueError: for an unrecognised *prompt_method*. Previously that case
            fell through with ``output_file`` unbound and surfaced as a
            confusing NameError.
    """
    for response in responses:
        person_id = response["person_id"]

        for window_response in response["responses"]:
            if prompt_method == "general":
                output_file = Path(output_folder) / f"{person_id}_{window_response['window_index']}.json"
            elif prompt_method == "persona":
                specialty = response.get("specialty", "unknown")
                output_file = Path(output_folder) / f"{person_id}_{specialty}_{window_response['window_index']}.json"
            else:
                raise ValueError(f"Unknown prompt method: {prompt_method}")

            metadata = {
                "person_id": person_id,
                "window_index": window_response["window_index"],
                "window_token_count": window_response["window_token_count"],
                "window_percent_full": window_response["window_percent_full"],
                "start_date": window_response["start_date"],
                "end_date": window_response["end_date"],
                "response_data": window_response["response"]
            }

            output_file.parent.mkdir(parents=True, exist_ok=True)
            save_single_response(metadata, output_file, person_id)

            print(f"Saved response for person_id: {person_id}, "
                  f"window: {window_response['window_index']}, "
                  f"dates: {window_response['start_date']} to {window_response['end_date']}")

    print(f"Saved responses from {len(responses)} timelines to {output_folder}")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
async def main(args):
    """
    End-to-end driver: merge per-patient JSONL timelines, window and process
    them through the LLM, save the responses, and report timing statistics.

    Args:
        args: Parsed argparse namespace (see the __main__ block for fields).
    """
    start_time = time.time()

    # Plain string literal — the previous f-string had no placeholders.
    merged_jsonl_file = Path(args.materialized_ehr_folder) / "merged_timelines.jsonl"
    windowed_dir = Path(args.materialized_ehr_folder) / "windowed"
    windowed_dir.mkdir(parents=True, exist_ok=True)
    windowed_jsonl_file = windowed_dir / f"windowed_timelines_{args.prompt_method}_{args.ehr_context}.jsonl"

    merge_jsonl_files(args.materialized_ehr_folder, merged_jsonl_file)

    print("Processing timelines...")
    responses = await process_timelines(merged_jsonl_file, args, windowed_jsonl_file)
    print(f"Processed {len(responses)} samples")

    if args.max_samples:
        print(f"Limited to {args.max_samples} samples for testing")

    if responses:
        context_path = get_context_path_safe(args.ehr_context)
        output_folder = Path(args.output_folder) / args.model_name / args.prompt_method / context_path / Path(args.prompt_template).stem
        save_responses(responses, output_folder, args.prompt_method)

        total_time = time.time() - start_time
        print(f"Total time taken: {total_time:.2f} seconds")
        # len(responses) > 0 inside this branch, so the division is safe.
        print(f"Average time per sample: {total_time / len(responses):.2f} seconds")
    else:
        print("No responses were generated. Check the logs for more information.")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == "__main__":
    # CLI entry point: configure arguments and launch the async pipeline.
    parser = argparse.ArgumentParser(description="Generate synthetic instruction-response pairs from materialized patient timelines using async processing.")
    parser.add_argument("--materialized_ehr_folder", type=str, help="Folder containing materialized EHR JSONL files")
    parser.add_argument("--prompt_template", type=str)
    parser.add_argument("--prompt_method", type=str, default="persona", help="Prompt method: general, persona")
    parser.add_argument("--ehr_context", type=validate_ehr_context, default=16384, help="EHR context to use: 'full', 'last_five', or a positive integer for context length (e.g., 8192)")
    parser.add_argument("--output_folder", type=str)
    parser.add_argument("--project_id", type=str)
    parser.add_argument("--location", type=str, default="us-central1", help="Google Cloud location")
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--max_samples", type=int)
    # Parse as int: this value bounds concurrent LLM calls; argparse would
    # otherwise hand downstream code a string.
    parser.add_argument("--max_concurrent_calls", type=int)
    parser.add_argument("--csv_file", type=str)
    parser.add_argument("--dataset_id", type=str)
    parser.add_argument("--table_id", type=str)
    args = parser.parse_args()

    asyncio.run(main(args))
|
preprocess/__init__.py
ADDED
|
File without changes
|
preprocess/wtts_builder.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import argparse
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from src.utils.data_loader import DataLoader
|
| 8 |
+
|
| 9 |
+
class WTTSBuilder:
    """Builds Weighted Temporal Tuple Sequence (WTTS) strings from patient notes.

    Each sentence of each clinical note becomes one tuple of the form
    ``[S_idx] ("timestamp", "sentence", P_j, W)`` where:
      - P_j is the note's normalized position within the admission (0.0–1.0),
      - W is a keyword-derived importance weight
        (1.0 critical, 0.5 chronic/default, 0.1 routine).
    Tuples are joined with " | " into a single line.
    """

    def __init__(self):
        # --- WEIGHTING RULES (W) ---
        # Matched with re.search on lowercased text; first matching tier wins.
        self.critical_patterns = [
            r'respiratory failure', r'seizure', r'cardiac arrest', r'intubat',
            r'abnormal', r'critical', r'hemorrhage', r'positive',
            r'emergency', r'acute', r'hypoxia', r'flagged', r'icu',
            r'dyspnea', r'shortness of breath', r'sob', r'mrc grade', r'nyha'
        ]
        self.chronic_patterns = [
            r'history of', r'chronic', r'stable', r'continued',
            r'maintained', r'diagnosed with', r'previous', r'known'
        ]
        self.routine_patterns = [
            r'routine', r'normal', r'negative', r'unremarkable',
            r'no acute', r'clear', r'regular diet', r'resting'
        ]

    def _get_normalized_time(self, event_time_str, admit_str, disch_str):
        """Calculates P_j (0.0 to 1.0): fraction of the stay elapsed at event time."""
        try:
            e_dt = pd.to_datetime(event_time_str)
            a_dt = pd.to_datetime(admit_str)
            d_dt = pd.to_datetime(disch_str)

            total_duration = (d_dt - a_dt).total_seconds()
            elapsed = (e_dt - a_dt).total_seconds()

            if total_duration <= 0: return 1.0
            return round(max(0.0, min(1.0, elapsed / total_duration)), 2)
        except Exception:
            # Fall back to mid-stay when any timestamp is missing/unparseable.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit — narrowed to Exception.)
            return 0.5

    def _get_weight(self, text):
        """Return importance weight W for a sentence (1.0 / 0.5 / 0.1)."""
        t = text.lower()
        if any(re.search(p, t) for p in self.critical_patterns): return 1.0
        if any(re.search(p, t) for p in self.chronic_patterns): return 0.5
        if any(re.search(p, t) for p in self.routine_patterns): return 0.1
        return 0.5

    def _extract_sentences_with_ids(self, text, start_index):
        """
        Splits notes and assigns UNIQUE IDs.
        CRITICAL FIX: Sanitizes newlines to preserve WTTS structure.

        Returns:
            (list of (sid, sentence) pairs, next free sentence index)
        """
        # 1. Replace newlines/tabs with spaces to keep tuple on one line
        text = text.replace('\n', ' ').replace('\r', '').replace('\t', ' ')

        # 2. Remove de-id brackets
        text = re.sub(r'\[\*\*.*?\*\*\]', '', text)

        # 3. Split by sentence boundaries
        sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', text)

        results = []
        current_idx = start_index
        for s in sentences:
            clean_s = s.strip()
            # Ignore very short/empty fragments
            if len(clean_s) > 5:
                sid = f"S_{current_idx}"
                results.append((sid, clean_s))
                current_idx += 1
        return results, current_idx

    def build_wtts_string(self, patient_data):
        """
        Build the full WTTS line for one patient.

        Args:
            patient_data: Dict with 'admission_time', 'discharge_time', and a
                'notes' list of {'timestamp', 'text'} dicts.

        Returns:
            str: " | "-joined sequence of formatted tuples ("" if no notes).
        """
        tuples = []
        admit = patient_data.get('admission_time')
        disch = patient_data.get('discharge_time')

        sorted_notes = sorted(patient_data.get('notes', []), key=lambda x: x['timestamp'])

        global_sent_idx = 0

        for note in sorted_notes:
            raw_ts = note['timestamp']
            p_j = self._get_normalized_time(raw_ts, admit, disch)

            events, global_sent_idx = self._extract_sentences_with_ids(note['text'], global_sent_idx)

            for (sid, event) in events:
                w = self._get_weight(event)
                # Format: [ID] ("Date", "Event", P_j, W)
                tuples.append(f'[{sid}] ("{raw_ts}", "{event}", {p_j}, {w})')

        return " | ".join(tuples)
|
| 95 |
+
|
| 96 |
+
# --- EXECUTION LOGIC ---
if __name__ == "__main__":
    # CLI entry point: load raw parquet note shards plus ground truth, build a
    # WTTS string per patient, and write one JSONL file per patient id.
    parser = argparse.ArgumentParser(description="Process raw clinical notes into WTTS format.")
    # Set defaults for easier running
    parser.add_argument("--input_dirs", nargs="+",
                        default=[
                            "data/raw/dyspnea-clinical-notes",
                            "data/raw/dyspnea-crf-development",
                        ],
                        help="Directories containing .parquet shards (searched recursively)")
    parser.add_argument("--gt_file", type=str,
                        default="data/raw/dev_gt.jsonl",
                        help="Path to the ground truth JSONL file.")
    parser.add_argument("--output_dir", type=str,
                        default="data/processed/materialized_ehr",
                        help="Path to store processed JSONL files.")

    args = parser.parse_args()

    # DataLoader merges note shards with ground-truth records per patient.
    loader = DataLoader(data_folders=args.input_dirs, gt_path=args.gt_file)
    builder = WTTSBuilder()

    patients = loader.load_and_merge()
    os.makedirs(args.output_dir, exist_ok=True)

    if not patients:
        print("No patients found! Check paths.")
    else:
        print(f"Materializing timelines for {len(patients)} patients...")
        for p in patients:
            wtts_output = builder.build_wtts_string(p)

            # FIX: Prioritize document_id to match DataLoader logic
            pid = str(p.get('document_id') or p.get('patient_id') or p.get('hadm_id'))

            # One JSON record per file, keyed by patient/document id.
            output_path = os.path.join(args.output_dir, f"{pid}.jsonl")
            with open(output_path, 'w') as f:
                json.dump({"person_id": pid, "text": wtts_output}, f)

        print(f"Successfully stored outputs in: {args.output_dir}")
|
rag/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RAG-guided CRF extraction module
|
| 2 |
+
from src.rag.embedder import WTTSEmbedder
|
| 3 |
+
from src.rag.retriever import WTTSRetriever
|
| 4 |
+
from src.rag.rag_pipeline import RAGCRFExtractor
|
rag/embedder.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WTTS Tuple Embedder — Embeds clinical event tuples and CRF queries
|
| 3 |
+
into vector space using SentenceTransformers.
|
| 4 |
+
|
| 5 |
+
Swap the model_name to a clinical model (e.g., MedCPT) when GPU is available.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import numpy as np
|
| 10 |
+
from typing import List, Dict, Optional
|
| 11 |
+
from sentence_transformers import SentenceTransformer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class WTTSEmbedder:
    """Embeds WTTS tuples and CRF item queries into dense vectors.

    All embeddings are L2-normalized so inner product equals cosine
    similarity (required by the downstream FAISS IndexFlatIP retriever).
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", device: str = "cpu"):
        """
        Args:
            model_name: SentenceTransformer model ID.
                CPU default: 'all-MiniLM-L6-v2' (384-dim, fast)
                GPU clinical: 'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb'
                              or 'medicalai/ClinicalBERT'
            device: 'cpu' or 'cuda'
        """
        print(f"Loading embedding model: {model_name} on {device}...")
        self.model = SentenceTransformer(model_name, device=device)
        self.device = device
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f" Embedding dimension: {self.embedding_dim}")

    # ------------------------------------------------------------------
    # Parse WTTS string → structured list of tuple dicts
    # ------------------------------------------------------------------
    def parse_wtts_string(self, wtts_string: str) -> List[Dict]:
        """
        Parses the WTTS pipe-delimited string back into structured dicts.

        Input format: [S_0] ("2026-01-01", "event text", 0.5, 1.0) | [S_1] (...)
        Output: [
            {"sid": "S_0", "timestamp": "2026-01-01", "event": "event text", "p_j": 0.5, "weight": 1.0},
            ...
        ]

        Entries that do not match the expected format are silently dropped.
        NOTE(review): events containing a double-quote character will not
        match the [^"]* pattern and are dropped — acceptable only while the
        builder never emits quotes inside events.
        """
        tuples = []
        # Split by pipe separator
        raw_entries = wtts_string.split(" | ")

        # Pattern to extract: [S_xx] ("timestamp", "event", P_j, W)
        pattern = re.compile(
            r'\[(?P<sid>S_\d+)\]\s*\('
            r'"(?P<timestamp>[^"]*)",\s*'
            r'"(?P<event>[^"]*)",\s*'
            r'(?P<p_j>[\d.]+),\s*'
            r'(?P<weight>[\d.]+)\)'
        )

        for entry in raw_entries:
            entry = entry.strip()
            if not entry:
                continue
            match = pattern.search(entry)
            if match:
                tuples.append({
                    "sid": match.group("sid"),
                    "timestamp": match.group("timestamp"),
                    "event": match.group("event"),
                    "p_j": float(match.group("p_j")),
                    "weight": float(match.group("weight")),
                })

        return tuples

    # ------------------------------------------------------------------
    # Embed tuple event texts
    # ------------------------------------------------------------------
    def embed_tuples(self, tuples: List[Dict]) -> np.ndarray:
        """
        Embed the event text from each WTTS tuple.

        Args:
            tuples: List of parsed tuple dicts (from parse_wtts_string)

        Returns:
            np.ndarray of shape (n_tuples, embedding_dim). For empty input a
            (0, embedding_dim) matrix is returned so callers can still index
            .shape[1] safely.
        """
        if not tuples:
            # Fix: previously returned np.array([]) with shape (0,), which is
            # inconsistent with the documented 2-D contract.
            return np.zeros((0, self.embedding_dim), dtype=np.float32)

        texts = [t["event"] for t in tuples]
        embeddings = self.model.encode(
            texts,
            show_progress_bar=False,
            normalize_embeddings=True,  # L2-normalize for cosine similarity via dot product
            batch_size=64,
        )
        return np.array(embeddings, dtype=np.float32)

    # ------------------------------------------------------------------
    # Embed a CRF item query
    # ------------------------------------------------------------------
    def embed_query(self, crf_item: str, valid_options: Optional[List[str]] = None) -> np.ndarray:
        """
        Create an embedding for a CRF item query.
        Combines the item name with its valid options to create a richer query.

        Args:
            crf_item: e.g., "mrc_grade" or "administration of bronchodilators"
            valid_options: e.g., ["y", "n", "unknown"]

        Returns:
            np.ndarray of shape (embedding_dim,)
        """
        # Build a descriptive query string
        query_parts = [crf_item.replace("_", " ")]

        if valid_options:
            # Add option context to help embedding understand what we're looking for
            opts_str = ", ".join(str(o) for o in valid_options[:10])  # limit to avoid huge queries
            query_parts.append(f"options: {opts_str}")

        query_text = " | ".join(query_parts)

        embedding = self.model.encode(
            [query_text],
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return np.array(embedding[0], dtype=np.float32)

    # ------------------------------------------------------------------
    # Batch embed multiple CRF queries at once
    # ------------------------------------------------------------------
    def embed_queries_batch(
        self,
        crf_items: List[str],
        valid_options_map: Dict[str, List[str]]
    ) -> Dict[str, np.ndarray]:
        """
        Embed all CRF items in one batch for efficiency.

        Returns:
            Dict mapping crf_item name → embedding vector (empty dict for
            empty input — avoids calling the model with a zero-length batch).
        """
        if not crf_items:
            return {}

        query_texts = []
        item_names = []

        for item in crf_items:
            item_names.append(item)
            parts = [item.replace("_", " ")]
            opts = valid_options_map.get(item, [])
            if opts:
                opts_str = ", ".join(str(o) for o in opts[:10])
                parts.append(f"options: {opts_str}")
            query_texts.append(" | ".join(parts))

        embeddings = self.model.encode(
            query_texts,
            show_progress_bar=False,
            normalize_embeddings=True,
            batch_size=64,
        )

        return {
            name: np.array(emb, dtype=np.float32)
            for name, emb in zip(item_names, embeddings)
        }
|
rag/rag_pipeline.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG CRF Extractor — Orchestrates the full RAG-guided pipeline:
|
| 3 |
+
1. Build WTTS tuples from patient data
|
| 4 |
+
2. Embed all tuples into FAISS index
|
| 5 |
+
3. For each CRF item: retrieve relevant tuples → LLM extraction
|
| 6 |
+
4. Return predictions
|
| 7 |
+
|
| 8 |
+
This replaces the two-pass (Skeleton → Extraction) approach in main.py.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import asyncio
|
| 12 |
+
import json
|
| 13 |
+
from typing import Dict, List, Optional, Any
|
| 14 |
+
|
| 15 |
+
from src.rag.embedder import WTTSEmbedder
|
| 16 |
+
from src.rag.retriever import WTTSRetriever
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# RAG-optimized prompt — shorter, focused on retrieved evidence only
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
RAG_EXTRACTION_PROMPT = """\
|
| 24 |
+
You are a Clinical Coding Expert.
|
| 25 |
+
|
| 26 |
+
RETRIEVED CLINICAL EVIDENCE (sorted chronologically, most relevant events for these items):
|
| 27 |
+
{retrieved_evidence}
|
| 28 |
+
|
| 29 |
+
TASK:
|
| 30 |
+
For each Clinical Item below, determine the value based ONLY on the evidence above.
|
| 31 |
+
1. **Value**: Must come strictly from the "Valid Options".
|
| 32 |
+
2. **Evidence**: Cite the [S_xx] ID that supports your choice.
|
| 33 |
+
3. If no evidence supports any option, choose "unknown".
|
| 34 |
+
|
| 35 |
+
ITEMS TO EXTRACT & THEIR OPTIONS:
|
| 36 |
+
{chunk_schema_json}
|
| 37 |
+
|
| 38 |
+
OUTPUT FORMAT (JSON Object):
|
| 39 |
+
{{
|
| 40 |
+
"item_name": {{
|
| 41 |
+
"value": "Selected Option",
|
| 42 |
+
"evidence": "S_xx",
|
| 43 |
+
"reasoning": "Brief explanation"
|
| 44 |
+
}},
|
| 45 |
+
...
|
| 46 |
+
}}
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RAGCRFExtractor:
    """
    RAG-guided CRF extraction for clinical notes.

    Replaces the two-pass (Skeleton -> Extraction) pipeline with per-item
    retrieval so each LLM call sees focused, temporally-ordered evidence.
    """

    def __init__(
        self,
        embedder: WTTSEmbedder,
        generate_fn,
        top_k: int = 15,
        weight_boost: float = 0.3,
        items_per_chunk: int = 5,
    ):
        """
        Args:
            embedder: Shared WTTSEmbedder instance (reused across patients).
            generate_fn: Async LLM call with signature (prompt, model).
            top_k: Number of tuples retrieved per CRF item group.
            weight_boost: Re-ranking boost applied to high-weight events.
            items_per_chunk: CRF items grouped per LLM call; a group shares
                one retrieved evidence pool.
        """
        self.embedder = embedder
        self.generate_fn = generate_fn
        self.top_k = top_k
        self.weight_boost = weight_boost
        self.items_per_chunk = items_per_chunk

    # ------------------------------------------------------------------
    # Group CRF items for batched retrieval
    # ------------------------------------------------------------------
    def _group_crf_items(
        self,
        target_items: List[str],
        valid_options: Dict[str, List[str]],
    ) -> List[List[str]]:
        """
        Chunk CRF items sequentially, items_per_chunk at a time.
        (valid_options is accepted for future embedding-based clustering.)
        """
        size = self.items_per_chunk
        return [target_items[pos:pos + size] for pos in range(0, len(target_items), size)]

    # ------------------------------------------------------------------
    # Combined query for a group of CRF items
    # ------------------------------------------------------------------
    def _build_group_query(
        self,
        items: List[str],
        valid_options: Dict[str, List[str]],
    ):
        """
        Combined query embedding for a group: the mean of the per-item
        query embeddings, re-normalized to unit length.
        """
        import numpy as np

        per_item = self.embedder.embed_queries_batch(items, valid_options)
        mean_vec = np.mean(list(per_item.values()), axis=0).astype(np.float32)

        # Averaging breaks unit length; restore it (guard the zero vector).
        length = np.linalg.norm(mean_vec)
        if length > 0:
            mean_vec = mean_vec / length

        return mean_vec

    # ------------------------------------------------------------------
    # Full RAG pipeline for one patient
    # ------------------------------------------------------------------
    async def extract_patient(
        self,
        patient_data: Dict,
        builder,  # WTTSBuilder instance
        target_items: List[str],
        valid_options: Dict[str, List[str]],
        semaphore: asyncio.Semaphore,
        model: Any = None,  # Gemini model (passed to generate_fn)
    ) -> Optional[Dict]:
        """
        Build the WTTS sequence, index its tuples, then for each item group
        retrieve evidence and extract values with the LLM.

        Returns:
            Dict with patient_id / predictions / rag_debug, or None when no
            tuples could be produced or an unexpected error occurred.
        """
        pid = str(
            patient_data.get('document_id')
            or patient_data.get('patient_id')
            or patient_data.get('hadm_id')
            or 'unknown'
        )

        try:
            # Step 1: materialize the weighted temporal tuple sequence.
            wtts_string = builder.build_wtts_string(patient_data)
            if not wtts_string.strip():
                print(f" [{pid}] No WTTS tuples generated, skipping.")
                return None

            # Step 2: parse tuples and embed their event texts.
            tuples = self.embedder.parse_wtts_string(wtts_string)
            if not tuples:
                print(f" [{pid}] Failed to parse WTTS tuples, skipping.")
                return None
            tuple_embeddings = self.embedder.embed_tuples(tuples)

            # Step 3: per-patient FAISS index.
            retriever = WTTSRetriever(self.embedder)
            retriever.build_index(tuples, tuple_embeddings)

            print(f" [{pid}] Indexed {len(tuples)} tuples. "
                  f"Retrieving for {len(target_items)} CRF items...")

            # Step 4: one retrieval + one LLM call per item group.
            final_predictions = {}
            item_groups = self._group_crf_items(target_items, valid_options)

            for group_items in item_groups:
                group_query = self._build_group_query(group_items, valid_options)

                retrieved = retriever.retrieve_and_rerank(
                    group_query,
                    top_k=self.top_k,
                    weight_boost=self.weight_boost,
                )

                evidence_str = WTTSRetriever.format_retrieved_tuples(retrieved)

                chunk_schema = {
                    name: valid_options.get(name, ["y", "n", "unknown"])
                    for name in group_items
                }

                prompt = RAG_EXTRACTION_PROMPT.format(
                    retrieved_evidence=evidence_str,
                    chunk_schema_json=json.dumps(chunk_schema, indent=2),
                )

                # Bound concurrent LLM calls.
                async with semaphore:
                    response = await self.generate_fn(prompt, model)

                # Only well-formed dict responses are merged; error payloads
                # are logged and skipped (items fall out of the final dict).
                if isinstance(response, dict):
                    if "error" in response:
                        print(f" [{pid}] LLM error for items {group_items[:2]}...: "
                              f"{response['error']}")
                    else:
                        final_predictions.update(response)

            # Step 5: package results.
            return {
                "patient_id": pid,
                "predictions": final_predictions,
                "rag_debug": {
                    "total_tuples": len(tuples),
                    "top_k": self.top_k,
                    "item_groups": len(item_groups),
                },
            }

        except Exception as e:
            print(f" [{pid}] RAG extraction error: {e}")
            import traceback
            traceback.print_exc()
            return None
|
rag/retriever.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WTTS Retriever — Builds per-patient FAISS index and retrieves
|
| 3 |
+
relevant tuples per CRF item with weight-based re-ranking
|
| 4 |
+
and P_j temporal sorting.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import faiss
|
| 9 |
+
from typing import List, Dict, Optional, Tuple
|
| 10 |
+
from src.rag.embedder import WTTSEmbedder
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class WTTSRetriever:
|
| 14 |
+
"""Per-patient FAISS index for retrieving relevant WTTS tuples."""
|
| 15 |
+
|
| 16 |
+
    def __init__(self, embedder: WTTSEmbedder):
        # Shared embedder used upstream to produce the query/tuple vectors.
        self.embedder = embedder
        # Inner-product index; equals cosine similarity because all vectors
        # are L2-normalized at embedding time. None until build_index runs.
        self.index: Optional[faiss.IndexFlatIP] = None  # Inner product (cosine on normalized vecs)
        # Parsed WTTS tuple dicts, positionally aligned with index rows.
        self.tuples: List[Dict] = []
        # Raw embedding matrix kept for potential re-ranking/debugging.
        self.embeddings: Optional[np.ndarray] = None
|
| 21 |
+
|
| 22 |
+
# ------------------------------------------------------------------
|
| 23 |
+
# Build index for one patient's tuples
|
| 24 |
+
# ------------------------------------------------------------------
|
| 25 |
+
def build_index(self, tuples: List[Dict], embeddings: np.ndarray):
|
| 26 |
+
"""
|
| 27 |
+
Build a FAISS index from pre-computed tuple embeddings.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
tuples: Parsed WTTS tuple dicts
|
| 31 |
+
embeddings: np.ndarray of shape (n_tuples, embedding_dim)
|
| 32 |
+
"""
|
| 33 |
+
self.tuples = tuples
|
| 34 |
+
self.embeddings = embeddings
|
| 35 |
+
|
| 36 |
+
if len(tuples) == 0:
|
| 37 |
+
self.index = None
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
dim = embeddings.shape[1]
|
| 41 |
+
# Use Inner Product (IP) since embeddings are L2-normalized
|
| 42 |
+
# This makes IP equivalent to cosine similarity
|
| 43 |
+
self.index = faiss.IndexFlatIP(dim)
|
| 44 |
+
self.index.add(embeddings)
|
| 45 |
+
|
| 46 |
+
# ------------------------------------------------------------------
|
| 47 |
+
# Raw retrieval (top-k by cosine similarity)
|
| 48 |
+
# ------------------------------------------------------------------
|
| 49 |
+
def retrieve(self, query_embedding: np.ndarray, top_k: int = 15) -> List[Dict]:
|
| 50 |
+
"""
|
| 51 |
+
Retrieve top-k most similar tuples to the query.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
query_embedding: 1D vector of shape (embedding_dim,)
|
| 55 |
+
top_k: Number of tuples to retrieve
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
List of tuple dicts with added 'similarity_score' field
|
| 59 |
+
"""
|
| 60 |
+
if self.index is None or len(self.tuples) == 0:
|
| 61 |
+
return []
|
| 62 |
+
|
| 63 |
+
# Clamp top_k to available tuples
|
| 64 |
+
top_k = min(top_k, len(self.tuples))
|
| 65 |
+
|
| 66 |
+
# FAISS expects 2D input
|
| 67 |
+
query = query_embedding.reshape(1, -1).astype(np.float32)
|
| 68 |
+
scores, indices = self.index.search(query, top_k)
|
| 69 |
+
|
| 70 |
+
results = []
|
| 71 |
+
for score, idx in zip(scores[0], indices[0]):
|
| 72 |
+
if idx == -1: # FAISS returns -1 for missing results
|
| 73 |
+
continue
|
| 74 |
+
result = self.tuples[idx].copy()
|
| 75 |
+
result["similarity_score"] = float(score)
|
| 76 |
+
results.append(result)
|
| 77 |
+
|
| 78 |
+
return results
|
| 79 |
+
|
| 80 |
+
# ------------------------------------------------------------------
|
| 81 |
+
# Retrieve + re-rank by weight + sort by P_j
|
| 82 |
+
# ------------------------------------------------------------------
|
| 83 |
+
def retrieve_and_rerank(
|
| 84 |
+
self,
|
| 85 |
+
query_embedding: np.ndarray,
|
| 86 |
+
top_k: int = 15,
|
| 87 |
+
weight_boost: float = 0.3,
|
| 88 |
+
fetch_multiplier: int = 3,
|
| 89 |
+
) -> List[Dict]:
|
| 90 |
+
"""
|
| 91 |
+
Retrieve, re-rank using weight W, then sort by P_j for temporal order.
|
| 92 |
+
|
| 93 |
+
Strategy:
|
| 94 |
+
1. Over-fetch (top_k * fetch_multiplier) candidates from FAISS
|
| 95 |
+
2. Re-score: final_score = similarity + weight_boost * W
|
| 96 |
+
3. Take top_k by final_score
|
| 97 |
+
4. Sort the final set by P_j (ascending) to preserve temporal order
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
query_embedding: 1D vector
|
| 101 |
+
top_k: Final number of tuples to return
|
| 102 |
+
weight_boost: How much to boost critical events (W=1.0 gets +0.3)
|
| 103 |
+
fetch_multiplier: How many extra candidates to fetch for re-ranking
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
List of tuple dicts sorted by P_j (temporal order),
|
| 107 |
+
each with 'similarity_score', 'rerank_score' fields
|
| 108 |
+
"""
|
| 109 |
+
if self.index is None or len(self.tuples) == 0:
|
| 110 |
+
return []
|
| 111 |
+
|
| 112 |
+
# Step 1: Over-fetch candidates
|
| 113 |
+
fetch_k = min(top_k * fetch_multiplier, len(self.tuples))
|
| 114 |
+
candidates = self.retrieve(query_embedding, top_k=fetch_k)
|
| 115 |
+
|
| 116 |
+
# Step 2: Re-rank with weight boost
|
| 117 |
+
for candidate in candidates:
|
| 118 |
+
sim = candidate["similarity_score"]
|
| 119 |
+
w = candidate.get("weight", 0.5)
|
| 120 |
+
candidate["rerank_score"] = sim + (weight_boost * w)
|
| 121 |
+
|
| 122 |
+
# Step 3: Take top_k by re-rank score
|
| 123 |
+
candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
|
| 124 |
+
top_candidates = candidates[:top_k]
|
| 125 |
+
|
| 126 |
+
# Step 4: Sort by P_j (temporal order) — THIS is what preserves continuity
|
| 127 |
+
top_candidates.sort(key=lambda x: x.get("p_j", 0.5))
|
| 128 |
+
|
| 129 |
+
return top_candidates
|
| 130 |
+
|
| 131 |
+
# ------------------------------------------------------------------
|
| 132 |
+
# Format retrieved tuples back into a readable string for the LLM
|
| 133 |
+
# ------------------------------------------------------------------
|
| 134 |
+
@staticmethod
|
| 135 |
+
def format_retrieved_tuples(tuples: List[Dict]) -> str:
|
| 136 |
+
"""
|
| 137 |
+
Format retrieved tuples into a clean string for the LLM prompt.
|
| 138 |
+
Sorted by P_j (temporal order) by this point.
|
| 139 |
+
|
| 140 |
+
Output format:
|
| 141 |
+
[S_14] (EARLY | W:HIGH) "Patient reports increasing dyspnea"
|
| 142 |
+
[S_80] (MID | W:HIGH) "SpO2 dropped to 85%, intubated"
|
| 143 |
+
[S_155](LATE | W:MED) "Stable on room air at discharge"
|
| 144 |
+
"""
|
| 145 |
+
if not tuples:
|
| 146 |
+
return "(No relevant clinical events found)"
|
| 147 |
+
|
| 148 |
+
lines = []
|
| 149 |
+
for t in tuples:
|
| 150 |
+
p_j = t.get("p_j", 0.5)
|
| 151 |
+
w = t.get("weight", 0.5)
|
| 152 |
+
|
| 153 |
+
# Temporal phase label
|
| 154 |
+
if p_j <= 0.15:
|
| 155 |
+
phase = "ADMISSION"
|
| 156 |
+
elif p_j <= 0.35:
|
| 157 |
+
phase = "EARLY"
|
| 158 |
+
elif p_j <= 0.65:
|
| 159 |
+
phase = "MID"
|
| 160 |
+
elif p_j <= 0.85:
|
| 161 |
+
phase = "LATE"
|
| 162 |
+
else:
|
| 163 |
+
phase = "DISCHARGE"
|
| 164 |
+
|
| 165 |
+
# Weight label
|
| 166 |
+
if w >= 0.8:
|
| 167 |
+
w_label = "CRITICAL"
|
| 168 |
+
elif w >= 0.4:
|
| 169 |
+
w_label = "MODERATE"
|
| 170 |
+
else:
|
| 171 |
+
w_label = "ROUTINE"
|
| 172 |
+
|
| 173 |
+
sid = t.get("sid", "S_?")
|
| 174 |
+
event = t.get("event", "")
|
| 175 |
+
ts = t.get("timestamp", "")
|
| 176 |
+
|
| 177 |
+
lines.append(f'[{sid}] ({phase} | {w_label}) [{ts}] "{event}"')
|
| 178 |
+
|
| 179 |
+
return "\n".join(lines)
|