# BactKing — engine/parser_fusion.py
# (repository page residue: uploaded by EphAsad, "Update engine/parser_fusion.py",
#  commit 957a730 verified — kept here as a comment so the file stays valid Python)
# engine/parser_fusion.py
# ------------------------------------------------------------
# Tri-Parser Fusion — Stage 12B (Weighted, SOTA-style)
#
# This module combines:
# - Rule parser (parser_rules.parse_text_rules)
# - Extended parser (parser_ext.parse_text_extended)
# - LLM parser (parser_llm.parse_llm) [optional]
#
# using per-field reliability weights learned in Stage 12A
# and stored in:
# data/field_weights.json
#
# Behaviour:
# - For each field, gather predictions from available parsers.
# - For that field, load weights:
# field_weights[field] (if present)
# else global weights
# else equal weights across available parsers
# - Discard parsers that:
# * did not predict the field
# * or only predicted "Unknown"
# - Group by predicted value and sum the weights of parsers
# that voted for each value.
# - Choose the value with highest total weight.
# Tie-break: prefer rules > extended > llm if needed.
#
# Output format:
# {
# "fused_fields": { field: value, ... }, # used by DB identifier AND genus ML
# "by_parser": {
# "rules": { ... },
# "extended": { ... },
# "llm": { ... } # may be empty
# },
# "votes": {
# field_name: {
# "per_parser": {
# "rules": {"value": "Positive", "weight": 0.95},
# "extended": {"value": "Unknown", "weight": 0.03},
# ...
# },
# "summed": {
# "Positive": 0.97,
# "Negative": 0.02
# },
# "chosen": "Positive"
# },
# ...
# },
# "weights_meta": {
# "has_weights_file": True/False,
# "weights_path": "data/field_weights.json",
# "meta": { ... } # from file if present
# }
# }
# ------------------------------------------------------------
from __future__ import annotations
import json
import os
from typing import Any, Dict, Optional
from engine.parser_rules import parse_text_rules
from engine.parser_ext import parse_text_extended
# Optional LLM parser: the module must work without it, so the import is
# guarded and HAS_LLM records whether the backend is available.
try:
    from engine.parser_llm import parse_llm as parse_text_llm  # type: ignore
    HAS_LLM = True
except Exception:
    # Any import failure (missing module, missing deps) degrades gracefully
    # to the two deterministic parsers only.
    parse_text_llm = None  # type: ignore
    HAS_LLM = False
# Path to learned weights (produced by Stage 12A)
FIELD_WEIGHTS_PATH = os.path.join("data", "field_weights.json")
# Sentinel value meaning "no usable prediction" for a field.
UNKNOWN = "Unknown"
PARSER_ORDER = ["rules", "extended", "llm"]  # tie-breaking priority
# ------------------------------------------------------------
# Weights loading and helpers
# ------------------------------------------------------------
def _load_field_weights(path: str = FIELD_WEIGHTS_PATH) -> Dict[str, Any]:
    """
    Read the Stage 12A weights file at *path*.

    Expected JSON layout:
        {
          "global": {"rules": 0.7, "extended": 0.2, "llm": 0.1},
          "fields": {"DNase": {"rules": 0.95, "extended": 0.03,
                               "llm": 0.02, "support": 123}, ...},
          "meta":   {...}
        }

    Returns {} when the file is absent, unreadable, malformed JSON, or
    not a JSON object — which makes the rest of the module fall back to
    equal-weight voting.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except Exception:
        # Best-effort load: a broken weights file must never crash fusion.
        return {}
    return loaded if isinstance(loaded, dict) else {}
# Weights are loaded once at module import; the process sees a stable
# snapshot (re-run Stage 12A and restart to pick up new weights).
FIELD_WEIGHTS_RAW: Dict[str, Any] = _load_field_weights()
# True when a usable (non-empty dict) weights file was found.
HAS_WEIGHTS_FILE: bool = bool(FIELD_WEIGHTS_RAW)
def _normalise_scores(scores: Dict[str, float]) -> Dict[str, float]:
"""
Normalise parser -> score into weights summing to 1.
If all scores are zero or dict is empty, return equal weights.
"""
cleaned = {k: max(0.0, float(v)) for k, v in scores.items()}
total = sum(cleaned.values())
if total <= 0:
n = len(cleaned) or 1
return {k: 1.0 / n for k in cleaned}
return {k: v / total for k, v in cleaned.items()}
def _get_base_weights_for_parsers(include_llm: bool) -> Dict[str, float]:
"""
Equal-weight distribution across available parsers.
Used when no learned weights are available.
"""
parsers = ["rules", "extended"]
if include_llm:
parsers.append("llm")
n = len(parsers) or 1
return {p: 1.0 / n for p in parsers}
def _get_weights_for_field(field_name: str, include_llm: bool) -> Dict[str, float]:
    """
    Resolve the weight vector for one field.

    Lookup priority:
      1) FIELD_WEIGHTS_RAW["fields"][field_name]
      2) FIELD_WEIGHTS_RAW["global"]
      3) Equal weights over the available parsers
    The "llm" weight is dropped when include_llm is False, and the final
    mapping is always renormalised to sum to 1.
    """
    if not FIELD_WEIGHTS_RAW:
        return _normalise_scores(_get_base_weights_for_parsers(include_llm))

    known = ("rules", "extended", "llm")

    def _pick(entry: Any) -> Dict[str, float]:
        # Keep only recognised parser keys; ignore metadata like "support".
        if not isinstance(entry, dict):
            return {}
        return {key: float(val) for key, val in entry.items() if key in known}

    fields_block = FIELD_WEIGHTS_RAW.get("fields", {}) or {}
    raw = _pick(fields_block.get(field_name))
    if not raw:
        raw = _pick(FIELD_WEIGHTS_RAW.get("global", {}) or {})
    if not raw:
        raw = _get_base_weights_for_parsers(include_llm)
    if not include_llm:
        raw.pop("llm", None)
        if not raw:
            # Per-field entry only contained an "llm" weight: fall back.
            raw = _get_base_weights_for_parsers(include_llm=False)
    return _normalise_scores(raw)
# ------------------------------------------------------------
# Fusion logic
# ------------------------------------------------------------
def _clean_pred_value(val: Optional[str]) -> Optional[str]:
"""
Treat None, empty string, or explicit "Unknown" as missing.
"""
if val is None:
return None
s = str(val).strip()
if not s:
return None
if s.lower() == UNKNOWN.lower():
return None
return s
def _collect_llm_fields(
    text: str,
    rules_fields: Dict[str, Any],
    ext_fields: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Run the optional LLM parser and normalise its output to a flat
    field -> value dict.

    The deterministic parsers' output is passed along as
    ``existing_fields`` so the LLM backend can see what is already
    known.  Any failure — backend missing, raising, or returning an
    unexpected shape — degrades to {} so fusion never depends on it.
    """
    if parse_text_llm is None:
        return {}
    try:
        merged_existing: Dict[str, Any] = {**rules_fields, **ext_fields}
        llm_out = parse_text_llm(text, existing_fields=merged_existing)
        if not isinstance(llm_out, dict):
            return {}
        if "parsed_fields" in llm_out:
            return dict(llm_out.get("parsed_fields", {}))
        # Bare dict: assume it already is field -> value.
        return {str(k): v for k, v in llm_out.items()}
    except Exception:
        # Best-effort: a broken LLM response must not break fusion.
        return {}


def _fuse_single_field(
    parser_preds: Dict[str, Optional[str]],
    weights: Dict[str, float],
    include_llm: bool,
) -> Dict[str, Any]:
    """
    Weighted majority vote for one field.

    parser_preds maps parser name -> cleaned prediction (None = no vote);
    weights maps parser name -> reliability weight.  Returns the
    {"per_parser", "summed", "chosen"} dict exposed in the "votes"
    debug block of parse_text_fused().
    """
    per_parser_info: Dict[str, Any] = {}
    value_scores: Dict[str, float] = {}
    for parser_name in PARSER_ORDER:
        if parser_name == "llm" and not include_llm:
            continue
        pred = parser_preds.get(parser_name)
        weight = float(weights.get(parser_name, 0.0))
        per_parser_info[parser_name] = {
            "value": pred if pred is not None else UNKNOWN,
            "weight": weight,
        }
        if pred is not None:
            value_scores[pred] = value_scores.get(pred, 0.0) + weight

    if not value_scores:
        # No parser produced a usable prediction for this field.
        chosen: str = UNKNOWN
    else:
        top_score = max(value_scores.values())
        best_values = [v for v, s in value_scores.items() if s == top_score]
        chosen = best_values[0]
        if len(best_values) > 1:
            # Tie-break by parser priority: rules > extended > llm.
            for parser_name in PARSER_ORDER:
                if parser_name == "llm" and not include_llm:
                    continue
                if parser_preds.get(parser_name) in best_values:
                    chosen = parser_preds[parser_name]  # type: ignore[assignment]
                    break
    return {"per_parser": per_parser_info, "summed": value_scores, "chosen": chosen}


def parse_text_fused(text: str, use_llm: Optional[bool] = None) -> Dict[str, Any]:
    """
    Main tri-parser fusion entrypoint.

    Parameters
    ----------
    text : str
        Raw input text; None is treated as "".
    use_llm : bool or None
        True  -> include LLM
        False -> exclude LLM
        None  -> include if available (HAS_LLM)

    Returns
    -------
    Dict[str, Any]
        {"fused_fields", "by_parser", "votes", "weights_meta"} — the full
        fusion output including per-parser breakdowns and vote debug info
        (see the module header for the exact layout).
    """
    original = text or ""
    include_llm = HAS_LLM if use_llm is None else bool(use_llm)

    rules_out = parse_text_rules(original) or {}
    ext_out = parse_text_extended(original) or {}
    rules_fields = dict(rules_out.get("parsed_fields", {}))
    ext_fields = dict(ext_out.get("parsed_fields", {}))

    llm_fields: Dict[str, Any] = {}
    if include_llm and parse_text_llm is not None:
        llm_fields = _collect_llm_fields(original, rules_fields, ext_fields)
    else:
        # LLM requested but unavailable: drop it from weighting entirely.
        include_llm = False

    by_parser: Dict[str, Dict[str, Any]] = {
        "rules": rules_fields,
        "extended": ext_fields,
        "llm": llm_fields if include_llm else {},
    }

    fused_fields: Dict[str, Any] = {}
    votes_debug: Dict[str, Any] = {}
    candidate_fields = set(rules_fields) | set(ext_fields) | set(llm_fields)
    for field in sorted(candidate_fields):
        weights = _get_weights_for_field(field, include_llm)
        parser_preds: Dict[str, Optional[str]] = {
            "rules": _clean_pred_value(rules_fields.get(field)),
            "extended": _clean_pred_value(ext_fields.get(field)),
            "llm": _clean_pred_value(llm_fields.get(field)) if include_llm else None,
        }
        vote = _fuse_single_field(parser_preds, weights, include_llm)
        fused_fields[field] = vote["chosen"]
        votes_debug[field] = vote

    weights_meta = {
        "has_weights_file": HAS_WEIGHTS_FILE,
        "weights_path": FIELD_WEIGHTS_PATH,
        "meta": FIELD_WEIGHTS_RAW.get("meta", {}) if HAS_WEIGHTS_FILE else {},
    }
    return {
        "fused_fields": fused_fields,
        "by_parser": by_parser,
        "votes": votes_debug,
        "weights_meta": weights_meta,
    }