|
|
import numpy as np |
|
|
import torch |
|
|
import sys |
|
|
|
|
|
|
|
|
try: |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
from .ndim_topo_align import NDimTopoAlignSolver |
|
|
|
|
|
class SpernerTrainer:
    """
    Hugging Face Adapter for Topological Alignment.

    Allows mixing multiple LoRA adapters (e.g., [Safety, Helpfulness, Coding])
    and finding the optimal weighting to balance conflicting objectives
    without retraining the base model.
    """

    def __init__(self, base_model_name, adapter_paths, objective_funcs, device=None, mock=True):
        """
        Args:
            base_model_name: str, Hugging Face model ID.
            adapter_paths: list of str, paths to LoRA adapters.
            objective_funcs: list of functions f(model, tokenizer) -> float (Loss).
            device: optional device (default: cuda if available).
            mock: if True, skip loading the model (simulation only). Set to False to load real weights.

        Raises:
            ValueError: if adapter_paths is empty and mock=False (there would be
                nothing to mix; previously this failed later with an opaque
                IndexError inside _load_model_real).
        """
        self.n_objs = len(adapter_paths)

        if self.n_objs == 0:
            if not mock:
                # Fail fast instead of crashing on paths[0] during model loading.
                raise ValueError("adapter_paths must be non-empty when mock=False")
            # Mock default: simulate three conflicting alignment objectives.
            self.n_objs = 3
            self.adapter_names = ["Safety", "Helpfulness", "Reasoning"]
        else:
            self.adapter_names = [f"adapter_{i}" for i in range(self.n_objs)]

        self.objective_funcs = objective_funcs
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.mock_mode = mock

        if not self.mock_mode:
            print(f"[INIT] Loading Base Model: {base_model_name}...", flush=True)
            self._load_model_real(base_model_name, adapter_paths)
        else:
            print("[INIT] Running in MOCK MODE (Simulation).", flush=True)
            self.model = None
            self.tokenizer = None

    def _load_model_real(self, base_name, paths):
        """Load tokenizer, base model, and all LoRA adapters (non-mock mode).

        Requires `transformers` and `peft`; the module-level import is wrapped
        in try/except, so this raises NameError if they are not installed.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(base_name)
        base = AutoModelForCausalLM.from_pretrained(base_name, device_map=self.device)
        # The first adapter initializes the PeftModel; the rest attach to it.
        self.model = PeftModel.from_pretrained(base, paths[0], adapter_name=self.adapter_names[0])
        for i, path in enumerate(paths[1:], 1):
            self.model.load_adapter(path, adapter_name=self.adapter_names[i])

    def evaluate_mixed_model(self, weights):
        """
        Merges adapters with given weights and evaluates metrics.
        Returns array of Losses (one per objective).

        Performance note: add_weighted_adapter + set_adapter takes ~0.1–0.5s per call.
        A Sperner walk of 50–100 steps can thus add tens of seconds of latency in
        interactive mode. See README "Known limitation: Interactive mode performance"
        for recommendations and future work (e.g. CUDA kernel for weighted sum at inference).
        """
        if self.mock_mode:
            # Toy convex losses: objective i is happiest (loss -> 0) as its
            # own mixing weight approaches 1.
            return [(1 - w)**2 for w in weights]

        combined_name = "sperner_mix"
        # Best-effort cleanup of a stale mix from a previous call; peft raises
        # if the adapter does not exist yet, which is expected on the first call.
        try:
            self.model.delete_adapter(combined_name)
        except Exception:
            pass

        self.model.add_weighted_adapter(
            adapters=self.adapter_names,
            weights=list(weights),
            adapter_name=combined_name,
            combination_type="linear",
        )
        self.model.set_adapter(combined_name)

        # One loss per objective, measured on the freshly mixed model.
        return [func(self.model, self.tokenizer) for func in self.objective_funcs]

    def oracle_label(self, weights):
        """
        Determines which objective is the 'Unhappiest' (highest loss).

        Returns:
            int: index of the objective with the maximum loss.
        """
        losses = self.evaluate_mixed_model(weights)
        # Plain int (not np.int64) keeps comparisons/serialization simple.
        return int(np.argmax(losses))

    def train(self, grid_size=20):
        """Run a full automated Sperner walk and return the optimal mixing weights.

        Args:
            grid_size: subdivision of the simplex grid (finer = slower, more precise).
        """
        print(f"\n[SPERNER-TRAIN] Aligning {self.n_objs} Adapters (Grid {grid_size})...", flush=True)

        solver = NDimTopoAlignSolver(n_objs=self.n_objs, subdivision=grid_size)

        # The solver hands us grid coordinates; convert them to barycentric
        # mixing weights before asking the oracle which objective is worst.
        def wrapped_oracle(y_vec):
            w = solver.get_barycentric_weights(y_vec)
            return self.oracle_label(w)

        solver.oracle_label = wrapped_oracle

        best_weights = solver.solve()
        print(f"[SPERNER-TRAIN] Optimal Mixing Weights: {best_weights}", flush=True)
        return best_weights

    def train_generator(self, grid_size=20):
        """
        Interactive Mode: Yields (weights, phase) tuples and waits for a Human Label.
        Use with send():
            gen = trainer.train_generator()
            current_weights, phase = next(gen)
            # Show to human... Human says "Too Toxic" (Label 0)
            next_weights, phase = gen.send(0)

        Returns (via StopIteration.value): the solver's final optimal weights.
        """
        print(f"\n[SPERNER-INTERACTIVE] Starting Generator for {self.n_objs} Adapters...", flush=True)
        solver = NDimTopoAlignSolver(n_objs=self.n_objs, subdivision=grid_size)
        solver_gen = solver.solve_generator()

        def unpack(out):
            # The solver yields (vertex, weights) or (vertex, weights, phase);
            # the vertex itself is not needed here.
            phase = out[2] if len(out) == 3 else None
            return out[1], phase

        try:
            current_w, current_phase = unpack(next(solver_gen))
            while True:
                # Hand the current candidate to the caller; resume with the
                # human-provided label of the unhappiest objective.
                label = yield (current_w, current_phase)
                current_w, current_phase = unpack(solver_gen.send(label))
        except StopIteration as e:
            # solve_generator's return value is the final optimal weights.
            return e.value
|
|
|