Instructions to use nraptisss/tmf921-intent-training with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use nraptisss/tmf921-intent-training with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """QLoRA SFT training for TMF921 intent-to-config research dataset. | |
| Designed for a single RTX 6000 Ada 48/50GB server. Uses TRL SFTTrainer with PEFT QLoRA. | |
| """ | |
| import argparse | |
| import math | |
| import os | |
| import re | |
| from pathlib import Path | |
| import torch | |
| from datasets import load_dataset | |
| from peft import LoraConfig | |
| from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback, set_seed | |
| from trl import SFTConfig, SFTTrainer | |
| from tmf921_train.utils import load_config, write_json | |
| try: | |
| import trackio | |
| except Exception: # pragma: no cover | |
| trackio = None | |
class TrackioAlertCallback(TrainerCallback):
    """Emit Trackio alerts when logged training/eval metrics look unhealthy.

    Alerts are sent only from the world-process-zero rank, and only when the
    optional ``trackio`` import succeeded (``trackio`` is not None).

    Args:
        grad_norm_threshold: a logged ``grad_norm`` above this value raises a
            WARN alert. Defaults to 10.0.
        eval_loss_threshold: an ``eval_loss`` above this value raises a WARN
            alert. Defaults to 1.0.

    Non-finite (NaN/Inf) values always alert with level ERROR. NaN compares
    False against any threshold, so a plain ``>`` check would silently miss it.
    """

    def __init__(self, grad_norm_threshold=10.0, eval_loss_threshold=1.0):
        super().__init__()
        self.grad_norm_threshold = float(grad_norm_threshold)
        self.eval_loss_threshold = float(eval_loss_threshold)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Check the training ``loss`` and ``grad_norm`` of each log event."""
        if not state.is_world_process_zero or not logs or trackio is None:
            return
        loss = logs.get("loss")
        if loss is not None and not math.isfinite(float(loss)):
            trackio.alert(
                title="NaN/Inf training loss",
                text=f"step={state.global_step} loss={loss} — stop run and reduce learning_rate by 10x.",
                level="ERROR",
            )
        grad_norm = logs.get("grad_norm")
        if grad_norm is not None:
            grad_norm = float(grad_norm)
            finite = math.isfinite(grad_norm)
            if not finite or grad_norm > self.grad_norm_threshold:
                trackio.alert(
                    title="Gradient norm spike",
                    text=f"step={state.global_step} grad_norm={grad_norm:.3f} — consider lower lr or max_grad_norm.",
                    level="WARN" if finite else "ERROR",
                )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Check ``eval_loss`` after each evaluation pass."""
        if not state.is_world_process_zero or not metrics or trackio is None:
            return
        loss = metrics.get("eval_loss")
        if loss is None:
            return
        loss = float(loss)
        finite = math.isfinite(loss)
        if not finite or loss > self.eval_loss_threshold:
            trackio.alert(
                title="High validation loss",
                text=f"step={state.global_step} eval_loss={loss:.4f} — check convergence and rare-class oversampling.",
                level="WARN" if finite else "ERROR",
            )
def parse_args():
    """Define the training command line and parse ``sys.argv``.

    Override flags default to ``None`` (``False`` for the store_true switches),
    so the caller can tell "flag not given" apart from an explicit value and
    fall back to the YAML config in the first case.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", default="configs/rtx6000ada_qwen3_8b_qlora.yaml")
    # Plain string overrides, each named after the config key it replaces.
    for flag in (
        "--model_name_or_path",
        "--dataset_name",
        "--train_split",
        "--eval_split",
        "--output_dir",
        "--hub_model_id",
    ):
        cli.add_argument(flag)
    cli.add_argument("--max_steps", type=int, default=None, help="Debug/short run override")
    cli.add_argument("--no_push", action="store_true")
    cli.add_argument(
        "--packing",
        action="store_true",
        help="Override config and enable packing. Requires compatible attention setup.",
    )
    cli.add_argument(
        "--flash_attn",
        action="store_true",
        help="Use flash_attention_2 in model_init_kwargs. Install flash-attn first.",
    )
    cli.add_argument(
        "--resume_from_checkpoint",
        default=None,
        help="Path to checkpoint dir, or 'true' to auto-resume latest checkpoint in output_dir",
    )
    cli.add_argument("--seed", type=int, default=42)
    parsed = cli.parse_args()
    return parsed
def require_cuda():
    """Print the PyTorch/CUDA environment and refuse to continue without a GPU.

    On success, also prints the CUDA device count and the name of device 0.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` returns False.
    """
    print("=== CUDA CHECK ===")
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={visible}")
    if torch.cuda.is_available():
        print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")
        return
    raise RuntimeError(
        "CUDA is not available to PyTorch. Refusing to train on CPU. "
        "Run `bash scripts/install_rtx6000ada.sh`, verify `nvidia-smi`, and set CUDA_VISIBLE_DEVICES=0."
    )
# One repo-id segment: 1-96 chars from [A-Za-z0-9._-]. It starts with an
# alphanumeric and ends with an alphanumeric or underscore, because a trailing
# "-" or "." is rejected by the Hub.
_HF_REPO_SEGMENT = r"[A-Za-z0-9](?:[A-Za-z0-9._-]{0,94}[A-Za-z0-9_])?"
_HF_REPO_ID_RE = re.compile(rf"{_HF_REPO_SEGMENT}/{_HF_REPO_SEGMENT}")


def valid_hf_repo_id(repo_id):
    """Return True if ``repo_id`` is a well-formed ``namespace/name`` Hub repo id.

    Besides the character/length pattern, this also rejects ids containing
    "--" or "..", and ids ending in ".git", matching the extra rules in
    huggingface_hub's repo-id validation. Non-string and empty inputs return
    False.
    """
    if not isinstance(repo_id, str) or not repo_id:
        return False
    # fullmatch (unlike re.match with "$") does not accept a trailing "\n".
    if _HF_REPO_ID_RE.fullmatch(repo_id) is None:
        return False
    if "--" in repo_id or ".." in repo_id or repo_id.endswith(".git"):
        return False
    return True
def sanitize_trackio_config(cfg):
    """Resolve the Trackio Space id and apply the DISABLE_TRACKIO switch.

    Candidates are tried in order: the ``TRACKIO_SPACE_ID`` environment
    variable, then ``cfg["trackio_space_id"]``. The first one accepted by
    ``valid_hf_repo_id`` wins. Invalid values (e.g. "nraptisss/") crash Trackio
    before training starts, so they are reported and skipped instead of used.
    An invalid env value therefore no longer hides a valid config value.

    If ``DISABLE_TRACKIO`` is "1", "true" or "yes" (case-insensitive),
    ``project`` and ``trackio_space_id`` are cleared so Trackio is bypassed.

    Args:
        cfg: mutable config mapping. ``trackio_space_id`` (and ``project`` when
            disabled) are overwritten in place.

    Returns:
        The same ``cfg`` object.
    """
    env_space = os.environ.get("TRACKIO_SPACE_ID", "").strip()
    cfg_space = str(cfg.get("trackio_space_id") or "").strip()

    chosen = None
    for candidate in (env_space, cfg_space):
        if not candidate:
            continue
        if valid_hf_repo_id(candidate):
            chosen = candidate
            break
        print(f"WARNING: ignoring invalid Trackio Space ID: {candidate!r}. Expected format: namespace/space-name")

    cfg["trackio_space_id"] = chosen
    if chosen is not None:
        print(f"Trackio Space: {chosen}")
    if chosen is None or chosen != env_space:
        # The env var was empty or invalid. Remove it so Trackio cannot pick
        # it up on its own.
        os.environ.pop("TRACKIO_SPACE_ID", None)

    disable_flag = os.environ.get("DISABLE_TRACKIO", "0").strip()
    if disable_flag.lower() in ("1", "true", "yes"):
        print(f"Trackio disabled via DISABLE_TRACKIO={disable_flag}")
        cfg["project"] = None
        cfg["trackio_space_id"] = None
    return cfg
def _apply_cli_overrides(cfg, args):
    """Merge explicitly passed CLI flags from ``args`` into ``cfg`` (in place).

    Flags left at None / False do not touch the YAML config. Also records the
    effective seed under ``cfg["seed"]`` so resolved_config.json captures it.
    Returns the same ``cfg`` object.
    """
    for key in ("model_name_or_path", "dataset_name", "train_split", "eval_split", "output_dir", "hub_model_id"):
        value = getattr(args, key)
        if value is not None:
            cfg[key] = value
    if args.max_steps is not None:
        cfg["max_steps"] = args.max_steps
        # _build_sft_config reads the epoch count from "epochs"; the
        # "num_train_epochs" key is kept for readers of resolved_config.json.
        cfg["epochs"] = 1
        cfg["num_train_epochs"] = 1
    if args.no_push:
        cfg["push_to_hub"] = False
    if args.packing:
        cfg["packing"] = True
    cfg["seed"] = args.seed
    return cfg


def _load_sft_datasets(cfg):
    """Load the train/eval splits of ``cfg["dataset_name"]``, keeping only ``messages``.

    Returns:
        ``(train_dataset, eval_dataset)``.
    """
    print("Loading dataset", cfg["dataset_name"])
    ds = load_dataset(cfg["dataset_name"])
    # `or` also covers keys present in the YAML but set to null.
    train_dataset = ds[cfg.get("train_split") or "train_sota"]
    eval_dataset = ds[cfg.get("eval_split") or "validation"]
    print(train_dataset)
    print(eval_dataset)
    # TRL infers dataset type from column names. This research dataset includes both
    # `messages` and convenience `prompt`/`completion` columns; passing all columns can
    # make TRL classify it as prompt-completion instead of conversational and reject
    # assistant_only_loss=True. For SFT we intentionally train from ChatML `messages`.
    train_dataset = train_dataset.select_columns(["messages"])
    eval_dataset = eval_dataset.select_columns(["messages"])
    print("SFT train columns:", train_dataset.column_names)
    print("SFT eval columns:", eval_dataset.column_names)
    return train_dataset, eval_dataset


def _load_tokenizer(cfg):
    """Load the base model's tokenizer, using EOS as pad token when none is set."""
    tokenizer = AutoTokenizer.from_pretrained(cfg["model_name_or_path"], trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _build_model_init_kwargs(cfg, flash_attn):
    """Build the ``model_init_kwargs`` SFTTrainer uses to load the base model.

    The model is mapped to CUDA device 0 (``device_map={"": 0}``). 4-bit
    bitsandbytes quantization (default quant type "nf4") is used unless
    ``cfg["load_in_4bit"]`` is falsy. ``flash_attn`` selects flash_attention_2.
    """
    kwargs = {
        "trust_remote_code": True,
        "device_map": {"": 0},
        "dtype": torch.bfloat16 if cfg.get("bf16", True) else torch.float16,
    }
    if cfg.get("load_in_4bit", True):
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=cfg.get("bnb_4bit_quant_type", "nf4"),
            bnb_4bit_use_double_quant=bool(cfg.get("bnb_4bit_use_double_quant", True)),
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    if flash_attn:
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


def _build_peft_config(cfg):
    """Build the LoRA adapter config from the ``cfg["lora_*"]`` keys."""
    return LoraConfig(
        r=int(cfg.get("lora_r", 64)),
        lora_alpha=int(cfg.get("lora_alpha", 16)),
        lora_dropout=float(cfg.get("lora_dropout", 0.05)),
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg.get("lora_target_modules", "all-linear"),
    )


def _build_sft_config(cfg, model_init_kwargs, seed):
    """Translate ``cfg`` into an ``SFTConfig``.

    ``seed`` is forwarded explicitly. The Trainer seeds from
    ``SFTConfig.seed`` (default 42) at the start of training, so without it
    the ``--seed`` flag would be overridden.
    """
    report_to = "trackio" if cfg.get("project") else "none"
    return SFTConfig(
        output_dir=cfg["output_dir"],
        model_init_kwargs=model_init_kwargs,
        seed=int(seed),
        max_length=int(cfg.get("max_length", 2048)),
        packing=bool(cfg.get("packing", False)),
        assistant_only_loss=bool(cfg.get("assistant_only_loss", True)),
        dataset_num_proc=int(cfg.get("dataset_num_proc", 8)),
        learning_rate=float(cfg.get("learning_rate", 2e-4)),
        lr_scheduler_type=cfg.get("lr_scheduler_type", "constant"),
        warmup_steps=int(cfg.get("warmup_steps", 0)),
        weight_decay=float(cfg.get("weight_decay", 0.0)),
        max_grad_norm=float(cfg.get("max_grad_norm", 0.3)),
        num_train_epochs=float(cfg.get("epochs", 2)),
        max_steps=int(cfg["max_steps"]) if cfg.get("max_steps") is not None else -1,
        per_device_train_batch_size=int(cfg.get("per_device_train_batch_size", 2)),
        gradient_accumulation_steps=int(cfg.get("gradient_accumulation_steps", 8)),
        per_device_eval_batch_size=int(cfg.get("per_device_eval_batch_size", 2)),
        bf16=bool(cfg.get("bf16", True)),
        gradient_checkpointing=bool(cfg.get("gradient_checkpointing", True)),
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim=cfg.get("optim", "paged_adamw_32bit"),
        eval_strategy="steps",
        eval_steps=int(cfg.get("eval_steps", 250)),
        save_strategy="steps",
        save_steps=int(cfg.get("save_steps", 250)),
        save_total_limit=int(cfg.get("save_total_limit", 3)),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        logging_strategy="steps",
        logging_steps=int(cfg.get("logging_steps", 10)),
        logging_first_step=True,
        disable_tqdm=True,
        report_to=report_to,
        run_name=cfg.get("run_name"),
        project=cfg.get("project"),
        trackio_space_id=cfg.get("trackio_space_id"),
        push_to_hub=bool(cfg.get("push_to_hub", True)),
        hub_model_id=cfg.get("hub_model_id"),
    )


def _resolve_resume_arg(value):
    """Map the ``--resume_from_checkpoint`` string to what ``Trainer.train`` expects.

    "true" (any case) -> True, i.e. auto-resume from the latest checkpoint in
    output_dir. "false" / "none" / "" -> None, i.e. a fresh run. Anything else
    is passed through unchanged as a checkpoint path.
    """
    if value is None:
        return None
    lowered = str(value).strip().lower()
    if lowered == "true":
        return True
    if lowered in ("false", "none", ""):
        return None
    return value


def main():
    """Run QLoRA SFT end to end.

    Steps: parse CLI, check CUDA, resolve config, load data and tokenizer,
    train, evaluate, save, and optionally push to the Hub.
    """
    args = parse_args()
    require_cuda()
    cfg = load_config(args.config)
    cfg = sanitize_trackio_config(cfg)
    cfg = _apply_cli_overrides(cfg, args)
    set_seed(args.seed)

    output_dir = Path(cfg["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    write_json(output_dir / "resolved_config.json", cfg)

    train_dataset, eval_dataset = _load_sft_datasets(cfg)
    tokenizer = _load_tokenizer(cfg)
    model_init_kwargs = _build_model_init_kwargs(cfg, args.flash_attn)
    peft_config = _build_peft_config(cfg)
    sft_args = _build_sft_config(cfg, model_init_kwargs, args.seed)

    trainer = SFTTrainer(
        model=cfg["model_name_or_path"],
        args=sft_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        callbacks=[TrackioAlertCallback()],
    )
    trainer.train(resume_from_checkpoint=_resolve_resume_arg(args.resume_from_checkpoint))

    metrics = trainer.evaluate()
    write_json(output_dir / "final_eval_metrics.json", metrics)
    trainer.save_model(cfg["output_dir"])
    tokenizer.save_pretrained(cfg["output_dir"])
    if bool(cfg.get("push_to_hub", True)):
        trainer.push_to_hub(
            commit_message="Qwen TMF921 QLoRA SFT",
            dataset_name=cfg["dataset_name"],
        )
        print(f"Pushed model/adapters to https://huggingface.co/{cfg.get('hub_model_id')}")


if __name__ == "__main__":
    main()