# driftcall/scripts/run_driftcall_train.py
# Uploaded via huggingface_hub by saumilyajj (commit f2df60e, verified).
"""DriftCall training orchestrator — local GPU run. **DEPRECATED.**
This script wires the legacy TRL ``GRPOTrainer`` + Unsloth ``UnslothGRPOTrainer``
chain to ``cells/step_15/16/17_train_stage*.train()``. That stack hits Unsloth's
broken ``Linear4bit.forward`` patch on Gemma 3n's ``per_layer_model_projection``
inside TRL's training loop (our lazy fix-up patch fires too late) and breaks on
every TRL/Unsloth version bump.
**Use the new self-contained loop instead:**
scripts/train_driftcall_grpo.py # single stage
scripts/train_full_gemma3n.sh # all three stages
That loop bypasses TRL entirely and drives rollouts/rewards/updates directly,
with controlled patch ordering. See ``docs/modules/training.md`` §3.2.
This file is kept on disk because some environments still reference it.
Do not extend it. New training work goes in ``scripts/train_driftcall_grpo.py``.
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
import sys
from pathlib import Path
from typing import Any
# Make the repo root importable (so `cells/` resolves) without an install step.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_root_str = str(_REPO_ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
# transformers 5.x removed the legacy ``TRANSFORMERS_CACHE`` symbol that
# ``llm_blender`` (a TRL 0.24 transitive dep) imports at module load time.
# Restore it BEFORE any ``from trl import ...`` so GRPOTrainer can boot.
def _patch_transformers_cache() -> None:
try:
import transformers.utils.hub as _hub
except Exception:
return
if not hasattr(_hub, "TRANSFORMERS_CACHE"):
_hub.TRANSFORMERS_CACHE = os.environ.get(
"HF_HOME",
os.path.expanduser("~/.cache/huggingface"),
)
_patch_transformers_cache()
# Unsloth 2026.4.x ships a buggy ``Gemma3nRMSNorm_forward`` patch that
# unconditionally reads ``self.weight``; for Gemma3n's
# ``embedding_post_projection_norm`` (constructed with ``with_scale=False``)
# this raises AttributeError during model.generate(). Override the patched
# ``Gemma3nMultimodalEmbedder.forward`` with a with_scale-aware version.
def _patch_unsloth_gemma3n_rmsnorm() -> None:
    """Override Unsloth's broken Gemma3n RMSNorm patch with a safe version.

    Unsloth's ``Gemma3nRMSNorm_forward`` unconditionally reads ``self.weight``;
    Gemma3n norms constructed with ``with_scale=False`` (notably
    ``embedding_post_projection_norm``) have no ``weight`` and raise
    AttributeError during ``model.generate()``. This replaces
    ``Gemma3nMultimodalEmbedder.forward`` with a ``with_scale``-aware
    re-implementation. No-op when torch/transformers (or the gemma3n
    model module) cannot be imported.
    """
    try:
        import torch
        from transformers.models.gemma3n.modeling_gemma3n import (
            Gemma3nMultimodalEmbedder,
        )
    except Exception:
        # torch missing, or a transformers build without gemma3n — skip.
        return

    def _safe_rmsnorm(norm_module: Any, x: Any) -> Any:
        # Mirror the canonical Gemma3n RMSNorm forward but respect with_scale.
        normed = norm_module._norm(x.float())
        if getattr(norm_module, "with_scale", True):
            normed = normed * norm_module.weight.float()
        return normed.type_as(x)

    def _patched_forward(
        self: Any,
        input_ids: Any = None,
        inputs_embeds: Any = None,
    ) -> Any:
        # Exactly one of input_ids / inputs_embeds may be supplied; the XOR
        # raises when both or neither are given (upstream contract).
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )
        if inputs_embeds is not None:
            emb_norm = _safe_rmsnorm(self.soft_embedding_norm, inputs_embeds)
        else:
            # Hard (token-id) path: ids are offset into the multimodal vocab.
            hard_emb = self.embedding(input_ids - self.vocab_offset)
            emb_norm = _safe_rmsnorm(self.hard_embedding_norm, hard_emb)
        # Run the projection in float32, then cast back to the input dtype.
        old_dtype = emb_norm.dtype
        emb_norm = emb_norm.to(torch.float32)
        with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=True):
            emb_norm_proj = self.embedding_projection(emb_norm)
        emb_norm_proj = emb_norm_proj.to(old_dtype)
        return _safe_rmsnorm(self.embedding_post_projection_norm, emb_norm_proj)

    Gemma3nMultimodalEmbedder.forward = _patched_forward
def _patch_unsloth_bnb_linear4bit_quant_state() -> None:
    """Trigger ``fix_4bit_weight_quant_state_from_module`` on packed 4-bit
    weights that bnb stored with ``shape[0] == 1`` (transposed packing).

    Unsloth's stock patch only checks ``weight.shape[-1] == 1``; some Gemma3n
    layers (notably ``per_layer_model_projection``) ship with the packed
    dim on axis 0 instead, so the auto-fix never fires and we crash with
    ``mat1 and mat2 shapes cannot be multiplied (..., 1xPACKED)``.

    Wraps ``Linear4bit.forward`` so the fix is attempted lazily on every
    call; no-op when bitsandbytes / unsloth_zoo are not importable.
    """
    try:
        import torch
        import bitsandbytes
        from unsloth_zoo.temporary_patches.bitsandbytes import (
            fix_4bit_weight_quant_state_from_module,
        )
    except Exception:
        # bitsandbytes / unsloth_zoo not installed — nothing to patch.
        return
    Linear4bit = bitsandbytes.nn.modules.Linear4bit
    # Keep the original forward so the wrapper can delegate after the fix-up.
    _orig_forward = Linear4bit.forward

    def _safe_forward(self: Any, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        # Detect packed 4-bit tensors with no quant_state and a flat shape
        # in either orientation. Trigger the fix routine which restores
        # quant_state from the ``module`` attribute and reshapes the weight.
        try:
            if (
                getattr(weight, "quant_state", None) is None
                and weight.dim() == 2
                and (weight.shape[0] == 1 or weight.shape[-1] == 1)
            ):
                fix_4bit_weight_quant_state_from_module(self)
        except Exception:
            # Best-effort: on any failure, fall through to the stock forward.
            pass
        return _orig_forward(self, x)

    Linear4bit.forward = _safe_forward
# ---------------------------------------------------------------------------
# Action parser — extract DriftCallAction from model text output
# ---------------------------------------------------------------------------
def _parse_action(text: str) -> Any:
    """Parse a DriftCallAction from the model's assistant turn text.

    Tries JSON extraction first; falls back to a SUBMIT action when the
    text mentions "submit", and otherwise to ABORT, so the episode
    terminates cleanly rather than hanging.
    """
    from cells.step_04_models import ActionType, DriftCallAction

    text = text.strip()
    # Locate the first balanced {...} span and try to parse it as JSON.
    try:
        open_pos = text.index("{")
        close_pos = -1
        nesting = 0
        for pos in range(open_pos, len(text)):
            char = text[pos]
            if char == "{":
                nesting += 1
            elif char == "}":
                nesting -= 1
                if nesting == 0:
                    close_pos = pos + 1
                    break
        if close_pos > open_pos:
            payload = json.loads(text[open_pos:close_pos])
            raw_type = payload.get("action_type", "abort")
            try:
                parsed_type = ActionType(raw_type)
            except ValueError:
                # Unknown action_type string — degrade to ABORT.
                parsed_type = ActionType.ABORT
            return DriftCallAction(
                action_type=parsed_type,
                tool_name=payload.get("tool_name"),
                tool_args=payload.get("tool_args"),
                message=payload.get("message"),
                confidence=payload.get("confidence"),
                rationale=payload.get("rationale"),
            )
    except (ValueError, json.JSONDecodeError, KeyError):
        pass
    # Fallback: submit if the model said something confidence-like, else abort
    if "submit" in text.lower():
        return DriftCallAction(
            action_type=ActionType.SUBMIT,
            message=text[:200],
            confidence=0.5,
        )
    return DriftCallAction(action_type=ActionType.ABORT, message=text[:200])
# ---------------------------------------------------------------------------
# Observation serializer — obs → messages list (training.md §3.2.1)
# ---------------------------------------------------------------------------
def _obs_to_messages(
    goal: Any,
    obs: Any,
    history: list[dict[str, str]],
    is_turn_zero: bool,
) -> list[dict[str, str]]:
    """Append the latest observation fields to the message history.

    training.md §3.2.1 — returns the updated history list (mutated in-place
    for efficiency; callers can deepcopy if needed).
    """
    from cells.step_14_custom_trainer import PINNED_SYSTEM_PROMPT

    if is_turn_zero:
        # Fresh episode: reset history with the system prompt (plus any tool
        # schemas) and the user's seed utterance.
        tools = getattr(obs, "available_tools", [])
        system_text = PINNED_SYSTEM_PROMPT
        if tools:
            system_text = system_text + "\nAvailable tools: " + json.dumps(
                tools, ensure_ascii=False, sort_keys=True
            )
        del history[:]
        history.extend(
            [
                {"role": "system", "content": system_text},
                {"role": "user", "content": getattr(goal, "seed_utterance", "")},
            ]
        )
        return history

    # Mid-episode: surface only the latest tool result, if any.
    tool_results = getattr(obs, "tool_results", []) or []
    for result in tool_results[-1:]:
        payload = {
            "tool": getattr(result, "tool_name", ""),
            "status": getattr(result, "status", ""),
            "response": getattr(result, "response", {}),
        }
        history.append(
            {
                "role": "tool",
                "content": json.dumps(payload, ensure_ascii=False, sort_keys=True),
            }
        )

    # Surface recent drift events (last 3, to cap the token budget) as a
    # user message so the model sees the drift signal.
    recent_drifts = (getattr(obs, "drift_log", []) or [])[-3:]
    if recent_drifts:
        drift_json = json.dumps(
            [
                {
                    "turn": getattr(d, "turn", 0),
                    "type": getattr(d, "drift_type", ""),
                    "domain": getattr(d, "domain", ""),
                    "description": getattr(d, "description", ""),
                }
                for d in recent_drifts
            ],
            ensure_ascii=False,
            sort_keys=True,
        )
        history.append({"role": "user", "content": f"[drift] {drift_json}"})
    return history
def _derive_rollout_seed(goal: Any, g_index: int, episode_seed: int) -> int:
"""Deterministic seed per rollout within a group (training.md §3.2)."""
payload = f"{episode_seed}:{getattr(goal, 'seed_utterance', '')}:{g_index}".encode()
digest = hashlib.blake2b(payload, digest_size=8).digest()
return int.from_bytes(digest, "little") & 0x7FFF_FFFF
# ---------------------------------------------------------------------------
# rollout_group_fn — the core multi-turn inference loop
# ---------------------------------------------------------------------------
def build_rollout_group_fn(
    *,
    max_turns: int = 8,
    max_new_tokens: int = 512,
    temperature: float = 0.9,
    top_p: float = 0.95,
    hardware: str = "v100",
) -> Any:
    """Build and return a RolloutGroupFn (training.md §3.2).

    Runs G independent multi-turn episodes with the live model. Each rollout:
      1. env.reset(seed=derived_seed)
      2. Serialise obs → messages, generate one action token-by-token
      3. Parse → DriftCallAction → env.step(action)
      4. Repeat until obs.done or max_turns reached

    Returns (tuple[Episode, ...], tuple[str, ...]) of length G.

    Args:
        max_turns: Hard cap on environment steps per rollout.
        max_new_tokens: Generation budget for each assistant turn.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling cutoff passed to ``model.generate``.
        hardware: Accepted for signature parity with the CLI; not referenced
            inside the rollout loop itself.
    """
    def rollout_group_fn(
        *,
        model: Any,
        tokenizer: Any,
        goal: Any,
        episode_seed: int,
        num_generations: int,
        env_factory: Any,
    ) -> tuple[tuple[Any, ...], tuple[str, ...]]:
        import torch
        from cells.step_04_models import ActionType, DriftCallAction
        # Apply Gemma3n + bnb patches on first rollout call, AFTER Unsloth has
        # already monkey-patched its (broken) versions.
        _patch_unsloth_gemma3n_rmsnorm()
        _patch_unsloth_bnb_linear4bit_quant_state()
        device = next(model.parameters()).device
        episodes_out: list[Any] = []
        completions_out: list[str] = []
        for g in range(num_generations):
            # Each rollout in the group gets its own deterministic seed.
            seed = _derive_rollout_seed(goal, g, episode_seed)
            env = env_factory()
            obs = env.reset(seed=seed)
            history: list[dict[str, str]] = []
            all_responses: list[str] = []
            is_turn_zero = True
            for _turn in range(max_turns):
                # Build messages for this turn (mutates `history` in place).
                _obs_to_messages(goal, obs, history, is_turn_zero)
                is_turn_zero = False
                # Tokenize the conversation so far.
                try:
                    prompt_str = tokenizer.apply_chat_template(
                        history,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                except Exception:
                    # Tokenizer without a chat template: crude concatenation.
                    prompt_str = " ".join(
                        m.get("content", "") for m in history
                    )
                # Gemma 3n's processor is multimodal — pass `text=` explicitly
                # so the call dispatches to the text-only branch.
                inputs = tokenizer(
                    text=prompt_str,
                    return_tensors="pt",
                    truncation=True,
                    max_length=1024,
                ).to(device)
                # Generate the assistant response.
                with torch.no_grad():
                    gen_kwargs: dict[str, Any] = {
                        "max_new_tokens": max_new_tokens,
                        "do_sample": True,
                        "temperature": temperature,
                        "top_p": top_p,
                        "pad_token_id": tokenizer.eos_token_id,
                    }
                    output_ids = model.generate(
                        **inputs, **gen_kwargs
                    )
                # Drop the prompt tokens; keep only the newly generated tail.
                new_token_ids = output_ids[0][inputs["input_ids"].shape[1]:]
                response_text = tokenizer.decode(
                    new_token_ids, skip_special_tokens=True
                ).strip()
                all_responses.append(response_text)
                history.append({"role": "assistant", "content": response_text})
                # Parse and step the environment.
                action = _parse_action(response_text)
                obs = env.step(action)
                if obs.done:
                    break
            # Collect the completed episode from the env (also reached when
            # max_turns is exhausted without obs.done).
            episode = env.episode()
            episodes_out.append(episode)
            completions_out.append("\n".join(all_responses))
        return tuple(episodes_out), tuple(completions_out)

    return rollout_group_fn
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p = argparse.ArgumentParser(
prog="run_driftcall_train",
description="DriftCall 3-stage GRPO training on a local GPU.",
)
p.add_argument("--stage1-steps", type=int, default=150)
p.add_argument("--stage2-steps", type=int, default=200)
p.add_argument("--stage3-steps", type=int, default=150)
p.add_argument("--hardware", choices=["v100", "h100"], default="v100")
p.add_argument("--output-dir", type=Path, default=Path("/workspace/checkpoints"))
p.add_argument("--eval-episodes", type=int, default=50)
p.add_argument("--probe-episodes", type=int, default=200)
p.add_argument(
"--skip-eval", action="store_true", help="Skip baseline/final eval + probe."
)
p.add_argument(
"--push-to-hub",
action="store_true",
help="Push trained LoRA to HF Hub after stage 3.",
)
p.add_argument("--hf-repo", type=str, default=os.environ.get("DRIFTCALL_HF_REPO", ""))
return p.parse_args(argv)
def main(argv: list[str] | None = None) -> int:
    """Run DriftCall's 3-stage GRPO curriculum on a local GPU.

    Stages 1→3 train sequentially, each resuming from the previous stage's
    checkpoint. Optionally runs baseline/final eval and pushes the final
    LoRA to the HF Hub. Returns 0 on success; exceptions propagate.
    """
    args = _parse_args(argv)
    out_dir = args.output_dir

    def _fmt_r1(report: Any) -> str:
        # BUGFIX: the old code formatted getattr(report, "r1_mean", "?") with
        # :.3f, which raises ValueError whenever the attribute is missing
        # (str objects do not support the 'f' format code), so the "?"
        # fallback could never actually print. Format numerically only when
        # the value really is a number.
        r1 = getattr(report, "r1_mean", None)
        return f"{r1:.3f}" if isinstance(r1, (int, float)) else "?"

    print(f"[train] hardware={args.hardware} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'unset')}")
    print(f"[train] steps: stage1={args.stage1_steps} stage2={args.stage2_steps} stage3={args.stage3_steps}")
    print(f"[train] output_dir={out_dir}")
    # Project imports are deferred to keep module import cheap and to let the
    # transformers-cache patch above run first.
    from cells.step_07_task_generator import generate as task_gen_fn
    from cells.step_10_env import DriftCallEnv
    from cells.step_12_gemma_boot import BootConfig
    boot_config = BootConfig(hardware=args.hardware)
    rollout_group_fn = build_rollout_group_fn(hardware=args.hardware)
    # ---------------------------------------------------------------------------
    # Stage 1
    # ---------------------------------------------------------------------------
    print("\n[train] === Stage 1 ===")
    from cells.step_15_train_stage1 import train as train_stage1

    def env_factory_s1() -> DriftCallEnv:
        # Curriculum stage 1; audio boundary disabled for all local stages.
        return DriftCallEnv(config={"curriculum_stage": 1, "audio_boundary_enabled": False})

    ckpt_s1 = train_stage1(
        num_steps=args.stage1_steps,
        output_dir=out_dir / "stage1" / "final",
        boot_config=boot_config,
        task_gen=task_gen_fn,
        env_factory=env_factory_s1,
        rollout_group_fn=rollout_group_fn,
    )
    print(f"[train] Stage 1 complete → {ckpt_s1}")
    # ---------------------------------------------------------------------------
    # Stage 2
    # ---------------------------------------------------------------------------
    print("\n[train] === Stage 2 ===")
    from cells.step_16_train_stage2 import train as train_stage2

    def env_factory_s2() -> DriftCallEnv:
        return DriftCallEnv(config={"curriculum_stage": 2, "audio_boundary_enabled": False})

    ckpt_s2 = train_stage2(
        num_steps=args.stage2_steps,
        resume_from=ckpt_s1,
        output_dir=out_dir / "stage2" / "final",
        boot_config=boot_config,
        task_gen=task_gen_fn,
        env_factory=env_factory_s2,
        rollout_group_fn=rollout_group_fn,
    )
    print(f"[train] Stage 2 complete → {ckpt_s2}")
    # ---------------------------------------------------------------------------
    # Stage 3
    # ---------------------------------------------------------------------------
    print("\n[train] === Stage 3 ===")
    from cells.step_17_train_stage3 import train as train_stage3

    def env_factory_s3() -> DriftCallEnv:
        return DriftCallEnv(config={"curriculum_stage": 3, "audio_boundary_enabled": False})

    ckpt_s3 = train_stage3(
        num_steps=args.stage3_steps,
        resume_from=ckpt_s2,
        output_dir=out_dir / "stage3" / "final",
        boot_config=boot_config,
        task_gen=task_gen_fn,
        env_factory=env_factory_s3,
        rollout_group_fn=rollout_group_fn,
    )
    print(f"[train] Stage 3 complete → {ckpt_s3}")
    if not args.skip_eval:
        # ---------------------------------------------------------------------------
        # Baseline + Final eval
        # ---------------------------------------------------------------------------
        print("\n[train] === Baseline eval ===")
        from cells.step_18_eval_baseline import eval_baseline
        eval_dir = Path(os.environ.get("DRIFTCALL_EVAL_DIR", "/workspace/eval_reports"))
        eval_dir.mkdir(parents=True, exist_ok=True)
        baseline_report = eval_baseline(
            n_episodes=args.eval_episodes,
            output_path=eval_dir / "baseline.json",
            env_factory=env_factory_s3,
            task_gen=task_gen_fn,
        )
        print(f"[train] Baseline eval: R1={_fmt_r1(baseline_report)}")
        print("\n[train] === Final eval ===")
        from cells.step_19_eval_final import eval_final
        final_report = eval_final(
            checkpoint_path=ckpt_s3,
            n_episodes=args.eval_episodes,
            output_path=eval_dir / "final.json",
            env_factory=env_factory_s3,
            task_gen=task_gen_fn,
        )
        print(f"[train] Final eval: R1={_fmt_r1(final_report)}")
    if args.push_to_hub and args.hf_repo:
        print(f"\n[train] === Pushing LoRA to HF Hub: {args.hf_repo} ===")
        from cells.step_24_deploy_hf import push_lora_to_hub
        result = push_lora_to_hub(
            ckpt_s3, repo_id=args.hf_repo, token=os.environ.get("HF_TOKEN")
        )
        if result.success:
            print(f"[train] Pushed to hub: {args.hf_repo}")
        else:
            print(f"[train] Push failed (rc={result.return_code}): {result.stderr[:200]}")
    print("\n[train] === COMPLETE ===")
    print(f"[train] Checkpoints at: {out_dir}/stage{{1,2,3}}/final")
    return 0
if __name__ == "__main__":
    # CLI entry point: exit with main()'s return code.
    raise SystemExit(main())