PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
tmf921-intent-training / scripts /train_focused.py
nraptisss's picture
Add focused 4-layer SFT script: train only on tmf921/camara/3gpp/etsi_zsm (removes weak O1/A1/lifecycle layers)
0c12387 verified
#!/usr/bin/env python3
"""Focused SFT training on 4 strong layers only.
Removes O1 NRM, A1 policy, and lifecycle layers from training.
Keeps: tmf921, camara, intent_3gpp, etsi_zsm, plus any adversarial rows
(kept in both train_sota and validation because they teach rejection).
Expected improvement: ~85%+ normalized field F1 (vs 79.6% with all layers).
Same recipe as Stage 1, just cleaner data.
Usage:
export HF_TOKEN=hf_...
export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
python scripts/train_focused.py
"""
import gc
import json
import os
from pathlib import Path
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, BitsAndBytesConfig, set_seed
from trl import SFTConfig, SFTTrainer
# ============================================================
# Configuration
# ============================================================
# Base checkpoint: used for both the tokenizer and the model handed to SFTTrainer.
BASE_MODEL = "Qwen/Qwen3-8B"
# Hub dataset passed to load_dataset(); main() reads its "train_sota" and
# "validation" splits.
DATASET_NAME = "nraptisss/TMF921-intent-to-config-research-sota"
# Local directory for checkpoints, final_eval_metrics.json, and the saved
# model and tokenizer.
OUTPUT_DIR = "outputs/qwen3-8b-tmf921-focused"
# Hub repository id the trained model is pushed to.
HUB_MODEL_ID = "nraptisss/Qwen3-8B-TMF921-Focused-4Layer"
# Layers to KEEP (strong layers only); compared exactly against each row's
# "target_layer" value.
KEEP_LAYERS = {"tmf921", "camara", "intent_3gpp", "etsi_zsm"}
# Adversarial examples are kept as well: they teach rejection rather than generation.
# Adversarial rows have a target_layer containing "adversarial" (or a
# lifecycle_operation other than "create"); the filter below checks only the
# target_layer marker. Most adversarial rows live in the test_adversarial split,
# but train_sota also contains some, marked the same way.
def _keep_example(example):
    """Return True if a dataset row belongs to the focused training set.

    A row is kept when its ``target_layer`` is one of ``KEEP_LAYERS`` or
    contains the substring ``"adversarial"`` (those rows teach rejection).
    A missing or ``None`` ``target_layer`` is treated as an empty string, so
    the row is dropped instead of raising ``TypeError`` on the substring test.
    """
    layer = example.get("target_layer") or ""
    return layer in KEEP_LAYERS or "adversarial" in layer


def main():
    """Fine-tune ``BASE_MODEL`` with QLoRA SFT on the strong-layer subset.

    Steps:
      1. Load ``DATASET_NAME`` and filter its ``train_sota`` and
         ``validation`` splits down to ``KEEP_LAYERS`` plus adversarial rows,
         keeping only the ``messages`` column.
      2. Build the tokenizer, 4-bit quantization config, LoRA config and
         ``SFTConfig``.
      3. Train, run a final evaluation, write the metrics to
         ``OUTPUT_DIR/final_eval_metrics.json``, save the model and
         tokenizer to ``OUTPUT_DIR`` and push to ``HUB_MODEL_ID``.

    Raises:
        ValueError: If filtering leaves no training rows.
    """
    # Local import: Counter is only needed for the per-layer summary below.
    from collections import Counter

    set_seed(42)

    print("=" * 60)
    print("TMF921 Focused 4-Layer SFT Training")
    print("=" * 60)
    print(f"Base model: {BASE_MODEL}")
    print(f"Keep layers: {sorted(KEEP_LAYERS)}")
    print(f"Output: {OUTPUT_DIR}")
    print(f"Hub: {HUB_MODEL_ID}")
    print("=" * 60)

    # Step 1: Load and filter dataset
    print("\nStep 1: Loading and filtering dataset...")
    ds = load_dataset(DATASET_NAME)

    # Filter train_sota: keep only strong layers + adversarial rows
    train_full = ds["train_sota"]
    print(f" train_sota before filter: {len(train_full)}")
    train_filtered = train_full.filter(_keep_example)
    print(f" train_sota after filter: {len(train_filtered)}")

    # Fail fast, before any model setup, if the filter matched nothing
    # (e.g. the dataset's layer names changed).
    if len(train_filtered) == 0:
        raise ValueError(
            f"No rows left in train_sota after filtering to {sorted(KEEP_LAYERS)} "
            "plus adversarial rows; check the dataset's target_layer values."
        )

    # Show what we kept
    layer_counts = Counter(train_filtered["target_layer"])
    for layer, count in layer_counts.most_common():
        print(f" {layer}: {count}")

    # Filter validation too, so eval loss is measured on the same layers
    val_full = ds["validation"]
    val_filtered = val_full.filter(_keep_example)
    print(f" validation: {len(val_full)} -> {len(val_filtered)}")

    # For SFT, only pass the messages column
    train_dataset = train_filtered.select_columns(["messages"])
    eval_dataset = val_filtered.select_columns(["messages"])
    print(f"\n Final train: {len(train_dataset)} examples")
    print(f" Final eval: {len(eval_dataset)} examples")

    # Step 2: Configure model + training
    print("\nStep 2: Configuring model and training...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    # QLoRA-style quantization: NF4 4-bit weights with double quantization,
    # computing in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # SFTTrainer receives the model as a name (BASE_MODEL), so the loading
    # options are forwarded through SFTConfig.model_init_kwargs.
    model_init_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
        "quantization_config": bnb_config,
        # Place the whole model on device index 0.
        "device_map": {"": 0},
    }

    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )

    # Same proven recipe as Stage 1, just on focused data
    sft_config = SFTConfig(
        output_dir=OUTPUT_DIR,
        model_init_kwargs=model_init_kwargs,
        # Data
        max_length=2048,
        packing=False,
        assistant_only_loss=True,
        dataset_num_proc=8,
        # Optimization
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_steps=50,
        weight_decay=0.01,
        max_grad_norm=0.3,
        num_train_epochs=3,  # 3 epochs on smaller dataset (was 2 on full)
        # 2 sequences x 8 accumulation steps = 16 sequences per optimizer step
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=2,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="paged_adamw_32bit",
        # Eval/Save: save_steps equals eval_steps so every evaluated step has
        # a checkpoint that load_best_model_at_end can restore.
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # Logging
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        run_name="qwen3-8b-tmf921-focused-4layer",
        # Hub
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        # Thinking mode off
        chat_template_kwargs={"enable_thinking": False},
    )

    # Step 3: Train
    print("\nStep 3: Starting training...")
    trainer = SFTTrainer(
        model=BASE_MODEL,
        args=sft_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    print(f" Trainable params: {trainable:,}")
    trainer.train()

    # Save
    print("\nSaving final model...")
    metrics = trainer.evaluate()
    print(f" Final eval loss: {metrics.get('eval_loss', 'N/A')}")

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Explicit encoding keeps the metrics file identical across platforms.
    with (output_dir / "final_eval_metrics.json").open("w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    if sft_config.push_to_hub:
        print(f"\nPushing to hub: {HUB_MODEL_ID}")
        trainer.push_to_hub(commit_message="Focused 4-layer SFT: tmf921/camara/3gpp/etsi_zsm only")

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Model: {OUTPUT_DIR}")
    print(f"Hub: https://huggingface.co/{HUB_MODEL_ID}")
    print("=" * 60)
# Run the training job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()