"""
OncoAgent — Dataset Builder / Unifier.
Combines real (filtered from HuggingFace datasets) and synthetic
(Qwen-generated) oncology data into a single training-ready JSONL
corpus in the ChatML chat template format.
Hardware Target: CPU only.
Rule Compliance: #12 (JSONL + ChatML format), #22 (seeds), #26 (type hints).
"""
import json
import os
import random
import logging
import hashlib
from typing import Dict, List, Set, Tuple
from dotenv import load_dotenv
load_dotenv()
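# Fixed seed keeps the shuffle and train/eval split reproducible (Rule #22).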
random.seed(42)
logging.basicConfig(
level=os.getenv("LOG_LEVEL", "INFO"),
format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)
# ── Paths ───────────────────────────────────────────────────────────────────
FILTERED_REAL: str = os.path.join("data", "filtered", "onco_real_filtered.jsonl")
SYNTHETIC_DIR: str = os.path.join("data", "synthetic")
FINAL_OUTPUT: str = os.path.join("data", "final", "train_oncoagent.jsonl")
SYSTEM_PROMPT: str = (
"You are an expert clinical oncologist specializing in cancer triage. "
"Analyze the patient's clinical presentation using temporal-causal "
"reasoning (OncoCoT). Provide: (1) key findings, (2) step-by-step "
"diagnostic reasoning with staging, and (3) evidence-based recommendations "
"citing NCCN/ESMO guidelines where applicable."
)
def format_synthetic_to_chatml(case: Dict[str, str]) -> str:
"""Convert a synthetic case dict to ChatML template.
Args:
case: Dict with 'history', 'reasoning', 'conclusion' keys.
Returns:
Formatted ChatML template string.
"""
history = case.get("history", "")
reasoning = case.get("reasoning", "")
conclusion = case.get("conclusion", "")
user_msg = f"Clinical Presentation:\n{history}"
assistant_msg = (
f"Diagnostic Reasoning:\n{reasoning}\n\n"
f"Assessment & Plan:\n{conclusion}"
)
return (
f"<|im_start|>system\n"
f"{SYSTEM_PROMPT}<|im_end|>\n"
f"<|im_start|>user\n"
f"{user_msg}<|im_end|>\n"
f"<|im_start|>assistant\n"
f"{assistant_msg}<|im_end|>"
)
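# Illustrative usage sketch (the case values below are hypothetical):
#   sample = format_synthetic_to_chatml({
#       "history": "58-year-old male, 3-month progressive dysphagia and weight loss.",
#       "reasoning": "Step 1: Red-flag symptoms suggest an upper-GI malignancy. ...",
#       "conclusion": "Suspect esophageal carcinoma; recommend urgent endoscopy per NCCN.",
#   })
# `sample` is a single string containing the system, user, and assistant turns
# wrapped in <|im_start|>/<|im_end|> ChatML delimiters.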
def load_real_data() -> List[Dict[str, str]]:
"""Load real filtered oncology data."""
if not os.path.exists(FILTERED_REAL):
logger.warning(f"⚠️ Real data not found: {FILTERED_REAL}")
return []
entries: List[Dict[str, str]] = []
with open(FILTERED_REAL, "r", encoding="utf-8") as f:
for line in f:
try:
entries.append(json.loads(line.strip()))
except json.JSONDecodeError:
continue
logger.info(f"📚 Loaded {len(entries):,} real oncology samples")
return entries
def load_synthetic_data() -> List[Dict[str, str]]:
"""Load all synthetic generated data and format to ChatML."""
if not os.path.exists(SYNTHETIC_DIR):
logger.warning(f"⚠️ Synthetic dir not found: {SYNTHETIC_DIR}")
return []
    # Prefer the consolidated file; otherwise read every generated_*.jsonl shard.
    final = os.path.join(SYNTHETIC_DIR, "onco_synthetic_final.jsonl")
    if os.path.exists(final):
        files_to_read = [final]
    else:
        files_to_read = sorted(
            os.path.join(SYNTHETIC_DIR, f)
            for f in os.listdir(SYNTHETIC_DIR)
            if f.endswith(".jsonl") and f.startswith("generated_")
        )
entries: List[Dict[str, str]] = []
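    # Each line is expected to be a JSON object with 'history', 'reasoning',
    # and 'conclusion' keys; malformed lines are skipped silently.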
for fpath in files_to_read:
with open(fpath, "r", encoding="utf-8") as f:
for line in f:
try:
case = json.loads(line.strip())
formatted = format_synthetic_to_chatml(case)
entries.append({
"text": formatted,
"source": "synthetic_qwen",
})
except (json.JSONDecodeError, KeyError):
continue
logger.info(f"🧬 Loaded {len(entries):,} synthetic oncology samples")
return entries
def _compute_corpus_hash(entries: List[Dict[str, str]]) -> str:
"""Compute a deterministic hash of the corpus for reproducibility tracking.
Args:
entries: List of training entries.
Returns:
SHA-256 hex digest (first 12 chars).
"""
h = hashlib.sha256()
for e in entries:
h.update(e.get("text", "").encode("utf-8"))
return h.hexdigest()[:12]
def build_unified_corpus(
eval_ratio: float = 0.10,
) -> Tuple[str, str]:
"""Build the final unified training corpus with a train/eval split.
Args:
eval_ratio: Fraction of samples reserved for evaluation (default 10%).
Returns:
Tuple of (train_path, eval_path).
"""
logger.info("🚀 Building unified OncoAgent training corpus...")
logger.info("=" * 60)
real = load_real_data()
synthetic = load_synthetic_data()
combined = real + synthetic
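    # The module-level random.seed(42) makes this shuffle, and therefore the
    # train/eval split below, deterministic across runs.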
random.shuffle(combined)
# Deduplicate by text content
    seen_hashes: Set[str] = set()
    deduped: List[Dict[str, str]] = []
    for entry in combined:
        text_hash = hashlib.sha256(entry.get("text", "").encode("utf-8")).hexdigest()
if text_hash not in seen_hashes:
seen_hashes.add(text_hash)
deduped.append(entry)
if len(deduped) < len(combined):
logger.info(f"🧹 Deduplication: removed {len(combined) - len(deduped):,} duplicate samples")
combined = deduped
# Train/eval split
split_idx = max(1, int(len(combined) * (1.0 - eval_ratio)))
train_set = combined[:split_idx]
eval_set = combined[split_idx:]
os.makedirs(os.path.dirname(FINAL_OUTPUT), exist_ok=True)
eval_output = FINAL_OUTPUT.replace(".jsonl", "_eval.jsonl")
for path, entries in [(FINAL_OUTPUT, train_set), (eval_output, eval_set)]:
with open(path, "w", encoding="utf-8") as f:
for entry in entries:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
# Statistics
source_counts: Dict[str, int] = {}
for e in combined:
src = e.get("source", "unknown")
source_counts[src] = source_counts.get(src, 0) + 1
corpus_hash = _compute_corpus_hash(combined)
logger.info(f"📊 UNIFIED CORPUS BUILT — {len(combined):,} total samples")
logger.info(f" ├── Train: {len(train_set):,} ({100*(1-eval_ratio):.0f}%)")
logger.info(f" ├── Eval: {len(eval_set):,} ({100*eval_ratio:.0f}%)")
for src, cnt in sorted(source_counts.items(), key=lambda x: -x[1]):
pct = (cnt / len(combined)) * 100
logger.info(f" ├── {src}: {cnt:,} ({pct:.1f}%)")
logger.info(f" ├── Corpus hash: {corpus_hash}")
logger.info(f" ├── Train output: {FINAL_OUTPUT}")
logger.info(f" └── Eval output: {eval_output}")
logger.info("=" * 60)
return FINAL_OUTPUT, eval_output
if __name__ == "__main__":
build_unified_corpus()
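# Typical invocation (from the repo root, assuming data/filtered/ and
# data/synthetic/ have been populated by the earlier pipeline steps):
#   python data_prep/dataset_builder.py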