File size: 2,836 Bytes
1e250c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Basic example: score a LIBERO-format HDF5 dataset and print a report.

Usage:
    python examples/basic_scoring.py --hdf5 path/to/demo.hdf5

To run with synthetic data (no HDF5 file needed):
    python examples/basic_scoring.py --synthetic
"""
import argparse
import numpy as np

from haptal_curate import score, curate, report
from haptal_curate.types import Demo, DemoDataset
from haptal_curate.metrics import SmoothnessMetric, GripperTimingMetric, EnsembleMetric


def make_synthetic_dataset(n: int = 20, T: int = 80, seed: int = 42) -> DemoDataset:
    """Build a reproducible random dataset of ``n`` demos, each ``T`` steps long.

    Every demo gets a (T, 7) float32 action array and a (T, 9) float32
    observation array drawn from a standard normal distribution. The last
    action column is then overwritten with a step signal: -1.0 before a
    randomly chosen step and 1.0 from that step onward.

    Args:
        n: Number of demos to generate.
        T: Number of timesteps per demo.
        seed: Seed for the NumPy random generator.

    Returns:
        A DemoDataset wrapping the generated Demo objects, with ids
        "demo_0" ... "demo_{n-1}".
    """
    generator = np.random.default_rng(seed)
    timesteps = np.arange(T)

    def _build_demo(index: int) -> Demo:
        # Draw order (actions, observations, switch step) is kept fixed so a
        # given seed always yields the same dataset.
        action_seq = generator.standard_normal((T, 7)).astype(np.float32)
        obs_seq = generator.standard_normal((T, 9)).astype(np.float32)
        # The switch step varies per demo; the file treats the last action
        # column as the gripper channel, with early switches as "defective"
        # demos and late switches as "clean" ones.
        switch_step = generator.integers(T // 4, T)
        action_seq[:, -1] = np.where(timesteps < switch_step, -1.0, 1.0)
        return Demo(
            obs=obs_seq,
            actions=action_seq,
            episode_length=T,
            demo_id=f"demo_{index}",
        )

    return DemoDataset(demos=[_build_demo(index) for index in range(n)])


def main(argv=None):
    """Score a demo dataset, curate it, and print a short report.

    The demos come from the HDF5 file given with ``--hdf5``. When no path is
    given, a synthetic dataset is generated instead; the ``--synthetic`` flag
    is accepted for readability but synthetic data is the fallback either way.

    Args:
        argv: Optional list of command-line argument strings. When None (the
            default) argparse reads the process arguments, so calling
            ``main()`` with no arguments behaves exactly as before.
    """
    parser = argparse.ArgumentParser(description="haptal-curate basic scoring example")
    parser.add_argument("--hdf5", type=str, default=None, help="Path to HDF5 demo file")
    parser.add_argument("--synthetic", action="store_true", help="Use synthetic data")
    parser.add_argument("--fraction", type=float, default=0.5, help="Fraction of demos to keep")
    args = parser.parse_args(argv)

    if args.hdf5:
        print(f"Loading demos from {args.hdf5}")
        score_result = score(args.hdf5)
    else:
        # One source of truth for the synthetic dataset size, so the status
        # message cannot drift from what is actually generated.
        n_demos, episode_len = 20, 80
        print(f"Using synthetic dataset ({n_demos} demos, T={episode_len})")
        dataset = make_synthetic_dataset(n=n_demos, T=episode_len)
        score_result = score(dataset, metrics=[EnsembleMetric()])

    scores = score_result.scores
    print(f"\nScored {len(scores)} demos using metric: {score_result.metric_name}")
    print(f"  Score range: [{scores.min():.3f}, {scores.max():.3f}]")
    print(f"  Score mean:  {scores.mean():.3f} ± {scores.std():.3f}")

    # Curate: keep the fraction of demos requested via --fraction (default 0.5).
    curation_result = curate(score_result, fraction=args.fraction)
    print(f"\nCuration (top {args.fraction*100:.0f}%):")
    print(f"  Kept {len(curation_result.kept_indices)} demos, removed {len(curation_result.removed_indices)}")

    # Report
    summary = report(score_result, curation_result)
    print("\nTop 5 demos (highest quality):", summary["top5_demo_ids"])
    print("Bottom 5 demos (lowest quality):", summary["bottom5_demo_ids"])
    # The confound section is only printed when the report includes it.
    if "confound" in summary:
        c = summary["confound"]
        print(f"\nLength confound: severity={c['severity']}, Spearman r={c['spearman_r']:.3f}")
        print(f"  {c['message']}")


# Run the example only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()