Instructions to use nraptisss/tmf921-intent-training with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use nraptisss/tmf921-intent-training with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """Best-of-N Rejection Fine-Tuning (RFT) for TMF921 — FIXED for RTX 6000 Ada. | |
| Generates N=8 completions per prompt using BATCHED generation (not sequential). | |
| Reduced to 200 prompts focused on weak layers. Total runtime: ~20-24 hours. | |
| Usage: | |
| export PYTHONPATH="$PWD/src" | |
| python scripts/train_rft.py --stage generate # ~20-24h | |
| python scripts/train_rft.py --stage train # ~2h | |
| python scripts/train_rft.py --stage all # both sequential | |
| """ | |
| import argparse | |
| import gc | |
| import json | |
| import os | |
| import re | |
| import random | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import torch | |
| from datasets import Dataset, load_dataset | |
| from peft import LoraConfig, PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed | |
| from tqdm import tqdm | |
| # ============================================================ | |
| # JSON Parsing & Evaluation | |
| # ============================================================ | |
def strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence from *text*.

    The input is whitespace-trimmed first.  Only text that begins with a
    triple-backtick fence is altered: the opening fence (optionally tagged
    ``json``, case-insensitive) and a trailing fence are dropped.  The
    result is trimmed again before being returned.
    """
    trimmed = text.strip()
    if not trimmed.startswith("```"):
        return trimmed
    without_opening = re.sub(r"^```(?:json)?\s*", "", trimmed, flags=re.IGNORECASE)
    without_closing = re.sub(r"\s*```$", "", without_opening)
    return without_closing.strip()
def extract_json_text(text: str) -> str:
    """Return the substring of *text* most likely to be a JSON object.

    Code fences are stripped first.  If the remaining text contains a ``{``
    followed later by a ``}``, the span from the first ``{`` to the last
    ``}`` (inclusive) is returned; otherwise the fence-stripped text is
    returned unchanged.
    """
    body = strip_code_fence(text)
    first_open = body.find("{")
    last_close = body.rfind("}")
    spans_object = first_open >= 0 and last_close > first_open
    return body[first_open:last_close + 1].strip() if spans_object else body
def parse_json(text: str) -> Tuple[Optional[Any], Optional[str]]:
    """Parse model output as JSON, tolerating fences and surrounding prose.

    Args:
        text: Raw completion text.

    Returns:
        ``(value, None)`` on success, or ``(None, message)`` on failure where
        *message* is the first decoder error truncated to 200 characters.

    The brace-delimited candidate from ``extract_json_text`` is tried first,
    exactly as before.  If that fails, the whole fence-stripped text is tried
    as well: the first-``{``/last-``}`` extraction mangles a valid top-level
    array such as ``[{"a": 1}, {"b": 2}]`` into ``{"a": 1}, {"b": 2}``, which
    previously made such completions count as invalid JSON.
    """
    candidate = extract_json_text(text)
    attempts = [candidate]
    whole = strip_code_fence(text)
    if whole != candidate:
        attempts.append(whole)

    first_error: Optional[str] = None
    for attempt in attempts:
        try:
            return json.loads(attempt), None
        except (ValueError, RecursionError) as exc:
            # json.JSONDecodeError is a ValueError; RecursionError covers
            # pathologically nested model output.
            if first_error is None:
                first_error = str(exc)[:200]
    return None, first_error
def canonical_json(obj: Any) -> str:
    """Serialize *obj* to a compact, key-sorted JSON string.

    Keys are sorted, no whitespace is emitted between tokens, and non-ASCII
    characters are written verbatim rather than escaped.
    """
    compact_separators = (",", ":")
    return json.dumps(
        obj,
        separators=compact_separators,
        ensure_ascii=False,
        sort_keys=True,
    )
def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists into a single ``path -> leaf`` mapping.

    Dict keys are joined with ``.`` and list positions are written as
    ``[i]``.  Anything that is neither a dict nor a list is a leaf and is
    stored under the path built so far (``prefix`` for a top-level scalar).
    Empty containers contribute no entries.
    """

    def walk(node: Any, path: str):
        # Yield (path, leaf) pairs depth-first, in container order.
        if isinstance(node, dict):
            for key, child in node.items():
                child_path = f"{path}.{key}" if path else str(key)
                yield from walk(child, child_path)
        elif isinstance(node, list):
            for position, child in enumerate(node):
                yield from walk(child, f"{path}[{position}]")
        else:
            yield path, node

    return dict(walk(obj, prefix))
# Keys whose values are dropped before comparison (identifiers, free-text
# labels, schema pointers, timestamps, request-tracking ids).
VOLATILE_KEY_EXACT = {
    "id", "uuid", "href", "name", "description", "displayName", "label",
    "@schemaLocation", "schemaLocation", "version", "revision",
    "createdAt", "updatedAt", "modifiedAt", "lastModified", "timestamp",
    "creationDate", "lastUpdate", "requestedStartDate", "requestedCompletionDate",
    "startTime", "endTime", "validFrom", "validTo", "validFor",
    "correlationId", "requestId", "transactionId", "reservationId",
}
# Case-insensitive substrings that also mark a key as volatile.
VOLATILE_KEY_FRAGMENTS = ["href", "schema", "timestamp", "uuid", "correlation", "transaction"]
# Keys that are never treated as volatile, whatever the rules above say.
PROTECTED_KEYS = {"sst", "sd", "sliceType", "slice_type", "latency", "reliability", "dl", "ul", "maxUEs", "maxNumberOfUEs"}
# "<prefix>-<token>" / "<prefix>_<token>" style identifiers inside string values.
ID_LIKE_RE = re.compile(r"\b(?:intent|slice|policy|booking|cell|me|gnb|nsi|nssi|req|report|monitor|assurance)[-_][A-Za-z0-9._:-]+", re.IGNORECASE)
# Runs of 8+ hex digits (this also matches purely decimal runs).
HEX_RE = re.compile(r"\b[0-9a-f]{8,}\b", re.IGNORECASE)
# ISO-like "YYYY-MM-DD" followed by "T" or a space and an optional time part.
# The hyphen is escaped on purpose: unescaped, "+-Z" is a character RANGE
# from "+" to "Z" that also matches every uppercase letter and ",/;<=>?@",
# so text such as "2024-01-01 URLLC" was swallowed whole as a timestamp.
ISO_TIME_RE = re.compile(r"\b\d{4}-\d{2}-\d{2}[T ][0-9:.+\-Z]*\b")
def is_volatile_key(key: str) -> bool:
    """Tell whether a JSON key should be ignored when comparing outputs.

    Protected keys are never volatile.  Otherwise a key is volatile when it
    equals an entry of ``VOLATILE_KEY_EXACT`` (exactly or ignoring case) or
    when its lowercase form contains any ``VOLATILE_KEY_FRAGMENTS`` entry.
    """
    if key in PROTECTED_KEYS:
        return False
    lowered = key.lower()
    matches_exact = key in VOLATILE_KEY_EXACT or any(
        lowered == known.lower() for known in VOLATILE_KEY_EXACT
    )
    if matches_exact:
        return True
    for fragment in VOLATILE_KEY_FRAGMENTS:
        if fragment in lowered:
            return True
    return False
def normalize_string(s: str) -> str:
    """Replace volatile substrings of *s* with fixed placeholder tokens.

    Substitutions run in a fixed order -- timestamps (``<TIME>``), then
    id-like tokens (``<ID>``), then hex runs (``<HEX>``) -- and the result
    is whitespace-trimmed.
    """
    replacements = (
        (ISO_TIME_RE, "<TIME>"),
        (ID_LIKE_RE, "<ID>"),
        (HEX_RE, "<HEX>"),
    )
    normalized = s
    for pattern, placeholder in replacements:
        normalized = pattern.sub(placeholder, normalized)
    return normalized.strip()
def normalize_json_obj(obj: Any) -> Any:
    """Recursively canonicalize a parsed JSON value for comparison.

    * dict: volatile keys are dropped, values are normalized, entries whose
      normalized value is ``{}``, ``[]`` or ``None`` are removed, and the
      remaining entries are returned with keys (as strings) in sorted order.
    * list: items are normalized, empty/``None`` items are removed, and the
      rest are ordered by their canonical JSON text.
    * str: passed through ``normalize_string``.
    * anything else: returned unchanged.
    """
    if isinstance(obj, str):
        return normalize_string(obj)

    if isinstance(obj, list):
        kept_items = []
        for element in obj:
            cleaned = normalize_json_obj(element)
            if cleaned not in ({}, [], None):
                kept_items.append(cleaned)
        kept_items.sort(key=canonical_json)
        return kept_items

    if isinstance(obj, dict):
        kept_fields = {}
        for raw_key, raw_value in obj.items():
            field = str(raw_key)
            if is_volatile_key(field):
                continue
            cleaned = normalize_json_obj(raw_value)
            if cleaned is None or cleaned == {} or cleaned == []:
                continue
            kept_fields[field] = cleaned
        return {field: kept_fields[field] for field in sorted(kept_fields)}

    return obj
def compute_normalized_field_f1(pred_text: str, gold_text: str) -> float:
    """Score a prediction against a reference as field-level F1.

    Both texts are parsed as JSON, normalized and flattened into
    ``(path, value-as-text)`` pairs; F1 is computed over those two sets.
    Returns 0.0 if either text fails to parse.  Precision (recall) counts
    as 1.0 when the prediction (reference) has no pairs at all.
    """
    pred_obj, _ = parse_json(pred_text)
    gold_obj, _ = parse_json(gold_text)
    if pred_obj is None or gold_obj is None:
        return 0.0

    def field_pairs(parsed: Any) -> set:
        # (path, text) pairs of the normalized, flattened object.
        flat = flatten_json(normalize_json_obj(parsed))
        return {
            (path, canonical_json(value) if isinstance(value, (dict, list)) else str(value))
            for path, value in flat.items()
        }

    predicted = field_pairs(pred_obj)
    expected = field_pairs(gold_obj)

    matched = len(predicted & expected)
    precision = matched / len(predicted) if predicted else 1.0
    recall = matched / len(expected) if expected else 1.0
    denominator = precision + recall
    if not denominator:
        return 0.0
    return 2 * precision * recall / denominator
| # ============================================================ | |
| # Configuration | |
| # ============================================================ | |
# --- Model and data sources -------------------------------------------------
BASE_MODEL = "Qwen/Qwen3-8B"
SFT_ADAPTER = "nraptisss/Qwen3-8B-TMF921-Intent-QLoRA-qwen3-8b-qlora-20260501-083834"
DATASET_NAME = "nraptisss/TMF921-intent-to-config-research-sota"
TRAIN_SPLIT = "train_sota"

# --- Generation budget (sized for an RTX 6000 Ada, roughly 20-24h) ----------
N_SAMPLES = 8             # completions sampled per prompt
TEMPERATURE = 0.9         # sampling temperature
TOP_P = 0.95              # nucleus sampling cutoff
MAX_NEW_TOKENS = 1536     # generation length cap per completion
NUM_PROMPTS = 200         # prompts drawn from the training split
REWARD_THRESHOLD = 0.75   # minimum normalized field F1 for a completion to be kept
MAX_PER_PROMPT = 3        # at most this many kept completions per prompt

# --- Training outputs -------------------------------------------------------
RFT_OUTPUT_DIR = "outputs/qwen3-8b-tmf921-rft"
RFT_HUB_MODEL_ID = "nraptisss/Qwen3-8B-TMF921-Intent-RFT"
RFT_DATA_DIR = "outputs/rft_data"
| # ============================================================ | |
| # Stage 1: Generate + Score + Filter (BATCHED) | |
| # ============================================================ | |
def _rft_select_prompts(ds):
    """Draw the prompt set: up to 70% from weak layers, the rest from others.

    Returns ``(selected, n_weak, n_strong)`` where the counts are the numbers
    actually sampled.  The RNG is re-seeded so the selection is reproducible,
    which the resume logic in ``stage_generate`` relies on.
    """
    weak_layers = (
        "o1_nrm", "a1_policy", "tmf921_lifecycle_report",
        "tmf921_lifecycle_monitor", "tmf921_lifecycle_scale",
    )
    weak_examples, strong_examples = [], []
    for ex in ds:
        bucket = weak_examples if ex["target_layer"] in weak_layers else strong_examples
        bucket.append(ex)

    random.seed(42)
    n_weak = min(len(weak_examples), int(NUM_PROMPTS * 0.7))  # 70% weak layers
    # Clamp so the reported count matches what is really sampled.
    n_strong = min(NUM_PROMPTS - n_weak, len(strong_examples))
    selected = random.sample(weak_examples, n_weak) + random.sample(strong_examples, n_strong)
    random.shuffle(selected)
    return selected, n_weak, n_strong


def _rft_load_resume_state():
    """Load previously saved samples, processed prompt ids and counters.

    Returns ``(all_passing, done_ids, stats)``.  ``done_ids`` is the union of
    the ids recorded under ``processed_ids`` in the stats file and the
    ``source_id`` of every saved sample, so data written before
    ``processed_ids`` existed still resumes.  Generation counters are only
    restored when ``processed_ids`` is present; without it, prompts that
    produced no passing sample are re-run and would be double counted.
    """
    data_path = f"{RFT_DATA_DIR}/rft_training_data.json"
    stats_path = f"{RFT_DATA_DIR}/rft_stats.json"
    all_passing = []
    saved_stats = {}
    resuming = os.path.exists(data_path)
    if resuming:
        with open(data_path) as f:
            all_passing = json.load(f)
        if os.path.exists(stats_path):
            try:
                with open(stats_path) as f:
                    loaded = json.load(f)
            except ValueError:
                loaded = None  # unreadable stats are not fatal; counters restart
            if isinstance(loaded, dict):
                saved_stats = loaded

    done_ids = {ex["source_id"] for ex in all_passing}
    stats = {"total_generated": 0, "total_valid_json": 0, "total_passing": len(all_passing)}
    if "processed_ids" in saved_stats:
        done_ids.update(saved_stats.get("processed_ids") or [])
        stats["total_generated"] = int(saved_stats.get("total_generated", 0))
        stats["total_valid_json"] = int(saved_stats.get("total_valid_json", 0))
    if resuming:
        print(f" Resuming: {len(done_ids)} prompts already done, {len(all_passing)} samples collected")
    return all_passing, done_ids, stats


def _rft_pick_passing(completions, gold_text):
    """Score completions against *gold_text* and keep the best distinct ones.

    Returns ``(n_valid_json, passing)`` where ``passing`` is a list of
    ``(completion, f1)`` with F1 >= REWARD_THRESHOLD, highest first, at most
    MAX_PER_PROMPT long, and with at most one completion per distinct set of
    flattened key paths.
    """
    scored = []
    for comp in completions:
        comp = comp.strip()
        obj, _ = parse_json(comp)
        if obj is None:
            continue
        scored.append((comp, compute_normalized_field_f1(comp, gold_text), obj))
    n_valid_json = len(scored)

    scored.sort(key=lambda item: -item[1])
    seen = set()
    passing = []
    for comp, f1, obj in scored:
        if f1 < REWARD_THRESHOLD:
            break  # sorted descending: everything after this fails too
        sig = frozenset(flatten_json(obj).keys()) if obj else None
        if sig in seen:
            continue
        seen.add(sig)
        passing.append((comp, f1))
        if len(passing) >= MAX_PER_PROMPT:
            break
    return n_valid_json, passing


def stage_generate():
    """Stage 1: sample N completions per prompt, score them, keep the best.

    Loads the 4-bit base model with the SFT adapter merged in, selects the
    prompt set, generates ``N_SAMPLES`` completions per prompt in a single
    ``generate`` call, scores each against the reference answer and appends
    the passing ones to ``rft_training_data.json`` via ``_save_rft_data``
    (every 10 prompts and once at the end).

    Resume: every processed prompt id is recorded in the stats file under
    ``processed_ids``, including prompts that yielded no passing sample, so
    an interrupted run does not regenerate them.
    """
    set_seed(42)
    print("=" * 60)
    print("RFT Stage 1: Generate + Score + Filter (BATCHED)")
    print("=" * 60)
    print(f" N samples per prompt: {N_SAMPLES}")
    print(f" Temperature: {TEMPERATURE}")
    print(f" Num prompts: {NUM_PROMPTS}")
    print(f" Reward threshold: {REWARD_THRESHOLD}")
    print(f" Max per prompt: {MAX_PER_PROMPT}")
    print(f" Estimated time: ~20-24 hours")
    print("=" * 60)

    # Load model
    print("\nLoading SFT model...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, trust_remote_code=True, quantization_config=bnb_config,
        device_map={"": 0}, torch_dtype=torch.bfloat16,
    )
    model = PeftModel.from_pretrained(model, SFT_ADAPTER)
    model = model.merge_and_unload()
    model.eval()
    gc.collect()
    torch.cuda.empty_cache()
    print(f" GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # `or` would discard a legitimate pad id of 0, so test for None explicitly.
    pad_token_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    # Select prompts — focus on weak layers
    print("\nSelecting prompts (weak-layer focused)...")
    ds = load_dataset(DATASET_NAME, split=TRAIN_SPLIT)
    selected, n_weak, n_strong = _rft_select_prompts(ds)
    print(f" Selected {len(selected)} prompts ({n_weak} weak, {n_strong} strong)")

    # Resume support
    Path(RFT_DATA_DIR).mkdir(parents=True, exist_ok=True)
    all_passing, done_ids, stats = _rft_load_resume_state()

    # Give every prompt a stable id: its dataset "id" when present, otherwise
    # its position in the seeded selection (identical on every run), so the
    # id used for filtering equals the id stored with the samples.
    keyed = []
    for position, ex in enumerate(selected):
        source_id = ex.get("id")
        if source_id is None:
            source_id = f"idx_{position}"
        keyed.append((source_id, ex))

    # Generate — one prompt at a time but N samples per call using num_return_sequences
    print(f"\nGenerating {N_SAMPLES} completions per prompt (batched)...")
    remaining = [(source_id, ex) for source_id, ex in keyed if source_id not in done_ids]
    print(f" Remaining: {len(remaining)} prompts")

    def checkpoint():
        # Persist samples together with the ids of all processed prompts.
        stats["processed_ids"] = sorted(done_ids, key=str)
        _save_rft_data(all_passing, stats)

    for i, (source_id, example) in enumerate(tqdm(remaining, desc="Generating")):
        messages = example["messages"]
        prompt_msgs = [m for m in messages if m["role"] != "assistant"]
        gold_text = next((m["content"] for m in messages if m["role"] == "assistant"), None)
        target_layer = example["target_layer"]
        if gold_text is None:
            # Nothing to score against; skip rather than abort a long run.
            tqdm.write(f" Skipping {source_id}: no assistant reference message")
            done_ids.add(source_id)
            continue

        # Build prompt. Thinking mode is disabled to match the
        # chat_template_kwargs used for training in stage_train.
        prompt_text = tokenizer.apply_chat_template(
            prompt_msgs, tokenize=False, add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

        # Generate N samples in ONE call
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                num_return_sequences=N_SAMPLES,
                pad_token_id=pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode all N completions (drop the prompt tokens)
        new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
        completions = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
        stats["total_generated"] += N_SAMPLES

        # Score and filter
        n_valid_json, passing = _rft_pick_passing(completions, gold_text)
        stats["total_valid_json"] += n_valid_json
        for comp, f1 in passing:
            all_passing.append({
                "messages": prompt_msgs + [{"role": "assistant", "content": comp}],
                "target_layer": target_layer,
                "reward_f1": f1,
                "source_id": source_id,
            })
            stats["total_passing"] += 1
        done_ids.add(source_id)

        # Save every 10 prompts
        if (i + 1) % 10 == 0:
            checkpoint()
            tqdm.write(f" [{i+1}/{len(remaining)}] gen={stats['total_generated']}, "
                       f"json={stats['total_valid_json']}, pass={stats['total_passing']}")

    # Final save
    checkpoint()
    print(f"\n{'='*60}")
    print(f"Generation complete!")
    print(f" Total generated: {stats['total_generated']}")
    print(f" Valid JSON: {stats['total_valid_json']} ({stats['total_valid_json']/max(1,stats['total_generated'])*100:.1f}%)")
    print(f" Passing (F1>{REWARD_THRESHOLD}): {stats['total_passing']} ({stats['total_passing']/max(1,stats['total_generated'])*100:.1f}%)")
    print(f" Saved to: {RFT_DATA_DIR}/")
    print(f"{'='*60}")
def _save_rft_data(all_passing, stats):
    """Write the collected samples and run statistics to ``RFT_DATA_DIR``.

    Args:
        all_passing: JSON-serializable list of kept samples, written to
            ``rft_training_data.json``.
        stats: JSON-serializable dict of counters, written to
            ``rft_stats.json``.

    Each file is written to a sibling ``*.tmp`` file first and then moved
    into place with ``os.replace``.  Generation runs for many hours and
    reads these files back to resume; dumping straight into the target
    would leave a truncated, unparseable file if the process died mid-write.
    """
    out_dir = Path(RFT_DATA_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
    outputs = (
        (out_dir / "rft_training_data.json", all_passing, {"indent": 2, "ensure_ascii": False}),
        (out_dir / "rft_stats.json", stats, {"indent": 2}),
    )
    for target, payload, dump_kwargs in outputs:
        tmp_path = target.with_name(target.name + ".tmp")
        with open(tmp_path, "w") as f:
            json.dump(payload, f, **dump_kwargs)
        # Replaces the target in one step once the new content is complete.
        os.replace(tmp_path, target)
| # ============================================================ | |
| # Stage 2: Fine-tune | |
| # ============================================================ | |
def stage_train():
    """Stage 2: fine-tune a new LoRA adapter on the filtered completions.

    Reads ``rft_training_data.json`` from ``RFT_DATA_DIR``, loads the 4-bit
    base model with the SFT adapter merged in, trains with TRL's
    ``SFTTrainer``, saves to ``RFT_OUTPUT_DIR`` and pushes to the Hub.

    Raises:
        FileNotFoundError: if the generate stage has not produced any data.

    Returns early (after printing an error) when fewer than 10 examples are
    available.
    """
    set_seed(42)
    rule = "=" * 60
    print(rule)
    print("RFT Stage 2: Fine-tune on filtered completions")
    print(rule)

    data_path = f"{RFT_DATA_DIR}/rft_training_data.json"
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"No data at {data_path}. Run --stage generate first.")
    with open(data_path) as f:
        raw_data = json.load(f)
    print(f" Loaded {len(raw_data)} filtered examples")
    if len(raw_data) < 10:
        print(" ERROR: Too few samples. Lower REWARD_THRESHOLD or increase NUM_PROMPTS.")
        return

    # Build dataset: only the chat messages are used for training.
    conversations = [{"messages": record["messages"]} for record in raw_data]
    train_data = Dataset.from_list(conversations)

    # Load SFT model
    print("\nLoading SFT model...")
    from trl import SFTConfig, SFTTrainer

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        quantization_config=quantization,
        device_map={"": 0},
        torch_dtype=torch.bfloat16,
    )
    # Fold the SFT adapter into the base weights; the new adapter is trained on top.
    model = PeftModel.from_pretrained(base, SFT_ADAPTER).merge_and_unload()
    gc.collect()
    torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    lora_settings = {
        "r": 32,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": "all-linear",
    }
    peft_config = LoraConfig(**lora_settings)

    training_settings = {
        "output_dir": RFT_OUTPUT_DIR,
        "num_train_epochs": 3,
        "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 8,
        "learning_rate": 2e-5,
        "lr_scheduler_type": "cosine",
        "warmup_steps": 20,
        "bf16": True,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "optim": "paged_adamw_32bit",
        "max_grad_norm": 1.0,
        "max_length": 2048,
        "assistant_only_loss": True,
        "logging_strategy": "steps",
        "logging_steps": 10,
        "logging_first_step": True,
        "disable_tqdm": True,
        "save_strategy": "epoch",
        "save_total_limit": 2,
        "push_to_hub": True,
        "hub_model_id": RFT_HUB_MODEL_ID,
        "report_to": "none",
        "chat_template_kwargs": {"enable_thinking": False},
    }
    sft_config = SFTConfig(**training_settings)

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_data,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainable_count = sum(
        param.numel() for param in trainer.model.parameters() if param.requires_grad
    )
    print(f" Trainable params: {trainable_count:,}")
    print(f" Training examples: {len(train_data)}")

    trainer.train()
    trainer.save_model(RFT_OUTPUT_DIR)
    if sft_config.push_to_hub:
        trainer.push_to_hub(commit_message="RFT: Best-of-N rejection fine-tuning")

    print(f"\n{rule}")
    print(f"RFT complete! Model at: https://huggingface.co/{RFT_HUB_MODEL_ID}")
    print(rule)
| # ============================================================ | |
| # Main | |
| # ============================================================ | |
def main():
    """Parse ``--stage`` and run the requested pipeline stage(s).

    ``generate`` runs stage 1 only, ``train`` runs stage 2 only, and ``all``
    (the default) runs stage 1 followed by stage 2.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--stage", choices=["generate", "train", "all"], default="all")
    stage = cli.parse_args().stage

    # Stages always run in pipeline order: generate first, then train.
    pipeline = (
        (("generate", "all"), stage_generate),
        (("train", "all"), stage_train),
    )
    for triggers, run_stage in pipeline:
        if stage in triggers:
            run_stage()


if __name__ == "__main__":
    main()