""" Parity check: does the trajectory-eval shallow clone produce the same polyglot-parsed graph + BERT features as the pre-existing big-machine clone (data_multilang), for the same commit? Runs on big where both clones exist. For each common (repo, commit) pair it encounters, it snapshots the working tree from *both* clones, canonical- hashes the graph structure + feature tensors, and reports match/mismatch. A match confirms: - git clone --filter=blob:none + checkout fetches the same file content as the original full clone - parse_repo_polyglot is deterministic w.r.t. the file tree (modulo rglob ordering — we sort before hashing) - BertTokenEmbedder is deterministic Usage (on big): python -m graphjepa.check_polyglot_parity \\ --traj-repos ./outputs/traj_real/repos \\ --multi-repos /raid/train/datasets/code-graph-v7/data_multilang \\ --n-pairs 4 If it picks a commit not present in the shallow clone's blobless ref set (some base_commits may need lazy blob fetch), the script does the fetch automatically via checkout. 
""" from __future__ import annotations import argparse import hashlib import json import subprocess import sys from pathlib import Path from typing import List, Optional, Tuple def run(cmd: List[str], cwd: Optional[Path] = None, check: bool = True): r = subprocess.run(cmd, cwd=str(cwd) if cwd else None, capture_output=True, text=True) if check and r.returncode != 0: raise RuntimeError(f'{" ".join(cmd)} failed: {r.stderr[-400:]}') return r def list_commits(repo_dir: Path, n: int = 20) -> List[str]: r = run(['git', 'log', '--format=%H', '-n', str(n)], cwd=repo_dir) return r.stdout.split() def checkout(repo_dir: Path, sha: str) -> bool: run(['git', 'reset', '--hard', '-q'], cwd=repo_dir, check=False) run(['git', 'clean', '-fdx', '-q'], cwd=repo_dir, check=False) r = run(['git', 'checkout', '-q', '--detach', sha], cwd=repo_dir, check=False) if r.returncode != 0: # Try fetching the ref run(['git', 'fetch', '-q', 'origin', sha], cwd=repo_dir, check=False) r = run(['git', 'checkout', '-q', '--detach', sha], cwd=repo_dir, check=False) return r.returncode == 0 def canonical_hash(graph, features) -> Tuple[str, str, dict]: """Deterministic hash of (graph structure, feature tensors). Sorts node IDs so walk order doesn't matter. Returns (graph_hash, feature_hash, stats_dict). """ import torch # Nodes: sort by id, hash (id, kind, content, type_description). h_nodes = hashlib.sha256() node_items = sorted(graph.nodes.items()) for nid, n in node_items: h_nodes.update(nid.encode()) h_nodes.update(b'\x00') h_nodes.update(getattr(n.kind, 'value', str(n.kind)).encode()) h_nodes.update(b'\x00') h_nodes.update((n.content or '').encode()) h_nodes.update(b'\x00') h_nodes.update((n.type_description or '').encode()) h_nodes.update(b'\x01') # Edges: sort by (src, dst, kind). 
edge_keys = sorted( (e.src, e.dst, getattr(e.kind, 'value', str(e.kind))) for e in graph.edges.values() ) for src, dst, k in edge_keys: h_nodes.update(f'E|{src}|{dst}|{k}|'.encode()) graph_hash = h_nodes.hexdigest() # Feature tensors: for each kind in deterministic order, hash # (sorted_ids, content_sum, type_sum, content_first_vec, type_first_vec). h_feats = hashlib.sha256() for kind, d in sorted((k, v) for k, v in features.items() if v is not None): kind_str = getattr(kind, 'value', str(kind)) h_feats.update(kind_str.encode()) h_feats.update(b'\x00') ids = list(d['ids']) sort_idx = sorted(range(len(ids)), key=lambda i: ids[i]) content = d['content'][sort_idx] if sort_idx else d['content'] typev = d['type'][sort_idx] if sort_idx else d['type'] sorted_ids = [ids[i] for i in sort_idx] for sid in sorted_ids: h_feats.update(sid.encode()); h_feats.update(b'\x00') # Digest feature tensors numerically with fixed precision so # hashes match across float ops that might differ in trailing ULP. content_q = (content * 1e5).round().to(torch.int64) typev_q = (typev * 1e5).round().to(torch.int64) h_feats.update(content_q.cpu().numpy().tobytes()) h_feats.update(typev_q.cpu().numpy().tobytes()) feat_hash = h_feats.hexdigest() stats = { 'n_nodes': len(graph.nodes), 'n_edges': len(graph.edges), 'n_feat_kinds': sum(1 for v in features.values() if v is not None), 'feat_dim': next((v['content'].shape[1] for v in features.values() if v is not None), None), } return graph_hash, feat_hash, stats def snapshot(repo_dir: Path, embedder) -> Tuple[str, str, dict]: from graphjepa.trajectory_pipeline import snapshot_working_tree g, feats = snapshot_working_tree(repo_dir, embedder, verbose=False) return canonical_hash(g, feats) # Mapping from trajectory-eval repo dirname → data_multilang subpath. 
# traj repos: django__django; data_multilang: python/django
_REPO_DIR_MAP = {
    'django__django': ('python', 'django'),
    'sympy__sympy': ('python', 'sympy'),
    'sphinx-doc__sphinx': ('python', 'sphinx'),
    'matplotlib__matplotlib': ('python', 'matplotlib'),
    'scikit-learn__scikit-learn': ('python', 'scikit-learn'),
    'astropy__astropy': ('python', 'astropy'),
    'pydata__xarray': ('python', 'xarray'),
    'pytest-dev__pytest': ('python', 'pytest'),
    'pylint-dev__pylint': ('python', 'pylint'),
    'psf__requests': ('python', 'requests'),
    'mwaskom__seaborn': ('python', 'seaborn'),
    'pallets__flask': ('python', 'flask'),
}


def find_pairs(traj_root: Path, multi_root: Path) -> List[Tuple[str, Path, Path]]:
    """Return (name, traj_path, multi_path) for every mapped repo present under both roots.

    Order follows ``_REPO_DIR_MAP``. Returns [] when ``traj_root`` is not a
    directory.
    """
    pairs: List[Tuple[str, Path, Path]] = []
    if not traj_root.is_dir():
        return pairs
    for name, (lang, mname) in _REPO_DIR_MAP.items():
        tpath = traj_root / name
        mpath = multi_root / lang / mname
        if tpath.is_dir() and mpath.is_dir():
            pairs.append((name, tpath, mpath))
    return pairs


def _select_tests(pairs: List[Tuple[str, Path, Path]],
                  n_pairs: int) -> List[Tuple[str, Path, Path, str]]:
    """Pair each of the first ``n_pairs`` repos with the HEAD commit of its multi clone."""
    tests = []
    for name, tpath, mpath in pairs[:n_pairs]:
        try:
            mcommits = list_commits(mpath, n=1)
        except RuntimeError as exc:
            # `git log` exits non-zero (instead of printing nothing) on a repo
            # with no commits, so treat failure like an empty history.
            print(f'[parity] {name}: git log failed in multi clone ({exc}), skip')
            continue
        if not mcommits:
            print(f'[parity] {name}: no commits in multi clone, skip')
            continue
        tests.append((name, tpath, mpath, mcommits[0]))
    return tests


def _run_one(name: str, tpath: Path, mpath: Path, sha: str, embedder) -> dict:
    """Check out ``sha`` in both clones, snapshot each, and return a result record.

    Checkout or snapshot failures are recorded under ``'error'`` rather than
    aborting the remaining pairs.
    """
    print(f'\n[parity] === {name} @ {sha[:10]} ===')
    for label, path in (('traj', tpath), ('multi', mpath)):
        print(f' checkout {label} clone ...')
        if not checkout(path, sha):
            print(f' [parity] {label} clone cannot reach {sha[:10]}; skip')
            return {'repo': name, 'sha': sha, 'error': f'{label}_checkout_failed'}

    try:
        print(' snapshotting traj clone ...')
        tg, tf, tstats = snapshot(tpath, embedder)
        print(' snapshotting multi clone ...')
        mg, mf, mstats = snapshot(mpath, embedder)
    except Exception as exc:  # per-pair boundary: record and continue
        print(f' [parity] snapshot failed: {exc!r}')
        return {'repo': name, 'sha': sha, 'error': f'snapshot_failed: {exc!r}'}

    match_g = tg == mg
    match_f = tf == mf
    print(f' graph hash traj={tg[:12]} multi={mg[:12]} '
          f'{"MATCH" if match_g else "MISMATCH"}')
    print(f' feature hash traj={tf[:12]} multi={mf[:12]} '
          f'{"MATCH" if match_f else "MISMATCH"}')
    print(f' stats traj={tstats} multi={mstats}')
    return {
        'repo': name, 'sha': sha,
        'graph_match': match_g, 'feature_match': match_f,
        'traj_graph_hash': tg, 'multi_graph_hash': mg,
        'traj_feature_hash': tf, 'multi_feature_hash': mf,
        'traj_stats': tstats, 'multi_stats': mstats,
    }


def main():
    """CLI entry point: exit 0 only if every tested pair matched on graph and features."""
    p = argparse.ArgumentParser()
    p.add_argument('--traj-repos', required=True,
                   help='outputs/traj_real/repos dir from the transfer bundle')
    p.add_argument('--multi-repos', required=True,
                   help='data_multilang dir used to build cache_v7')
    p.add_argument('--n-pairs', type=int, default=3,
                   help='Number of (repo, commit) pairs to test')
    p.add_argument('--output', default=None,
                   help='Write a JSON report here')
    args = p.parse_args()
    if args.n_pairs < 1:
        p.error('--n-pairs must be >= 1')

    traj_root = Path(args.traj_repos)
    multi_root = Path(args.multi_repos)
    pairs = find_pairs(traj_root, multi_root)
    if not pairs:
        print(f'[parity] no common repos found under {traj_root} and '
              f'{multi_root}')
        sys.exit(1)
    print(f'[parity] {len(pairs)} repo pairs available:')
    for n, t, m in pairs:
        print(f' {n:30s} traj={t} multi={m}')

    # For each pair, pick a commit that exists in both. HEAD of the
    # multi clone is a safe default since that clone has full history.
    tests = _select_tests(pairs, args.n_pairs)

    results = []
    if tests:
        # Import/load the embedder once, and only when there is work — BERT load is slow.
        from graphjepa.features import BertTokenEmbedder
        print('\n[parity] loading BERT embedder ...')
        embedder = BertTokenEmbedder(device='cpu')
        results = [_run_one(name, tpath, mpath, sha, embedder)
                   for name, tpath, mpath, sha in tests]

    print('\n' + '=' * 60)
    n_g = sum(1 for r in results if r.get('graph_match'))
    n_f = sum(1 for r in results if r.get('feature_match'))
    print(f'graph parity: {n_g}/{len(results)} matched')
    print(f'feature parity: {n_f}/{len(results)} matched')
    print('=' * 60)

    if args.output:
        out = Path(args.output)
        out.parent.mkdir(parents=True, exist_ok=True)
        with open(out, 'w') as f:
            json.dump(results, f, indent=2)
        print(f'[parity] report saved: {args.output}')

    sys.exit(0 if (results and n_g == n_f == len(results)) else 1)


if __name__ == '__main__':
    main()