#!/usr/bin/env python3 """Best-of-N Rejection Fine-Tuning (RFT) for TMF921 — FIXED for RTX 6000 Ada. Generates N=8 completions per prompt using BATCHED generation (not sequential). Reduced to 200 prompts focused on weak layers. Total runtime: ~20-24 hours. Usage: export PYTHONPATH="$PWD/src" python scripts/train_rft.py --stage generate # ~20-24h python scripts/train_rft.py --stage train # ~2h python scripts/train_rft.py --stage all # both sequential """ import argparse import gc import json import os import re import random from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch from datasets import Dataset, load_dataset from peft import LoraConfig, PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed from tqdm import tqdm # ============================================================ # JSON Parsing & Evaluation # ============================================================ def strip_code_fence(text: str) -> str: text = text.strip() if text.startswith("```"): text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE) text = re.sub(r"\s*```$", "", text) return text.strip() def extract_json_text(text: str) -> str: text = strip_code_fence(text) if not text: return text start = text.find("{") end = text.rfind("}") if start >= 0 and end > start: return text[start:end + 1].strip() return text def parse_json(text: str) -> Tuple[Optional[Any], Optional[str]]: candidate = extract_json_text(text) try: return json.loads(candidate), None except Exception as e: return None, str(e)[:200] def canonical_json(obj: Any) -> str: return json.dumps(obj, sort_keys=True, ensure_ascii=False, separators=(",", ":")) def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]: out = {} if isinstance(obj, dict): for k, v in obj.items(): p = f"{prefix}.{k}" if prefix else str(k) out.update(flatten_json(v, p)) elif isinstance(obj, list): for i, v in enumerate(obj): out.update(flatten_json(v, f"{prefix}[{i}]")) else: out[prefix] = obj return out VOLATILE_KEY_EXACT = { "id", "uuid", "href", "name", "description", "displayName", "label", "@schemaLocation", "schemaLocation", "version", "revision", "createdAt", "updatedAt", "modifiedAt", "lastModified", "timestamp", "creationDate", "lastUpdate", "requestedStartDate", "requestedCompletionDate", "startTime", "endTime", "validFrom", "validTo", "validFor", "correlationId", "requestId", "transactionId", "reservationId", } VOLATILE_KEY_FRAGMENTS = ["href", "schema", "timestamp", "uuid", "correlation", "transaction"] PROTECTED_KEYS = {"sst", "sd", "sliceType", "slice_type", "latency", "reliability", "dl", "ul", "maxUEs", "maxNumberOfUEs"} ID_LIKE_RE = re.compile(r"\b(?:intent|slice|policy|booking|cell|me|gnb|nsi|nssi|req|report|monitor|assurance)[-_][A-Za-z0-9._:-]+", re.IGNORECASE) HEX_RE = re.compile(r"\b[0-9a-f]{8,}\b", re.IGNORECASE) ISO_TIME_RE = re.compile(r"\b\d{4}-\d{2}-\d{2}[T ][0-9:.+-Z]*\b") def is_volatile_key(key: str) -> bool: if key in PROTECTED_KEYS: return False if key in VOLATILE_KEY_EXACT: return True lk = key.lower() if lk in {k.lower() for k in VOLATILE_KEY_EXACT}: return True return any(fragment in lk for fragment in VOLATILE_KEY_FRAGMENTS) def normalize_string(s: str) -> str: s = ISO_TIME_RE.sub("