RFTSystems committed on
Commit
9ff9da6
·
verified ·
1 Parent(s): 2976335

Update drp/diff.py

Browse files
Files changed (1) hide show
  1. drp/diff.py +149 -56
drp/diff.py CHANGED
@@ -1,11 +1,10 @@
1
  import difflib
2
- from typing import Any, Dict, List, Optional, Tuple
3
 
4
- from .bundle import Bundle, load_bundle
5
 
6
 
7
  def _normalize_for_compare(x: Any) -> Any:
8
- # Avoid false diffs from ordering
9
  if isinstance(x, dict):
10
  return {k: _normalize_for_compare(x[k]) for k in sorted(x.keys())}
11
  if isinstance(x, list):
@@ -13,11 +12,38 @@ def _normalize_for_compare(x: Any) -> Any:
13
  return x
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def _json_diff(a: Any, b: Any, path: str = "") -> List[Dict[str, Any]]:
17
- """
18
- Small recursive diff (no heavy deps).
19
- Emits list of {path, a, b, kind}.
20
- """
21
  diffs: List[Dict[str, Any]] = []
22
 
23
  if type(a) != type(b):
@@ -36,7 +62,6 @@ def _json_diff(a: Any, b: Any, path: str = "") -> List[Dict[str, Any]]:
36
  return diffs
37
 
38
  if isinstance(a, list):
39
- # list diff by index (simple)
40
  n = max(len(a), len(b))
41
  for i in range(n):
42
  pa = a[i] if i < len(a) else None
@@ -54,18 +79,16 @@ def _json_diff(a: Any, b: Any, path: str = "") -> List[Dict[str, Any]]:
54
  return diffs
55
 
56
 
57
- def _classify_divergence(ev_a: Dict[str, Any], ev_b: Dict[str, Any]) -> str:
58
- ka = ev_a.get("kind")
59
- kb = ev_b.get("kind")
60
- if ka != kb:
61
  return "control-flow"
62
- if ka in ("tool_call", "tool_result"):
63
  return "tool"
64
- if ka in ("memory_write", "memory_read"):
65
  return "memory"
66
- if ka in ("llm_sample", "llm_call"):
67
  return "sampling"
68
- if ka in ("guardrail",):
69
  return "governance"
70
  return "state"
71
 
@@ -77,48 +100,115 @@ def _text_delta(a: str, b: str) -> str:
77
  return "\n".join(diff)
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def diff_bundles(zip_a: str, zip_b: str) -> Dict[str, Any]:
81
  A = load_bundle(zip_a)
82
  B = load_bundle(zip_b)
83
 
84
  ea = A.events
85
  eb = B.events
86
- n = min(len(ea), len(eb))
87
 
 
 
 
88
  first_div: Optional[int] = None
89
- per_event: List[Dict[str, Any]] = []
 
 
 
90
 
 
 
 
91
  for i in range(n):
92
- na = _normalize_for_compare({k: ea[i].get(k) for k in ("kind", "step", "payload")})
93
- nb = _normalize_for_compare({k: eb[i].get(k) for k in ("kind", "step", "payload")})
94
- if na != nb and first_div is None:
95
- first_div = i
96
-
97
- if na != nb:
98
- diffs = _json_diff(na, nb)
99
- item = {
100
- "i": i,
101
- "step_a": ea[i].get("step"),
102
- "step_b": eb[i].get("step"),
103
- "kind_a": ea[i].get("kind"),
104
- "kind_b": eb[i].get("kind"),
105
- "class": _classify_divergence(ea[i], eb[i]),
106
- "diffs": diffs[:200], # cap
107
- }
108
-
109
- # Optional friendly text diff if payload has 'text'
110
- ta = ea[i].get("payload", {}).get("text")
111
- tb = eb[i].get("payload", {}).get("text")
112
- if isinstance(ta, str) and isinstance(tb, str) and ta != tb:
113
- item["text_unified_diff"] = _text_delta(ta, tb)[:20000]
114
-
115
- per_event.append(item)
116
-
117
- # length mismatch
118
- if len(ea) != len(eb):
119
- first_div = first_div if first_div is not None else n
 
 
 
 
 
 
 
 
 
120
 
121
- summary = {
122
  "run_a": A.manifest.get("run_id"),
123
  "run_b": B.manifest.get("run_id"),
124
  "framework_a": A.manifest.get("framework"),
@@ -128,16 +218,19 @@ def diff_bundles(zip_a: str, zip_b: str) -> Dict[str, Any]:
128
  "events_a": len(ea),
129
  "events_b": len(eb),
130
  "first_divergence_index": first_div,
 
 
 
 
 
 
 
 
131
  }
132
 
133
- # simple counts by class
134
- counts: Dict[str, int] = {}
135
- for item in per_event:
136
- counts[item["class"]] = counts.get(item["class"], 0) + 1
137
-
138
- out = {
139
  "summary": summary,
140
  "class_counts": counts,
141
- "differences": per_event[:400], # cap for UI
142
- }
143
- return out
 
1
  import difflib
2
+ from typing import Any, Dict, List, Optional
3
 
4
+ from .bundle import load_bundle
5
 
6
 
7
  def _normalize_for_compare(x: Any) -> Any:
 
8
  if isinstance(x, dict):
9
  return {k: _normalize_for_compare(x[k]) for k in sorted(x.keys())}
10
  if isinstance(x, list):
 
12
  return x
13
 
14
 
15
def _event_core(ev: Dict[str, Any]) -> Any:
    """Return the comparable core of an event.

    Only the ``kind``, ``step`` and ``payload`` fields are kept (a missing
    field becomes ``None``); the resulting dict is then passed through
    ``_normalize_for_compare`` so two events can be compared with ``==``.
    """
    core: Dict[str, Any] = {}
    for field in ("kind", "step", "payload"):
        core[field] = ev.get(field)
    return _normalize_for_compare(core)
17
+
18
+
19
def build_alignment(A_events: List[Dict[str, Any]], B_events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Pair up the two event lists by index and describe each position.

    One row is produced per index up to the length of the longer list. Each
    row holds the index ``i``, a ``status`` and the ``kind``/``step`` of the
    event on either side (``None`` where that side has no event).

    ``status`` is one of:

    - ``"missing_in_A"`` / ``"missing_in_B"``: that side has no event at ``i``
    - ``"same"``: both events have equal ``_event_core`` values
    - ``"diff"``: both events exist but their cores differ
    """

    def _field(event: Optional[Dict[str, Any]], key: str) -> Any:
        # Truthiness check on purpose: an empty event dict reports None too.
        return event.get(key) if event else None

    len_a = len(A_events)
    len_b = len(B_events)
    aligned: List[Dict[str, Any]] = []

    for idx in range(max(len_a, len_b)):
        ev_a = A_events[idx] if idx < len_a else None
        ev_b = B_events[idx] if idx < len_b else None

        if ev_a is None:
            state = "missing_in_A"
        elif ev_b is None:
            state = "missing_in_B"
        elif _event_core(ev_a) == _event_core(ev_b):
            state = "same"
        else:
            state = "diff"

        row: Dict[str, Any] = {"i": idx, "status": state}
        row["kind_a"] = _field(ev_a, "kind")
        row["step_a"] = _field(ev_a, "step")
        row["kind_b"] = _field(ev_b, "kind")
        row["step_b"] = _field(ev_b, "step")
        aligned.append(row)

    return aligned
44
+
45
+
46
  def _json_diff(a: Any, b: Any, path: str = "") -> List[Dict[str, Any]]:
 
 
 
 
47
  diffs: List[Dict[str, Any]] = []
48
 
49
  if type(a) != type(b):
 
62
  return diffs
63
 
64
  if isinstance(a, list):
 
65
  n = max(len(a), len(b))
66
  for i in range(n):
67
  pa = a[i] if i < len(a) else None
 
79
  return diffs
80
 
81
 
82
+ def _classify_divergence(kind_a: Optional[str], kind_b: Optional[str]) -> str:
83
+ if kind_a != kind_b:
 
 
84
  return "control-flow"
85
+ if kind_a in ("tool_call", "tool_result"):
86
  return "tool"
87
+ if kind_a in ("memory_write", "memory_read"):
88
  return "memory"
89
+ if kind_a in ("llm_sample", "llm_call"):
90
  return "sampling"
91
+ if kind_a in ("guardrail",):
92
  return "governance"
93
  return "state"
94
 
 
100
  return "\n".join(diff)
101
 
102
 
103
+ def _extract_final_reward(events: List[Dict[str, Any]]) -> Optional[float]:
104
+ """
105
+ Looks for last state_snapshot payload containing:
106
+ - payload.reward_total
107
+ - payload.metrics.reward_total
108
+ """
109
+ for ev in reversed(events):
110
+ if ev.get("kind") != "state_snapshot":
111
+ continue
112
+ p = ev.get("payload", {}) or {}
113
+ if isinstance(p, dict):
114
+ rt = p.get("reward_total")
115
+ if isinstance(rt, (int, float)):
116
+ return float(rt)
117
+ m = p.get("metrics")
118
+ if isinstance(m, dict):
119
+ rt2 = m.get("reward_total")
120
+ if isinstance(rt2, (int, float)):
121
+ return float(rt2)
122
+ return None
123
+
124
+
125
+ def _event_link(manifest: Dict[str, Any], i: int) -> Optional[str]:
126
+ """
127
+ Optional deep-link generation.
128
+ Supported:
129
+ - manifest.replay.base_url + manifest.replay.pattern with {run_id} and {i}
130
+ - manifest.run_url + ?i={i}
131
+ """
132
+ run_id = manifest.get("run_id")
133
+ replay = manifest.get("replay")
134
+
135
+ if isinstance(replay, dict):
136
+ base = replay.get("base_url")
137
+ pattern = replay.get("pattern", "")
138
+ if isinstance(base, str) and isinstance(pattern, str) and run_id:
139
+ try:
140
+ return base.rstrip("/") + pattern.format(run_id=run_id, i=i)
141
+ except Exception:
142
+ return None
143
+
144
+ run_url = manifest.get("run_url")
145
+ if isinstance(run_url, str) and run_url:
146
+ # append i in a minimal, non-destructive way
147
+ joiner = "&" if "?" in run_url else "?"
148
+ return f"{run_url}{joiner}i={i}"
149
+
150
+ return None
151
+
152
+
153
  def diff_bundles(zip_a: str, zip_b: str) -> Dict[str, Any]:
154
  A = load_bundle(zip_a)
155
  B = load_bundle(zip_b)
156
 
157
  ea = A.events
158
  eb = B.events
 
159
 
160
+ alignment = build_alignment(ea, eb)
161
+
162
+ # first divergence index (including length mismatch)
163
  first_div: Optional[int] = None
164
+ for row in alignment:
165
+ if row["status"] != "same":
166
+ first_div = row["i"]
167
+ break
168
 
169
+ # diff details (per index where both exist and differ)
170
+ per_event: List[Dict[str, Any]] = []
171
+ n = min(len(ea), len(eb))
172
  for i in range(n):
173
+ na = _event_core(ea[i])
174
+ nb = _event_core(eb[i])
175
+ if na == nb:
176
+ continue
177
+
178
+ diffs = _json_diff(na, nb)
179
+ item: Dict[str, Any] = {
180
+ "i": i,
181
+ "kind_a": ea[i].get("kind"),
182
+ "kind_b": eb[i].get("kind"),
183
+ "step_a": ea[i].get("step"),
184
+ "step_b": eb[i].get("step"),
185
+ "class": _classify_divergence(ea[i].get("kind"), eb[i].get("kind")),
186
+ "diffs": diffs[:200],
187
+ "link_a": _event_link(A.manifest, i),
188
+ "link_b": _event_link(B.manifest, i),
189
+ }
190
+
191
+ ta = (ea[i].get("payload", {}) or {}).get("text")
192
+ tb = (eb[i].get("payload", {}) or {}).get("text")
193
+ if isinstance(ta, str) and isinstance(tb, str) and ta != tb:
194
+ item["text_unified_diff"] = _text_delta(ta, tb)[:20000]
195
+
196
+ per_event.append(item)
197
+
198
+ diff_count = sum(1 for r in alignment if r["status"] == "diff")
199
+ missing_count = sum(1 for r in alignment if r["status"] in ("missing_in_A", "missing_in_B"))
200
+
201
+ ra = _extract_final_reward(ea)
202
+ rb = _extract_final_reward(eb)
203
+ reward_delta = (rb - ra) if (ra is not None and rb is not None) else None
204
+
205
+ # class counts
206
+ counts: Dict[str, int] = {}
207
+ for item in per_event:
208
+ c = item["class"]
209
+ counts[c] = counts.get(c, 0) + 1
210
 
211
+ summary: Dict[str, Any] = {
212
  "run_a": A.manifest.get("run_id"),
213
  "run_b": B.manifest.get("run_id"),
214
  "framework_a": A.manifest.get("framework"),
 
218
  "events_a": len(ea),
219
  "events_b": len(eb),
220
  "first_divergence_index": first_div,
221
+ "identical_until_index": first_div, # same semantic, explicit name
222
+ "diff_event_count": diff_count,
223
+ "missing_event_count": missing_count,
224
+ "final_reward_a": ra,
225
+ "final_reward_b": rb,
226
+ "final_reward_delta": reward_delta,
227
+ "run_link_a": _event_link(A.manifest, 0),
228
+ "run_link_b": _event_link(B.manifest, 0),
229
  }
230
 
231
+ return {
 
 
 
 
 
232
  "summary": summary,
233
  "class_counts": counts,
234
+ "alignment": alignment,
235
+ "differences": per_event[:400],
236
+ }