# hackathon/training/trajectory.py
"""Trajectory serialisation and dataset utilities.

A ``Trajectory`` stores the full history of one episode (task, actions,
observations, rewards, latent-state snapshots) in a format that supports:

- offline RL training
- imitation learning from expert demonstrations
- evaluation / replay
- simulator calibration
"""
from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

from models import (
    ExperimentAction,
    ExperimentObservation,
    TaskSpec,
)
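
# On-disk layout written by ``Trajectory.save`` (one JSON file per episode).
# This is a sketch derived from ``Trajectory.to_dict`` below; the values shown
# are placeholders:
#
#     {
#       "episode_id": "...",
#       "task": {...},
#       "steps": [
#         {"step_index": 0, "action": {...}, "observation": {...},
#          "reward": 0.0, "done": false,
#          "reward_breakdown": {...}, "latent_snapshot": null}
#       ],
#       "total_reward": 0.0,
#       "success": false,
#       "metadata": {...}
#     }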


@dataclass
class TrajectoryStep:
    """A single (action, observation, reward) transition within an episode."""

    step_index: int
    action: Dict[str, Any]          # serialised ExperimentAction
    observation: Dict[str, Any]     # serialised ExperimentObservation
    reward: float
    done: bool
    reward_breakdown: Dict[str, float] = field(default_factory=dict)
    latent_snapshot: Optional[Dict[str, Any]] = None


@dataclass
class Trajectory:
    """Complete record of one environment episode."""

    episode_id: str
    task: Dict[str, Any]            # serialised task specification
    steps: List[TrajectoryStep] = field(default_factory=list)
    total_reward: float = 0.0
    success: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

    # ── construction helpers ────────────────────────────────────────────

    def add_step(
        self,
        action: ExperimentAction,
        observation: ExperimentObservation,
        reward: float,
        done: bool,
        reward_breakdown: Optional[Dict[str, float]] = None,
        latent_snapshot: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append one transition and update the episode aggregates."""
        self.steps.append(TrajectoryStep(
            step_index=len(self.steps),
            action=action.model_dump(),
            observation=observation.model_dump(),
            reward=reward,
            done=done,
            reward_breakdown=reward_breakdown or {},
            latent_snapshot=latent_snapshot,
        ))
        self.total_reward += reward
        if done:
            # Convention: the episode counts as a success when its terminal
            # step carries positive reward.
            self.success = reward > 0

    # ── serialisation ───────────────────────────────────────────────────

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict (layout sketched after the imports)."""
return {
"episode_id": self.episode_id,
"task": self.task,
"steps": [
{
"step_index": s.step_index,
"action": s.action,
"observation": s.observation,
"reward": s.reward,
"done": s.done,
"reward_breakdown": s.reward_breakdown,
"latent_snapshot": s.latent_snapshot,
}
for s in self.steps
],
"total_reward": self.total_reward,
"success": self.success,
"metadata": self.metadata,
}

    def save(self, path: str | Path) -> None:
        """Write the trajectory as pretty-printed JSON, creating parent dirs.

        ``default=str`` coerces non-JSON-native values (e.g. timestamps)
        to strings instead of raising.
        """
p = Path(path)
p.parent.mkdir(parents=True, exist_ok=True)
with open(p, "w") as f:
json.dump(self.to_dict(), f, indent=2, default=str)

    @classmethod
    def load(cls, path: str | Path) -> "Trajectory":
        """Inverse of ``save``: rebuild a trajectory from its JSON file."""
with open(path) as f:
d = json.load(f)
traj = cls(
episode_id=d["episode_id"],
task=d["task"],
total_reward=d.get("total_reward", 0.0),
success=d.get("success", False),
metadata=d.get("metadata", {}),
)
for s in d.get("steps", []):
traj.steps.append(TrajectoryStep(**s))
return traj
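
# Minimal usage sketch (illustrative; the concrete ExperimentAction /
# ExperimentObservation fields are defined in models.py and are assumed here):
#
#     traj = Trajectory(episode_id="ep-001", task=task_spec.model_dump())
#     traj.add_step(action, obs, reward=1.0, done=True)
#     traj.save("runs/ep-001.json")
#     restored = Trajectory.load("runs/ep-001.json")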


class TrajectoryDataset:
    """In-memory collection of trajectories with convenience accessors."""

    def __init__(self, trajectories: Optional[List[Trajectory]] = None):
        self.trajectories: List[Trajectory] = trajectories or []

    def add(self, traj: Trajectory) -> None:
        self.trajectories.append(traj)

    def __len__(self) -> int:
        return len(self.trajectories)

    def __getitem__(self, idx: int) -> Trajectory:
        return self.trajectories[idx]

    def filter_successful(self) -> "TrajectoryDataset":
        """Return a new dataset containing only the successful episodes."""
        return TrajectoryDataset([t for t in self.trajectories if t.success])

    def save_dir(self, directory: str | Path) -> None:
        """Write one ``<episode_id>.json`` file per trajectory."""
        d = Path(directory)
        d.mkdir(parents=True, exist_ok=True)
        for t in self.trajectories:
            t.save(d / f"{t.episode_id}.json")

    @classmethod
    def load_dir(cls, directory: str | Path) -> "TrajectoryDataset":
        """Load every ``*.json`` file in ``directory``, sorted by filename."""
        d = Path(directory)
        trajs = [Trajectory.load(p) for p in sorted(d.glob("*.json"))]
        return cls(trajs)

    def summary(self) -> Dict[str, Any]:
        """Aggregate statistics; an empty dataset yields ``{"n": 0}``."""
        if not self.trajectories:
            return {"n": 0}
        rewards = [t.total_reward for t in self.trajectories]
        lengths = [len(t.steps) for t in self.trajectories]
        success_rate = sum(1 for t in self.trajectories if t.success) / len(self.trajectories)
        return {
            "n": len(self.trajectories),
            "success_rate": success_rate,
            "mean_reward": sum(rewards) / len(rewards),
            "mean_length": sum(lengths) / len(lengths),
            "max_reward": max(rewards),
            "min_reward": min(rewards),
        }
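

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). The step is built from plain
    # dicts because the ExperimentAction / ExperimentObservation schemas live
    # in models.py; real callers should go through ``add_step``.
    t = Trajectory(episode_id="demo-0", task={"name": "demo"})
    t.steps.append(TrajectoryStep(
        step_index=0,
        action={"kind": "noop"},
        observation={"ok": True},
        reward=1.0,
        done=True,
    ))
    t.total_reward = 1.0
    t.success = True

    ds = TrajectoryDataset([t])
    print(ds.summary())  # {"n": 1, "success_rate": 1.0, "mean_reward": 1.0, ...}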