Spaces:
Sleeping
Sleeping
| """DriftCall training orchestrator — local GPU run. **DEPRECATED.** | |
| This script wires the legacy TRL ``GRPOTrainer`` + Unsloth ``UnslothGRPOTrainer`` | |
| chain to ``cells/step_15/16/17_train_stage*.train()``. That stack hits Unsloth's | |
| broken ``Linear4bit.forward`` patch on Gemma 3n's ``per_layer_model_projection`` | |
| inside TRL's training loop (our lazy fix-up patch fires too late) and breaks on | |
| every TRL/Unsloth version bump. | |
| **Use the new self-contained loop instead:** | |
| scripts/train_driftcall_grpo.py # single stage | |
| scripts/train_full_gemma3n.sh # all three stages | |
| That loop bypasses TRL entirely and drives rollouts/rewards/updates directly, | |
| with controlled patch ordering. See ``docs/modules/training.md`` §3.2. | |
| This file is kept on disk because some environments still reference it. | |
| Do not extend it. New training work goes in ``scripts/train_driftcall_grpo.py``. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Any | |
# Make the repository root importable so ``cells/`` resolves without an install.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(_REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_REPO_ROOT))
| # transformers 5.x removed the legacy ``TRANSFORMERS_CACHE`` symbol that | |
| # ``llm_blender`` (a TRL 0.24 transitive dep) imports at module load time. | |
| # Restore it BEFORE any ``from trl import ...`` so GRPOTrainer can boot. | |
def _patch_transformers_cache() -> None:
    """Restore ``transformers.utils.hub.TRANSFORMERS_CACHE`` when it is absent.

    Does nothing when ``transformers.utils.hub`` cannot be imported or when
    the attribute already exists. Otherwise the attribute is set to the value
    of ``$HF_HOME`` or, if that variable is unset, to the expanded
    ``~/.cache/huggingface``.
    """
    try:
        import transformers.utils.hub as hub_module
    except Exception:
        # Best effort: nothing to patch if transformers is unavailable.
        return
    if hasattr(hub_module, "TRANSFORMERS_CACHE"):
        return
    fallback_dir = os.path.expanduser("~/.cache/huggingface")
    hub_module.TRANSFORMERS_CACHE = os.environ.get("HF_HOME", fallback_dir)


# Run at import time so the symbol exists before anything imports ``trl``.
_patch_transformers_cache()
| # Unsloth 2026.4.x ships a buggy ``Gemma3nRMSNorm_forward`` patch that | |
| # unconditionally reads ``self.weight``; for Gemma3n's | |
| # ``embedding_post_projection_norm`` (constructed with ``with_scale=False``) | |
| # this raises AttributeError during model.generate(). Override the patched | |
| # ``Gemma3nMultimodalEmbedder.forward`` with a with_scale-aware version. | |
def _patch_unsloth_gemma3n_rmsnorm() -> None:
    """Install a ``with_scale``-aware ``Gemma3nMultimodalEmbedder.forward``.

    Returns silently when ``torch`` or the Gemma3n modeling module cannot be
    imported. Otherwise the class attribute ``forward`` is replaced by a
    version that multiplies by ``norm.weight`` only when the norm module
    reports ``with_scale`` (treated as ``True`` when the attribute is absent).
    """
    try:
        import torch
        from transformers.models.gemma3n.modeling_gemma3n import (
            Gemma3nMultimodalEmbedder,
        )
    except Exception:
        return

    def _scaled_norm(norm: Any, tensor: Any) -> Any:
        # Normalise in float32, apply the learned scale only if the module
        # has one, then cast back to the input tensor's type.
        result = norm._norm(tensor.float())
        if getattr(norm, "with_scale", True):
            result = result * norm.weight.float()
        return result.type_as(tensor)

    def _patched_forward(
        self: Any,
        input_ids: Any = None,
        inputs_embeds: Any = None,
    ) -> Any:
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        # Exactly one of the two inputs must be supplied.
        if has_ids == has_embeds:
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if has_embeds:
            normed = _scaled_norm(self.soft_embedding_norm, inputs_embeds)
        else:
            token_embeds = self.embedding(input_ids - self.vocab_offset)
            normed = _scaled_norm(self.hard_embedding_norm, token_embeds)

        # Run the projection in float32, then restore the incoming dtype.
        original_dtype = normed.dtype
        normed = normed.to(torch.float32)
        with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=True):
            projected = self.embedding_projection(normed)
        projected = projected.to(original_dtype)
        return _scaled_norm(self.embedding_post_projection_norm, projected)

    Gemma3nMultimodalEmbedder.forward = _patched_forward
def _patch_unsloth_bnb_linear4bit_quant_state() -> None:
    """Wrap ``bitsandbytes`` ``Linear4bit.forward`` to repair missing quant state.

    Unsloth's stock patch only checks ``weight.shape[-1] == 1``; some Gemma3n
    layers (notably ``per_layer_model_projection``) ship with the packed
    dim on axis 0 instead, so the auto-fix never fires and we crash with
    ``mat1 and mat2 shapes cannot be multiplied (..., 1xPACKED)``.

    The wrapper calls ``fix_4bit_weight_quant_state_from_module`` before
    delegating to the previous ``forward`` whenever the weight has no
    ``quant_state``, is 2-D, and has size 1 on its first or last axis.

    The patch is idempotent: the wrapper carries a marker attribute, and a
    call that finds the marker on the current ``Linear4bit.forward`` returns
    without wrapping again. Callers invoke this function repeatedly, so
    without the marker every call would add another wrapper layer around
    ``forward``. If another library replaces ``forward`` afterwards, the
    marker is gone and the next call re-applies the wrapper on top.

    Returns silently when ``bitsandbytes`` or the ``unsloth_zoo`` fix routine
    cannot be imported.
    """
    try:
        import bitsandbytes
        from unsloth_zoo.temporary_patches.bitsandbytes import (
            fix_4bit_weight_quant_state_from_module,
        )
    except Exception:
        return

    marker = "_driftcall_quant_state_patch"
    Linear4bit = bitsandbytes.nn.modules.Linear4bit
    if getattr(Linear4bit.forward, marker, False):
        # Already wrapped by an earlier call; do not stack another layer.
        return
    _orig_forward = Linear4bit.forward

    def _safe_forward(self: Any, x: Any) -> Any:
        weight = self.weight
        # Detect packed 4-bit tensors with no quant_state and a flat shape
        # in either orientation, then trigger the routine that restores the
        # quant_state from the module.
        try:
            if (
                getattr(weight, "quant_state", None) is None
                and weight.dim() == 2
                and (weight.shape[0] == 1 or weight.shape[-1] == 1)
            ):
                fix_4bit_weight_quant_state_from_module(self)
        except Exception:
            # Deliberate best effort: if detection or repair fails, fall
            # through so the original forward reports the real error.
            pass
        return _orig_forward(self, x)

    setattr(_safe_forward, marker, True)
    Linear4bit.forward = _safe_forward
| # --------------------------------------------------------------------------- | |
| # Action parser — extract DriftCallAction from model text output | |
| # --------------------------------------------------------------------------- | |
def _first_json_object(text: str) -> Any:
    """Return the first decodable JSON object embedded in *text*, or ``None``.

    Each ``{`` is tried in order as a potential start. Decoding uses the real
    JSON grammar, so braces inside string values do not confuse the object
    boundary, and anything after the object is ignored.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def _parse_action(text: str) -> Any:
    """Parse a ``DriftCallAction`` from the model's assistant turn text.

    Resolution order:

    1. The first JSON object found in the text. Its ``action_type`` is mapped
       through ``ActionType``; a missing or unrecognised value becomes
       ``ActionType.ABORT``. The remaining fields (``tool_name``,
       ``tool_args``, ``message``, ``confidence``, ``rationale``) are passed
       through, defaulting to ``None``.
    2. If no object can be decoded, or constructing the action raises
       ``ValueError``: a ``SUBMIT`` action with ``confidence=0.5`` when the
       text contains "submit" (case-insensitive).
    3. Otherwise an ``ABORT`` action.

    Fallback actions carry the first 200 characters of the stripped text as
    ``message``. The function therefore always returns an action.
    """
    from cells.step_04_models import ActionType, DriftCallAction

    text = text.strip()

    obj = _first_json_object(text)
    if obj is not None:
        try:
            atype = ActionType(obj.get("action_type", "abort"))
        except ValueError:
            atype = ActionType.ABORT
        try:
            return DriftCallAction(
                action_type=atype,
                tool_name=obj.get("tool_name"),
                tool_args=obj.get("tool_args"),
                message=obj.get("message"),
                confidence=obj.get("confidence"),
                rationale=obj.get("rationale"),
            )
        except ValueError:
            # A ValueError from the constructor degrades to the keyword
            # fallback below instead of propagating.
            pass

    # Keyword fallback for free-form output.
    if "submit" in text.lower():
        return DriftCallAction(
            action_type=ActionType.SUBMIT,
            message=text[:200],
            confidence=0.5,
        )
    return DriftCallAction(action_type=ActionType.ABORT, message=text[:200])
| # --------------------------------------------------------------------------- | |
| # Observation serializer — obs → messages list (training.md §3.2.1) | |
| # --------------------------------------------------------------------------- | |
def _obs_to_messages(
    goal: Any,
    obs: Any,
    history: list[dict[str, str]],
    is_turn_zero: bool,
) -> list[dict[str, str]]:
    """Fold the latest observation into the chat ``history`` (training.md §3.2.1).

    On turn zero the history is reset to a system message (the pinned system
    prompt, plus a JSON dump of ``obs.available_tools`` when non-empty)
    followed by a user message holding ``goal.seed_utterance``.

    On later turns the most recent entry of ``obs.tool_results`` (if any) is
    appended as a ``tool`` message, and the last three entries of
    ``obs.drift_log`` (if any) are appended as one ``[drift]`` user message.

    ``history`` is mutated in place and also returned.
    """
    from cells.step_14_custom_trainer import PINNED_SYSTEM_PROMPT

    if is_turn_zero:
        schemas = getattr(obs, "available_tools", [])
        tools_suffix = ""
        if schemas:
            tools_suffix = "\nAvailable tools: " + json.dumps(
                schemas, ensure_ascii=False, sort_keys=True
            )
        history.clear()
        history.extend(
            [
                {"role": "system", "content": PINNED_SYSTEM_PROMPT + tools_suffix},
                {"role": "user", "content": getattr(goal, "seed_utterance", "")},
            ]
        )
        return history

    # Later turns: surface only the newest tool result.
    results = getattr(obs, "tool_results", [])
    if results:
        for result in results[-1:]:
            payload = {
                "tool": getattr(result, "tool_name", ""),
                "status": getattr(result, "status", ""),
                "response": getattr(result, "response", {}),
            }
            history.append(
                {
                    "role": "tool",
                    "content": json.dumps(payload, ensure_ascii=False, sort_keys=True),
                }
            )

    # Drift events go in as a user message so the model sees the signal;
    # only the last three are included to cap the token budget.
    drifts = getattr(obs, "drift_log", [])
    if drifts:
        recent = []
        for event in drifts[-3:]:
            recent.append(
                {
                    "turn": getattr(event, "turn", 0),
                    "type": getattr(event, "drift_type", ""),
                    "domain": getattr(event, "domain", ""),
                    "description": getattr(event, "description", ""),
                }
            )
        encoded = json.dumps(recent, ensure_ascii=False, sort_keys=True)
        history.append({"role": "user", "content": f"[drift] {encoded}"})

    return history
def _derive_rollout_seed(goal: Any, g_index: int, episode_seed: int) -> int:
    """Return a deterministic 31-bit seed for one rollout of a group (training.md §3.2).

    The seed is the low 31 bits of an 8-byte BLAKE2b digest (read
    little-endian) over ``"<episode_seed>:<goal.seed_utterance>:<g_index>"``;
    a goal without ``seed_utterance`` contributes an empty string.
    """
    utterance = getattr(goal, "seed_utterance", "")
    material = f"{episode_seed}:{utterance}:{g_index}".encode()
    hasher = hashlib.blake2b(digest_size=8)
    hasher.update(material)
    raw = int.from_bytes(hasher.digest(), "little")
    return raw % 0x8000_0000
| # --------------------------------------------------------------------------- | |
| # rollout_group_fn — the core multi-turn inference loop | |
| # --------------------------------------------------------------------------- | |
def build_rollout_group_fn(
    *,
    max_turns: int = 8,
    max_new_tokens: int = 512,
    temperature: float = 0.9,
    top_p: float = 0.95,
    hardware: str = "v100",
) -> Any:
    """Create the multi-turn rollout callable (training.md §3.2).

    Args:
        max_turns: Upper bound on generate/step iterations per episode.
        max_new_tokens: Generation budget for each assistant turn.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        hardware: Accepted for interface compatibility; not read here.

    Returns:
        ``rollout_group_fn``, which runs ``num_generations`` independent
        episodes and returns ``(episodes, completions)`` as two tuples of
        that length. For each episode it:

        1. creates an env and resets it with a seed derived per rollout;
        2. serialises the observation into chat messages and samples one
           assistant reply;
        3. parses the reply into an action and steps the env;
        4. repeats until ``obs.done`` or ``max_turns`` is reached.

        Each completion string is that episode's replies joined by newlines.
    """

    def rollout_group_fn(
        *,
        model: Any,
        tokenizer: Any,
        goal: Any,
        episode_seed: int,
        num_generations: int,
        env_factory: Any,
    ) -> tuple[tuple[Any, ...], tuple[str, ...]]:
        import torch
        # Not referenced below; kept so a broken ``cells`` import still
        # surfaces before any patching or generation happens.
        from cells.step_04_models import ActionType, DriftCallAction  # noqa: F401

        # Re-applied on every call so our versions end up installed after
        # Unsloth's own monkey-patches.
        _patch_unsloth_gemma3n_rmsnorm()
        _patch_unsloth_bnb_linear4bit_quant_state()

        target_device = next(model.parameters()).device

        def _render_prompt(messages: list[dict[str, str]]) -> Any:
            # Prefer the chat template; if it fails for any reason, fall
            # back to the message contents joined by spaces.
            try:
                return tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception:
                return " ".join(msg.get("content", "") for msg in messages)

        def _sample_reply(messages: list[dict[str, str]]) -> str:
            # Gemma 3n's processor is multimodal, so ``text=`` is passed
            # explicitly to reach the text-only branch.
            encoded = tokenizer(
                text=_render_prompt(messages),
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            ).to(target_device)
            with torch.no_grad():
                sampling: dict[str, Any] = dict(
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=tokenizer.eos_token_id,
                )
                generated = model.generate(**encoded, **sampling)
            # Drop the prompt tokens so only the new completion is decoded.
            prompt_length = encoded["input_ids"].shape[1]
            reply_ids = generated[0][prompt_length:]
            return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

        collected_episodes: list[Any] = []
        collected_completions: list[str] = []

        for group_index in range(num_generations):
            rollout_seed = _derive_rollout_seed(goal, group_index, episode_seed)
            env = env_factory()
            obs = env.reset(seed=rollout_seed)
            messages: list[dict[str, str]] = []
            replies: list[str] = []

            for turn_index in range(max_turns):
                _obs_to_messages(goal, obs, messages, turn_index == 0)

                reply = _sample_reply(messages)
                replies.append(reply)
                messages.append({"role": "assistant", "content": reply})

                obs = env.step(_parse_action(reply))
                if obs.done:
                    break

            collected_episodes.append(env.episode())
            collected_completions.append("\n".join(replies))

        return tuple(collected_episodes), tuple(collected_completions)

    return rollout_group_fn
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
| def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: | |
| p = argparse.ArgumentParser( | |
| prog="run_driftcall_train", | |
| description="DriftCall 3-stage GRPO training on a local GPU.", | |
| ) | |
| p.add_argument("--stage1-steps", type=int, default=150) | |
| p.add_argument("--stage2-steps", type=int, default=200) | |
| p.add_argument("--stage3-steps", type=int, default=150) | |
| p.add_argument("--hardware", choices=["v100", "h100"], default="v100") | |
| p.add_argument("--output-dir", type=Path, default=Path("/workspace/checkpoints")) | |
| p.add_argument("--eval-episodes", type=int, default=50) | |
| p.add_argument("--probe-episodes", type=int, default=200) | |
| p.add_argument( | |
| "--skip-eval", action="store_true", help="Skip baseline/final eval + probe." | |
| ) | |
| p.add_argument( | |
| "--push-to-hub", | |
| action="store_true", | |
| help="Push trained LoRA to HF Hub after stage 3.", | |
| ) | |
| p.add_argument("--hf-repo", type=str, default=os.environ.get("DRIFTCALL_HF_REPO", "")) | |
| return p.parse_args(argv) | |
| def main(argv: list[str] | None = None) -> int: | |
| args = _parse_args(argv) | |
| out_dir = args.output_dir | |
| print(f"[train] hardware={args.hardware} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'unset')}") | |
| print(f"[train] steps: stage1={args.stage1_steps} stage2={args.stage2_steps} stage3={args.stage3_steps}") | |
| print(f"[train] output_dir={out_dir}") | |
| from cells.step_07_task_generator import generate as task_gen_fn | |
| from cells.step_10_env import DriftCallEnv | |
| from cells.step_12_gemma_boot import BootConfig | |
| boot_config = BootConfig(hardware=args.hardware) | |
| rollout_group_fn = build_rollout_group_fn(hardware=args.hardware) | |
| # --------------------------------------------------------------------------- | |
| # Stage 1 | |
| # --------------------------------------------------------------------------- | |
| print("\n[train] === Stage 1 ===") | |
| from cells.step_15_train_stage1 import train as train_stage1 | |
| def env_factory_s1() -> DriftCallEnv: | |
| return DriftCallEnv(config={"curriculum_stage": 1, "audio_boundary_enabled": False}) | |
| ckpt_s1 = train_stage1( | |
| num_steps=args.stage1_steps, | |
| output_dir=out_dir / "stage1" / "final", | |
| boot_config=boot_config, | |
| task_gen=task_gen_fn, | |
| env_factory=env_factory_s1, | |
| rollout_group_fn=rollout_group_fn, | |
| ) | |
| print(f"[train] Stage 1 complete → {ckpt_s1}") | |
| # --------------------------------------------------------------------------- | |
| # Stage 2 | |
| # --------------------------------------------------------------------------- | |
| print("\n[train] === Stage 2 ===") | |
| from cells.step_16_train_stage2 import train as train_stage2 | |
| def env_factory_s2() -> DriftCallEnv: | |
| return DriftCallEnv(config={"curriculum_stage": 2, "audio_boundary_enabled": False}) | |
| ckpt_s2 = train_stage2( | |
| num_steps=args.stage2_steps, | |
| resume_from=ckpt_s1, | |
| output_dir=out_dir / "stage2" / "final", | |
| boot_config=boot_config, | |
| task_gen=task_gen_fn, | |
| env_factory=env_factory_s2, | |
| rollout_group_fn=rollout_group_fn, | |
| ) | |
| print(f"[train] Stage 2 complete → {ckpt_s2}") | |
| # --------------------------------------------------------------------------- | |
| # Stage 3 | |
| # --------------------------------------------------------------------------- | |
| print("\n[train] === Stage 3 ===") | |
| from cells.step_17_train_stage3 import train as train_stage3 | |
| def env_factory_s3() -> DriftCallEnv: | |
| return DriftCallEnv(config={"curriculum_stage": 3, "audio_boundary_enabled": False}) | |
| ckpt_s3 = train_stage3( | |
| num_steps=args.stage3_steps, | |
| resume_from=ckpt_s2, | |
| output_dir=out_dir / "stage3" / "final", | |
| boot_config=boot_config, | |
| task_gen=task_gen_fn, | |
| env_factory=env_factory_s3, | |
| rollout_group_fn=rollout_group_fn, | |
| ) | |
| print(f"[train] Stage 3 complete → {ckpt_s3}") | |
| if not args.skip_eval: | |
| # --------------------------------------------------------------------------- | |
| # Baseline + Final eval | |
| # --------------------------------------------------------------------------- | |
| print("\n[train] === Baseline eval ===") | |
| from cells.step_18_eval_baseline import eval_baseline | |
| eval_dir = Path(os.environ.get("DRIFTCALL_EVAL_DIR", "/workspace/eval_reports")) | |
| eval_dir.mkdir(parents=True, exist_ok=True) | |
| baseline_report = eval_baseline( | |
| n_episodes=args.eval_episodes, | |
| output_path=eval_dir / "baseline.json", | |
| env_factory=env_factory_s3, | |
| task_gen=task_gen_fn, | |
| ) | |
| print(f"[train] Baseline eval: R1={getattr(baseline_report, 'r1_mean', '?'):.3f}") | |
| print("\n[train] === Final eval ===") | |
| from cells.step_19_eval_final import eval_final | |
| final_report = eval_final( | |
| checkpoint_path=ckpt_s3, | |
| n_episodes=args.eval_episodes, | |
| output_path=eval_dir / "final.json", | |
| env_factory=env_factory_s3, | |
| task_gen=task_gen_fn, | |
| ) | |
| print(f"[train] Final eval: R1={getattr(final_report, 'r1_mean', '?'):.3f}") | |
| if args.push_to_hub and args.hf_repo: | |
| print(f"\n[train] === Pushing LoRA to HF Hub: {args.hf_repo} ===") | |
| from cells.step_24_deploy_hf import push_lora_to_hub | |
| result = push_lora_to_hub( | |
| ckpt_s3, repo_id=args.hf_repo, token=os.environ.get("HF_TOKEN") | |
| ) | |
| if result.success: | |
| print(f"[train] Pushed to hub: {args.hf_repo}") | |
| else: | |
| print(f"[train] Push failed (rc={result.return_code}): {result.stderr[:200]}") | |
| print("\n[train] === COMPLETE ===") | |
| print(f"[train] Checkpoints at: {out_dir}/stage{{1,2,3}}/final") | |
| return 0 | |
| if __name__ == "__main__": | |
| raise SystemExit(main()) | |