# engine/parser_fusion.py
# ------------------------------------------------------------
# Tri-Parser Fusion — Stage 12B (weighted per-field voting)
#
# This module combines:
#   - Rule parser     (parser_rules.parse_text_rules)
#   - Extended parser (parser_ext.parse_text_extended)
#   - LLM parser      (parser_llm.parse_llm)   [optional]
#
# using per-field reliability weights learned in Stage 12A
# and stored in:
#   data/field_weights.json
#
# Behaviour:
#   - For each field, gather predictions from available parsers.
#   - For that field, load weights:
#       field_weights[field]  (if present)
#       else global weights
#       else equal weights across available parsers
#   - Discard parsers that:
#       * did not predict the field
#       * or only predicted "Unknown"
#   - Group by predicted value and sum the weights of parsers
#     that voted for each value.
#   - Choose the value with the highest total weight.
#     Tie-break: prefer rules > extended > llm if needed.
#
# Output format:
#   {
#     "fused_fields": { field: value, ... },  # used by DB identifier AND genus ML
#     "by_parser": {
#         "rules":    { ... },
#         "extended": { ... },
#         "llm":      { ... }   # may be empty
#     },
#     "votes": {
#         field_name: {
#             "per_parser": {
#                 "rules":    {"value": "Positive", "weight": 0.95},
#                 "extended": {"value": "Unknown",  "weight": 0.03},
#                 ...
#             },
#             "summed": {
#                 "Positive": 0.97,
#                 "Negative": 0.02
#             },
#             "chosen": "Positive"
#         },
#         ...
#     },
#     "weights_meta": {
#         "has_weights_file": True/False,
#         "weights_path": "data/field_weights.json",
#         "meta": { ... }   # from file if present
#     }
#   }
# ------------------------------------------------------------

from __future__ import annotations

import json
import os
from typing import Any, Dict, Optional

from engine.parser_rules import parse_text_rules
from engine.parser_ext import parse_text_extended

# Optional LLM parser
try:
    from engine.parser_llm import parse_llm as parse_text_llm  # type: ignore
    HAS_LLM = True
except Exception:
    parse_text_llm = None  # type: ignore
    HAS_LLM = False

# Path to learned weights
FIELD_WEIGHTS_PATH = os.path.join("data", "field_weights.json")

UNKNOWN = "Unknown"
PARSER_ORDER = ["rules", "extended", "llm"]  # tie-breaking priority


# ------------------------------------------------------------
# Weights loading and helpers
# ------------------------------------------------------------

def _load_field_weights(path: str = FIELD_WEIGHTS_PATH) -> Dict[str, Any]:
    """
    Load the JSON weights file produced by Stage 12A.

    Expected structure:
        {
          "global": { "rules": 0.7, "extended": 0.2, "llm": 0.1 },
          "fields": {
              "DNase": { "rules": 0.95, "extended": 0.03, "llm": 0.02, "support": 123 },
              ...
          },
          "meta": { ... }
        }

    If the file is missing or broken, fall back to an empty dict,
    which triggers the equal-weight behaviour later.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        return obj if isinstance(obj, dict) else {}
    except Exception:
        return {}


FIELD_WEIGHTS_RAW: Dict[str, Any] = _load_field_weights()
HAS_WEIGHTS_FILE: bool = bool(FIELD_WEIGHTS_RAW)


def _normalise_scores(scores: Dict[str, float]) -> Dict[str, float]:
    """
    Normalise a parser -> score mapping into weights that sum to 1.
    If all scores are zero or the dict is empty, return equal weights.
    """
    cleaned = {k: max(0.0, float(v)) for k, v in scores.items()}
    total = sum(cleaned.values())
    if total <= 0:
        n = len(cleaned) or 1
        return {k: 1.0 / n for k in cleaned}
    return {k: v / total for k, v in cleaned.items()}


def _get_base_weights_for_parsers(include_llm: bool) -> Dict[str, float]:
    """
    Equal-weight distribution across available parsers.
    Used when no learned weights are available.
    """
    parsers = ["rules", "extended"]
    if include_llm:
        parsers.append("llm")
    n = len(parsers) or 1
    return {p: 1.0 / n for p in parsers}


def _get_weights_for_field(field_name: str, include_llm: bool) -> Dict[str, float]:
    """
    Get weights for a specific field.

    Priority:
        1) FIELD_WEIGHTS_RAW["fields"][field_name]
        2) FIELD_WEIGHTS_RAW["global"]
        3) Equal weights

    Always:
        - Drop 'llm' if include_llm == False
        - Normalise
    """
    if not FIELD_WEIGHTS_RAW:
        return _normalise_scores(_get_base_weights_for_parsers(include_llm))

    fields_block = FIELD_WEIGHTS_RAW.get("fields", {}) or {}
    global_block = FIELD_WEIGHTS_RAW.get("global", {}) or {}

    raw: Dict[str, float] = {}

    field_entry = fields_block.get(field_name)
    if isinstance(field_entry, dict):
        for k, v in field_entry.items():
            if k in ("rules", "extended", "llm"):
                raw[k] = float(v)

    if not raw and isinstance(global_block, dict):
        for k, v in global_block.items():
            if k in ("rules", "extended", "llm"):
                raw[k] = float(v)

    if not raw:
        raw = _get_base_weights_for_parsers(include_llm)

    if not include_llm:
        raw.pop("llm", None)

    if not raw:
        raw = _get_base_weights_for_parsers(include_llm=False)

    return _normalise_scores(raw)


# ------------------------------------------------------------
# Fusion logic
# ------------------------------------------------------------

def _clean_pred_value(val: Optional[str]) -> Optional[str]:
    """
    Treat None, empty string, or an explicit "Unknown" as missing.
    """
    if val is None:
        return None
    s = str(val).strip()
    if not s:
        return None
    if s.lower() == UNKNOWN.lower():
        return None
    return s


def parse_text_fused(text: str, use_llm: Optional[bool] = None) -> Dict[str, Any]:
    """
    Main tri-parser fusion entrypoint.

    Parameters
    ----------
    text : str
    use_llm : bool or None
        True  -> include LLM
        False -> exclude LLM
        None  -> include if available

    Returns
    -------
    Dict[str, Any]
        Full fusion output including votes and per-parser breakdowns.
    """
    original = text or ""
    include_llm = HAS_LLM if use_llm is None else bool(use_llm)

    rules_out = parse_text_rules(original) or {}
    ext_out = parse_text_extended(original) or {}

    rules_fields = dict(rules_out.get("parsed_fields", {}))
    ext_fields = dict(ext_out.get("parsed_fields", {}))

    llm_fields: Dict[str, Any] = {}
    if include_llm and parse_text_llm is not None:
        try:
            # Pass the fields already extracted by the rule/extended parsers
            # along to the LLM parser.
            merged_existing = {}
            merged_existing.update(rules_fields)
            merged_existing.update(ext_fields)
            llm_out = parse_text_llm(original, existing_fields=merged_existing)
            if isinstance(llm_out, dict):
                if "parsed_fields" in llm_out:
                    llm_fields = dict(llm_out.get("parsed_fields", {}))
                else:
                    llm_fields = {str(k): v for k, v in llm_out.items()}
        except Exception:
            llm_fields = {}
    else:
        include_llm = False

    by_parser: Dict[str, Dict[str, Any]] = {
        "rules": rules_fields,
        "extended": ext_fields,
        "llm": llm_fields if include_llm else {},
    }

    candidate_fields = (
        set(rules_fields.keys())
        | set(ext_fields.keys())
        | set(llm_fields.keys())
    )

    fused_fields: Dict[str, Any] = {}
    votes_debug: Dict[str, Any] = {}

    for field in sorted(candidate_fields):
        weights = _get_weights_for_field(field, include_llm)

        parser_preds: Dict[str, Optional[str]] = {
            "rules": _clean_pred_value(rules_fields.get(field)),
            "extended": _clean_pred_value(ext_fields.get(field)),
            "llm": _clean_pred_value(llm_fields.get(field)) if include_llm else None,
        }

        per_parser_info: Dict[str, Any] = {}
        value_scores: Dict[str, float] = {}

        for parser_name in PARSER_ORDER:
            if parser_name == "llm" and not include_llm:
                continue
            pred = parser_preds.get(parser_name)
            w = float(weights.get(parser_name, 0.0))
            per_parser_info[parser_name] = {
                "value": pred if pred is not None else UNKNOWN,
                "weight": w,
            }
            if pred is not None:
                value_scores[pred] = value_scores.get(pred, 0.0) + w

        if not value_scores:
            fused_value = UNKNOWN
        else:
            max_score = max(value_scores.values())
            best_values = [v for v, s in value_scores.items() if s == max_score]
            if len(best_values) == 1:
                fused_value = best_values[0]
            else:
                # Tie-break by parser priority: rules > extended > llm.
                fused_value = best_values[0]
                for parser_name in PARSER_ORDER:
                    if parser_name == "llm" and not include_llm:
                        continue
                    if parser_preds.get(parser_name) in best_values:
                        fused_value = parser_preds[parser_name]  # type: ignore
                        break

        fused_fields[field] = fused_value
        votes_debug[field] = {
            "per_parser": per_parser_info,
            "summed": value_scores,
            "chosen": fused_value,
        }

    weights_meta = {
        "has_weights_file": HAS_WEIGHTS_FILE,
        "weights_path": FIELD_WEIGHTS_PATH,
        "meta": FIELD_WEIGHTS_RAW.get("meta", {}) if HAS_WEIGHTS_FILE else {},
    }

    return {
        "fused_fields": fused_fields,
        "by_parser": by_parser,
        "votes": votes_debug,
        "weights_meta": weights_meta,
    }