harshraj22's picture
download
raw
19.6 kB
import os
import re
import sys
import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from peft import LoraConfig, get_peft_model
import wandb
# Ensure the root directory is on the path so cropRL module works
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from cropRL.tasks import create_env_for_task
from cropRL.models import MultiAgentAction
from cropRL.inference import parse_action, get_agent_system_prompt
def get_action_logprobs(model, input_ids, attention_mask, gen_seqs, gen_mask):
    """
    Sum the log-probabilities the model assigns to the generated tokens.

    The generated tokens are taken to be the trailing ``gen_seqs.shape[1]``
    positions of ``input_ids`` (prompt followed by generation, as returned by
    ``model.generate``).

    Args:
        model: Causal LM called as ``model(input_ids, attention_mask=...)``;
            its output must expose ``.logits`` indexed as (batch, seq, vocab).
        input_ids: (batch, seq) token ids of prompt + generation.
        attention_mask: (batch, seq) mask forwarded to the model.
        gen_seqs: (batch, gen_len) generated ids; only its length is used.
        gen_mask: (batch, gen_len) 1 for generated tokens to score, 0 for padding.

    Returns:
        (batch,) float32 tensor: masked sum of per-token log-probabilities.
    """
    outputs = model(input_ids, attention_mask=attention_mask)
    gen_len = gen_seqs.shape[1]
    # Logits at position t predict token t + 1: drop the last logit and the
    # first label so they line up, then keep only the generated tail.
    shifted_logits = outputs.logits[:, :-1, :]
    shifted_labels = input_ids[:, 1:]
    # Upcast before log_softmax: bf16/fp16 log-probs are too coarse for the
    # importance ratio exp(new - old) of the PPO objective. The slice is only
    # gen_len positions long, so the upcast is cheap.
    gen_logits = shifted_logits[:, -gen_len:, :].float()
    gen_labels = shifted_labels[:, -gen_len:]
    logprobs = F.log_softmax(gen_logits, dim=-1)
    token_logprobs = logprobs.gather(dim=-1, index=gen_labels.unsqueeze(-1)).squeeze(-1)
    # Select rather than multiply by the mask: -inf * 0 would yield NaN.
    token_logprobs = torch.where(gen_mask.bool(), token_logprobs, torch.zeros_like(token_logprobs))
    return token_logprobs.sum(dim=-1)
def get_action_prefix_fn(tokenizer, prompt_length):
    """
    Build a ``prefix_allowed_tokens_fn`` for ``model.generate`` that constrains
    generation to the action grammar.

    Allowed generations (EOS may end any of them):
      * one digit ``0``-``9``, optionally followed by a single space;
      * ``1`` followed by ``0``, ``2``, ``3`` or ``4``;
      * ``11`` followed by a forced space and then free text (any token id
        below ``tokenizer.vocab_size``, plus EOS).

    Args:
        tokenizer: Tokenizer whose single digits and ``" "`` are encoded and
            whose first resulting token id is used for each.
        prompt_length: Length of the (left-padded) prompt batch; tokens at or
            after this index are treated as generated.

    Returns:
        Callable ``(batch_id, input_ids) -> list[int]`` of allowed next ids.
    """
    def first_token(text):
        return tokenizer.encode(text, add_special_tokens=False)[0]

    eos_id = tokenizer.eos_token_id
    digit_ids = [first_token(str(d)) for d in range(10)]
    one_id = digit_ids[1]
    space_id = first_token(" ")
    after_one = digit_ids[:5] + [space_id, eos_id]
    after_other_digit = [space_id, eos_id]
    # Built once: rebuilding a vocab-sized list per row per decode step is costly.
    free_text = list(range(tokenizer.vocab_size))
    # ``vocab_size`` excludes added special tokens; for some chat tokenizers
    # (e.g. Qwen) the EOS id lies above it, which would make forum messages
    # unable to terminate before max_new_tokens.
    if eos_id is not None and eos_id >= tokenizer.vocab_size:
        free_text.append(eos_id)

    def prefix_allowed_tokens_fn(batch_id, input_ids):
        generated = input_ids[prompt_length:].tolist()
        if not generated:
            return digit_ids
        if len(generated) == 1:
            return after_one if generated[0] == one_id else after_other_digit
        if generated[0] == one_id and generated[1] == one_id:
            # "11" is the forum-post action: force a separating space, then free text.
            return [space_id] if len(generated) == 2 else free_text
        return [eos_id]

    return prefix_allowed_tokens_fn
def _print_run_config(args, lora_targets):
    """Print a banner summarising the run's key settings."""
    banner = "=" * 50
    model_source = "Local Checkpoint" if os.path.isdir(args.model_name) else "HuggingFace Hub"
    print(banner)
    print("GRPO TRAINING CONFIGURATION")
    print(f"Model Taken From: {args.model_name}")
    print(f"Model Source: {model_source}")
    print(f"Task: {args.task}")
    print(f"Group Size (G): {args.group_size}")
    print(f"LoRA Targets: {lora_targets}")
    print(banner)


def _model_dtype():
    """Pick bf16 on GPUs that support it, fp16 on other GPUs, fp32 on CPU."""
    if not torch.cuda.is_available():
        # Half precision on CPU is slow and poorly supported.
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def _load_tokenizer(model_name):
    """Load the tokenizer configured for left-padded batched generation."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # important for batched generation
    # Truncate over-long prompts from the left: the default right-side
    # truncation would cut off the chat template's generation prompt and the
    # "Recent History" tail, leaving the model to continue the user message.
    tokenizer.truncation_side = "left"
    return tokenizer


def _load_lora_model(args, lora_targets):
    """Load the base causal LM and wrap it with a trainable LoRA adapter."""
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map="auto",
        torch_dtype=_model_dtype(),
    )
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=list(lora_targets),
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    print("LoRA applied successfully. Trainable parameters:")
    model.print_trainable_parameters()
    return model


def _build_agent_prompt(tokenizer, obs, history, agent_id, n_agents):
    """Render one agent's observation plus its last 12 history lines as a chat prompt string."""
    user_msg = obs.text_summary if getattr(obs, "text_summary", None) else str(obs)
    history_block = "\n".join(history[-12:]) if history else "None"
    user_msg += f"\n\nRecent History:\n{history_block}"
    messages = [
        {"role": "system", "content": get_agent_system_prompt(agent_id, n_agents)},
        {"role": "user", "content": user_msg},
    ]
    return tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
        enable_thinking=False,
    )


def _generated_token_mask(gen_seqs, eos_token_id):
    """Return a long mask that is 1 up to and including the first EOS, 0 afterwards.

    Unlike ``gen_seqs != pad_token_id`` this stays correct when the pad token
    falls back to EOS (the generated EOS is kept, trailing padding is dropped).
    """
    is_eos = (gen_seqs == eos_token_id).long()
    # Number of EOS tokens strictly before each position.
    eos_before = is_eos.cumsum(dim=1) - is_eos
    return (eos_before == 0).long()


def _collect_rollouts(model, tokenizer, args, iteration, max_months, device):
    """Play ``args.group_size`` multi-agent episodes in lockstep with the current policy.

    Returns:
        trajectories[env_idx][agent_id]: list of per-decision dicts holding the
        prompt+generation token tensors (on CPU), their masks, the rollout-time
        log-probability, the shaped reward and the chosen action id.
    """
    model.eval()  # Prevent dropout noise during rollout
    envs = [create_env_for_task(args.task, text_mode=True) for _ in range(args.group_size)]
    n_agents = envs[0]._ma_cfg.num_agents
    print(f"Curriculum Horizon: {max_months} months")
    for env_idx, env in enumerate(envs):
        env._env_cfg.max_months = max_months
        # Unique seed per (iteration, env) so the group does not overfit a
        # single weather/market trajectory.
        env.reset(seed=iteration * args.group_size + env_idx)

    # Net worth after each agent's previous decision (per env, per agent), for reward shaping.
    prev_net_worths = [[env._farms[a].compute_net_worth() for a in range(n_agents)] for env in envs]
    active_envs = list(range(args.group_size))
    done_agents = {i: set() for i in range(args.group_size)}
    histories = {i: {a: [] for a in range(n_agents)} for i in range(args.group_size)}
    trajectories = [[[] for _ in range(n_agents)] for _ in range(args.group_size)]
    step_count = 0
    total_env_steps = envs[0]._env_cfg.max_months * envs[0]._ma_cfg.action_slots_per_month

    with torch.no_grad(), tqdm(total=total_env_steps, desc="Rollout Phase", leave=False) as pbar:
        while active_envs:
            step_count += 1
            for agent_slot in range(n_agents):
                prompts = []
                valid_env_indices = []
                agent_ids_for_batch = []
                # Each env resolves the agent for this slot from its own turn order.
                for env_idx in active_envs:
                    env = envs[env_idx]
                    agent_id = env.get_turn_order()[agent_slot]
                    if agent_id in done_agents[env_idx]:
                        env.step(MultiAgentAction(action_id=0, agent_id=agent_id, forum_message=None))
                        continue
                    obs = env.get_obs(agent_id)
                    if obs.done:
                        done_agents[env_idx].add(agent_id)
                        # Dead/done agents must wait out their slots so they don't block TimeController
                        env.step(MultiAgentAction(action_id=0, agent_id=agent_id, forum_message=None))
                        continue
                    prompts.append(
                        _build_agent_prompt(tokenizer, obs, histories[env_idx][agent_id], agent_id, n_agents)
                    )
                    valid_env_indices.append(env_idx)
                    agent_ids_for_batch.append(agent_id)
                if not prompts:
                    continue

                inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
                prompt_len = inputs.input_ids.shape[1]
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=args.max_new_tokens,
                    do_sample=True,
                    temperature=args.temperature,
                    top_p=0.8,
                    top_k=20,
                    min_p=0,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    prefix_allowed_tokens_fn=get_action_prefix_fn(tokenizer, prompt_len),
                )
                gen_seqs = outputs[:, prompt_len:]
                action_texts = tokenizer.batch_decode(gen_seqs, skip_special_tokens=True)
                gen_mask = _generated_token_mask(gen_seqs, tokenizer.eos_token_id)
                # Reuse the tokenizer's prompt mask instead of `!= pad_token_id`,
                # which would also hide real EOS tokens when pad falls back to EOS.
                full_attention_mask = torch.cat([inputs.attention_mask, gen_mask.to(inputs.attention_mask.dtype)], dim=1)
                old_logprobs = get_action_logprobs(model, outputs, full_attention_mask, gen_seqs, gen_mask)

                for idx, env_idx in enumerate(valid_env_indices):
                    env = envs[env_idx]
                    agent_id = agent_ids_for_batch[idx]
                    act_id, forum_msg = parse_action(action_texts[idx], fallback_action=0)
                    next_obs = env.step(MultiAgentAction(action_id=act_id, agent_id=agent_id, forum_message=forum_msg))
                    # Reward shaping: change in net worth since this agent's previous decision.
                    current_net_worth = env._farms[agent_id].compute_net_worth()
                    reward = current_net_worth - prev_net_worths[env_idx][agent_id]
                    prev_net_worths[env_idx][agent_id] = current_net_worth
                    action_names = env._env_cfg.action_names
                    action_name = action_names[act_id] if act_id < len(action_names) else f"Action {act_id}"
                    histories[env_idx][agent_id].append(
                        f"Step {getattr(next_obs, 'current_step', step_count)}: Selected '{action_name}' -> Reward {reward:+.2f}"
                    )
                    trajectories[env_idx][agent_id].append({
                        "input_ids": outputs[idx].cpu(),
                        "attention_mask": full_attention_mask[idx].cpu(),
                        "gen_seqs": gen_seqs[idx].cpu(),
                        "gen_mask": gen_mask[idx].cpu(),
                        "old_logprob": old_logprobs[idx].item(),
                        "reward": reward,
                        "action_id": act_id,
                    })
                    if next_obs.done:
                        done_agents[env_idx].add(agent_id)
            # Keep only envs where at least one agent is still playing.
            active_envs = [i for i in active_envs if len(done_agents[i]) < n_agents]
            pbar.update(1)
    return trajectories


def _group_returns(trajectories):
    """Return (returns, mean, std) over every (env, agent) episode in env-major order.

    ``std`` has 1e-8 added so it can be used as a divisor.
    """
    all_returns = np.array([
        sum(step["reward"] for step in agent_steps)
        for env_trajectories in trajectories
        for agent_steps in env_trajectories
    ])
    return all_returns, all_returns.mean(), all_returns.std() + 1e-8


def _build_dataset(trajectories, all_returns, mean_return, std_return):
    """Flatten trajectories into training samples.

    Every decision of an (env, agent) episode shares that episode's
    group-normalised advantage ``(return - mean) / std``.
    """
    episodes = (agent_steps for env_trajectories in trajectories for agent_steps in env_trajectories)
    dataset = []
    for episode_return, agent_steps in zip(all_returns, episodes):
        advantage = float((episode_return - mean_return) / std_return)
        for step in agent_steps:
            dataset.append({
                "input_ids": step["input_ids"],
                "attention_mask": step["attention_mask"],
                "gen_seqs": step["gen_seqs"],
                "gen_mask": step["gen_mask"],
                "old_logprob": step["old_logprob"],
                "A_i": advantage,
            })
    return dataset


def _optimize(model, optimizer, dataset, args, device):
    """Run one shuffled pass of clipped-surrogate + KL updates over ``dataset`` (shuffled in place).

    Gradients are accumulated over ``args.gradient_accumulation_steps``
    samples per optimizer step. Returns (mean loss, mean KL) per sample.
    """
    model.train()  # Enable dropout/training mode
    # Shuffle to break temporal correlations between consecutive decisions.
    random.shuffle(dataset)
    accum_steps = args.gradient_accumulation_steps
    total_loss = 0.0
    total_kl = 0.0
    optimizer.zero_grad()
    for step_idx, step in tqdm(enumerate(dataset), total=len(dataset), desc="Optimization Phase", leave=False):
        full_seq = step["input_ids"].unsqueeze(0).to(device)
        attention_mask = step["attention_mask"].unsqueeze(0).to(device)
        gen_seqs = step["gen_seqs"].unsqueeze(0).to(device)
        gen_mask = step["gen_mask"].unsqueeze(0).to(device)
        advantage = step["A_i"]

        current_logprobs = get_action_logprobs(model, full_seq, attention_mask, gen_seqs, gen_mask).squeeze(0)
        # Reference policy = base model with the LoRA adapter disabled.
        with torch.no_grad(), model.disable_adapter():
            ref_logprobs = get_action_logprobs(model, full_seq, attention_mask, gen_seqs, gen_mask).squeeze(0)

        ratio = torch.exp(current_logprobs - step["old_logprob"])
        # k3 estimator of KL(current || reference): exp(d) - d - 1 with d = ref - current.
        log_ref_ratio = ref_logprobs - current_logprobs
        kl_div = torch.exp(log_ref_ratio) - log_ref_ratio - 1
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1.0 - args.clip_eps, 1.0 + args.clip_eps) * advantage
        loss = -torch.min(surr1, surr2) + args.beta * kl_div

        (loss / accum_steps).backward()
        total_loss += loss.item()
        total_kl += kl_div.item()

        if (step_idx + 1) % accum_steps == 0 or (step_idx + 1) == len(dataset):
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
    n_samples = max(1, len(dataset))
    return total_loss / n_samples, total_kl / n_samples


def train(args):
    """
    Train a LoRA adapter with GRPO on CropRL multi-agent episodes.

    Each iteration rolls out ``args.group_size`` episodes with the current
    policy, normalises every (env, agent) episode return against the whole
    group, and takes clipped-surrogate steps with a KL penalty against the
    adapter-free base model. Metrics are logged to WandB. Adapters are saved
    every ``args.save_every`` iterations and once more at the end.

    Args:
        args: Parsed command-line namespace (fields defined in the ``__main__`` block).
    """
    lora_targets = ["q_proj", "v_proj"]
    _print_run_config(args, lora_targets)

    wandb.init(project="CropRL-GRPO", name=args.run_name, config=vars(args))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = _load_tokenizer(args.model_name)
    model = _load_lora_model(args, lora_targets)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.warmup_iterations,
        num_training_steps=args.num_iterations,
    )
    os.makedirs(args.output_dir, exist_ok=True)

    for iteration in tqdm(range(1, args.num_iterations + 1), desc="Training Iterations"):
        print(f"\n--- Iteration {iteration}/{args.num_iterations} ---")
        # Curriculum: the horizon grows 2 months per iteration (capped at 60)
        # so short-term consequences are learned first.
        current_max_months = min(60, 10 + iteration * 2)

        # --- 1. Rollout Phase ---
        trajectories = _collect_rollouts(model, tokenizer, args, iteration, current_max_months, device)

        # --- 2. Compute Advantages (GRPO) across all agents and group envs ---
        all_returns, mean_return, std_return = _group_returns(trajectories)
        print(f"Returns: {all_returns.round(2)}")
        print(f"Mean Return: {mean_return:.2f} | Std: {std_return:.2f}")

        # --- 3. Optimization Phase ---
        dataset = _build_dataset(trajectories, all_returns, mean_return, std_return)
        avg_loss, avg_kl = _optimize(model, optimizer, dataset, args, device)

        wandb.log({
            "iteration": iteration,
            "mean_return": mean_return,
            "mean_return_per_month": mean_return / max(1, current_max_months),
            "current_horizon": current_max_months,
            "dataset_size": len(dataset),
            "std_return": std_return,
            "loss": avg_loss,
            "kl_divergence": avg_kl,
            "max_return": all_returns.max(),
            "min_return": all_returns.min(),
            "learning_rate": lr_scheduler.get_last_lr()[0],
        })
        # The scheduler is stepped once per iteration, not per optimizer step.
        lr_scheduler.step()

        if iteration % args.save_every == 0:
            ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{iteration}")
            model.save_pretrained(ckpt_dir)
            tokenizer.save_pretrained(ckpt_dir)
            print(f"Checkpoint saved to {ckpt_dir}")

    print("Training complete!")
    final_dir = os.path.join(args.output_dir, "final")
    model.save_pretrained(final_dir)
    # Save the tokenizer too, matching the periodic checkpoints.
    tokenizer.save_pretrained(final_dir)
    wandb.finish()
if __name__ == "__main__":
    # Every CLI option as (flag, type, default, help), registered in this order.
    cli_options = (
        ("--model_name", str, "Qwen/Qwen3-0.6B", "Hugging Face model path"),
        ("--run_name", str, "CropRL_GRPO_Run_1", "WandB run name"),
        ("--task", str, "easy_2agent", "CropRL task identifier"),
        ("--num_iterations", int, 50, "Total training iterations"),
        ("--group_size", int, 8, "Number of trajectories to collect per iteration (G)"),
        ("--gradient_accumulation_steps", int, 16, "Batch size equivalent via grad accumulation"),
        ("--learning_rate", float, 5e-5, "Learning rate for LoRA"),
        ("--lr_scheduler_type", str, "cosine", "Scheduler type (cosine, linear)"),
        ("--warmup_iterations", int, 5, "Number of warmup iterations"),
        ("--lora_r", int, 8, "LoRA rank"),
        ("--lora_alpha", int, 16, "LoRA alpha"),
        ("--clip_eps", float, 0.2, "PPO clipping parameter"),
        ("--beta", float, 0.01, "KL penalty coefficient"),
        ("--max_grad_norm", float, 1.0, "Max gradient norm"),
        ("--temperature", float, 0.7, "Sampling temperature"),
        ("--max_new_tokens", int, 10, "Max tokens per action generation"),
        ("--save_every", int, 2, "Save checkpoint every N iterations"),
        ("--output_dir", str, "./train/checkpoints", "Output directory for checkpoints"),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    train(cli.parse_args())

Xet Storage Details

Size:
19.6 kB
·
Xet hash:
d1efd23f7e77b6eec3109731dd377e6fa0fe4a1cff384c24194115c9f64289ee

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.