# drp/diff.py
# Last commit: 9ff9da6 "Update drp/diff.py" (author: RFTSystems, verified)
import difflib
from typing import Any, Dict, List, Optional
from .bundle import load_bundle
def _normalize_for_compare(x: Any) -> Any:
if isinstance(x, dict):
return {k: _normalize_for_compare(x[k]) for k in sorted(x.keys())}
if isinstance(x, list):
return [_normalize_for_compare(v) for v in x]
return x
def _event_core(ev: Dict[str, Any]) -> Any:
    """Project an event onto the fields that decide whether two events match.

    Only ``kind``, ``step`` and ``payload`` are kept; absent ones become
    ``None``. The projection is canonicalised with ``_normalize_for_compare``.
    """
    core_fields = ("kind", "step", "payload")
    projected = dict(zip(core_fields, map(ev.get, core_fields)))
    return _normalize_for_compare(projected)
def build_alignment(A_events: List[Dict[str, Any]], B_events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Pair the events of two runs by position and label each pair.

    Row ``i`` compares ``A_events[i]`` with ``B_events[i]``. Its ``status`` is:

    - ``"missing_in_A"`` when A has no event at ``i`` (past its end, or the entry is ``None``);
    - ``"missing_in_B"`` likewise for B;
    - otherwise ``"same"`` or ``"diff"``, depending on whether the two
      ``_event_core`` projections are equal.

    ``kind_a``/``step_a``/``kind_b``/``step_b`` are copied from the events.
    They are ``None`` when the event is absent or lacks that key.
    """
    len_a, len_b = len(A_events), len(B_events)

    def pick(events, length, idx):
        return events[idx] if idx < length else None

    def field(ev, name):
        return ev.get(name) if ev else None

    rows: List[Dict[str, Any]] = []
    for idx in range(max(len_a, len_b)):
        ev_a = pick(A_events, len_a, idx)
        ev_b = pick(B_events, len_b, idx)
        if ev_a is None:
            status = "missing_in_A"
        elif ev_b is None:
            status = "missing_in_B"
        elif _event_core(ev_a) == _event_core(ev_b):
            status = "same"
        else:
            status = "diff"
        rows.append(
            {
                "i": idx,
                "status": status,
                "kind_a": field(ev_a, "kind"),
                "step_a": field(ev_a, "step"),
                "kind_b": field(ev_b, "kind"),
                "step_b": field(ev_b, "step"),
            }
        )
    return rows
def _json_diff(a: Any, b: Any, path: str = "") -> List[Dict[str, Any]]:
diffs: List[Dict[str, Any]] = []
if type(a) != type(b):
diffs.append({"path": path or "$", "kind": "type", "a": str(type(a)), "b": str(type(b))})
return diffs
if isinstance(a, dict):
akeys = set(a.keys())
bkeys = set(b.keys())
for k in sorted(akeys - bkeys):
diffs.append({"path": f"{path}.{k}" if path else k, "kind": "removed", "a": a[k], "b": None})
for k in sorted(bkeys - akeys):
diffs.append({"path": f"{path}.{k}" if path else k, "kind": "added", "a": None, "b": b[k]})
for k in sorted(akeys & bkeys):
diffs.extend(_json_diff(a[k], b[k], f"{path}.{k}" if path else k))
return diffs
if isinstance(a, list):
n = max(len(a), len(b))
for i in range(n):
pa = a[i] if i < len(a) else None
pb = b[i] if i < len(b) else None
if i >= len(a):
diffs.append({"path": f"{path}[{i}]", "kind": "added", "a": None, "b": pb})
elif i >= len(b):
diffs.append({"path": f"{path}[{i}]", "kind": "removed", "a": pa, "b": None})
else:
diffs.extend(_json_diff(pa, pb, f"{path}[{i}]"))
return diffs
if a != b:
diffs.append({"path": path or "$", "kind": "value", "a": a, "b": b})
return diffs
def _classify_divergence(kind_a: Optional[str], kind_b: Optional[str]) -> str:
if kind_a != kind_b:
return "control-flow"
if kind_a in ("tool_call", "tool_result"):
return "tool"
if kind_a in ("memory_write", "memory_read"):
return "memory"
if kind_a in ("llm_sample", "llm_call"):
return "sampling"
if kind_a in ("guardrail",):
return "governance"
return "state"
def _text_delta(a: str, b: str) -> str:
a_lines = a.splitlines()
b_lines = b.splitlines()
diff = difflib.unified_diff(a_lines, b_lines, fromfile="A", tofile="B", lineterm="")
return "\n".join(diff)
def _extract_final_reward(events: List[Dict[str, Any]]) -> Optional[float]:
"""
Looks for last state_snapshot payload containing:
- payload.reward_total
- payload.metrics.reward_total
"""
for ev in reversed(events):
if ev.get("kind") != "state_snapshot":
continue
p = ev.get("payload", {}) or {}
if isinstance(p, dict):
rt = p.get("reward_total")
if isinstance(rt, (int, float)):
return float(rt)
m = p.get("metrics")
if isinstance(m, dict):
rt2 = m.get("reward_total")
if isinstance(rt2, (int, float)):
return float(rt2)
return None
def _event_link(manifest: Dict[str, Any], i: int) -> Optional[str]:
"""
Optional deep-link generation.
Supported:
- manifest.replay.base_url + manifest.replay.pattern with {run_id} and {i}
- manifest.run_url + ?i={i}
"""
run_id = manifest.get("run_id")
replay = manifest.get("replay")
if isinstance(replay, dict):
base = replay.get("base_url")
pattern = replay.get("pattern", "")
if isinstance(base, str) and isinstance(pattern, str) and run_id:
try:
return base.rstrip("/") + pattern.format(run_id=run_id, i=i)
except Exception:
return None
run_url = manifest.get("run_url")
if isinstance(run_url, str) and run_url:
# append i in a minimal, non-destructive way
joiner = "&" if "?" in run_url else "?"
return f"{run_url}{joiner}i={i}"
return None
def _payload_text(ev: Dict[str, Any]) -> Any:
    """Return ``ev["payload"]["text"]``, or ``None`` when the payload is missing or not a dict."""
    payload = ev.get("payload", {}) or {}
    if not isinstance(payload, dict):
        # e.g. a string or list payload: there is no "text" field to diff.
        return None
    return payload.get("text")


def diff_bundles(
    zip_a: str,
    zip_b: str,
    *,
    max_diffs_per_event: int = 200,
    max_events: int = 400,
    max_text_chars: int = 20000,
) -> Dict[str, Any]:
    """Compare two run bundles event by event and summarise where they diverge.

    Both paths are opened with ``load_bundle``. Each loaded bundle is used
    through its ``events`` list and its ``manifest`` mapping.

    Args:
        zip_a: Path of bundle A, passed straight to ``load_bundle``.
        zip_b: Path of bundle B.
        max_diffs_per_event: Cap on JSON-diff records kept per differing event.
        max_events: Cap on the number of entries returned in ``"differences"``.
        max_text_chars: Cap on the length of each ``text_unified_diff`` string.

    Returns:
        A dict with four keys:

        - ``"summary"``: run/framework/model ids from both manifests, event
          counts, ``first_divergence_index`` (``None`` when every aligned row
          is "same"), diff/missing counts, and final rewards with their delta
          (B - A). It also includes run links.
        - ``"class_counts"``: number of differing events per divergence class,
          counted before the ``max_events`` cap.
        - ``"alignment"``: the full output of ``build_alignment``.
        - ``"differences"``: one record per index where both runs have an
          event and the events differ.
    """
    bundle_a = load_bundle(zip_a)
    bundle_b = load_bundle(zip_b)
    ea = bundle_a.events
    eb = bundle_b.events
    manifest_a = bundle_a.manifest
    manifest_b = bundle_b.manifest

    alignment = build_alignment(ea, eb)

    # First index that is not "same"; a length mismatch counts as a divergence.
    first_div: Optional[int] = next(
        (row["i"] for row in alignment if row["status"] != "same"), None
    )

    # Details for indices where both runs have an event and the events differ.
    # The loop is driven by the alignment, so rows that build_alignment marked as
    # missing (including None entries) are never dereferenced here.
    per_event: List[Dict[str, Any]] = []
    for row in alignment:
        if row["status"] != "diff":
            continue
        i = row["i"]
        ev_a, ev_b = ea[i], eb[i]
        kind_a, kind_b = ev_a.get("kind"), ev_b.get("kind")
        diffs = _json_diff(_event_core(ev_a), _event_core(ev_b))
        item: Dict[str, Any] = {
            "i": i,
            "kind_a": kind_a,
            "kind_b": kind_b,
            "step_a": ev_a.get("step"),
            "step_b": ev_b.get("step"),
            "class": _classify_divergence(kind_a, kind_b),
            "diffs": diffs[:max_diffs_per_event],
            "link_a": _event_link(manifest_a, i),
            "link_b": _event_link(manifest_b, i),
        }
        ta = _payload_text(ev_a)
        tb = _payload_text(ev_b)
        if isinstance(ta, str) and isinstance(tb, str) and ta != tb:
            item["text_unified_diff"] = _text_delta(ta, tb)[:max_text_chars]
        per_event.append(item)

    diff_count = sum(1 for r in alignment if r["status"] == "diff")
    missing_count = sum(1 for r in alignment if r["status"] in ("missing_in_A", "missing_in_B"))

    ra = _extract_final_reward(ea)
    rb = _extract_final_reward(eb)
    reward_delta = (rb - ra) if (ra is not None and rb is not None) else None

    # Divergence-class histogram over all differing events (before truncation).
    counts: Dict[str, int] = {}
    for item in per_event:
        counts[item["class"]] = counts.get(item["class"], 0) + 1

    summary: Dict[str, Any] = {
        "run_a": manifest_a.get("run_id"),
        "run_b": manifest_b.get("run_id"),
        "framework_a": manifest_a.get("framework"),
        "framework_b": manifest_b.get("framework"),
        "model_a": manifest_a.get("model_id"),
        "model_b": manifest_b.get("model_id"),
        "events_a": len(ea),
        "events_b": len(eb),
        "first_divergence_index": first_div,
        "identical_until_index": first_div,  # same semantic, explicit name
        "diff_event_count": diff_count,
        "missing_event_count": missing_count,
        "final_reward_a": ra,
        "final_reward_b": rb,
        "final_reward_delta": reward_delta,
        "run_link_a": _event_link(manifest_a, 0),
        "run_link_b": _event_link(manifest_b, 0),
    }
    return {
        "summary": summary,
        "class_counts": counts,
        "alignment": alignment,
        "differences": per_event[:max_events],
    }