import asyncio
import concurrent.futures
import copy
import logging
import os
import traceback
from datetime import datetime
import art
import wandb
from art.local import LocalBackend
from art.utils import iterate_dataset, limit_concurrency
from dotenv import load_dotenv
from tqdm.asyncio import tqdm_asyncio
from linalg_zero.grpo.agents.tool_calling_agent import ToolCallingRLAgent
from linalg_zero.grpo.envs import get_env
from linalg_zero.grpo.general_rm import calculate_reward, create_general_rm_trajectory_groups
from linalg_zero.grpo.rl_utils import (
log_trajectory_to_openpipe,
write_eval_trajectories,
)
from linalg_zero.grpo.run import agent_factory
from linalg_zero.grpo.task_selection import (
get_task_indices,
)
from linalg_zero.grpo.types import LinAlgPolicyConfig, RunConfig, SolveResult
from linalg_zero.grpo.utils.checkpointing import archive_checkpoint, delete_checkpoints_keep_best
from linalg_zero.grpo.utils.curriculum import (
COVERAGE_LOG_MAX_TOOL_CALLS_BUCKET,
CurriculumCoverageTracker,
iterate_curriculum,
prefill_coverage_tracker,
)
from linalg_zero.grpo.utils.eval_metrics import (
aggregate_retry_summaries,
log_eval_aggregate,
log_group_diversity,
summarize_trajectories,
)
from linalg_zero.grpo.utils.hf_upload import push_experiment_dir_to_hf_sync
from linalg_zero.grpo.utils.trajectory_messages import clean_messages
# Load environment variables
load_dotenv(override=True)
# Suppress LiteLLM logging spam
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
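# Allow at most 256 rollouts to be in flight at once.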
@limit_concurrency(256)
async def rollout_linalg_task(
model: art.Model[LinAlgPolicyConfig],
task_index: int,
step: int = 0,
phase: str = "train",
reward_type: str = "real",
is_shadow: bool = False,
) -> art.Trajectory:
"""
Generate a trajectory for a single tau-bench task using the given model.
This adapts the tau-bench evaluation loop for RL trajectory generation.
Now truly async.
"""
# print(f"Rolling out task {task_index} (step {step}, phase {phase})")
config = copy.deepcopy(model.config.run_config)
success_reward = 1.0
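    # Shadow rollouts query a fixed reference model (gpt-4.1 via the OpenAI
    # API) instead of the trainable policy.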
if is_shadow:
config.model = "gpt-4.1"
config.model_provider = "openai"
config.api_key = None
config.base_url = None
# Get isolated environment for this task
env = get_env(
config.env,
user_strategy=config.user_strategy,
task_split=phase,
dataset_path=config.dataset_path,
task_index=task_index,
)
if config.add_no_think:
env.wiki += "\n/no_think"
# Create agent with the trainable model
agent = agent_factory(
tools_info=env.tools_info,
wiki=env.wiki,
config=config,
)
if not isinstance(agent, ToolCallingRLAgent):
raise TypeError("Agent must be a ToolCallingRLAgent")
# Create trajectory object
traj = art.Trajectory(
messages_and_choices=[],
tools=env.tools_info,
reward=0,
metadata={
"task_index": str(task_index),
"env": config.env,
"training_step": str(step),
"phase": phase,
"model": model.name,
"reward_type": config.reward_type,
"is_shadow": str(is_shadow),
},
)
try:
# Run the agent on the task (now async call)
result = await agent.solve(
env=env,
task_index=task_index,
max_assistant_turns=config.max_assistant_turns,
)
# optimal_trajectory: 1 if the entire trajectory is optimal (correct answer + optimal efficiency)
optimal_trajectory = 1 if abs(result.reward - success_reward) <= 1e-6 else 0
# Convert result to trajectory format
traj.reward, explanation = await calculate_reward(result, config)
# Build metrics dictionary
reward_info = result.info.get("reward_info") or {}
info = reward_info.get("info") or {}
outputs = info.get("outputs") or {}
has_valid_outputs = "correctness_score" in outputs
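        # Fall back to reconstructing the total from the per-step average when
        # the agent did not report a numeric total.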
total_completion_tokens = result.info.get("total_completion_tokens")
if not isinstance(total_completion_tokens, int | float):
total_completion_tokens = result.info["avg_completion_tokens"] * result.info["total_steps"]
traj.metrics = {
"total_steps": result.info["total_steps"],
"final_prompt_tokens": result.info["final_prompt_tokens"],
"total_completion_tokens": float(total_completion_tokens),
"avg_completion_tokens": result.info["avg_completion_tokens"],
"max_completion_tokens": result.info["max_completion_tokens"],
"forced_stop": result.info["forced_stop"],
"optimal_trajectory": optimal_trajectory,
"valid_trajectory": 1 if has_valid_outputs else 0,
}
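        # Per-component reward metrics exist only when the scorer returned
        # structured outputs for this trajectory.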
if has_valid_outputs:
traj.metrics.update({
"correctness_score": outputs["correctness_score"],
"format_score": outputs["format_score"],
"tool_success_score": outputs["tool_success_score"],
"efficiency_penalty": outputs["efficiency_penalty"],
"num_turns": outputs["num_turns"],
"expected_turns": outputs["expected_turns"],
"turn_deviation": outputs["num_turns"] - outputs["expected_turns"],
})
traj.metadata.update(result.info)
traj.metadata["reward"] = "pending_general_rm" if config.reward_type == "general_rm" else traj.reward
traj.metadata["optimal_trajectory"] = traj.metrics["optimal_trajectory"]
traj.metadata["judge_explanation"] = explanation
if config.messages_only:
traj.messages_and_choices = clean_messages(result.messages)
else:
traj.messages_and_choices = agent.create_messages_and_choices()
except Exception as e:
print(f"Error in rollout for task {task_index}: {e}")
traj.reward = -1.0
traj.metadata["error"] = str(e)
traj.metadata["traceback"] = traceback.format_exc()
traj.messages_and_choices = agent.create_messages_and_choices()
result = SolveResult(
reward=-1.0,
info={"error": str(e)},
messages=agent.messages,
total_cost=0.0,
)
traj.finish()
# Log to langfuse/openpipe
try:
await log_trajectory_to_openpipe(traj, result.messages)
except Exception as e:
print(f"Error logging trajectory to openpipe: {e}")
# print(f"Finished rolling out task {task_index} (reward: {traj.reward})")
return traj
async def evaluate_model(
model: art.Model[LinAlgPolicyConfig],
config: RunConfig,
step: int,
val_task_indices: list[int],
split: str = "val",
) -> float:
"""Evaluate the model on a subset of tasks"""
eval_retries = 1
try:
training_config = model.config.training_config
if training_config is not None and getattr(training_config, "eval_retries", None) is not None:
eval_retries = max(1, int(training_config.eval_retries))
except Exception:
eval_retries = 1
print(f"Evaluating model on {len(val_task_indices)} tasks (passes={eval_retries})...")
summaries: list[dict[str, float]] = []
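    # Tag eval artifacts with whichever step counter is further along; after a
    # resume, the persisted model step can be ahead of the loop step.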
model_step = await model.get_step()
eval_step = max(step, model_step)
run_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
os.makedirs(config.log_dir, exist_ok=True)
for pass_idx in range(eval_retries):
trajectories = await art.gather_trajectories(
rollout_linalg_task(
model=model,
task_index=val_task_index,
step=eval_step,
phase=split,
reward_type=config.reward_type,
)
for val_task_index in val_task_indices
)
output_path = os.path.join(
config.log_dir,
f"eval_trajectories_{split}_step_{eval_step}_pass_{pass_idx}_{run_id}.jsonl",
)
write_eval_trajectories(
output_path=output_path,
trajectories=trajectories,
eval_step=eval_step,
pass_idx=pass_idx,
split=split,
)
print(f"Wrote eval trajectories to {output_path}")
summaries.append(summarize_trajectories(trajectories))
if eval_retries > 1:
print(f"Eval pass {pass_idx + 1}/{eval_retries}: reward={summaries[-1].get('reward', float('nan')):.4f}")
aggregate = aggregate_retry_summaries(summaries=summaries)
log_eval_aggregate(split=split, step=eval_step, aggregate=aggregate)
print(
f"[{split}] step={eval_step} n={int(aggregate.get('n', 0))} "
f"reward_mean={aggregate.get('reward_mean', float('nan')):.4f} "
f"reward_retry_std={aggregate.get('reward_std', float('nan')):.4f}"
)
if "optimal_trajectory_mean" in aggregate:
print(f"[{split}] optimal_trajectory_mean={aggregate.get('optimal_trajectory_mean', float('nan')):.4f}")
# Return mean reward across retries (for callers that use the float).
return float(aggregate.get("reward_mean", float("nan")))
async def test(model: art.TrainableModel[LinAlgPolicyConfig]):
"""Main evaluation loop"""
    loop = asyncio.get_running_loop()
big_pool = concurrent.futures.ThreadPoolExecutor(max_workers=50)
loop.set_default_executor(big_pool)
config = model.config.run_config
training_config = model.config.training_config
if training_config is None:
raise ValueError("Training config is not set")
register_kwargs = {}
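    # If a custom chat template is configured, forward it to the inference
    # server through the OpenAI-compatible server config.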
    if training_config.chat_template is not None:
        register_kwargs["_openai_client_config"] = art.dev.OpenAIServerConfig(
            server_args=art.dev.ServerArgs(chat_template=training_config.chat_template)
        )
with LocalBackend(in_process=config.in_process) as backend:
# Resume/fork must happen *before* register() so the Unsloth/vLLM service
# loads the correct LoRA adapter on startup.
if model.config.run_config.resume:
await backend._experimental_fork_checkpoint(
model,
from_model=model.config.run_config.resume_from,
from_project=model.config.run_config.project,
not_after_step=model.config.run_config.resume_step,
verbose=True,
)
# Setup model with backend (starts the inference server + loads LoRA)
await model.register(backend, **register_kwargs)
config.api_key = model.inference_api_key
config.base_url = model.inference_base_url
config.base_model = model.base_model
print("Loading training tasks...")
# Load validation environment
test_env = get_env(
config.env,
user_strategy=config.user_strategy,
task_split="test",
dataset_path=config.dataset_path,
)
test_task_indices = get_task_indices(
task_ids=None,
start_index=0,
end_index=-1,
tasks=test_env.tasks,
curriculum=None,
difficulty=None,
seed=config.seed,
)
print(f"Validation on {len(test_task_indices)} tasks")
# Final evaluation
print("\n--- Final Evaluation ---")
final_step = await model.get_step()
final_reward = await evaluate_model(model, config, final_step, test_task_indices, split="test")
print(f"Final average reward: {final_reward}")
print("Evaluation complete!")
async def train(model: art.TrainableModel[LinAlgPolicyConfig]): # noqa: C901
"""Main training loop adapted from art-e example"""
    loop = asyncio.get_running_loop()
big_pool = concurrent.futures.ThreadPoolExecutor(max_workers=50)
loop.set_default_executor(big_pool)
config = model.config.run_config
training_config = model.config.training_config
if training_config is None:
raise ValueError("Training config is not set")
register_kwargs = {}
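    # If a custom chat template is configured, forward it to the inference
    # server through the OpenAI-compatible server config.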
    if training_config.chat_template is not None:
        register_kwargs["_openai_client_config"] = art.dev.OpenAIServerConfig(
            server_args=art.dev.ServerArgs(chat_template=training_config.chat_template)
        )
with LocalBackend(in_process=config.in_process) as backend:
# Resume/fork must happen *before* register() so the Unsloth/vLLM service
# loads the correct LoRA adapter on startup.
if model.config.run_config.resume:
await backend._experimental_fork_checkpoint(
model,
from_model=model.config.run_config.resume_from,
from_project=model.config.run_config.project,
not_after_step=model.config.run_config.resume_step,
verbose=True,
)
else:
print("Will continue training from previous latest checkpoint")
# Setup model with backend (starts the inference server + loads LoRA)
await model.register(backend, **register_kwargs)
config.api_key = model.inference_api_key
config.base_url = model.inference_base_url
config.base_model = model.base_model
print("Loading training tasks...")
# Get environment to access tasks
train_env = get_env(
config.env,
user_strategy=config.user_strategy,
task_split="train",
dataset_path=config.dataset_path,
)
# Load validation environment
val_env = get_env(
config.env,
user_strategy=config.user_strategy,
task_split="val",
dataset_path=config.dataset_path,
)
train_task_indices = get_task_indices(
task_ids=config.task_ids,
start_index=config.start_index,
end_index=config.end_index,
tasks=train_env.tasks,
curriculum=None,
difficulty=None,
seed=config.seed,
)
val_task_indices = get_task_indices(
task_ids=config.val_task_ids,
start_index=config.start_val_index,
end_index=config.end_val_index,
tasks=val_env.tasks,
curriculum=None,
difficulty=None,
seed=config.seed,
)
print(f"Training on {len(train_task_indices)} tasks")
print(f"Validation on {len(val_task_indices)} tasks")
coverage: CurriculumCoverageTracker | None = None
if config.curriculum is not None and config.curriculum.enabled:
tool_calls_by_index = {idx: len(train_env.tasks[idx].actions) for idx in train_task_indices}
coverage = CurriculumCoverageTracker(
tool_calls_by_index=tool_calls_by_index,
max_bucket_to_log=COVERAGE_LOG_MAX_TOOL_CALLS_BUCKET,
)
initial_step = await model.get_step()
base_epoch_size = len(train_task_indices)
if config.curriculum is not None and config.curriculum.enabled:
if coverage is not None:
prefill_coverage_tracker(
coverage=coverage,
initial_step=initial_step,
train_task_indices=train_task_indices,
tasks=train_env.tasks,
config=config,
training_config=training_config,
)
train_iterator = iterate_curriculum(
base_epoch_size=base_epoch_size,
groups_per_step=training_config.groups_per_step,
num_epochs=training_config.num_epochs,
initial_step=initial_step,
tasks=train_env.tasks,
config=config,
seed=config.seed,
)
else:
# Training iterator
train_iterator = iterate_dataset(
train_task_indices,
groups_per_step=training_config.groups_per_step,
num_epochs=training_config.num_epochs,
initial_step=initial_step,
)
for batch in train_iterator:
print(f"\n--- Training Step {batch.step} (Epoch {batch.epoch}, Step {batch.epoch_step}) ---")
if coverage is not None:
coverage_metrics = coverage.update(step=batch.step, sampled_indices=list(batch.items))
if wandb.run is not None:
wandb.log(coverage_metrics, step=batch.step)
# Evaluation
if batch.step % training_config.eval_steps == 0 and not config.skip_eval:
print(f"\n--- Evaluating at Step {batch.step} ---")
await evaluate_model(model, config, batch.step, val_task_indices)
archive_checkpoint(model=model, step=await model.get_step(), split="val")
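                # The interactive confirmation below is disabled: `proceed` is
                # pinned to "yes", so superseded checkpoints are always deleted
                # after each eval.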
while True:
# proceed = input("Delete all previous checkpoints? (yes/no/exit): ").lower().strip()
proceed = "yes"
if proceed == "yes":
print("Deleting checkpoints...")
await delete_checkpoints_keep_best(model)
break
elif proceed == "no":
print("Skipping checkpoint deletion.")
break
elif proceed == "exit":
print("Exiting...")
if wandb.run is not None:
wandb.finish()
return
else:
print("Please type 'yes', 'no', or 'exit'.")
# Generate trajectory groups
print(f"Generating trajectories for {len(batch.items)} tasks...")
groups = await art.gather_trajectory_groups(
art.TrajectoryGroup(
rollout_linalg_task(
model=model,
task_index=task_index,
step=batch.step,
phase="train",
reward_type=config.reward_type,
                        is_shadow=config.add_shadow_trajectory and rollout_idx == 0,
)
for rollout_idx in range(training_config.trajectories_per_group)
)
for task_index in batch.items
)
# await model.log(groups, split="train")
log_group_diversity(step=batch.step, groups=groups, split="train")
if config.reward_type == "general_rm":
print("Creating general RM trajectory groups...")
updated_groups = await tqdm_asyncio.gather(
*[create_general_rm_trajectory_groups(group, config) for group in groups],
desc="Creating general RM trajectory groups",
total=len(groups),
)
groups = updated_groups
# Training step
print(f"Training on {len(groups)} trajectory groups...")
dev_train_config: art.dev.TrainConfig = {
"plot_tensors": config.plot_tensors,
"importance_sampling_level": training_config.importance_sampling_level,
"allow_training_without_logprobs": bool(config.messages_only),
"scale_rewards": training_config.scale_rewards,
}
if training_config.epsilon is not None:
dev_train_config["epsilon"] = training_config.epsilon
if training_config.epsilon_high is not None:
dev_train_config["epsilon_high"] = training_config.epsilon_high
if training_config.max_negative_advantage_importance_sampling_weight is not None:
dev_train_config["max_negative_advantage_importance_sampling_weight"] = (
training_config.max_negative_advantage_importance_sampling_weight
)
if training_config.truncated_importance_sampling is not None:
dev_train_config["truncated_importance_sampling"] = training_config.truncated_importance_sampling
await model.train(
groups,
config=art.TrainConfig(learning_rate=training_config.learning_rate, beta=training_config.beta),
_config=dev_train_config,
)
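            # On multi-GPU runs, delete all but the best checkpoint after each
            # training step.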
if config.is_multi_gpu:
await delete_checkpoints_keep_best(model)
# Log progress
total_reward = sum(sum(traj.reward for traj in group.trajectories) for group in groups)
num_trajectories = sum(len(group.trajectories) for group in groups)
avg_reward = total_reward / num_trajectories if num_trajectories > 0 else 0
print(f"Step {batch.step}: Average training reward = {avg_reward}")
# Final evaluation
print("\n--- Final Evaluation ---")
final_step = await model.get_step()
final_reward = await evaluate_model(model, config, final_step, val_task_indices)
archive_checkpoint(model=model, step=await model.get_step(), split="val")
wandb.finish()
print(f"Final average reward: {final_reward}")
# Optional post-training upload to HF Hub (enabled when `HF_HUB_NAMESPACE` is set).
try:
await asyncio.to_thread(push_experiment_dir_to_hf_sync, model=model)
except Exception as e:
print(f"[hf] Warning: post-training upload failed: {e}")
print("Training completed!")