# main.py — Clinical CRF-filling pipeline (WTTS -> skeleton -> extraction).
# NOTE: Hugging Face upload-page residue ("Upload 8 files", commit 9ba48be)
# removed from the top of this file; it was not valid Python.
import asyncio
import json
import argparse
import os
import math
# Imports from your specific file structure
from src.preprocess.wtts_builder import WTTSBuilder
from src.utils.data_loader import DataLoader
# RAG pipeline imports (lazy — only used when --use_rag is set)
RAG_AVAILABLE = False
try:
from src.rag.embedder import WTTSEmbedder
from src.rag.rag_pipeline import RAGCRFExtractor
RAG_AVAILABLE = True
except ImportError:
pass
import google.generativeai as genai
# --- CONFIG ---
# SECURITY: an API key was hard-coded here and committed publicly — rotate it.
# Prefer the GOOGLE_API_KEY environment variable; the literal is kept only as
# a fallback so existing invocations keep working unchanged.
API_KEY = os.environ.get("GOOGLE_API_KEY", "AIzaSyAkgKha4IxsCjRXbeirhyoygT9Qmr4qYzU")
# --- PROMPTS ---
SKELETON_PROMPT = """
You are a Clinical Data Specialist.
Convert the Weighted Time Series (WTTS) below into a "Clinical Chronology Skeleton".
INPUT (WTTS):
{wtts_string}
INSTRUCTIONS:
1. Create a strict chronological timeline (Admission to Discharge).
2. IMPORTANT: You MUST retain the [S_xx] ID for every event you list.
3. Filter out "Routine" (Weight 0.1) events unless they indicate a status change.
4. Keep exact values (e.g., "BP 90/60", "Temp 102.5").
OUTPUT FORMAT:
[Date] [S_xx]: Event details
[Date] [S_xx]: Event details
...
"""
EXTRACTION_PROMPT = """
You are a Clinical Coding Expert.
Review the Patient Skeleton and the Valid Options for the requested items.
PATIENT SKELETON:
{skeleton}
TASK:
For each Clinical Item listed below, determine the value AND the supporting Sentence ID.
1. **Value**: Must come strictly from the "Valid Options" provided.
2. **Evidence**: Must be the specific [S_xx] ID from the skeleton that proves the value.
ITEMS TO EXTRACT & THEIR OPTIONS:
{chunk_schema_json}
OUTPUT FORMAT (JSON Object):
{{
"item_name": {{
"value": "Selected Option",
"evidence": "S_xx",
"reasoning": "Brief explanation"
}},
...
}}
"""
def chunk_data(data, size):
    """Split *data* into consecutive slices of at most *size* elements.

    Yields each slice in order; the final slice may be shorter than *size*.
    """
    start = 0
    while start < len(data):
        yield data[start:start + size]
        start += size
async def generate_async(prompt, model, max_retries=3, initial_delay=1):
    """Call Gemini via the google-generativeai SDK without blocking the loop.

    The blocking SDK call runs in the default thread-pool executor. Retries up
    to *max_retries* times on invalid-JSON responses (immediately) and on
    transient API errors containing "429" or "500" (with exponential backoff
    starting at *initial_delay* seconds).

    Returns the parsed JSON response (dict or list), or a dict with an
    "error" key on permanent failure.
    """
    # get_running_loop() is the correct call from inside a coroutine;
    # get_event_loop() is deprecated here and may create a stray loop.
    loop = asyncio.get_running_loop()
    for attempt in range(max_retries):
        try:
            response = await loop.run_in_executor(
                None,
                lambda: model.generate_content(
                    contents=prompt,
                    generation_config=genai.GenerationConfig(
                        response_mime_type="application/json"
                    ),
                ),
            )
            try:
                return json.loads(response.text)
            except json.JSONDecodeError:
                print("Generated content is not valid JSON. Retrying...")
                continue
        except Exception as e:  # SDK surfaces transport/API errors broadly
            error_message = str(e)
            # Only rate-limit (429) and server (500) errors are retryable.
            if "429" in error_message or "500" in error_message:
                if attempt < max_retries - 1:
                    delay = initial_delay * (2 ** attempt)
                    print(f"Rate limit / server error. Retrying in {delay}s...")
                    await asyncio.sleep(delay)
                else:
                    print("Max retries reached.")
                    return {"error": f"Max retries reached - {error_message}"}
            else:
                print(f"Error in generate_async: {error_message}")
                return {"error": error_message}
    return {"error": "Failed to generate valid JSON after multiple attempts"}
async def process_patient(model, builder, patient_data, target_items, valid_options, semaphore):
    """Executes the Two-Pass Pipeline for a single patient.

    Pass 1 converts the patient's WTTS into a chronological skeleton;
    Pass 2 extracts CRF items (in chunks of 10) against that skeleton,
    restricted to each item's valid options.

    Returns a dict with keys patient_id / skeleton_debug / predictions,
    or None if an unexpected exception occurred.
    """
    async with semaphore:
        # The corpora use several different ID fields; take the first present.
        pid = str(patient_data.get('document_id') or patient_data.get('patient_id')
                  or patient_data.get('hadm_id') or 'unknown')
        try:
            # --- PHASE 1: WTTS Construction ---
            wtts_string = builder.build_wtts_string(patient_data)
            # --- PHASE 2: Skeleton Generation (Pass 1) ---
            skeleton_input = SKELETON_PROMPT.format(wtts_string=wtts_string)
            skeleton_resp = await generate_async(skeleton_input, model)
            if isinstance(skeleton_resp, dict):
                if 'error' in skeleton_resp:
                    # A failed skeleton makes every extraction chunk unreliable;
                    # surface it instead of silently spending LLM calls.
                    print(f" [WARN] Skeleton generation failed for {pid}: {skeleton_resp['error']}")
                skeleton_text = json.dumps(skeleton_resp)
            else:
                skeleton_text = str(skeleton_resp)
            # --- PHASE 3: Extraction (Pass 2) ---
            final_predictions = {}
            for chunk_items in chunk_data(target_items, 10):
                chunk_schema = {
                    item: valid_options.get(item, ["Yes", "No", "Unknown"])
                    for item in chunk_items
                }
                extract_input = EXTRACTION_PROMPT.format(
                    skeleton=skeleton_text,
                    chunk_schema_json=json.dumps(chunk_schema)
                )
                chunk_resp = await generate_async(extract_input, model)
                if isinstance(chunk_resp, dict):
                    if 'error' in chunk_resp:
                        print(f" [WARN] LLM error for {pid}, chunk {chunk_items[:3]}...: {chunk_resp['error']}")
                    else:
                        final_predictions.update(chunk_resp)
            return {
                "patient_id": pid,
                # Append the truncation marker only when text was actually cut.
                "skeleton_debug": skeleton_text[:500] + ("..." if len(skeleton_text) > 500 else ""),
                "predictions": final_predictions
            }
        except Exception as e:
            print(f"Error processing {pid}: {e}")
            return None
# ---------------------------------------------------------------------------
# EVALUATION -- Accuracy & F1 Scoring
# ---------------------------------------------------------------------------
def _normalise(value):
"""Lowercase + strip for fair comparison."""
if value is None:
return ""
return str(value).strip().lower()
def evaluate_predictions(results, gt_path):
    """
    Compare pipeline results against dev_gt.jsonl.
    Prints accuracy, macro-F1, per-item breakdown, and sample errors.
    Returns (overall_dict, per_item_dict).

    Scoring convention: each (patient, item) pair present in the ground
    truth counts once. An exact normalised match is a true positive; a
    mismatch is a false negative, and additionally a false positive when
    the pipeline produced some (wrong) non-empty value.
    """
    # --- Load GT ---
    # dev_gt.jsonl: one JSON record per line with document_id + annotations.
    gt = {}
    with open(gt_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            doc_id = str(rec['document_id'])
            gt[doc_id] = {a['item']: a['ground_truth'] for a in rec.get('annotations', [])}
    # --- Build prediction lookup ---
    preds = {}
    for r in results:
        doc_id = str(r.get('patient_id', 'unknown'))
        items = {}
        for item_name, item_val in r.get('predictions', {}).items():
            # Predictions are normally {"value": ..., "evidence": ...} dicts,
            # but tolerate bare scalar values too.
            if isinstance(item_val, dict):
                items[item_name] = item_val.get('value', str(item_val))
            else:
                items[item_name] = str(item_val)
        preds[doc_id] = items
    # --- Collect all unique items ---
    # The item universe comes from the ground truth only; extra predicted
    # items are ignored rather than penalised.
    all_items = set()
    for doc_items in gt.values():
        all_items.update(doc_items.keys())
    # --- Score ---
    item_stats = {item: {'tp': 0, 'fp': 0, 'fn': 0, 'total': 0, 'correct': 0}
                  for item in all_items}
    total_comparisons = 0
    total_correct = 0
    matched_patients = 0
    errors = []  # (doc_id, item, gt_value, predicted_or_<MISSING>)
    for doc_id, gt_items in gt.items():
        pred_items = preds.get(doc_id, {})
        if pred_items:
            matched_patients += 1
        for item_name, gt_val in gt_items.items():
            gt_norm = _normalise(gt_val)
            pred_val = pred_items.get(item_name)
            pred_norm = _normalise(pred_val) if pred_val is not None else ""
            total_comparisons += 1
            item_stats[item_name]['total'] += 1
            if gt_norm == pred_norm:
                total_correct += 1
                item_stats[item_name]['correct'] += 1
                item_stats[item_name]['tp'] += 1
            else:
                item_stats[item_name]['fn'] += 1
                if pred_norm:
                    # Wrong non-empty answer also counts as a false positive.
                    item_stats[item_name]['fp'] += 1
                errors.append((doc_id, item_name, gt_val,
                               pred_val if pred_val is not None else '<MISSING>'))
    accuracy = total_correct / total_comparisons if total_comparisons > 0 else 0.0
    # --- Per-item P/R/F1 ---
    f1s = []
    per_item = {}
    for item_name in sorted(all_items):
        s = item_stats[item_name]
        tp, fp, fn = s['tp'], s['fp'], s['fn']
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        item_acc = s['correct'] / s['total'] if s['total'] > 0 else 0.0
        per_item[item_name] = {'accuracy': item_acc, 'precision': prec,
                               'recall': rec, 'f1': f1, 'total': s['total']}
        f1s.append(f1)
    # Macro-F1: unweighted mean over items (each item counts equally).
    macro_f1 = sum(f1s) / len(f1s) if f1s else 0.0
    # --- Print report ---
    print("\n" + "=" * 70)
    print(" CL4Health CRF Filling -- Evaluation Report")
    print("=" * 70)
    print(f"\n GT Patients: {len(gt)}")
    print(f" Pred Patients: {len(preds)}")
    print(f" Matched Patients: {matched_patients}")
    print(f"\n Total Comparisons: {total_comparisons}")
    print(f" Total Correct: {total_correct}")
    print(f"\n {'Accuracy':>20s}: {accuracy:.4f}")
    print(f" {'Macro F1':>20s}: {macro_f1:.4f}")
    # Top / bottom items
    sorted_items = sorted(per_item.items(), key=lambda x: x[1]['f1'], reverse=True)
    n_show = min(15, len(sorted_items))
    print(f"\n Top {n_show} Items by F1:")
    print(f" {'Item':<45s} {'Acc':>6s} {'P':>6s} {'R':>6s} {'F1':>6s}")
    print(f" {'-'*45} {'---':>6s} {'---':>6s} {'---':>6s} {'---':>6s}")
    for name, s in sorted_items[:n_show]:
        print(f" {name:<45s} {s['accuracy']:>6.2f} {s['precision']:>6.2f} {s['recall']:>6.2f} {s['f1']:>6.2f}")
    print(f"\n Bottom {n_show} Items by F1:")
    print(f" {'Item':<45s} {'Acc':>6s} {'P':>6s} {'R':>6s} {'F1':>6s}")
    print(f" {'-'*45} {'---':>6s} {'---':>6s} {'---':>6s} {'---':>6s}")
    for name, s in sorted_items[-n_show:]:
        print(f" {name:<45s} {s['accuracy']:>6.2f} {s['precision']:>6.2f} {s['recall']:>6.2f} {s['f1']:>6.2f}")
    # Sample errors
    if errors:
        n_err = min(15, len(errors))
        print(f"\n Sample Mismatches ({n_err} of {len(errors)}):")
        print(f" {'DocID':<12s} {'Item':<40s} {'GT':<20s} {'Pred':<20s}")
        print(f" {'-'*12} {'-'*40} {'-'*20} {'-'*20}")
        for doc_id, item, gt_v, pred_v in errors[:n_err]:
            print(f" {doc_id:<12s} {item:<40s} {str(gt_v):<20s} {str(pred_v):<20s}")
    print("=" * 70)
    return {'accuracy': round(accuracy, 4), 'macro_f1': round(macro_f1, 4)}, per_item
# ---------------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------------
async def main():
    """CLI entry point: load data, run the extraction pipeline, save, evaluate.

    Stages:
      1. Parse arguments and configure the Gemini client.
      2. Load patient notes, the target item schema, and valid options.
      3. Run either the RAG-guided extractor or the two-pass pipeline,
         bounded by a semaphore to respect API rate limits.
      4. Persist predictions as JSON.
      5. Optionally score against the ground-truth file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--api_key", default=API_KEY,
                        help="Google AI Studio API key")
    parser.add_argument("--model_name", default="gemini-1.5-pro",
                        help="Gemini model name")
    parser.add_argument("--data_folders", nargs="+",
                        default=[
                            r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-clinical-notes",
                            r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-crf-development",
                        ],
                        help="Directories containing .parquet shards (searched recursively)")
    parser.add_argument("--gt_file",
                        default=r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dev_gt.jsonl")
    parser.add_argument("--options_folder",
                        default=r"C:\Users\sai78\Desktop\Clinical_CRF_filling\data\raw\dyspnea-valid-options\dyspnea-valid-options\data")
    parser.add_argument("--output_file",
                        default="data/processed/materialized_ehr/submission.json")
    parser.add_argument("--skip_eval", action="store_true",
                        help="Skip evaluation after generating predictions")
    parser.add_argument("--concurrency", type=int, default=5,
                        help="Max concurrent LLM calls (free tier: keep at 5)")
    # --- RAG options ---
    parser.add_argument("--use_rag", action="store_true",
                        help="Use RAG-guided extraction (retrieves relevant tuples per CRF item)")
    parser.add_argument("--rag_top_k", type=int, default=15,
                        help="Number of WTTS tuples to retrieve per CRF item group (RAG mode)")
    parser.add_argument("--rag_model", type=str, default="all-MiniLM-L6-v2",
                        help="SentenceTransformer model for embeddings (swap to clinical model on GPU)")
    parser.add_argument("--rag_device", type=str, default="cpu",
                        help="Device for embedding model: 'cpu' or 'cuda'")
    args = parser.parse_args()
    # 1. Setup — Configure Gemini API
    genai.configure(api_key=args.api_key)
    model = genai.GenerativeModel(args.model_name)
    print(f"Using model: {args.model_name} (Google AI Studio)")
    # Limit concurrency (free tier = 15 RPM, so keep low)
    semaphore = asyncio.Semaphore(args.concurrency)
    # 2. Load Data
    loader = DataLoader(data_folders=args.data_folders, gt_path=args.gt_file)
    target_items = loader.get_target_schema()
    valid_options = loader.load_valid_options(args.options_folder)
    merged_data = loader.load_and_merge()
    if not merged_data:
        print("No data found. Exiting.")
        return
    # 3. Process
    builder = WTTSBuilder()
    print(f"Starting pipeline for {len(merged_data)} patients...")
    print(f"Schema: {len(target_items)} items per patient.")
    if args.use_rag:
        # --- RAG Pipeline ---
        if not RAG_AVAILABLE:
            print("ERROR: RAG dependencies not installed. Run:")
            print("  pip install sentence-transformers faiss-cpu")
            return
        print(f"\n [RAG MODE] Embedding model: {args.rag_model}")
        print(f" [RAG MODE] Device: {args.rag_device}")
        print(f" [RAG MODE] Top-k: {args.rag_top_k}\n")
        embedder = WTTSEmbedder(model_name=args.rag_model, device=args.rag_device)
        extractor = RAGCRFExtractor(
            embedder=embedder,
            generate_fn=generate_async,
            top_k=args.rag_top_k,
        )
        tasks = [
            extractor.extract_patient(
                p, builder, target_items, valid_options, semaphore, model
            )
            for p in merged_data
        ]
    else:
        # --- Original Two-Pass Pipeline ---
        tasks = [
            process_patient(model, builder, p, target_items, valid_options, semaphore)
            for p in merged_data
        ]
    results = await asyncio.gather(*tasks)
    results = [r for r in results if r is not None]
    # 4. Save
    # os.path.dirname() is "" for a bare filename; makedirs("") would raise,
    # so only create the directory when the path actually contains one.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"\nDone! {len(results)} results saved to {args.output_file}")
    # 5. Evaluate against GT
    if not args.skip_eval:
        print("\nRunning evaluation against ground truth...")
        overall, _ = evaluate_predictions(results, args.gt_file)
        print(f"\n >>> Final Accuracy: {overall['accuracy']:.4f} | Macro F1: {overall['macro_f1']:.4f}")
# Script entry point: run the full async pipeline under asyncio.run().
if __name__ == "__main__":
    asyncio.run(main())