|
|
"""TD3+BC: TD3 with Behavior Cloning regularization for offline RL. |
|
|
|
|
|
All computations done in normalized space: |
|
|
- States: zero mean, unit variance (from dataset stats) |
|
|
- Actions: scaled to [-1, 1] using joint limits |
|
|
""" |

import os
import csv
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from offline_dataset import OfflineRLDataset

|
# Joint limits for the 9-dim action space: 7 arm joints plus two gripper
# fingers (see joint_names in main()).
JOINT_LIMITS_LOW = torch.tensor(
    [-1.606, -1.221, -3.142, -2.251, -3.142, -2.16, -3.142, 0.0, 0.0],
    dtype=torch.float32,
)
JOINT_LIMITS_HIGH = torch.tensor(
    [1.606, 1.518, 3.142, 2.251, 3.142, 3.142, 3.142, 0.05, 0.05],
    dtype=torch.float32,
)

|
def normalize_action(action, low, high):
    """Map raw action from [low, high] to [-1, 1]."""
    return 2.0 * (action - low) / (high - low) - 1.0


def denormalize_action(action_norm, low, high):
    """Map normalized action from [-1, 1] to [low, high]."""
return low + (action_norm + 1.0) * 0.5 * (high - low) |
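

# A minimal round-trip check for the two helpers above (doctest-style sketch,
# not executed anywhere; the limits are those of the first joint):
#
#     >>> lo, hi = torch.tensor([-1.606]), torch.tensor([1.606])
#     >>> a = torch.tensor([0.5])
#     >>> torch.allclose(denormalize_action(normalize_action(a, lo, hi), lo, hi), a)
#     True

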
class Actor(nn.Module):
    """Deterministic tanh policy: consumes raw states, emits actions in [-1, 1].

    Normalization statistics and joint limits are registered as buffers so they
    are saved and restored with the model's state_dict.
    """

    def __init__(self, state_dim, action_dim, state_mean, state_std):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh(),
        )
        self.register_buffer("state_mean", state_mean)
        self.register_buffer("state_std", state_std)
        self.register_buffer("action_low", JOINT_LIMITS_LOW)
        self.register_buffer("action_high", JOINT_LIMITS_HIGH)

    def forward(self, state):
        """Returns normalized action in [-1, 1]; the raw state is normalized internally."""
        state_norm = (state - self.state_mean) / self.state_std
        return self.net(state_norm)

    def get_raw_action(self, state):
        """Returns denormalized action in joint-limit space."""
        a_norm = self.forward(state)
return denormalize_action(a_norm, self.action_low, self.action_high) |
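

# Illustrative usage sketch (the identity statistics here are placeholders,
# not dataset values):
#
#     actor = Actor(9, 9, torch.zeros(9), torch.ones(9))
#     raw = actor.get_raw_action(torch.randn(4, 9))  # shape (4, 9), inside limits

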
class Critic(nn.Module):
    """Twin Q-networks with LayerNorm for stable offline RL training."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, state_norm, action_norm):
        sa = torch.cat([state_norm, action_norm], dim=-1)
        return self.q1(sa), self.q2(sa)

    def q1_forward(self, state_norm, action_norm):
        """Q1 only, used for the actor update."""
        sa = torch.cat([state_norm, action_norm], dim=-1)
return self.q1(sa) |
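

# Shape sketch (illustrative only): each head maps a concatenated
# (state, action) batch to one scalar per sample.
#
#     critic = Critic(9, 9)
#     q1, q2 = critic(torch.randn(4, 9), torch.rand(4, 9) * 2.0 - 1.0)  # each (4, 1)

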
class TD3BC:
    def __init__(
        self,
        state_dim=9,
        action_dim=9,
        state_mean=None,
        state_std=None,
        lr=3e-4,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_delay=2,
        alpha=2.5,
        device="cuda",
    ):
        self.device = device
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay
        self.alpha = alpha
        self.max_action = 1.0

        # Fall back to identity normalization if no dataset statistics are given.
        if state_mean is None:
            state_mean = torch.zeros(state_dim)
        if state_std is None:
            state_std = torch.ones(state_dim)
        self.state_mean = state_mean.to(device)
        self.state_std = state_std.to(device)
        self.action_low = JOINT_LIMITS_LOW.to(device)
        self.action_high = JOINT_LIMITS_HIGH.to(device)

        self.actor = Actor(state_dim, action_dim, state_mean, state_std).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.total_it = 0

    def _normalize_state(self, state):
        return (state - self.state_mean) / self.state_std

    def _normalize_action(self, action):
        return normalize_action(action, self.action_low, self.action_high)

    def train_step(self, state, action, reward, next_state, done):
        """One training step. state/action/next_state are raw (unnormalized)."""
        self.total_it += 1

        s_norm = self._normalize_state(state)
        a_norm = self._normalize_action(action)
        ns_norm = self._normalize_state(next_state)

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (torch.randn_like(a_norm) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )
            # The actor normalizes states internally, so it takes the raw next_state.
            next_a_norm = (self.actor_target(next_state) + noise).clamp(-1.0, 1.0)

            # Clipped double-Q Bellman target.
            target_q1, target_q2 = self.critic_target(ns_norm, next_a_norm)
            target_q = torch.min(target_q1, target_q2)
            target_q = reward.unsqueeze(-1) + (1.0 - done.unsqueeze(-1)) * self.discount * target_q
            # Clip targets to a fixed range as a guard against value blow-up.
            target_q = target_q.clamp(-1.0, 2.0)

        current_q1, current_q2 = self.critic(s_norm, a_norm)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_optimizer.step()

        actor_loss_val = 0.0
        bc_loss_val = 0.0
        q_value_mean = 0.0

        # Delayed policy and target updates (TD3).
        if self.total_it % self.policy_delay == 0:
            pi_norm = self.actor(state)
            q_val = self.critic.q1_forward(s_norm, pi_norm)

            # TD3+BC adaptive weight: lambda = alpha / mean |Q(s, pi(s))|.
            lam = self.alpha / q_val.abs().mean().detach()

            bc_loss = ((pi_norm - a_norm) ** 2).mean()
            actor_loss = -lam * q_val.mean() + bc_loss

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
            self.actor_optimizer.step()

            # Polyak-average targets: theta' <- tau * theta + (1 - tau) * theta'.
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

            actor_loss_val = actor_loss.item()
            bc_loss_val = bc_loss.item()
            q_value_mean = q_val.mean().item()

        return {
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss_val,
            "bc_loss": bc_loss_val,
            "q_value_mean": q_value_mean,
            "q_value_std": current_q1.std().item(),
} |
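
    # Single-update usage sketch (illustrative; dataset.sample returns the
    # (state, action, reward, next_state, done) tuple used in main() below):
    #
    #     metrics = agent.train_step(*dataset.sample(256))
    #     print(metrics["critic_loss"], metrics["q_value_mean"])
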
    def save(self, filepath):
        torch.save({
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "actor_target": self.actor_target.state_dict(),
            "critic_target": self.critic_target.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "critic_optimizer": self.critic_optimizer.state_dict(),
            "total_it": self.total_it,
        }, filepath)

    def load(self, filepath):
        checkpoint = torch.load(filepath, map_location=self.device)
        self.actor.load_state_dict(checkpoint["actor"])
        self.critic.load_state_dict(checkpoint["critic"])
        self.actor_target.load_state_dict(checkpoint["actor_target"])
        self.critic_target.load_state_dict(checkpoint["critic_target"])
        self.actor_optimizer.load_state_dict(checkpoint["actor_optimizer"])
        self.critic_optimizer.load_state_dict(checkpoint["critic_optimizer"])
self.total_it = checkpoint["total_it"] |
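

# Actor-only export sketch for deployment (an assumption: only the policy is
# needed at run time). The Actor stores its normalization statistics and joint
# limits as buffers, so its state_dict alone reproduces actions:
#
#     agent.load("best_model.pt")                           # hypothetical path
#     torch.save(agent.actor.state_dict(), "actor_only.pt")
#     actor = Actor(9, 9, torch.zeros(9), torch.ones(9))    # placeholder stats
#     actor.load_state_dict(torch.load("actor_only.pt"))    # buffers restore them

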
def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="/code/zxx240000/training/offline_rl/data/offline_dataset.npz")
    parser.add_argument("--output_dir", default="/code/zxx240000/training/offline_rl/results/td3_bc")
    parser.add_argument("--num_iterations", type=int, default=100000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--discount", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--policy_noise", type=float, default=0.2)
    parser.add_argument("--noise_clip", type=float, default=0.5)
    parser.add_argument("--policy_delay", type=int, default=2)
    parser.add_argument("--alpha", type=float, default=2.5)
    parser.add_argument("--eval_freq", type=int, default=10000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    ckpt_dir = os.path.join(args.output_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    print(f"Loading dataset from {args.dataset}")
    dataset = OfflineRLDataset(args.dataset, device=args.device)
    print(f" {dataset.size} transitions loaded")

    state_mean = dataset.state_mean.to(args.device)
    state_std = dataset.state_std.to(args.device)

    agent = TD3BC(
        state_dim=9,
        action_dim=9,
        state_mean=state_mean,
        state_std=state_std,
        lr=args.lr,
        discount=args.discount,
        tau=args.tau,
        policy_noise=args.policy_noise,
        noise_clip=args.noise_clip,
        policy_delay=args.policy_delay,
        alpha=args.alpha,
        device=args.device,
    )
    print(f"TD3+BC agent created on {args.device}")

    # Sanity check: dataset actions should map into [-1, 1] under the joint limits.
    raw_actions = dataset.actions.to(args.device)
    norm_actions = normalize_action(raw_actions, JOINT_LIMITS_LOW.to(args.device), JOINT_LIMITS_HIGH.to(args.device))
    print(f" Normalized action range: [{norm_actions.min():.3f}, {norm_actions.max():.3f}]")
    print(f" Normalized action mean: {norm_actions.mean(0).cpu().numpy()}")

    log_path = os.path.join(args.output_dir, "training_log.csv")
    log_file = open(log_path, "w", newline="")
    log_writer = csv.writer(log_file)
log_writer.writerow(["step", "critic_loss", "actor_loss", "bc_loss", "q_value_mean", "q_value_std"]) |

    running = {"critic_loss": 0, "actor_loss": 0, "bc_loss": 0, "q_value_mean": 0, "q_value_std": 0}
    actor_updates = 0

    print(f"\nStarting training for {args.num_iterations} iterations...")
    for step in range(1, args.num_iterations + 1):
        state, action, reward, next_state, done = dataset.sample(args.batch_size)
        metrics = agent.train_step(state, action, reward, next_state, done)

        running["critic_loss"] += metrics["critic_loss"]
        running["q_value_std"] += metrics["q_value_std"]
        # The agent's total_it advances once per train_step, so this condition
        # matches exactly the steps on which the actor was updated.
        if step % args.policy_delay == 0:
            running["actor_loss"] += metrics["actor_loss"]
            running["bc_loss"] += metrics["bc_loss"]
            running["q_value_mean"] += metrics["q_value_mean"]
            actor_updates += 1

        if step % args.eval_freq == 0:
            n = args.eval_freq
            n_actor = max(actor_updates, 1)
            avg_critic = running["critic_loss"] / n
            avg_actor = running["actor_loss"] / n_actor
            avg_bc = running["bc_loss"] / n_actor
            avg_q_mean = running["q_value_mean"] / n_actor
            avg_q_std = running["q_value_std"] / n

            log_writer.writerow([step, f"{avg_critic:.6f}", f"{avg_actor:.6f}",
                                 f"{avg_bc:.6f}", f"{avg_q_mean:.6f}", f"{avg_q_std:.6f}"])
            log_file.flush()

            print(f"Step {step:>6d} | Critic: {avg_critic:.6f} | Actor: {avg_actor:.6f} | "
                  f"BC: {avg_bc:.6f} | Q-mean: {avg_q_mean:.4f} | Q-std: {avg_q_std:.4f}")

            ckpt_path = os.path.join(ckpt_dir, f"checkpoint_{step}.pt")
            agent.save(ckpt_path)

            running = {k: 0 for k in running}
            actor_updates = 0

            # Periodic sanity check: the policy should respect the joint limits
            # by construction (tanh followed by denormalization).
            with torch.no_grad():
                test_states = dataset.states[:100].to(args.device)
                test_actions_raw = agent.actor.get_raw_action(test_states)
                a_min = test_actions_raw.min(dim=0).values.cpu().numpy()
                a_max = test_actions_raw.max(dim=0).values.cpu().numpy()
                within_limits = (
                    (test_actions_raw >= JOINT_LIMITS_LOW.to(args.device) - 1e-5).all()
                    and (test_actions_raw <= JOINT_LIMITS_HIGH.to(args.device) + 1e-5).all()
                )
                if not within_limits:
                    print(" WARNING: Policy outputs outside joint limits!")
                    print(f" Min: {a_min}")
                    print(f" Max: {a_max}")

    log_file.close()

    best_path = os.path.join(args.output_dir, "best_model.pt")
    agent.save(best_path)
    print(f"\nFinal model saved to {best_path}")

    print("\n=== FINAL VALIDATION ===")
    with torch.no_grad():
        all_states = dataset.states.to(args.device)
        # Run the policy over the whole dataset in chunks to bound peak memory.
        chunk_size = 4096
        all_actions = []
        for i in range(0, len(all_states), chunk_size):
            chunk = all_states[i:i + chunk_size]
            all_actions.append(agent.actor.get_raw_action(chunk))
        all_actions = torch.cat(all_actions, dim=0)

        print("Policy action statistics (raw joint space):")
        joint_names = ["shoulder_pan", "shoulder_lift", "upperarm_roll", "elbow_flex",
                       "forearm_roll", "wrist_flex", "wrist_roll", "l_gripper", "r_gripper"]
        for i, name in enumerate(joint_names):
            a = all_actions[:, i]
            print(f" {name}: min={a.min():.4f}, max={a.max():.4f}, mean={a.mean():.4f}, "
                  f"limits=[{JOINT_LIMITS_LOW[i]:.3f}, {JOINT_LIMITS_HIGH[i]:.3f}]")

        within = (
            (all_actions >= JOINT_LIMITS_LOW.to(args.device) - 1e-5).all()
            and (all_actions <= JOINT_LIMITS_HIGH.to(args.device) + 1e-5).all()
        )
        print(f"\nAll actions within joint limits: {within.item()}")

    print("\nVerifying saved model loads correctly...")
    agent2 = TD3BC(state_dim=9, action_dim=9, state_mean=state_mean, state_std=state_std, device=args.device)
    agent2.load(best_path)
    with torch.no_grad():
        test_s = dataset.states[:10].to(args.device)
        test_a = agent2.actor.get_raw_action(test_s)
        print(f" Loaded model produces actions: shape={test_a.shape}, range=[{test_a.min():.4f}, {test_a.max():.4f}]")

    print(f"\nTraining log saved to {log_path}")
    print(f"Checkpoints saved to {ckpt_dir}")
    print(f"Best model saved to {best_path}")


if __name__ == "__main__":
    main()
|
|
|