PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
File size: 6,572 Bytes
0c12387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#!/usr/bin/env python3
"""Focused SFT training on 4 strong layers only.

Removes O1 NRM, A1 policy, and lifecycle layers from training.
Keeps: tmf921, camara, intent_3gpp, etsi_zsm, plus any row whose target_layer
contains 'adversarial' (the filter is applied to both train_sota and validation).

Expected improvement: ~85%+ normalized field F1 (vs 79.6% with all layers).
Same recipe as Stage 1 (apart from 3 epochs instead of 2), just cleaner data.

Usage:
    export HF_TOKEN=hf_...
    export CUDA_VISIBLE_DEVICES=0
    export TOKENIZERS_PARALLELISM=false
    python scripts/train_focused.py
"""
import gc
import json
import os
from pathlib import Path

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, BitsAndBytesConfig, set_seed
from trl import SFTConfig, SFTTrainer


# ============================================================
# Configuration
# ============================================================

# Where the run reads from and writes to.
DATASET_NAME = "nraptisss/TMF921-intent-to-config-research-sota"
BASE_MODEL = "Qwen/Qwen3-8B"
HUB_MODEL_ID = "nraptisss/Qwen3-8B-TMF921-Focused-4Layer"
OUTPUT_DIR = "outputs/qwen3-8b-tmf921-focused"

# target_layer values that survive the filter (the "strong" layers).
KEEP_LAYERS = {
    "tmf921",
    "camara",
    "intent_3gpp",
    "etsi_zsm",
}

# Rows whose target_layer contains "adversarial" are retained as well; that
# check lives in the row filter used by main() rather than in this set.
# Per the original notes, such rows sit mainly in the test_adversarial split
# but are also marked inside train_sota.


def _is_keep(example):
    """Return True when a dataset row belongs in the focused training set.

    A row is kept when its ``target_layer`` is one of ``KEEP_LAYERS`` or
    contains the substring ``"adversarial"``.

    Args:
        example: Mapping-like dataset row; only ``target_layer`` is read.

    Returns:
        bool: whether the row should be retained.
    """
    # `or ""` also covers rows where the column is present but holds None;
    # a plain .get() default would let None reach the substring test below
    # and raise TypeError.
    layer = example.get("target_layer") or ""
    return layer in KEEP_LAYERS or "adversarial" in layer


def _load_focused_splits():
    """Load the dataset, filter both splits, and return (train, eval).

    Returns:
        tuple: the filtered ``train_sota`` and ``validation`` splits, each
        reduced to the ``messages`` column.

    Raises:
        ValueError: if either split is empty after filtering.
    """
    from collections import Counter

    ds = load_dataset(DATASET_NAME)

    # Filter train_sota: keep only strong layers + adversarial rows
    train_full = ds["train_sota"]
    print(f"  train_sota before filter: {len(train_full)}")
    train_filtered = train_full.filter(_is_keep)
    print(f"  train_sota after filter: {len(train_filtered)}")

    # Show what we kept
    layer_counts = Counter(train_filtered["target_layer"])
    for layer, count in layer_counts.most_common():
        print(f"    {layer}: {count}")

    # Filter validation too
    val_full = ds["validation"]
    val_filtered = val_full.filter(_is_keep)
    print(f"  validation: {len(val_full)} -> {len(val_filtered)}")

    # Fail before the model is loaded: an empty split cannot be trained on,
    # and best-checkpoint selection below depends on eval_loss being computed.
    if len(train_filtered) == 0:
        raise ValueError(
            f"No train_sota rows left after filtering to {sorted(KEEP_LAYERS)}"
        )
    if len(val_filtered) == 0:
        raise ValueError(
            f"No validation rows left after filtering to {sorted(KEEP_LAYERS)}"
        )

    # For SFT, only pass the messages column
    return (
        train_filtered.select_columns(["messages"]),
        val_filtered.select_columns(["messages"]),
    )


def _build_training_configs():
    """Build the LoRA adapter config and the SFT trainer config.

    Returns:
        tuple: ``(peft_config, sft_config)``.
    """
    # 4-bit NF4 quantization with double quantization and bf16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Forwarded by the trainer when it instantiates the model from its name.
    model_init_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
        "quantization_config": bnb_config,
        "device_map": {"": 0},
    }

    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )

    # Stage 1 recipe reused on the focused data; per the original note only
    # the epoch count differs (3 here, 2 on the full dataset).
    sft_config = SFTConfig(
        output_dir=OUTPUT_DIR,
        model_init_kwargs=model_init_kwargs,
        # Data
        max_length=2048,
        packing=False,
        assistant_only_loss=True,
        dataset_num_proc=8,
        # Optimization
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_steps=50,
        weight_decay=0.01,
        max_grad_norm=0.3,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=2,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="paged_adamw_32bit",
        # Eval/Save (eval and save cadence match so the best checkpoint,
        # chosen by lowest eval_loss, can be reloaded at the end)
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # Logging
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        run_name="qwen3-8b-tmf921-focused-4layer",
        # Hub
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        # Thinking mode off
        chat_template_kwargs={"enable_thinking": False},
    )
    return peft_config, sft_config


def main():
    """Run the focused 4-layer SFT pipeline end to end.

    Steps: seed, load and filter the dataset, configure the quantized base
    model with a LoRA adapter, train, evaluate, write the final metrics and
    model to ``OUTPUT_DIR``, and push to the hub when enabled in the config.

    Raises:
        ValueError: if filtering leaves the train or validation split empty.
    """
    set_seed(42)

    rule = "=" * 60
    print(rule)
    print("TMF921 Focused 4-Layer SFT Training")
    print(rule)
    print(f"Base model: {BASE_MODEL}")
    print(f"Keep layers: {sorted(KEEP_LAYERS)}")
    print(f"Output: {OUTPUT_DIR}")
    print(f"Hub: {HUB_MODEL_ID}")
    print(rule)

    # Step 1: Load and filter dataset
    print("\nStep 1: Loading and filtering dataset...")
    train_dataset, eval_dataset = _load_focused_splits()
    print(f"\n  Final train: {len(train_dataset)} examples")
    print(f"  Final eval: {len(eval_dataset)} examples")

    # Step 2: Configure model + training
    print("\nStep 2: Configuring model and training...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    peft_config, sft_config = _build_training_configs()

    # Step 3: Train
    print("\nStep 3: Starting training...")
    trainer = SFTTrainer(
        model=BASE_MODEL,
        args=sft_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainable = sum(
        p.numel() for p in trainer.model.parameters() if p.requires_grad
    )
    print(f"  Trainable params: {trainable:,}")

    trainer.train()

    # Save
    print("\nSaving final model...")
    metrics = trainer.evaluate()
    print(f"  Final eval loss: {metrics.get('eval_loss', 'N/A')}")

    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    with open(
        f"{OUTPUT_DIR}/final_eval_metrics.json", "w", encoding="utf-8"
    ) as f:
        json.dump(metrics, f, indent=2)

    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    if sft_config.push_to_hub:
        print(f"\nPushing to hub: {HUB_MODEL_ID}")
        trainer.push_to_hub(commit_message="Focused 4-layer SFT: tmf921/camara/3gpp/etsi_zsm only")

    print("\n" + rule)
    print("Training complete!")
    print(f"Model: {OUTPUT_DIR}")
    print(f"Hub: https://huggingface.co/{HUB_MODEL_ID}")
    print(rule)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()