# fetch-lift-td3-bc / training_code.py
# Ason-jay's picture
# Upload folder using huggingface_hub
# 35f348e verified
"""TD3+BC: TD3 with Behavior Cloning regularization for offline RL.
All computations done in normalized space:
- States: zero mean, unit variance (from dataset stats)
- Actions: scaled to [-1, 1] using joint limits
"""
import os
import csv
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from offline_dataset import OfflineRLDataset
# Per-joint (low, high) limits used to scale actions to/from [-1, 1].
# Row order matches the joint_names list printed in main().
_JOINT_LIMIT_PAIRS = (
    (-1.606, 1.606),  # shoulder_pan
    (-1.221, 1.518),  # shoulder_lift
    (-3.142, 3.142),  # upperarm_roll
    (-2.251, 2.251),  # elbow_flex
    (-3.142, 3.142),  # forearm_roll
    (-2.16, 3.142),   # wrist_flex
    (-3.142, 3.142),  # wrist_roll
    (0.0, 0.05),      # l_gripper
    (0.0, 0.05),      # r_gripper
)
JOINT_LIMITS_LOW = torch.tensor([lo for lo, _ in _JOINT_LIMIT_PAIRS], dtype=torch.float32)
JOINT_LIMITS_HIGH = torch.tensor([hi for _, hi in _JOINT_LIMIT_PAIRS], dtype=torch.float32)
def normalize_action(action, low, high):
    """Linearly rescale ``action`` so that ``low`` maps to -1 and ``high`` to +1.

    Works element-wise on tensors (broadcasting) as well as on plain floats.
    """
    offset = action - low
    span = high - low
    return 2.0 * offset / span - 1.0
def denormalize_action(action_norm, low, high):
    """Inverse of ``normalize_action``: map -1 to ``low`` and +1 to ``high``.

    Works element-wise on tensors (broadcasting) as well as on plain floats.
    """
    unit = (action_norm + 1.0) * 0.5  # [-1, 1] -> [0, 1]
    span = high - low
    return low + unit * span
class Actor(nn.Module):
    """Deterministic policy mapping a raw state to a normalized action.

    The state statistics and joint limits are registered as buffers, so they
    are saved in ``state_dict`` and moved by ``.to(device)`` together with the
    network weights.
    """

    def __init__(self, state_dim, action_dim, state_mean, state_std):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        ]
        self.net = nn.Sequential(*layers)
        for buffer_name, buffer_value in (
            ("state_mean", state_mean),
            ("state_std", state_std),
            ("action_low", JOINT_LIMITS_LOW),
            ("action_high", JOINT_LIMITS_HIGH),
        ):
            self.register_buffer(buffer_name, buffer_value)

    def forward(self, state):
        """Standardize ``state`` and return the normalized action in [-1, 1]."""
        standardized = (state - self.state_mean) / self.state_std
        return self.net(standardized)

    def get_raw_action(self, state):
        """Return the policy action mapped back into joint-limit space."""
        normalized = self.forward(state)
        return denormalize_action(normalized, self.action_low, self.action_high)
class Critic(nn.Module):
    """Twin Q-networks with LayerNorm for stable offline RL training.

    Each head consumes the concatenation of a normalized state and a
    normalized action and outputs a single Q estimate (last dim of size 1).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        joint_dim = state_dim + action_dim
        # q1 is built fully before q2 so weight initialization draws from the
        # RNG in the same order as two inline Sequential definitions would.
        self.q1 = self._make_head(joint_dim)
        self.q2 = self._make_head(joint_dim)

    @staticmethod
    def _make_head(in_dim, hidden=256):
        """Build one Q head: (Linear -> LayerNorm -> ReLU) x 2 -> Linear(1)."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    @staticmethod
    def _join(state_norm, action_norm):
        """Concatenate state and action along the feature (last) dimension."""
        return torch.cat([state_norm, action_norm], dim=-1)

    def forward(self, state_norm, action_norm):
        """Return ``(Q1, Q2)`` for the given normalized state-action batch."""
        joint = self._join(state_norm, action_norm)
        return self.q1(joint), self.q2(joint)

    def q1_forward(self, state_norm, action_norm):
        """Return only the first head's Q estimate."""
        return self.q1(self._join(state_norm, action_norm))
class TD3BC:
    """TD3 with a behavior-cloning regularizer (TD3+BC) for offline RL.

    Networks operate in normalized spaces: states are standardized with
    ``state_mean`` / ``state_std`` and actions are scaled to [-1, 1] with the
    joint limits. ``train_step`` accepts raw batches and normalizes them.
    """

    # Floor for the Q scale in lambda = alpha / mean|Q|, so a near-zero critic
    # cannot make lambda infinite (and the actor loss / gradients NaN).
    _LAMBDA_EPS = 1e-6

    def __init__(
        self,
        state_dim=9,
        action_dim=9,
        state_mean=None,
        state_std=None,
        lr=3e-4,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_delay=2,
        alpha=2.5,
        device="cuda",
        target_q_min=-1.0,
        target_q_max=2.0,
        max_grad_norm=1.0,
    ):
        """Create actor/critic networks, their target copies, and optimizers.

        Args:
            state_dim: size of the state vector (actor/critic input).
            action_dim: size of the action vector.
            state_mean, state_std: per-dimension state statistics used for
                standardization. ``None`` means identity normalization
                (zeros / ones).
            lr: Adam learning rate for both actor and critic.
            discount: TD discount factor.
            tau: Polyak coefficient for the target-network updates.
            policy_noise: std of target-policy smoothing noise (normalized units).
            noise_clip: absolute clip bound for that noise.
            policy_delay: actor and target networks update every this many steps.
            alpha: weight of the Q term relative to the BC term.
            device: torch device for networks and statistics.
            target_q_min, target_q_max: clamp range for bootstrapped Q targets.
            max_grad_norm: gradient-norm clip applied to actor and critic.
        """
        if state_mean is None:
            state_mean = torch.zeros(state_dim)
        if state_std is None:
            state_std = torch.ones(state_dim)

        self.device = device
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay
        self.alpha = alpha
        self.target_q_min = target_q_min
        self.target_q_max = target_q_max
        self.max_grad_norm = max_grad_norm
        self.max_action = 1.0  # normalized action space
        self.state_mean = state_mean.to(device)
        self.state_std = state_std.to(device)
        self.action_low = JOINT_LIMITS_LOW.to(device)
        self.action_high = JOINT_LIMITS_HIGH.to(device)

        self.actor = Actor(state_dim, action_dim, state_mean, state_std).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.total_it = 0

    def _normalize_state(self, state):
        return (state - self.state_mean) / self.state_std

    def _normalize_action(self, action):
        return normalize_action(action, self.action_low, self.action_high)

    def _soft_update(self, source, target):
        """Polyak-average ``source`` parameters into ``target`` with weight ``tau``."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

    def train_step(self, state, action, reward, next_state, done):
        """Run one TD3+BC update on a batch of raw (unnormalized) transitions.

        Args:
            state, next_state: raw states; standardized internally.
            action: raw joint-space actions; scaled to [-1, 1] internally.
            reward, done: one value per transition, as shape ``(B,)`` or
                ``(B, 1)``; ``done`` may be bool or float.

        Returns:
            dict with ``critic_loss``, ``actor_loss``, ``bc_loss``,
            ``q_value_mean`` and ``q_value_std``. The three actor metrics are
            0.0 on steps where the delayed actor update does not run.
        """
        self.total_it += 1

        # Normalize inputs.
        s_norm = self._normalize_state(state)
        a_norm = self._normalize_action(action)
        ns_norm = self._normalize_state(next_state)

        with torch.no_grad():
            # Target policy smoothing in normalized action space.
            noise = (torch.randn_like(a_norm) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )
            # The actor standardizes raw states itself; its Tanh range is [-1, 1].
            next_a_norm = (self.actor_target(next_state) + noise).clamp(-1.0, 1.0)

            # Clipped double-Q target.
            target_q1, target_q2 = self.critic_target(ns_norm, next_a_norm)
            target_q = torch.min(target_q1, target_q2)
            # Force (B, 1) so both (B,) and (B, 1) inputs line up with the
            # critic output instead of broadcasting to (B, B, 1); the dtype
            # cast also lets bool `done` flags through.
            reward_col = reward.reshape(-1, 1).to(target_q.dtype)
            not_done = 1.0 - done.reshape(-1, 1).to(target_q.dtype)
            target_q = reward_col + not_done * self.discount * target_q
            # Bound targets to curb value explosion. NOTE(review): the default
            # [-1, 2] range assumes rewards of order 1; confirm for the dataset.
            target_q = target_q.clamp(self.target_q_min, self.target_q_max)

        # Critic update.
        current_q1, current_q2 = self.critic(s_norm, a_norm)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_optimizer.step()

        # Delayed actor update.
        actor_loss_val = 0.0
        bc_loss_val = 0.0
        q_value_mean = 0.0
        if self.total_it % self.policy_delay == 0:
            pi_norm = self.actor(state)
            q_val = self.critic.q1_forward(s_norm, pi_norm)
            # lambda = alpha / mean|Q| makes the Q term invariant to Q's scale.
            # NOTE: reference TD3+BC uses Q(s, pi(s)) here; this code uses Q of
            # the dataset actions. The value is a constant for the actor loss,
            # so it is computed without building a graph.
            with torch.no_grad():
                q_scale = self.critic.q1_forward(s_norm, a_norm).abs().mean()
            lam = self.alpha / q_scale.clamp(min=self._LAMBDA_EPS)
            # BC loss in normalized action space.
            bc_loss = ((pi_norm - a_norm) ** 2).mean()
            actor_loss = -lam * q_val.mean() + bc_loss

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optimizer.step()

            self._soft_update(self.actor, self.actor_target)
            self._soft_update(self.critic, self.critic_target)

            actor_loss_val = actor_loss.item()
            bc_loss_val = bc_loss.item()
            q_value_mean = q_val.mean().item()

        return {
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss_val,
            "bc_loss": bc_loss_val,
            "q_value_mean": q_value_mean,
            "q_value_std": current_q1.std().item(),
        }

    def save(self, filepath):
        """Write networks, targets, optimizer states and step count to ``filepath``."""
        torch.save({
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "actor_target": self.actor_target.state_dict(),
            "critic_target": self.critic_target.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "critic_optimizer": self.critic_optimizer.state_dict(),
            "total_it": self.total_it,
        }, filepath)

    def load(self, filepath):
        """Restore everything written by ``save`` (mapped onto ``self.device``)."""
        checkpoint = torch.load(filepath, map_location=self.device)
        self.actor.load_state_dict(checkpoint["actor"])
        self.critic.load_state_dict(checkpoint["critic"])
        self.actor_target.load_state_dict(checkpoint["actor_target"])
        self.critic_target.load_state_dict(checkpoint["critic_target"])
        self.actor_optimizer.load_state_dict(checkpoint["actor_optimizer"])
        self.critic_optimizer.load_state_dict(checkpoint["critic_optimizer"])
        self.total_it = checkpoint["total_it"]
        # The critic-side statistics may not share storage with the actor's
        # buffers (e.g. when the constructor's stats lived on another device),
        # so re-point them at the checkpointed values to keep actor and critic
        # normalization consistent.
        self.state_mean = self.actor.state_mean
        self.state_std = self.actor.state_std
def _parse_args():
    """Parse command-line options for TD3+BC training."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="/code/zxx240000/training/offline_rl/data/offline_dataset.npz")
    parser.add_argument("--output_dir", default="/code/zxx240000/training/offline_rl/results/td3_bc")
    parser.add_argument("--num_iterations", type=int, default=100000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--discount", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--policy_noise", type=float, default=0.2)
    parser.add_argument("--noise_clip", type=float, default=0.5)
    parser.add_argument("--policy_delay", type=int, default=2)
    parser.add_argument("--alpha", type=float, default=2.5)
    parser.add_argument("--eval_freq", type=int, default=10000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    return parser.parse_args()


def _within_joint_limits(actions, device, tol=1e-5):
    """Return True iff every entry of ``actions`` lies inside the joint limits (+/- ``tol``)."""
    low = JOINT_LIMITS_LOW.to(device) - tol
    high = JOINT_LIMITS_HIGH.to(device) + tol
    return bool(((actions >= low) & (actions <= high)).all())


def _raw_policy_actions(actor, states, chunk_size=4096):
    """Evaluate ``actor.get_raw_action`` over ``states`` chunk by chunk to bound memory."""
    with torch.no_grad():
        chunks = [
            actor.get_raw_action(states[i:i + chunk_size])
            for i in range(0, len(states), chunk_size)
        ]
    return torch.cat(chunks, dim=0)


def main():
    """Train TD3+BC on an offline dataset.

    Every ``eval_freq`` steps the window-averaged metrics are logged to CSV
    and stdout, a checkpoint is written, and policy outputs are checked
    against the joint limits. At the end the final model is saved, validated
    on the whole dataset, and reloaded once as a sanity check.
    """
    args = _parse_args()

    # Seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Dirs
    ckpt_dir = os.path.join(args.output_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Load dataset
    print(f"Loading dataset from {args.dataset}")
    dataset = OfflineRLDataset(args.dataset, device=args.device)
    print(f" {dataset.size} transitions loaded")

    # Move normalization stats to device
    state_mean = dataset.state_mean.to(args.device)
    state_std = dataset.state_std.to(args.device)

    agent = TD3BC(
        state_dim=9,
        action_dim=9,
        state_mean=state_mean,
        state_std=state_std,
        lr=args.lr,
        discount=args.discount,
        tau=args.tau,
        policy_noise=args.policy_noise,
        noise_clip=args.noise_clip,
        policy_delay=args.policy_delay,
        alpha=args.alpha,
        device=args.device,
    )
    print(f"TD3+BC agent created on {args.device}")

    # Sanity check: dataset actions expressed in the normalized action space.
    norm_actions = normalize_action(
        dataset.actions.to(args.device),
        JOINT_LIMITS_LOW.to(args.device),
        JOINT_LIMITS_HIGH.to(args.device),
    )
    print(f" Normalized action range: [{norm_actions.min():.3f}, {norm_actions.max():.3f}]")
    print(f" Normalized action mean: {norm_actions.mean(0).cpu().numpy()}")

    metric_keys = ("critic_loss", "actor_loss", "bc_loss", "q_value_mean", "q_value_std")
    actor_keys = ("actor_loss", "bc_loss", "q_value_mean")
    log_path = os.path.join(args.output_dir, "training_log.csv")
    # `with` guarantees the CSV is flushed and closed even if training raises.
    with open(log_path, "w", newline="") as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(["step", *metric_keys])

        # Running sums over the current eval window.
        running = dict.fromkeys(metric_keys, 0)
        actor_updates = 0

        print(f"\nStarting training for {args.num_iterations} iterations...")
        for step in range(1, args.num_iterations + 1):
            state, action, reward, next_state, done = dataset.sample(args.batch_size)
            metrics = agent.train_step(state, action, reward, next_state, done)
            running["critic_loss"] += metrics["critic_loss"]
            running["q_value_std"] += metrics["q_value_std"]
            # train_step has just incremented total_it; the delayed actor update
            # ran iff that count is a multiple of policy_delay. (Checking
            # actor_loss != 0 would silently drop a genuine zero loss.)
            if agent.total_it % agent.policy_delay == 0:
                for key in actor_keys:
                    running[key] += metrics[key]
                actor_updates += 1

            if step % args.eval_freq != 0:
                continue

            n = args.eval_freq
            n_actor = max(actor_updates, 1)
            avg_critic = running["critic_loss"] / n
            avg_actor = running["actor_loss"] / n_actor
            avg_bc = running["bc_loss"] / n_actor
            avg_q_mean = running["q_value_mean"] / n_actor
            avg_q_std = running["q_value_std"] / n
            log_writer.writerow([step, f"{avg_critic:.6f}", f"{avg_actor:.6f}",
                                 f"{avg_bc:.6f}", f"{avg_q_mean:.6f}", f"{avg_q_std:.6f}"])
            log_file.flush()
            print(f"Step {step:>6d} | Critic: {avg_critic:.6f} | Actor: {avg_actor:.6f} | "
                  f"BC: {avg_bc:.6f} | Q-mean: {avg_q_mean:.4f} | Q-std: {avg_q_std:.4f}")

            agent.save(os.path.join(ckpt_dir, f"checkpoint_{step}.pt"))

            # Start a fresh averaging window.
            running = dict.fromkeys(metric_keys, 0)
            actor_updates = 0

            # Validate policy outputs (raw action space) on a small sample.
            with torch.no_grad():
                test_actions_raw = agent.actor.get_raw_action(dataset.states[:100].to(args.device))
            if not _within_joint_limits(test_actions_raw, args.device):
                a_min = test_actions_raw.min(dim=0).values.cpu().numpy()
                a_max = test_actions_raw.max(dim=0).values.cpu().numpy()
                print(" WARNING: Policy outputs outside joint limits!")
                print(f" Min: {a_min}")
                print(f" Max: {a_max}")

    # NOTE: no model selection is performed; despite its name, "best_model.pt"
    # holds the weights from the final iteration.
    best_path = os.path.join(args.output_dir, "best_model.pt")
    agent.save(best_path)
    print(f"\nFinal model saved to {best_path}")

    print("\n=== FINAL VALIDATION ===")
    all_actions = _raw_policy_actions(agent.actor, dataset.states.to(args.device))
    print("Policy action statistics (raw joint space):")
    joint_names = ["shoulder_pan", "shoulder_lift", "upperarm_roll", "elbow_flex",
                   "forearm_roll", "wrist_flex", "wrist_roll", "l_gripper", "r_gripper"]
    for i, name in enumerate(joint_names):
        a = all_actions[:, i]
        print(f" {name}: min={a.min():.4f}, max={a.max():.4f}, mean={a.mean():.4f}, "
              f"limits=[{JOINT_LIMITS_LOW[i]:.3f}, {JOINT_LIMITS_HIGH[i]:.3f}]")
    within = _within_joint_limits(all_actions, args.device)
    print(f"\nAll actions within joint limits: {within}")

    # Reload and verify saved model
    print("\nVerifying saved model loads correctly...")
    agent2 = TD3BC(state_dim=9, action_dim=9, state_mean=state_mean, state_std=state_std, device=args.device)
    agent2.load(best_path)
    with torch.no_grad():
        test_a = agent2.actor.get_raw_action(dataset.states[:10].to(args.device))
    print(f" Loaded model produces actions: shape={test_a.shape}, range=[{test_a.min():.4f}, {test_a.max():.4f}]")

    print(f"\nTraining log saved to {log_path}")
    print(f"Checkpoints saved to {ckpt_dir}")
    print(f"Best model saved to {best_path}")


if __name__ == "__main__":
    main()