"""TD3+BC: TD3 with Behavior Cloning regularization for offline RL. All computations done in normalized space: - States: zero mean, unit variance (from dataset stats) - Actions: scaled to [-1, 1] using joint limits """ import os import csv import copy import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from offline_dataset import OfflineRLDataset # Joint limits for scaling JOINT_LIMITS_LOW = torch.tensor( [-1.606, -1.221, -3.142, -2.251, -3.142, -2.16, -3.142, 0.0, 0.0], dtype=torch.float32, ) JOINT_LIMITS_HIGH = torch.tensor( [1.606, 1.518, 3.142, 2.251, 3.142, 3.142, 3.142, 0.05, 0.05], dtype=torch.float32, ) def normalize_action(action, low, high): """Map raw action from [low, high] to [-1, 1].""" return 2.0 * (action - low) / (high - low) - 1.0 def denormalize_action(action_norm, low, high): """Map normalized action from [-1, 1] to [low, high].""" return low + (action_norm + 1.0) * 0.5 * (high - low) class Actor(nn.Module): def __init__(self, state_dim, action_dim, state_mean, state_std): super().__init__() self.net = nn.Sequential( nn.Linear(state_dim, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, action_dim), nn.Tanh(), ) self.register_buffer("state_mean", state_mean) self.register_buffer("state_std", state_std) self.register_buffer("action_low", JOINT_LIMITS_LOW) self.register_buffer("action_high", JOINT_LIMITS_HIGH) def forward(self, state): """Returns normalized action in [-1, 1].""" state_norm = (state - self.state_mean) / self.state_std return self.net(state_norm) def get_raw_action(self, state): """Returns denormalized action in joint-limit space.""" a_norm = self.forward(state) return denormalize_action(a_norm, self.action_low, self.action_high) class Critic(nn.Module): """Twin Q-networks with LayerNorm for stable offline RL training.""" def __init__(self, state_dim, action_dim): super().__init__() self.q1 = nn.Sequential( nn.Linear(state_dim + action_dim, 256), nn.LayerNorm(256), nn.ReLU(), nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU(), nn.Linear(256, 1), ) self.q2 = nn.Sequential( nn.Linear(state_dim + action_dim, 256), nn.LayerNorm(256), nn.ReLU(), nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU(), nn.Linear(256, 1), ) def forward(self, state_norm, action_norm): sa = torch.cat([state_norm, action_norm], dim=-1) return self.q1(sa), self.q2(sa) def q1_forward(self, state_norm, action_norm): sa = torch.cat([state_norm, action_norm], dim=-1) return self.q1(sa) class TD3BC: def __init__( self, state_dim=9, action_dim=9, state_mean=None, state_std=None, lr=3e-4, discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_delay=2, alpha=2.5, device="cuda", ): self.device = device self.discount = discount self.tau = tau self.policy_noise = policy_noise self.noise_clip = noise_clip self.policy_delay = policy_delay self.alpha = alpha self.max_action = 1.0 # normalized action space self.state_mean = state_mean.to(device) self.state_std = state_std.to(device) self.action_low = JOINT_LIMITS_LOW.to(device) self.action_high = JOINT_LIMITS_HIGH.to(device) self.actor = Actor(state_dim, action_dim, state_mean, state_std).to(device) self.actor_target = copy.deepcopy(self.actor) self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr) self.critic = Critic(state_dim, action_dim).to(device) self.critic_target = copy.deepcopy(self.critic) self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr) self.total_it = 0 def _normalize_state(self, state): return (state - self.state_mean) / self.state_std def _normalize_action(self, action): return normalize_action(action, self.action_low, self.action_high) def train_step(self, state, action, reward, next_state, done): """One training step. state/action/next_state are raw (unnormalized).""" self.total_it += 1 # Normalize inputs s_norm = self._normalize_state(state) a_norm = self._normalize_action(action) ns_norm = self._normalize_state(next_state) with torch.no_grad(): # Target policy smoothing in normalized action space noise = (torch.randn_like(a_norm) * self.policy_noise).clamp( -self.noise_clip, self.noise_clip ) # Actor outputs normalized actions next_a_norm = (self.actor_target(next_state) + noise).clamp(-1.0, 1.0) # Twin Q targets target_q1, target_q2 = self.critic_target(ns_norm, next_a_norm) target_q = torch.min(target_q1, target_q2) target_q = reward.unsqueeze(-1) + (1.0 - done.unsqueeze(-1)) * self.discount * target_q # Clamp to prevent value explosion (max possible Q ≈ gamma^50 * 1.0 ≈ 0.6) target_q = target_q.clamp(-1.0, 2.0) # Critic update current_q1, current_q2 = self.critic(s_norm, a_norm) critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) self.critic_optimizer.zero_grad() critic_loss.backward() nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0) self.critic_optimizer.step() # Delayed actor update actor_loss_val = 0.0 bc_loss_val = 0.0 q_value_mean = 0.0 if self.total_it % self.policy_delay == 0: # Actor outputs normalized actions pi_norm = self.actor(state) q_val = self.critic.q1_forward(s_norm, pi_norm) # Lambda normalization lam = self.alpha / self.critic.q1_forward(s_norm, a_norm).abs().mean().detach() # BC loss in normalized action space bc_loss = ((pi_norm - a_norm) ** 2).mean() actor_loss = -lam * q_val.mean() + bc_loss self.actor_optimizer.zero_grad() actor_loss.backward() nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0) self.actor_optimizer.step() # Soft update target networks for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data) for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data) actor_loss_val = actor_loss.item() bc_loss_val = bc_loss.item() q_value_mean = q_val.mean().item() return { "critic_loss": critic_loss.item(), "actor_loss": actor_loss_val, "bc_loss": bc_loss_val, "q_value_mean": q_value_mean, "q_value_std": current_q1.std().item(), } def save(self, filepath): torch.save({ "actor": self.actor.state_dict(), "critic": self.critic.state_dict(), "actor_target": self.actor_target.state_dict(), "critic_target": self.critic_target.state_dict(), "actor_optimizer": self.actor_optimizer.state_dict(), "critic_optimizer": self.critic_optimizer.state_dict(), "total_it": self.total_it, }, filepath) def load(self, filepath): checkpoint = torch.load(filepath, map_location=self.device) self.actor.load_state_dict(checkpoint["actor"]) self.critic.load_state_dict(checkpoint["critic"]) self.actor_target.load_state_dict(checkpoint["actor_target"]) self.critic_target.load_state_dict(checkpoint["critic_target"]) self.actor_optimizer.load_state_dict(checkpoint["actor_optimizer"]) self.critic_optimizer.load_state_dict(checkpoint["critic_optimizer"]) self.total_it = checkpoint["total_it"] def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--dataset", default="/code/zxx240000/training/offline_rl/data/offline_dataset.npz") parser.add_argument("--output_dir", default="/code/zxx240000/training/offline_rl/results/td3_bc") parser.add_argument("--num_iterations", type=int, default=100000) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--discount", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--policy_noise", type=float, default=0.2) parser.add_argument("--noise_clip", type=float, default=0.5) parser.add_argument("--policy_delay", type=int, default=2) parser.add_argument("--alpha", type=float, default=2.5) parser.add_argument("--eval_freq", type=int, default=10000) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") args = parser.parse_args() # Seed torch.manual_seed(args.seed) np.random.seed(args.seed) # Dirs ckpt_dir = os.path.join(args.output_dir, "checkpoints") os.makedirs(ckpt_dir, exist_ok=True) # Load dataset print(f"Loading dataset from {args.dataset}") dataset = OfflineRLDataset(args.dataset, device=args.device) print(f" {dataset.size} transitions loaded") # Move normalization stats to device state_mean = dataset.state_mean.to(args.device) state_std = dataset.state_std.to(args.device) # Create agent agent = TD3BC( state_dim=9, action_dim=9, state_mean=state_mean, state_std=state_std, lr=args.lr, discount=args.discount, tau=args.tau, policy_noise=args.policy_noise, noise_clip=args.noise_clip, policy_delay=args.policy_delay, alpha=args.alpha, device=args.device, ) print(f"TD3+BC agent created on {args.device}") # Print normalized action stats for sanity check raw_actions = dataset.actions.to(args.device) norm_actions = normalize_action(raw_actions, JOINT_LIMITS_LOW.to(args.device), JOINT_LIMITS_HIGH.to(args.device)) print(f" Normalized action range: [{norm_actions.min():.3f}, {norm_actions.max():.3f}]") print(f" Normalized action mean: {norm_actions.mean(0).cpu().numpy()}") # Training log log_path = os.path.join(args.output_dir, "training_log.csv") log_file = open(log_path, "w", newline="") log_writer = csv.writer(log_file) log_writer.writerow(["step", "critic_loss", "actor_loss", "bc_loss", "q_value_mean", "q_value_std"]) # Running metrics for averaging running = {"critic_loss": 0, "actor_loss": 0, "bc_loss": 0, "q_value_mean": 0, "q_value_std": 0} actor_updates = 0 print(f"\nStarting training for {args.num_iterations} iterations...") for step in range(1, args.num_iterations + 1): state, action, reward, next_state, done = dataset.sample(args.batch_size) metrics = agent.train_step(state, action, reward, next_state, done) running["critic_loss"] += metrics["critic_loss"] running["q_value_std"] += metrics["q_value_std"] if metrics["actor_loss"] != 0: running["actor_loss"] += metrics["actor_loss"] running["bc_loss"] += metrics["bc_loss"] running["q_value_mean"] += metrics["q_value_mean"] actor_updates += 1 if step % args.eval_freq == 0: n = args.eval_freq n_actor = max(actor_updates, 1) avg_critic = running["critic_loss"] / n avg_actor = running["actor_loss"] / n_actor avg_bc = running["bc_loss"] / n_actor avg_q_mean = running["q_value_mean"] / n_actor avg_q_std = running["q_value_std"] / n log_writer.writerow([step, f"{avg_critic:.6f}", f"{avg_actor:.6f}", f"{avg_bc:.6f}", f"{avg_q_mean:.6f}", f"{avg_q_std:.6f}"]) log_file.flush() print(f"Step {step:>6d} | Critic: {avg_critic:.6f} | Actor: {avg_actor:.6f} | " f"BC: {avg_bc:.6f} | Q-mean: {avg_q_mean:.4f} | Q-std: {avg_q_std:.4f}") # Save checkpoint ckpt_path = os.path.join(ckpt_dir, f"checkpoint_{step}.pt") agent.save(ckpt_path) # Reset running metrics running = {k: 0 for k in running} actor_updates = 0 # Validate policy outputs (raw action space) with torch.no_grad(): test_states = dataset.states[:100].to(args.device) test_actions_raw = agent.actor.get_raw_action(test_states) a_min = test_actions_raw.min(dim=0).values.cpu().numpy() a_max = test_actions_raw.max(dim=0).values.cpu().numpy() within_limits = ( (test_actions_raw >= JOINT_LIMITS_LOW.to(args.device) - 1e-5).all() and (test_actions_raw <= JOINT_LIMITS_HIGH.to(args.device) + 1e-5).all() ) if not within_limits: print(f" WARNING: Policy outputs outside joint limits!") print(f" Min: {a_min}") print(f" Max: {a_max}") log_file.close() # Save final model best_path = os.path.join(args.output_dir, "best_model.pt") agent.save(best_path) print(f"\nFinal model saved to {best_path}") # Final validation print("\n=== FINAL VALIDATION ===") with torch.no_grad(): all_states = dataset.states.to(args.device) chunk_size = 4096 all_actions = [] for i in range(0, len(all_states), chunk_size): chunk = all_states[i:i+chunk_size] all_actions.append(agent.actor.get_raw_action(chunk)) all_actions = torch.cat(all_actions, dim=0) print("Policy action statistics (raw joint space):") joint_names = ["shoulder_pan", "shoulder_lift", "upperarm_roll", "elbow_flex", "forearm_roll", "wrist_flex", "wrist_roll", "l_gripper", "r_gripper"] for i, name in enumerate(joint_names): a = all_actions[:, i] print(f" {name}: min={a.min():.4f}, max={a.max():.4f}, mean={a.mean():.4f}, " f"limits=[{JOINT_LIMITS_LOW[i]:.3f}, {JOINT_LIMITS_HIGH[i]:.3f}]") within = ( (all_actions >= JOINT_LIMITS_LOW.to(args.device) - 1e-5).all() and (all_actions <= JOINT_LIMITS_HIGH.to(args.device) + 1e-5).all() ) print(f"\nAll actions within joint limits: {within.item()}") # Reload and verify saved model print("\nVerifying saved model loads correctly...") agent2 = TD3BC(state_dim=9, action_dim=9, state_mean=state_mean, state_std=state_std, device=args.device) agent2.load(best_path) with torch.no_grad(): test_s = dataset.states[:10].to(args.device) test_a = agent2.actor.get_raw_action(test_s) print(f" Loaded model produces actions: shape={test_a.shape}, range=[{test_a.min():.4f}, {test_a.max():.4f}]") print(f"\nTraining log saved to {log_path}") print(f"Checkpoints saved to {ckpt_dir}") print(f"Best model saved to {best_path}") if __name__ == "__main__": main()