| """ |
| Parity check: does the trajectory-eval shallow clone produce the same |
| polyglot-parsed graph + BERT features as the pre-existing big-machine |
| clone (data_multilang), for the same commit? |
| |
| Runs on big where both clones exist. For each common (repo, commit) pair |
| it encounters, it snapshots the working tree from *both* clones, canonical- |
| hashes the graph structure + feature tensors, and reports match/mismatch. |
| |
| A match confirms: |
| - git clone --filter=blob:none + checkout fetches the same file content |
| as the original full clone |
| - parse_repo_polyglot is deterministic w.r.t. the file tree (modulo |
| rglob ordering — we sort before hashing) |
| - BertTokenEmbedder is deterministic |
| |
| Usage (on big): |
| python -m graphjepa.check_polyglot_parity \\ |
| --traj-repos ./outputs/traj_real/repos \\ |
| --multi-repos /raid/train/datasets/code-graph-v7/data_multilang \\ |
| --n-pairs 4 |
| |
| If it picks a commit not present in the shallow clone's blobless ref set |
| (some base_commits may need lazy blob fetch), the script does the fetch |
| automatically via checkout. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| import subprocess |
| import sys |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
|
|
def run(cmd: List[str], cwd: Optional[Path] = None, check: bool = True):
    """Execute *cmd* as a subprocess, capturing stdout/stderr as text.

    ``cwd`` (when given) is the working directory for the command.  With
    ``check`` enabled, a non-zero exit status raises ``RuntimeError`` whose
    message carries the command line and the last 400 characters of stderr.
    Otherwise the ``subprocess.CompletedProcess`` is returned regardless of
    the exit status.
    """
    workdir = None
    if cwd:
        workdir = str(cwd)
    completed = subprocess.run(cmd, cwd=workdir, capture_output=True,
                               text=True)
    # Success, or the caller asked to inspect failures itself.
    if completed.returncode == 0 or not check:
        return completed
    stderr_tail = completed.stderr[-400:]
    raise RuntimeError(f'{" ".join(cmd)} failed: {stderr_tail}')
|
|
|
|
def list_commits(repo_dir: Path, n: int = 20) -> List[str]:
    """Return up to *n* commit hashes from ``git log`` run inside *repo_dir*.

    Hashes come back in the order ``git log`` prints them.  A failing
    ``git log`` propagates the ``RuntimeError`` raised by ``run``.
    """
    git_log = ['git', 'log', '--format=%H', '-n', str(n)]
    return run(git_log, cwd=repo_dir).stdout.split()
|
|
|
|
def checkout(repo_dir: Path, sha: str) -> bool:
    """Put the clone at *repo_dir* on a detached checkout of *sha*.

    The working tree is first hard-reset and cleaned (``git clean -fdx``),
    so local modifications and untracked files in the clone are discarded.
    If the first checkout attempt fails, the commit is fetched from
    ``origin`` and the checkout is retried once.  None of the git calls
    raise on failure; the return value says whether the final checkout
    succeeded.
    """
    detach_cmd = ['git', 'checkout', '-q', '--detach', sha]
    cleanup_cmds = (
        ['git', 'reset', '--hard', '-q'],
        ['git', 'clean', '-fdx', '-q'],
    )
    for cleanup in cleanup_cmds:
        run(cleanup, cwd=repo_dir, check=False)

    if run(detach_cmd, cwd=repo_dir, check=False).returncode == 0:
        return True

    # The commit is not reachable locally: fetch it explicitly, then retry.
    run(['git', 'fetch', '-q', 'origin', sha], cwd=repo_dir, check=False)
    retry = run(detach_cmd, cwd=repo_dir, check=False)
    return retry.returncode == 0
|
|
|
|
def canonical_hash(graph, features) -> Tuple[str, str, dict]:
    """Deterministic hash of (graph structure, feature tensors).

    Nodes are visited in sorted node-ID order, edges in sorted
    (src, dst, kind) order, and feature groups in sorted kind-label order
    with their rows re-ordered by sorted ID, so the walk order that
    produced the inputs doesn't matter.

    ``graph`` must expose ``nodes`` (mapping of ID -> object with ``kind``,
    ``content`` and ``type_description``) and ``edges`` (mapping whose
    values have ``src``, ``dst`` and ``kind``).  ``features`` maps a kind to
    either ``None`` (ignored) or a dict with ``'ids'``, ``'content'`` and
    ``'type'`` entries, the latter two being torch tensors indexed by row.

    Returns (graph_hash, feature_hash, stats_dict).
    """
    import torch

    def _label(kind) -> str:
        # Enum-like kinds are hashed by their value, anything else by str().
        return getattr(kind, 'value', str(kind))

    # --- graph structure: nodes first, then edges, in one digest ---------
    graph_hasher = hashlib.sha256()
    for nid, node in sorted(graph.nodes.items(), key=lambda kv: kv[0]):
        graph_hasher.update(nid.encode())
        graph_hasher.update(b'\x00')
        graph_hasher.update(_label(node.kind).encode())
        graph_hasher.update(b'\x00')
        graph_hasher.update((node.content or '').encode())
        graph_hasher.update(b'\x00')
        graph_hasher.update((node.type_description or '').encode())
        graph_hasher.update(b'\x01')

    edge_keys = sorted(
        (e.src, e.dst, _label(e.kind)) for e in graph.edges.values()
    )
    for src, dst, kind_label in edge_keys:
        graph_hasher.update(f'E|{src}|{dst}|{kind_label}|'.encode())

    graph_hash = graph_hasher.hexdigest()

    # --- features ---------------------------------------------------------
    # Order groups by the string label rather than by the raw key: plain
    # Enum members do not define '<', so sorting the keys themselves raises
    # TypeError.  For string (or str-Enum) keys the order is unchanged.
    present = [(_label(k), v) for k, v in features.items() if v is not None]
    present.sort(key=lambda item: item[0])

    feat_hasher = hashlib.sha256()
    for kind_label, group in present:
        feat_hasher.update(kind_label.encode())
        feat_hasher.update(b'\x00')

        ids = list(group['ids'])
        order = sorted(range(len(ids)), key=lambda i: ids[i])
        # An empty index list is skipped so empty groups are hashed as-is.
        content = group['content'][order] if order else group['content']
        typev = group['type'][order] if order else group['type']
        for i in order:
            feat_hasher.update(ids[i].encode())
            feat_hasher.update(b'\x00')

        # Quantize to 1e-5 resolution before hashing so that float noise
        # below that scale does not normally change the digest.
        content_q = (content * 1e5).round().to(torch.int64)
        typev_q = (typev * 1e5).round().to(torch.int64)
        feat_hasher.update(content_q.cpu().numpy().tobytes())
        feat_hasher.update(typev_q.cpu().numpy().tobytes())

    feat_hash = feat_hasher.hexdigest()

    stats = {
        'n_nodes': len(graph.nodes),
        'n_edges': len(graph.edges),
        'n_feat_kinds': len(present),
        # Width of the first non-empty group's content tensor, if any.
        'feat_dim': next((v['content'].shape[1] for v in features.values()
                          if v is not None), None),
    }
    return graph_hash, feat_hash, stats
|
|
|
|
def snapshot(repo_dir: Path, embedder) -> Tuple[str, str, dict]:
    """Snapshot the working tree at *repo_dir* and return its canonical hashes.

    Delegates to ``snapshot_working_tree`` (imported lazily from
    ``graphjepa.trajectory_pipeline``) with ``verbose=False`` and feeds the
    resulting graph and features to ``canonical_hash``.
    """
    from graphjepa.trajectory_pipeline import snapshot_working_tree

    graph, features = snapshot_working_tree(repo_dir, embedder,
                                            verbose=False)
    return canonical_hash(graph, features)
|
|
|
|
| |
| |
| _REPO_DIR_MAP = { |
| 'django__django': ('python', 'django'), |
| 'sympy__sympy': ('python', 'sympy'), |
| 'sphinx-doc__sphinx': ('python', 'sphinx'), |
| 'matplotlib__matplotlib': ('python', 'matplotlib'), |
| 'scikit-learn__scikit-learn': ('python', 'scikit-learn'), |
| 'astropy__astropy': ('python', 'astropy'), |
| 'pydata__xarray': ('python', 'xarray'), |
| 'pytest-dev__pytest': ('python', 'pytest'), |
| 'pylint-dev__pylint': ('python', 'pylint'), |
| 'psf__requests': ('python', 'requests'), |
| 'mwaskom__seaborn': ('python', 'seaborn'), |
| 'pallets__flask': ('python', 'flask'), |
| } |
|
|
|
|
| def find_pairs(traj_root: Path, multi_root: Path) -> List[Tuple[str, Path, Path]]: |
| pairs = [] |
| if not traj_root.is_dir(): |
| return pairs |
| for name, (lang, mname) in _REPO_DIR_MAP.items(): |
| tpath = traj_root / name |
| mpath = multi_root / lang / mname |
| if tpath.is_dir() and mpath.is_dir(): |
| pairs.append((name, tpath, mpath)) |
| return pairs |
|
|
|
|
def main():
    """Command-line entry point for the parity check.

    Finds repos present under both ``--traj-repos`` and ``--multi-repos``,
    takes the first ``--n-pairs`` of them, and for each one checks out the
    multi clone's most recent commit in both clones, snapshots both working
    trees, and compares the canonical graph and feature hashes.  A JSON
    report is written when ``--output`` is given.

    Exits with status 0 only when at least one result was recorded and
    every recorded result matched on both graph and features; otherwise
    exits with status 1.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--traj-repos', required=True,
                   help='outputs/traj_real/repos dir from the transfer bundle')
    p.add_argument('--multi-repos', required=True,
                   help='data_multilang dir used to build cache_v7')
    p.add_argument('--n-pairs', type=int, default=3,
                   help='Number of (repo, commit) pairs to test')
    p.add_argument('--output', default=None,
                   help='Write a JSON report here')
    args = p.parse_args()

    traj_root = Path(args.traj_repos)
    multi_root = Path(args.multi_repos)

    pairs = find_pairs(traj_root, multi_root)
    if not pairs:
        print(f'[parity] no common repos found under {traj_root} and '
              f'{multi_root}')
        sys.exit(1)
    print(f'[parity] {len(pairs)} repo pairs available:')
    for n, t, m in pairs:
        print(f' {n:30s} traj={t} multi={m}')

    # Pick one commit per repo: the most recent one in the multi clone.
    tests = []
    for name, tpath, mpath in pairs[:args.n_pairs]:
        try:
            # Only the newest commit is used, so only one is requested.
            mcommits = list_commits(mpath, n=1)
        except RuntimeError as exc:
            # `git log` exits non-zero for an empty repository or a
            # directory that is not a git work tree; skip that repo rather
            # than aborting the whole run.
            print(f'[parity] {name}: cannot list commits in multi clone '
                  f'({exc}), skip')
            continue
        if not mcommits:
            print(f'[parity] {name}: no commits in multi clone, skip')
            continue
        tests.append((name, tpath, mpath, mcommits[0]))

    # Imported lazily so argument errors surface before the heavy import.
    from graphjepa.features import BertTokenEmbedder
    print('\n[parity] loading BERT embedder ...')
    embedder = BertTokenEmbedder(device='cpu')

    results = []
    for name, tpath, mpath, sha in tests:
        print(f'\n[parity] === {name} @ {sha[:10]} ===')

        print(' checkout traj clone ...')
        if not checkout(tpath, sha):
            print(f' [parity] traj clone cannot reach {sha[:10]}; skip')
            results.append({'repo': name, 'sha': sha,
                            'error': 'traj_checkout_failed'})
            continue
        print(' checkout multi clone ...')
        if not checkout(mpath, sha):
            print(f' [parity] multi clone cannot reach {sha[:10]}; skip')
            results.append({'repo': name, 'sha': sha,
                            'error': 'multi_checkout_failed'})
            continue

        print(' snapshotting traj clone ...')
        tg, tf, tstats = snapshot(tpath, embedder)
        print(' snapshotting multi clone ...')
        mg, mf, mstats = snapshot(mpath, embedder)

        match_g = tg == mg
        match_f = tf == mf
        print(f' graph hash traj={tg[:12]} multi={mg[:12]} '
              f'{"MATCH" if match_g else "MISMATCH"}')
        print(f' feature hash traj={tf[:12]} multi={mf[:12]} '
              f'{"MATCH" if match_f else "MISMATCH"}')
        print(f' stats traj={tstats} multi={mstats}')
        results.append({
            'repo': name, 'sha': sha,
            'graph_match': match_g, 'feature_match': match_f,
            'traj_stats': tstats, 'multi_stats': mstats,
        })

    print('\n' + '=' * 60)
    # Entries that recorded a checkout error have no *_match keys and
    # therefore count as non-matching in the totals below.
    n_g = sum(1 for r in results if r.get('graph_match'))
    n_f = sum(1 for r in results if r.get('feature_match'))
    print(f'graph parity: {n_g}/{len(results)} matched')
    print(f'feature parity: {n_f}/{len(results)} matched')
    print('=' * 60)

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f'[parity] report saved: {args.output}')

    all_matched = bool(results) and n_g == n_f == len(results)
    sys.exit(0 if all_matched else 1)


if __name__ == '__main__':
    main()
|
|