|
|
import streamlit as st |
|
|
import numpy as np |
|
|
import time |
|
|
|
|
|
|
|
|
try:
    from .sperner_trainer import SpernerTrainer
    from .analytics import calculate_frustration_score
except (ImportError, ValueError):
    # The package-relative imports failed (presumably because this file is
    # being run as a standalone script), so provide lightweight stand-ins
    # that expose the same names the UI below relies on.

    def calculate_frustration_score(path):
        """Stand-in scorer: return a constant score of 1.0 for any path."""
        return 1.0

    class SpernerTrainer:
        """Mock trainer exposing the attributes and methods the UI uses.

        Attributes:
            n_objs: Number of objectives being mixed (always 3 here).
            adapter_names: Display name for each objective.
        """

        def __init__(self, *args, **kwargs):
            """Accept and ignore any arguments; set up three objectives."""
            self.n_objs = 3
            self.adapter_names = ["Safety", "Helpfulness", "Creativity"]

        def evaluate_mixed_model(self, weights):
            """Return one synthetic loss value per weight.

            The loss is the squared distance between ``weights`` and a
            target mix of 0.33 per objective, scaled by each weight.

            Args:
                weights: Sequence of ``n_objs`` mixing weights.

            Returns:
                A list with one loss value per entry of ``weights``.
            """
            weights = np.asarray(weights, dtype=float)
            target = np.full(self.n_objs, 0.33)
            loss = np.sum((weights - target) ** 2)
            return list(loss * weights)

        def train_generator(self, grid_size=10):
            """Yield 20 random weight vectors, then return the last one.

            Each yielded vector is drawn from a flat Dirichlet distribution
            over ``n_objs`` objectives. Values sent back into the generator
            (the human verdicts) are ignored by this mock.

            Args:
                grid_size: Accepted for interface compatibility; unused.

            Returns:
                The last proposed weight vector, delivered to the caller as
                ``StopIteration.value`` so a final result is always
                available once the generator is exhausted.
            """
            w = None
            for _ in range(20):
                w = np.random.dirichlet(np.ones(self.n_objs))
                yield w  # the verdict sent back is intentionally ignored
            return w
|
|
|
|
|
|
|
|
|
|
|
# --- Page chrome -----------------------------------------------------------
page_options = {"page_title": "Topo-Align Human RLHF", "layout": "wide"}
st.set_page_config(**page_options)

st.title("🧬 Topo-Align: Human-in-the-Loop Alignment")
intro_text = """
Use Sperner's Lemma to align your LLM by enabling a human judge (you) to be the Oracle.
Instead of a reward model, **you** decide which trade-off is unacceptable.
"""
st.markdown(intro_text)

# --- One-time session bootstrap ---------------------------------------------
# The presence of "trainer" marks an already-initialised session; when it is
# missing, every piece of per-session state is (re)created together.
if "trainer" not in st.session_state:
    initial_state = {
        "trainer": SpernerTrainer("meta-llama/Llama-2-7b-hf", [], []),
        "solver_gen": None,        # active solver generator, if any
        "step": 0,                 # number of proposals shown so far
        "history": [],             # weight vectors recorded after each verdict
        "current_weights": None,   # proposal currently awaiting a verdict
        "current_phase": None,     # optional phase info reported by the solver
        "finished": False,         # True once the solver generator is exhausted
    }
    for state_key, state_value in initial_state.items():
        st.session_state[state_key] = state_value
|
|
|
|
|
def start_alignment():
    """Start (or restart) an alignment run and load the first proposal.

    Creates a fresh solver generator from the session's trainer, resets all
    per-run session state, and pulls the first value from the generator.
    The generator may yield either a bare weight vector or a tuple whose
    first item is the weights and whose optional second item is phase info.

    If the generator finishes without yielding anything, the session is
    marked as finished.

    Raises:
        ValueError: If the solver yields an empty tuple.
    """
    state = st.session_state

    # Reset everything up front so nothing from a previous run survives,
    # even when the new solver finishes without yielding a single proposal.
    state.step = 0
    state.history = []
    state.current_weights = None
    state.current_phase = None
    state.finished = False
    if "final_result" in state:
        del state["final_result"]

    state.solver_gen = state.trainer.train_generator(grid_size=8)

    # Only next() can raise StopIteration, so keep the try body minimal.
    try:
        params = next(state.solver_gen)
    except StopIteration:
        state.finished = True
        return

    if isinstance(params, tuple):
        if not params:
            raise ValueError(
                "solver yielded an empty tuple; expected weights or (weights, phase)"
            )
        state.current_weights = params[0]
        state.current_phase = params[1] if len(params) >= 2 else None
    else:
        state.current_weights = params
        state.current_phase = None

    state.step = 1
|
|
|
|
|
def submit_verdict(label_idx):
    """Send the human verdict to the solver and load its next proposal.

    Args:
        label_idx: Index of the objective the user judged most dissatisfied;
            forwarded to the solver generator via ``send``.

    Does nothing when no run is active or the run has already finished.
    When the generator is exhausted, the session is marked as finished and
    the generator's return value (if any) is stored as ``final_result``.

    Raises:
        ValueError: If the solver yields an empty tuple.
    """
    state = st.session_state

    # Sending to an exhausted generator would raise StopIteration with no
    # value and clobber the stored result, so bail out once finished.
    if state.solver_gen is None or state.finished:
        return

    # Only send() can raise StopIteration, so keep the try body minimal.
    try:
        params = state.solver_gen.send(label_idx)
    except StopIteration as stop:
        state.finished = True
        # StopIteration always has a ``value`` attribute; it is None when
        # the generator returned nothing, which is not a usable result.
        if stop.value is not None:
            state.final_result = stop.value
        elif "final_result" in state:
            del state["final_result"]
        return

    if isinstance(params, tuple):
        if not params:
            raise ValueError(
                "solver yielded an empty tuple; expected weights or (weights, phase)"
            )
        state.current_weights = params[0]
        state.current_phase = params[1] if len(params) >= 2 else None
    else:
        state.current_weights = params
        state.current_phase = None

    state.step += 1
    state.history.append(state.current_weights)
|
|
|
|
|
|
|
|
# Sidebar: run controls, step counter, feedback-consistency score and, once
# the run has finished, the final weights.
with st.sidebar:
    st.header("Configuration")
    if st.button("Start / Reset Alignment"):
        start_alignment()

    st.metric("Step", st.session_state.step)

    if st.session_state.history:
        # Scoring is best-effort: a failure must not take down the page, but
        # it is surfaced to the user instead of being silently discarded.
        try:
            score = calculate_frustration_score(st.session_state.history)
            if score > 3.0:
                st.error(f"High Conflict Detected (Score: {score:.2f}). You are giving contradictory feedback!")
            else:
                st.success(f"Alignment Stable (Score: {score:.2f})")
        except Exception as exc:
            st.warning(f"Frustration score unavailable: {exc}")

    if st.session_state.finished:
        st.success("Alignment Converged!")
        # The result may be absent or None when the solver returned nothing;
        # rounding None would raise, so only display a real value.
        final_result = st.session_state.get("final_result")
        if final_result is not None:
            st.write("Best Weights:", np.round(final_result, 3))
|
|
|
|
|
|
|
|
# Main panel. Three states:
#   * a run is active   -> show the current proposal and collect a verdict
#   * the run finished  -> show the completion message and weight history
#   * nothing started   -> show a hint pointing at the sidebar button
if st.session_state.solver_gen and not st.session_state.finished:
    phase = st.session_state.current_phase
    names = st.session_state.trainer.adapter_names
    weights = st.session_state.current_weights

    # Phase progress: when the solver reports an (active, total) pair, list
    # every phase with its status; otherwise show the raw phase value.
    if isinstance(phase, (tuple, list)) and len(phase) >= 2:
        active_dim, total_dim = int(phase[0]), int(phase[1])
        for k in range(1, total_dim + 1):
            if k == 1:
                msg = f"Aligning {names[0]} vs {names[1]}"
            elif k < len(names):
                msg = f"Adding {names[k]}"
            else:
                msg = f"Adding objective {k+1}"

            if k < active_dim:
                status = "Done."
            elif k == active_dim:
                status = "In progress."
            else:
                status = "Pending."
            st.caption(f"Phase {k}: {msg}... {status}")
    elif phase is not None:
        st.caption(f"Solver phase: {phase}")

    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Current Model Mix")
        for i, name in enumerate(names):
            # Objectives without a corresponding weight are shown as zero.
            val = float(weights[i]) if i < len(weights) else 0.0
            # st.progress rejects floats outside [0.0, 1.0]; clamp the bar
            # while the label still shows the actual value.
            st.progress(min(max(val, 0.0), 1.0), text=f"{name}: {val:.2f}")

    with col2:
        st.subheader("Model Generation Preview")
        st.info("Simulating text generation with current mixing weights...")
        st.code(f"Output with mix {np.round(weights, 2)}:\n\n'The user asked for code. This model is {weights[0]:.2f} Safe and {weights[1]:.2f} Helpful.'")

    st.divider()

    st.subheader("👨⚖️ Your Verdict: What is the PRIMARY defect?")
    st.write("Select the objective that is **most dissatisfied** (i.e. the one that needs MORE attention or is causing the failure).")

    # One button per objective; the pressed index is the verdict.
    cols = st.columns(len(names))
    for i, name in enumerate(names):
        with cols[i]:
            if st.button(f"Too Poor: {name}", key=f"btn_{i}", use_container_width=True):
                submit_verdict(i)
                # Rerun immediately so the page shows the next proposal.
                st.rerun()

elif st.session_state.finished:
    st.balloons()
    st.header("Optimization Complete")
    st.write("The topological walk has converged to a fixed point.")

    if st.session_state.history:
        chart_data = np.array(st.session_state.history)
        st.line_chart(chart_data)

else:
    # Name the button exactly as it is labelled in the sidebar.
    st.info("Click 'Start / Reset Alignment' in the sidebar to begin the Human-in-the-Loop session.")
|
|
|