# PEFT
# qlora
# sft
# trl
# qwen3
# tmf921
# intent-based-networking
# network-slicing
# rtx-6000-ada
# ml-intern
# nraptisss's picture
# Fix RFT: batch generation, reduce to 200 prompts x 8 samples (~24h feasible on RTX 6000 Ada)
# 05ea6fa verified
#!/usr/bin/env python3
"""Best-of-N Rejection Fine-Tuning (RFT) for TMF921 — FIXED for RTX 6000 Ada.
Generates N=8 completions per prompt using BATCHED generation (not sequential).
Reduced to 200 prompts focused on weak layers. Total runtime: ~20-24 hours.
Usage:
export PYTHONPATH="$PWD/src"
python scripts/train_rft.py --stage generate # ~20-24h
python scripts/train_rft.py --stage train # ~2h
python scripts/train_rft.py --stage all # both stages, run sequentially
"""
import argparse
import gc
import json
import os
import re
import random
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed
from tqdm import tqdm
# ============================================================
# JSON Parsing & Evaluation
# ============================================================
def strip_code_fence(text: str) -> str:
    """Trim *text* and remove a surrounding Markdown code fence, if any.

    Only text that starts with a triple backtick is treated as fenced: the
    opening fence (optionally tagged ``json``, case-insensitive) and a trailing
    closing fence are removed. Otherwise the whitespace-trimmed text is returned.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    without_open = re.sub(r"^```(?:json)?\s*", "", stripped, flags=re.IGNORECASE)
    without_close = re.sub(r"\s*```$", "", without_open)
    return without_close.strip()
def extract_json_text(text: str) -> str:
    """Return the most likely JSON-object substring of a model completion.

    After removing any code fence, the span from the first ``{`` to the last
    ``}`` (inclusive, whitespace-trimmed) is returned when such an ordered pair
    exists; otherwise the fence-stripped text is returned unchanged.
    """
    body = strip_code_fence(text)
    open_at, close_at = body.find("{"), body.rfind("}")
    if 0 <= open_at < close_at:
        return body[open_at:close_at + 1].strip()
    return body
def parse_json(text: str) -> Tuple[Optional[Any], Optional[str]]:
    """Parse the JSON payload embedded in *text*.

    Returns:
        ``(parsed_object, None)`` on success, or ``(None, message)`` on failure,
        where *message* is the exception text truncated to 200 characters.
    """
    payload = extract_json_text(text)
    try:
        parsed = json.loads(payload)
    except Exception as exc:  # any parse failure is reported, never raised
        return None, str(exc)[:200]
    return parsed, None
def canonical_json(obj: Any) -> str:
    """Serialize *obj* deterministically: sorted keys, no whitespace, raw Unicode."""
    compact_separators = (",", ":")
    return json.dumps(
        obj,
        ensure_ascii=False,
        sort_keys=True,
        separators=compact_separators,
    )
def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists into ``{path: leaf_value}``.

    Dict keys are joined with ``.`` and list positions appear as ``[i]``.
    A scalar at the root maps to the empty path ``""``; empty dicts and lists
    contribute no entries.
    """
    if isinstance(obj, dict):
        children = (
            (f"{prefix}.{key}" if prefix else str(key), value)
            for key, value in obj.items()
        )
    elif isinstance(obj, list):
        children = ((f"{prefix}[{idx}]", value) for idx, value in enumerate(obj))
    else:
        return {prefix: obj}

    flat: Dict[str, Any] = {}
    for path, value in children:
        flat.update(flatten_json(value, path))
    return flat
# Keys dropped before comparing a prediction with its reference; is_volatile_key
# matches these exactly or case-insensitively.
VOLATILE_KEY_EXACT = {
    "id", "uuid", "href", "name", "description", "displayName", "label",
    "@schemaLocation", "schemaLocation", "version", "revision",
    "createdAt", "updatedAt", "modifiedAt", "lastModified", "timestamp",
    "creationDate", "lastUpdate", "requestedStartDate", "requestedCompletionDate",
    "startTime", "endTime", "validFrom", "validTo", "validFor",
    "correlationId", "requestId", "transactionId", "reservationId",
}
# Any key whose lowercase form contains one of these substrings is also dropped.
VOLATILE_KEY_FRAGMENTS = ["href", "schema", "timestamp", "uuid", "correlation", "transaction"]
# Keys that is_volatile_key never treats as volatile (checked before the rules above).
PROTECTED_KEYS = {"sst", "sd", "sliceType", "slice_type", "latency", "reliability", "dl", "ul", "maxUEs", "maxNumberOfUEs"}
# Identifier-like tokens such as "slice-42" or "intent_abc" (replaced by <ID>).
ID_LIKE_RE = re.compile(r"\b(?:intent|slice|policy|booking|cell|me|gnb|nsi|nssi|req|report|monitor|assurance)[-_][A-Za-z0-9._:-]+", re.IGNORECASE)
# Runs of 8+ hex digits (replaced by <HEX>).
HEX_RE = re.compile(r"\b[0-9a-f]{8,}\b", re.IGNORECASE)
# ISO-8601-like date/time (replaced by <TIME>). The hyphen in the trailing class is
# escaped: an unescaped "+-Z" is a character RANGE ('+'..'Z') that also matches
# '/', ';', '<'..'@' and every uppercase letter, so text following a timestamp
# (e.g. the "/PT1H" of an ISO interval) was swallowed into <TIME>.
ISO_TIME_RE = re.compile(r"\b\d{4}-\d{2}-\d{2}[T ][0-9:.+\-Z]*\b")
def is_volatile_key(key: str) -> bool:
    """Return True when *key* names a field that should be ignored when scoring.

    Keys in PROTECTED_KEYS are never volatile. Otherwise a key is volatile if it
    equals an entry of VOLATILE_KEY_EXACT (exactly or case-insensitively), or if
    its lowercase form contains any substring listed in VOLATILE_KEY_FRAGMENTS.
    """
    if key in PROTECTED_KEYS:
        return False
    lowered = key.lower()
    exact_names_lower = {name.lower() for name in VOLATILE_KEY_EXACT}
    return (
        key in VOLATILE_KEY_EXACT
        or lowered in exact_names_lower
        or any(fragment in lowered for fragment in VOLATILE_KEY_FRAGMENTS)
    )
def normalize_string(s: str) -> str:
    """Mask run-specific substrings of *s* and trim surrounding whitespace.

    Substitutions run in a fixed order: ISO-like timestamps -> ``<TIME>``,
    identifier-like tokens -> ``<ID>``, then remaining long hex runs -> ``<HEX>``.
    Because IDs are replaced before hex runs, an ID with a hex suffix
    collapses to a single ``<ID>``.
    """
    replacements = (
        (ISO_TIME_RE, "<TIME>"),
        (ID_LIKE_RE, "<ID>"),
        (HEX_RE, "<HEX>"),
    )
    for pattern, placeholder in replacements:
        s = pattern.sub(placeholder, s)
    return s.strip()
def normalize_json_obj(obj: Any) -> Any:
    """Recursively canonicalize parsed JSON for order- and ID-insensitive comparison.

    - dict: keys for which ``is_volatile_key`` is true are dropped; values are
      normalized; entries that normalize to ``{}``, ``[]`` or ``None`` are dropped;
      surviving (stringified) keys are emitted in sorted order.
    - list: items are normalized, empty/None items dropped, and the rest sorted by
      their canonical JSON text, so element order no longer matters.
    - str: passed through ``normalize_string``.
    - anything else is returned unchanged.
    """
    def _is_empty(value: Any) -> bool:
        return value in ({}, [], None)

    if isinstance(obj, list):
        normalized_items = [normalize_json_obj(item) for item in obj]
        return sorted(
            (item for item in normalized_items if not _is_empty(item)),
            key=canonical_json,
        )
    if isinstance(obj, dict):
        kept = {}
        for raw_key, raw_value in obj.items():
            key = str(raw_key)
            if is_volatile_key(key):
                continue
            value = normalize_json_obj(raw_value)
            if not _is_empty(value):
                kept[key] = value
        return {key: kept[key] for key in sorted(kept)}
    if isinstance(obj, str):
        return normalize_string(obj)
    return obj
def compute_normalized_field_f1(pred_text: str, gold_text: str) -> float:
    """Field-level F1 between two JSON texts after normalization.

    Both texts are parsed with ``parse_json``. If either fails to parse, the
    result is 0.0. Otherwise each object is normalized with
    ``normalize_json_obj`` and flattened to ``(path, value-as-text)`` pairs.
    Precision and recall are computed over those pair sets. An empty side
    counts as precision/recall 1.0.
    """
    pred_obj = parse_json(pred_text)[0]
    gold_obj = parse_json(gold_text)[0]
    if pred_obj is None or gold_obj is None:
        return 0.0

    def _as_items(parsed: Any) -> set:
        flat = flatten_json(normalize_json_obj(parsed))
        return {
            (path, canonical_json(val) if isinstance(val, (dict, list)) else str(val))
            for path, val in flat.items()
        }

    predicted = _as_items(pred_obj)
    reference = _as_items(gold_obj)
    true_pos = len(predicted & reference)
    precision = true_pos / len(predicted) if predicted else 1.0
    recall = true_pos / len(reference) if reference else 1.0
    if not (precision + recall):
        return 0.0
    return 2 * precision * recall / (precision + recall)
# ============================================================
# Configuration
# ============================================================
# Hub identifiers: base model, SFT LoRA adapter, and the source dataset/split.
BASE_MODEL = "Qwen/Qwen3-8B"
SFT_ADAPTER = "nraptisss/Qwen3-8B-TMF921-Intent-QLoRA-qwen3-8b-qlora-20260501-083834"
DATASET_NAME = "nraptisss/TMF921-intent-to-config-research-sota"
TRAIN_SPLIT = "train_sota"

# Sampling budget, sized to finish in roughly 20-24h on one RTX 6000 Ada.
# Completions sampled per prompt in a single generate call.
N_SAMPLES = 8
TEMPERATURE = 0.9
TOP_P = 0.95
MAX_NEW_TOKENS = 1536
# Prompts drawn from the train split, weighted toward weak target layers.
NUM_PROMPTS = 200
# Completions whose normalized field F1 is below this value are rejected.
REWARD_THRESHOLD = 0.75
# Upper bound on accepted completions kept per prompt.
MAX_PER_PROMPT = 3

# Output locations for the fine-tuned adapter and the filtered samples.
RFT_OUTPUT_DIR = "outputs/qwen3-8b-tmf921-rft"
RFT_HUB_MODEL_ID = "nraptisss/Qwen3-8B-TMF921-Intent-RFT"
RFT_DATA_DIR = "outputs/rft_data"
# ============================================================
# Stage 1: Generate + Score + Filter (BATCHED)
# ============================================================
def _select_rft_prompts(ds) -> List[Tuple[Any, Dict[str, Any]]]:
    """Sample NUM_PROMPTS examples (~70% from weak layers), each paired with a stable id.

    Returns ``(source_id, example)`` pairs. ``source_id`` is the example's ``id``
    field when present, otherwise ``idx_<position in the selection>``. The
    selection is seeded right here, so the fallback id is reproducible across
    resumed runs on the same dataset.
    """
    weak_layers = ["o1_nrm", "a1_policy", "tmf921_lifecycle_report", "tmf921_lifecycle_monitor", "tmf921_lifecycle_scale"]
    weak_examples = [ex for ex in ds if ex["target_layer"] in weak_layers]
    strong_examples = [ex for ex in ds if ex["target_layer"] not in weak_layers]
    random.seed(42)
    n_weak = min(len(weak_examples), int(NUM_PROMPTS * 0.7))  # 70% weak layers
    # Clamp to what exists so the printed count matches what was actually sampled.
    n_strong = min(NUM_PROMPTS - n_weak, len(strong_examples))
    selected = random.sample(weak_examples, n_weak) + random.sample(strong_examples, n_strong)
    random.shuffle(selected)
    print(f" Selected {len(selected)} prompts ({n_weak} weak, {n_strong} strong)")
    pairs = []
    for pos, ex in enumerate(selected):
        sid = ex.get("id")
        pairs.append((f"idx_{pos}" if sid is None else sid, ex))
    return pairs


def _load_rft_resume_state(data_path: str, stats_path: str) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Load previously accepted samples, cumulative counters and processed prompt ids.

    Returns ``(all_passing, stats)``. ``stats["processed_ids"]`` lists every prompt
    already handled, including prompts that yielded no accepted sample. That
    information is only available from the stats file. Stats files written
    before this field existed fall back to the ids found in the sample file.
    """
    all_passing: List[Dict[str, Any]] = []
    stats: Dict[str, Any] = {"total_generated": 0, "total_valid_json": 0}
    processed_ids: List[Any] = []
    if os.path.exists(data_path):
        with open(data_path, encoding="utf-8") as f:
            all_passing = json.load(f)
        if os.path.exists(stats_path):
            with open(stats_path, encoding="utf-8") as f:
                previous = json.load(f)
            stats["total_generated"] = int(previous.get("total_generated", 0))
            stats["total_valid_json"] = int(previous.get("total_valid_json", 0))
            processed_ids = list(previous.get("processed_ids", []))
        seen = set(processed_ids)
        for ex in all_passing:
            if ex["source_id"] not in seen:
                seen.add(ex["source_id"])
                processed_ids.append(ex["source_id"])
    stats["total_passing"] = len(all_passing)
    stats["processed_ids"] = processed_ids
    return all_passing, stats


def _sample_completions(model, tokenizer, prompt_msgs) -> List[str]:
    """Sample N_SAMPLES completions for one chat prompt in a single generate call."""
    prompt_text = tokenizer.apply_chat_template(
        prompt_msgs,
        tokenize=False,
        add_generation_prompt=True,
        # Use the same enable_thinking=False setting that stage_train passes via
        # chat_template_kwargs, so samples are drawn from the prompt format the
        # model is later trained on.
        enable_thinking=False,
    )
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            num_return_sequences=N_SAMPLES,
            pad_token_id=pad_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)


def _select_passing_completions(completions: List[str], gold_text: str, stats: Dict[str, Any]) -> List[Tuple[str, float]]:
    """Score completions against the reference and keep the best distinct ones.

    Increments ``stats["total_valid_json"]`` for every parseable completion.
    Returns up to MAX_PER_PROMPT ``(completion, f1)`` pairs with
    ``f1 >= REWARD_THRESHOLD``, best first. A completion is skipped when its
    set of flattened key paths equals one already kept.
    """
    scored = []
    for raw in completions:
        # Defensive: drop a leading <think>...</think> block so only the answer is stored.
        comp = re.sub(r"^\s*<think>.*?</think>\s*", "", raw.strip(), flags=re.DOTALL).strip()
        obj, _ = parse_json(comp)
        if obj is None:
            continue  # unparseable output is never kept for training
        stats["total_valid_json"] += 1
        scored.append((comp, obj, compute_normalized_field_f1(comp, gold_text)))

    # Stable sort: ties keep sampling order.
    scored.sort(key=lambda item: item[2], reverse=True)
    seen_signatures = set()
    passing: List[Tuple[str, float]] = []
    for comp, obj, f1 in scored:
        if f1 < REWARD_THRESHOLD:
            break  # sorted descending, nothing later can pass
        signature = frozenset(flatten_json(obj)) if obj else None
        if signature in seen_signatures:
            continue
        seen_signatures.add(signature)
        passing.append((comp, f1))
        if len(passing) >= MAX_PER_PROMPT:
            break
    return passing


def stage_generate():
    """Stage 1: sample N_SAMPLES completions per prompt from the SFT model, score, filter.

    Accepted samples go to ``RFT_DATA_DIR/rft_training_data.json``. Cumulative
    counters and the ids of every processed prompt go to ``rft_stats.json``.
    Both files are checkpointed after every prompt, so an interrupted run
    resumes exactly where it stopped, including prompts that yielded no
    accepted sample.
    """
    set_seed(42)
    print("=" * 60)
    print("RFT Stage 1: Generate + Score + Filter (BATCHED)")
    print("=" * 60)
    print(f" N samples per prompt: {N_SAMPLES}")
    print(f" Temperature: {TEMPERATURE}")
    print(f" Num prompts: {NUM_PROMPTS}")
    print(f" Reward threshold: {REWARD_THRESHOLD}")
    print(f" Max per prompt: {MAX_PER_PROMPT}")
    print(" Estimated time: ~20-24 hours")
    print("=" * 60)

    # Load model
    print("\nLoading SFT model...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, trust_remote_code=True, quantization_config=bnb_config,
        device_map={"": 0}, torch_dtype=torch.bfloat16,
    )
    model = PeftModel.from_pretrained(model, SFT_ADAPTER)
    # Merged for faster sampling. NOTE(review): merging LoRA into a 4-bit base may
    # introduce rounding error; acceptable for sampling, confirm if exact SFT
    # policy matters.
    model = model.merge_and_unload()
    model.eval()
    gc.collect()
    torch.cuda.empty_cache()
    print(f" GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Select prompts — focus on weak layers
    print("\nSelecting prompts (weak-layer focused)...")
    ds = load_dataset(DATASET_NAME, split=TRAIN_SPLIT)
    selected = _select_rft_prompts(ds)

    # Resume support
    Path(RFT_DATA_DIR).mkdir(parents=True, exist_ok=True)
    all_passing, stats = _load_rft_resume_state(
        f"{RFT_DATA_DIR}/rft_training_data.json",
        f"{RFT_DATA_DIR}/rft_stats.json",
    )
    processed = set(stats["processed_ids"])
    if processed:
        print(f" Resuming: {len(processed)} prompts already done, {len(all_passing)} samples collected")

    print(f"\nGenerating {N_SAMPLES} completions per prompt (batched)...")
    remaining = [(sid, ex) for sid, ex in selected if sid not in processed]
    print(f" Remaining: {len(remaining)} prompts")
    for i, (source_id, example) in enumerate(tqdm(remaining, desc="Generating")):
        messages = example["messages"]
        prompt_msgs = [m for m in messages if m["role"] != "assistant"]
        gold_msgs = [m for m in messages if m["role"] == "assistant"]
        if gold_msgs:
            completions = _sample_completions(model, tokenizer, prompt_msgs)
            stats["total_generated"] += len(completions)
            passing = _select_passing_completions(completions, gold_msgs[0]["content"], stats)
        else:
            tqdm.write(f" Skipping {source_id}: no assistant reference message")
            passing = []

        for comp, f1 in passing:
            all_passing.append({
                "messages": prompt_msgs + [{"role": "assistant", "content": comp}],
                "target_layer": example["target_layer"],
                "reward_f1": f1,
                "source_id": source_id,
            })

        # Checkpoint after every prompt. Each generate call costs minutes, while
        # rewriting the JSON files is cheap; recording the id also marks
        # zero-pass prompts as done.
        stats["processed_ids"].append(source_id)
        stats["total_passing"] = len(all_passing)
        _save_rft_data(all_passing, stats)
        if (i + 1) % 10 == 0:
            tqdm.write(f" [{i+1}/{len(remaining)}] gen={stats['total_generated']}, "
                       f"json={stats['total_valid_json']}, pass={stats['total_passing']}")

    # Final save
    _save_rft_data(all_passing, stats)
    print(f"\n{'='*60}")
    print("Generation complete!")
    print(f" Total generated: {stats['total_generated']}")
    print(f" Valid JSON: {stats['total_valid_json']} ({stats['total_valid_json']/max(1,stats['total_generated'])*100:.1f}%)")
    print(f" Passing (F1>{REWARD_THRESHOLD}): {stats['total_passing']} ({stats['total_passing']/max(1,stats['total_generated'])*100:.1f}%)")
    print(f" Saved to: {RFT_DATA_DIR}/")
    print(f"{'='*60}")
def _save_rft_data(all_passing, stats):
    """Persist accepted samples and run statistics under RFT_DATA_DIR.

    Writes ``rft_training_data.json`` (the *all_passing* list) and
    ``rft_stats.json`` (the *stats* dict). Each file is first written to a
    ``.tmp`` sibling and then moved into place with ``os.replace``. An
    interruption mid-write therefore leaves the previous checkpoint intact,
    rather than a truncated JSON file that would break resuming.
    """
    out_dir = Path(RFT_DATA_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
    outputs = (
        ("rft_training_data.json", all_passing, {"indent": 2, "ensure_ascii": False}),
        ("rft_stats.json", stats, {"indent": 2}),
    )
    for filename, payload, dump_kwargs in outputs:
        target = out_dir / filename
        tmp_path = target.with_name(target.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, **dump_kwargs)
        os.replace(tmp_path, target)
# ============================================================
# Stage 2: Fine-tune
# ============================================================
def stage_train():
    """Stage 2: fine-tune on the completions accepted by the generation stage.

    The SFT LoRA adapter is loaded trainable on the 4-bit base and training
    continues on it. The alternative, merging SFT into the quantized base and
    stacking a fresh LoRA, has a problem: the adapter saved to RFT_OUTPUT_DIR
    and pushed to RFT_HUB_MODEL_ID would only hold the RFT delta, so loading
    it on BASE_MODEL would silently lack the SFT update. The saved adapter here
    contains SFT + RFT and loads directly on BASE_MODEL. LoRA rank/alpha/targets
    come from the SFT adapter's own config.

    Returns early (after printing an error) when fewer than 10 usable examples
    are available.

    Raises:
        FileNotFoundError: if the generation stage has not produced its data file.
    """
    set_seed(42)
    print("=" * 60)
    print("RFT Stage 2: Fine-tune on filtered completions")
    print("=" * 60)
    data_path = f"{RFT_DATA_DIR}/rft_training_data.json"
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"No data at {data_path}. Run --stage generate first.")
    with open(data_path, encoding="utf-8") as f:
        raw_data = json.load(f)
    print(f" Loaded {len(raw_data)} filtered examples")
    if len(raw_data) < 10:
        print(" ERROR: Too few samples. Lower REWARD_THRESHOLD or increase NUM_PROMPTS.")
        return

    from peft import prepare_model_for_kbit_training
    from trl import SFTConfig, SFTTrainer

    max_length = 2048
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Build dataset. Drop examples longer than max_length: SFTTrainer truncates to
    # max_length (presumably from the right), which would cut off the end of the
    # JSON answer and train the model on incomplete output.
    kept = []
    for ex in raw_data:
        rendered = tokenizer.apply_chat_template(ex["messages"], tokenize=False, enable_thinking=False)
        if len(tokenizer(rendered, add_special_tokens=False)["input_ids"]) <= max_length:
            kept.append({"messages": ex["messages"]})
    if len(kept) < len(raw_data):
        print(f" Dropped {len(raw_data) - len(kept)} examples longer than {max_length} tokens")
    if len(kept) < 10:
        print(" ERROR: Too few samples. Lower REWARD_THRESHOLD or increase NUM_PROMPTS.")
        return
    train_data = Dataset.from_list(kept)

    # Load SFT model (adapter kept separate and trainable, see docstring)
    print("\nLoading SFT model...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, trust_remote_code=True, quantization_config=bnb_config,
        device_map={"": 0}, torch_dtype=torch.bfloat16,
    )
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
    model = PeftModel.from_pretrained(model, SFT_ADAPTER, is_trainable=True)
    gc.collect()
    torch.cuda.empty_cache()

    sft_config = SFTConfig(
        output_dir=RFT_OUTPUT_DIR,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_steps=20,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="paged_adamw_32bit",
        max_grad_norm=1.0,
        max_length=max_length,
        # NOTE(review): assistant_only_loss needs a chat template that marks
        # assistant spans (generation markers); confirm the Qwen3 template in use
        # supports it.
        assistant_only_loss=True,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        save_strategy="epoch",
        save_total_limit=2,
        push_to_hub=True,
        hub_model_id=RFT_HUB_MODEL_ID,
        report_to="none",
        chat_template_kwargs={"enable_thinking": False},
    )
    trainer = SFTTrainer(
        model=model, args=sft_config, train_dataset=train_data,
        processing_class=tokenizer,
    )
    print(f" Trainable params: {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad):,}")
    print(f" Training examples: {len(train_data)}")
    trainer.train()
    trainer.save_model(RFT_OUTPUT_DIR)
    if sft_config.push_to_hub:
        trainer.push_to_hub(commit_message="RFT: Best-of-N rejection fine-tuning")
    print(f"\n{'='*60}")
    print(f"RFT complete! Model at: https://huggingface.co/{RFT_HUB_MODEL_ID}")
    print(f"{'='*60}")
# ============================================================
# Main
# ============================================================
def main():
    """Parse ``--stage`` and run the requested stage(s); "all" runs generate then train."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--stage", choices=["generate", "train", "all"], default="all")
    stage = cli.parse_args().stage
    run_generate = stage in ("generate", "all")
    run_train = stage in ("train", "all")
    if run_generate:
        stage_generate()
    if run_train:
        stage_train()


if __name__ == "__main__":
    main()