| """ | |
| OncoAgent — Dataset Builder / Unifier. | |
| Combines real (HuggingFace-filtered) and synthetic (Qwen generated) | |
| oncology data into a single training-ready JSONL corpus in ChatML | |
| chat template format. | |
| Hardware Target: CPU only. | |
| Rule Compliance: #12 (JSONL + ChatML format), #22 (seeds), #26 (type hints). | |
| """ | |
import json
import os
import random
import logging
import hashlib
from typing import Dict, List, Set, Tuple

from dotenv import load_dotenv

load_dotenv()
random.seed(42)  # Rule #22: fixed seed so the shuffle/split below is reproducible

logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)
# ── Paths ───────────────────────────────────────────────────────────────────
FILTERED_REAL: str = os.path.join("data", "filtered", "onco_real_filtered.jsonl")
SYNTHETIC_DIR: str = os.path.join("data", "synthetic")
FINAL_OUTPUT: str = os.path.join("data", "final", "train_oncoagent.jsonl")

SYSTEM_PROMPT: str = (
    "You are an expert clinical oncologist specializing in cancer triage. "
    "Analyze the patient's clinical presentation using temporal-causal "
    "reasoning (OncoCoT). Provide: (1) key findings, (2) step-by-step "
    "diagnostic reasoning with staging, and (3) evidence-based recommendations "
    "citing NCCN/ESMO guidelines where applicable."
)
def format_synthetic_to_chatml(case: Dict[str, str]) -> str:
    """Convert a synthetic case dict to a ChatML-formatted training string.

    Args:
        case: Dict with 'history', 'reasoning', 'conclusion' keys.

    Returns:
        Formatted ChatML template string.
    """
    history = case.get("history", "")
    reasoning = case.get("reasoning", "")
    conclusion = case.get("conclusion", "")
    user_msg = f"Clinical Presentation:\n{history}"
    assistant_msg = (
        f"Diagnostic Reasoning:\n{reasoning}\n\n"
        f"Assessment & Plan:\n{conclusion}"
    )
    return (
        f"<|im_start|>system\n"
        f"{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n"
        f"{user_msg}<|im_end|>\n"
        f"<|im_start|>assistant\n"
        f"{assistant_msg}<|im_end|>"
    )
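# Illustrative only: for a toy case (hypothetical values; system prompt elided
# with "..."), a call such as
#
#   format_synthetic_to_chatml({
#       "history": "62-year-old smoker, 3 weeks of hemoptysis and weight loss.",
#       "reasoning": "Step 1: ...",
#       "conclusion": "Suspicious for lung malignancy; CT chest and referral.",
#   })
#
# yields a single string of the form:
#
#   <|im_start|>system
#   You are an expert clinical oncologist ...<|im_end|>
#   <|im_start|>user
#   Clinical Presentation:
#   62-year-old smoker, 3 weeks of hemoptysis and weight loss.<|im_end|>
#   <|im_start|>assistant
#   Diagnostic Reasoning:
#   Step 1: ...
#
#   Assessment & Plan:
#   Suspicious for lung malignancy; CT chest and referral.<|im_end|>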
def load_real_data() -> List[Dict[str, str]]:
    """Load real filtered oncology data from the HuggingFace-filtered JSONL."""
    if not os.path.exists(FILTERED_REAL):
        logger.warning(f"⚠️ Real data not found: {FILTERED_REAL}")
        return []
    entries: List[Dict[str, str]] = []
    with open(FILTERED_REAL, "r", encoding="utf-8") as f:
        for line in f:
            try:
                entries.append(json.loads(line.strip()))
            except json.JSONDecodeError:
                # Skip malformed or blank lines rather than aborting the build.
                continue
    logger.info(f"📚 Loaded {len(entries):,} real oncology samples")
    return entries
def load_synthetic_data() -> List[Dict[str, str]]:
    """Load all synthetic generated data and format it to ChatML."""
    if not os.path.exists(SYNTHETIC_DIR):
        logger.warning(f"⚠️ Synthetic dir not found: {SYNTHETIC_DIR}")
        return []
    # Prefer the consolidated file; otherwise fall back to per-batch shards.
    final = os.path.join(SYNTHETIC_DIR, "onco_synthetic_final.jsonl")
    if os.path.exists(final):
        files_to_read = [final]
    else:
        files_to_read = sorted(
            os.path.join(SYNTHETIC_DIR, f)
            for f in os.listdir(SYNTHETIC_DIR)
            if f.endswith(".jsonl") and f.startswith("generated_")
        )
    entries: List[Dict[str, str]] = []
    for fpath in files_to_read:
        with open(fpath, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    case = json.loads(line.strip())
                    formatted = format_synthetic_to_chatml(case)
                    entries.append({
                        "text": formatted,
                        "source": "synthetic_qwen",
                    })
                except (json.JSONDecodeError, KeyError):
                    continue
    logger.info(f"🧬 Loaded {len(entries):,} synthetic oncology samples")
    return entries
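# For reference, each synthetic record is expected to be one JSON object per
# line whose field names match the .get() keys in format_synthetic_to_chatml:
#
#   {"history": "...", "reasoning": "...", "conclusion": "..."}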
def _compute_corpus_hash(entries: List[Dict[str, str]]) -> str:
    """Compute a deterministic hash of the corpus for reproducibility tracking.

    Args:
        entries: List of training entries.

    Returns:
        SHA-256 hex digest (first 12 chars).
    """
    h = hashlib.sha256()
    for e in entries:
        h.update(e.get("text", "").encode("utf-8"))
    return h.hexdigest()[:12]
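# Sketch of one way to persist the digest for run-to-run comparison (the
# sidecar filename below is illustrative, not part of the pipeline):
#
#   meta = {"corpus_hash": _compute_corpus_hash(combined), "seed": 42}
#   with open("data/final/train_oncoagent.meta.json", "w") as f:
#       json.dump(meta, f)
#
# Rebuilding from identical inputs with the same seed should reproduce the
# same 12-character digest.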
def build_unified_corpus(
    eval_ratio: float = 0.10,
) -> Tuple[str, str]:
    """Build the final unified training corpus with a train/eval split.

    Args:
        eval_ratio: Fraction of samples reserved for evaluation (default 10%).

    Returns:
        Tuple of (train_path, eval_path).
    """
    logger.info("🚀 Building unified OncoAgent training corpus...")
    logger.info("=" * 60)
    real = load_real_data()
    synthetic = load_synthetic_data()
    combined = real + synthetic
    if not combined:
        raise RuntimeError(
            "No training data found; run the filtering/generation steps first."
        )
    random.shuffle(combined)
    # Deduplicate by text content. Entries without a 'text' field are skipped
    # outright; otherwise they would all hash to the same empty-string digest
    # and silently collapse into a single entry.
    seen_hashes: Set[str] = set()
    deduped: List[Dict[str, str]] = []
    for entry in combined:
        text = entry.get("text", "")
        if not text:
            continue
        text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()
        if text_hash not in seen_hashes:
            seen_hashes.add(text_hash)
            deduped.append(entry)
    if len(deduped) < len(combined):
        removed = len(combined) - len(deduped)
        logger.info(f"🧹 Deduplication: removed {removed:,} duplicate/empty samples")
    combined = deduped
    # Train/eval split (order is already randomized by the seeded shuffle above)
    split_idx = max(1, int(len(combined) * (1.0 - eval_ratio)))
    train_set = combined[:split_idx]
    eval_set = combined[split_idx:]
    os.makedirs(os.path.dirname(FINAL_OUTPUT), exist_ok=True)
    eval_output = FINAL_OUTPUT.replace(".jsonl", "_eval.jsonl")
    for path, entries in [(FINAL_OUTPUT, train_set), (eval_output, eval_set)]:
        with open(path, "w", encoding="utf-8") as f:
            for entry in entries:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
    # Statistics
    source_counts: Dict[str, int] = {}
    for e in combined:
        src = e.get("source", "unknown")
        source_counts[src] = source_counts.get(src, 0) + 1
    corpus_hash = _compute_corpus_hash(combined)
    logger.info(f"📊 UNIFIED CORPUS BUILT — {len(combined):,} total samples")
    logger.info(f"   ├── Train: {len(train_set):,} ({100 * (1 - eval_ratio):.0f}%)")
    logger.info(f"   ├── Eval:  {len(eval_set):,} ({100 * eval_ratio:.0f}%)")
    for src, cnt in sorted(source_counts.items(), key=lambda x: -x[1]):
        pct = (cnt / len(combined)) * 100
        logger.info(f"   ├── {src}: {cnt:,} ({pct:.1f}%)")
    logger.info(f"   ├── Corpus hash: {corpus_hash}")
    logger.info(f"   ├── Train output: {FINAL_OUTPUT}")
    logger.info(f"   └── Eval output:  {eval_output}")
    logger.info("=" * 60)
    return FINAL_OUTPUT, eval_output
if __name__ == "__main__":
    build_unified_corpus()
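# Downstream consumption sketch (illustrative; assumes the `datasets` library
# is installed and is not part of this builder):
#
#   from datasets import load_dataset
#   ds = load_dataset(
#       "json",
#       data_files={
#           "train": "data/final/train_oncoagent.jsonl",
#           "eval": "data/final/train_oncoagent_eval.jsonl",
#       },
#   )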