# topo-align / equilib / sperner_trainer.py
# Commit author: omesbah
# Commit message: chore: initial commit for Hugging Face publication
# Commit: 8797abf
import numpy as np
import torch
import sys
# Mock imports for demonstration if libraries technically missing in environment
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
except ImportError:
pass
from .ndim_topo_align import NDimTopoAlignSolver
class SpernerTrainer:
    """
    Hugging Face Adapter for Topological Alignment.

    Allows mixing multiple LoRA adapters (e.g., [Safety, Helpfulness, Coding])
    and finding the optimal weighting to balance conflicting objectives
    without retraining the base model.

    The search itself is delegated to ``NDimTopoAlignSolver``. This class
    supplies the labeling oracle ("which objective is currently the
    unhappiest?"). The label comes either automatically (``train``) or from
    a human in the loop (``train_generator``).
    """

    def __init__(self, base_model_name, adapter_paths, objective_funcs, device=None, mock=True):
        """
        Args:
            base_model_name: str, Hugging Face model ID.
            adapter_paths: list of str, paths to LoRA adapters. ``None`` is
                treated as an empty list.
            objective_funcs: list of functions f(model, tokenizer) -> float (Loss).
                Only evaluated when ``mock`` is False. ``None`` is treated as
                an empty list. Presumably one function per adapter, since the
                argmax over their losses is used as the solver label.
            device: optional device (default: cuda if available).
            mock: if True, skip loading the model (simulation only). Set to
                False to load real weights.

        Raises:
            ValueError: if ``mock`` is False and no adapter paths are given.
        """
        adapter_paths = list(adapter_paths) if adapter_paths is not None else []
        self.n_objs = len(adapter_paths)
        if self.n_objs == 0 and mock:
            # Default for demo if no adapters provided
            self.n_objs = 3
            self.adapter_names = ["Safety", "Helpfulness", "Reasoning"]
        else:
            self.adapter_names = [f"adapter_{i}" for i in range(self.n_objs)]
        self.objective_funcs = objective_funcs if objective_funcs is not None else []
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.mock_mode = mock
        if not self.mock_mode:
            print(f"[INIT] Loading Base Model: {base_model_name}...", flush=True)
            self._load_model_real(base_model_name, adapter_paths)
        else:
            print("[INIT] Running in MOCK MODE (Simulation).", flush=True)
            self.model = None
            self.tokenizer = None

    def _load_model_real(self, base_name, paths):
        """
        Load the tokenizer, the base model, and every LoRA adapter in ``paths``.

        The i-th adapter is registered under ``self.adapter_names[i]``, so
        ``evaluate_mixed_model`` can later combine them by name.

        Raises:
            ValueError: if ``paths`` is empty.
        """
        if not paths:
            # PeftModel.from_pretrained needs at least one adapter to wrap the
            # base model. Fail with a clear message instead of an IndexError.
            raise ValueError("At least one adapter path is required when mock=False.")
        # This code would run in a real GPU environment
        self.tokenizer = AutoTokenizer.from_pretrained(base_name)
        base = AutoModelForCausalLM.from_pretrained(base_name, device_map=self.device)
        first_path, *other_paths = paths
        self.model = PeftModel.from_pretrained(base, first_path, adapter_name=self.adapter_names[0])
        for name, path in zip(self.adapter_names[1:], other_paths):
            self.model.load_adapter(path, adapter_name=name)

    def evaluate_mixed_model(self, weights):
        """
        Merges adapters with given weights and evaluates metrics.

        Args:
            weights: sequence of mixing weights, one per adapter. The mock
                simulation assumes they sum to 1.

        Returns:
            list of losses. In mock mode this is ``(1 - w_i)**2`` for each
            weight. Otherwise it holds one value per entry of
            ``objective_funcs``, in that order.

        Performance note: add_weighted_adapter + set_adapter takes ~0.1–0.5s per call.
        A Sperner walk of 50–100 steps can thus add tens of seconds of latency in
        interactive mode. See README "Known limitation: Interactive mode performance"
        for recommendations and future work (e.g. CUDA kernel for weighted sum at inference).
        """
        if self.mock_mode:
            # Simulation of trade-offs: objective i is unhappy when its own
            # weight w[i] is low. Loss_i = (1 - w[i])^2
            return [(1 - w) ** 2 for w in weights]

        # Real Implementation
        combined_name = "sperner_mix"
        # Drop the mix built by the previous call (if any) so it can be rebuilt.
        # An explicit membership check replaces a bare `except: pass`, which
        # would also have hidden real failures (e.g. KeyboardInterrupt).
        if combined_name in self.model.peft_config:
            self.model.delete_adapter(combined_name)

        # Weighted Merge
        self.model.add_weighted_adapter(
            adapters=self.adapter_names,
            weights=list(weights),
            adapter_name=combined_name,
            combination_type="linear",
        )
        self.model.set_adapter(combined_name)

        # Eval
        return [func(self.model, self.tokenizer) for func in self.objective_funcs]

    def oracle_label(self, weights):
        """
        Determines which objective is the 'Unhappiest'.

        Returns:
            int index of the largest loss. Ties resolve to the lowest index,
            as with ``np.argmax``.
        """
        losses = self.evaluate_mixed_model(weights)
        return int(np.argmax(losses))

    def train(self, grid_size=20):
        """
        Run the automatic Sperner search and return the best mixing weights.

        Args:
            grid_size: subdivision passed to ``NDimTopoAlignSolver``.

        Returns:
            The value returned by ``solver.solve()``, which is also printed
            as the optimal mixing weights.
        """
        print(f"\n[SPERNER-TRAIN] Aligning {self.n_objs} Adapters (Grid {grid_size})...", flush=True)
        solver = NDimTopoAlignSolver(n_objs=self.n_objs, subdivision=grid_size)

        # NDimSolver calls its oracle with 'y' coords (cumulative integers).
        # Convert y -> barycentric weights before asking our own labeler.
        def wrapped_oracle(y_vec):
            w = solver.get_barycentric_weights(y_vec)
            return self.oracle_label(w)

        solver.oracle_label = wrapped_oracle
        best_weights = solver.solve()
        print(f"[SPERNER-TRAIN] Optimal Mixing Weights: {best_weights}", flush=True)
        return best_weights

    @staticmethod
    def _unpack_solver_step(out):
        """Split a solver step ``(v, w[, phase])`` into ``(w, phase)``; phase is None if absent."""
        return out[1], (out[2] if len(out) == 3 else None)

    def train_generator(self, grid_size=20):
        """
        Interactive Mode: yields ``(weights, phase)`` and waits for a Human Label.

        ``phase`` is the third element of the solver's yielded tuple. It is
        ``None`` when the solver yields only ``(v, w)``.

        Use with send():
            gen = trainer.train_generator()
            current_weights, phase = next(gen)
            # Show to human... Human says "Too Toxic" (Label 0)
            next_weights, phase = gen.send(0)

        Returns (as ``StopIteration.value``) the solver generator's return value.
        """
        print(f"\n[SPERNER-INTERACTIVE] Starting Generator for {self.n_objs} Adapters...", flush=True)
        solver = NDimTopoAlignSolver(n_objs=self.n_objs, subdivision=grid_size)
        # Bridge the Solver's generator (yields v, w[, phase]; receives a label)
        # with this one (yields w, phase; receives a label).
        solver_gen = solver.solve_generator()
        try:
            # Prime the solver
            current_w, current_phase = self._unpack_solver_step(next(solver_gen))
            while True:
                # 1. Yield (weights, phase) to UI, receive the human label
                label = yield (current_w, current_phase)
                # 2. Forward the label to the solver and get its next step
                current_w, current_phase = self._unpack_solver_step(solver_gen.send(label))
        except StopIteration as e:
            # Solver finished
            return e.value
        finally:
            # If the caller closes or abandons this generator early, close the
            # solver generator too so that its own cleanup runs. This is a
            # no-op if the solver has already finished.
            solver_gen.close()