# Source: gnn_wm/sampled_data/frame_to_graph.py
# (uploaded by EndeavourDD via the upload-large-folder tool; commit 4ee0c8c)
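"""Extract the constraint graph for one frame (or every frame) of an episode as JSON.

Usage:
    python frame_to_graph.py <frame_idx> [--episode <path>] [--output <path>]
    python frame_to_graph.py --all [--episode <path>] [--output <dir>] [--view <name>]

Without --output, a single frame is written to frame_<idx:06d>_graph.json and
batch mode writes into ./graphs_per_frame/.

Examples:
    python frame_to_graph.py 42
    python frame_to_graph.py 42 --episode session_0408_162129/episode_00
    python frame_to_graph.py 42 --output my_graph.json
    python frame_to_graph.py --all --output graphs_per_frame
"""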
import argparse
import json
from pathlib import Path
import numpy as np
from gnn_disassembly_loader import load_frame_data
def frame_to_graph_dict(episode_dir: Path, frame_idx: int) -> dict:
    """Build a self-contained, JSON-serialisable graph dict for one frame.

    Args:
        episode_dir: Episode directory, passed through to ``load_frame_data``.
        frame_idx: Index of the frame to load.

    Returns:
        A dict with keys, in this order:
            frame_idx: the ``frame_idx`` argument.
            nodes: list of node dicts (id, type, color, visible, centroid_3d,
                embedding_norm, mask_area). One node per entry of the frame
                graph's "components", followed by a "robot" node when the
                loader supplies robot data.
            edges: list of edge dicts (src, dst, is_locked), one per entry of
                the frame graph's "edges".

    Fallbacks for missing per-component data: visible=True,
    centroid_3d=[0, 0, 0] (when depth is not flagged valid or no centroid is
    stored), embedding_norm=0.0, mask_area=0. Edges whose "src->dst" key has
    no entry in the frame's constraints default to is_locked=True.
    """
    fd = load_frame_data(episode_dir, frame_idx)
    graph_json = fd.graph
    depth_info = fd.depth_info

    # Product (component) nodes.
    nodes = []
    for comp in graph_json["components"]:
        cid = comp["id"]
        centroid_key = f"{cid}_centroid"
        depth_valid_key = f"{cid}_depth_valid"
        area_key = f"{cid}_area"
        # Trust the centroid only when depth is flagged valid AND the centroid
        # entry actually exists; previously a valid flag without a centroid
        # raised KeyError.
        has_depth = (
            depth_valid_key in depth_info
            and centroid_key in depth_info
            and int(depth_info[depth_valid_key][0]) == 1
        )
        nodes.append({
            "id": cid,
            "type": comp["type"],
            "color": comp["color"],
            # bool() keeps the value JSON-serialisable even if the loader
            # yields numpy booleans (NOTE(review): confirm loader dtype).
            "visible": bool(fd.visibility.get(cid, True)),
            "centroid_3d": depth_info[centroid_key].tolist() if has_depth else [0, 0, 0],
            "embedding_norm": (
                float(np.linalg.norm(fd.embeddings[cid])) if cid in fd.embeddings else 0.0
            ),
            "mask_area": int(depth_info[area_key][0]) if area_key in depth_info else 0,
        })

    # Robot node, only when the loader provides robot data for this frame.
    robot = fd.robot
    if robot is not None:
        nodes.append({
            "id": "robot",
            "type": "robot",
            "color": "#F5F5F5",
            "visible": True,
            "centroid_3d": robot["centroid"].tolist(),
            "embedding_norm": float(np.linalg.norm(robot["embedding"])),
            "mask_area": int(robot["area"][0]),
        })

    # Constraint edges; lock state is looked up under the "src->dst" key.
    edges = []
    for edge in graph_json["edges"]:
        src, dst = edge["src"], edge["dst"]
        edges.append({
            "src": src,
            "dst": dst,
            # Unresolved constraints are treated as locked; bool() as above.
            "is_locked": bool(fd.constraints.get(f"{src}->{dst}", True)),
        })

    return {
        "frame_idx": frame_idx,
        "nodes": nodes,
        "edges": edges,
    }
def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as 2-space-indented JSON, creating parent dirs."""
    # Creating the parent lets a single-frame --output point into a directory
    # that does not exist yet (previously only batch mode created its dir).
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)


def main():
    """Command-line entry point: write one frame's graph, or every frame's with --all."""
    import re

    parser = argparse.ArgumentParser(description="Extract constraint graph for a frame as JSON.")
    parser.add_argument("frame_idx", type=int, nargs="?", default=None,
                        help="Frame index (e.g. 42). Omit when using --all.")
    parser.add_argument("--episode", type=str, default="session_0408_162129/episode_00",
                        help="Path to episode directory")
    parser.add_argument("--output", type=str, default=None,
                        help="Single-frame: output JSON path. Batch: output directory.")
    parser.add_argument("--all", action="store_true",
                        help="Process all frames in the episode.")
    parser.add_argument("--view", type=str, default="side",
                        help="Camera view used to count frames in --all mode (default: side).")
    args = parser.parse_args()
    episode = Path(args.episode)

    if args.all:
        # Batch mode; a positional frame_idx, if also given, is ignored.
        rgb_dir = episode / args.view / "rgb"
        # A frame index is the first run of digits in each file stem (a stem
        # like "000042" -> 42). Stems without digits are skipped, and repeated
        # indices (e.g. the same frame under two extensions) are collapsed.
        digit_run = re.compile(r"\d+")
        frame_ids = set()
        for p in rgb_dir.glob("*.*"):
            m = digit_run.search(p.stem)
            if m:
                frame_ids.add(int(m.group()))
        frames = sorted(frame_ids)
        if not frames:
            raise SystemExit(f"No frames found in {rgb_dir}")
        out_dir = Path(args.output) if args.output else Path("graphs_per_frame")
        out_dir.mkdir(parents=True, exist_ok=True)
        for i in frames:
            _write_json(out_dir / f"frame_{i:06d}_graph.json", frame_to_graph_dict(episode, i))
        print(f"Saved {len(frames)} graphs to {out_dir}")
        return

    if args.frame_idx is None:
        parser.error("frame_idx is required unless --all is given")
    graph_dict = frame_to_graph_dict(episode, args.frame_idx)
    out_path = Path(args.output) if args.output else Path(f"frame_{args.frame_idx:06d}_graph.json")
    _write_json(out_path, graph_dict)

    n_locked = sum(1 for e in graph_dict["edges"] if e["is_locked"])
    n_unlocked = len(graph_dict["edges"]) - n_locked
    print(f"Saved {out_path}")
    print(f" {len(graph_dict['nodes'])} nodes, {len(graph_dict['edges'])} constraint edges "
          f"({n_locked} locked, {n_unlocked} unlocked)")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()