# ARC-Base-8B / inference.py
# LoganResearch's picture
# Fix device mismatch and tensor.item() bugs
# 749c71a verified
#!/usr/bin/env python3
"""
ARC-8B: Adaptive Repetition Controller
=======================================
Decode-time behavioral control for language models.
This script loads the complete ARC system and runs inference with
multi-head cognitive control that detects and suppresses:
- Repetition loops (125× separation)
- Hedging phrases (1.5× separation)
- Verbosity/filler (2.1× separation)
- Sycophancy (experimental)
Usage:
python inference.py # Interactive mode
python inference.py --prompt "Hello" # Single prompt
python inference.py --no-arc # Disable ARC (baseline)
Requirements:
pip install torch transformers accelerate bitsandbytes
Model: LoganResearch/ARC-Base-8B (16GB, runs in ~10GB with 4-bit)
"""
import os
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
# =============================================================================
# CONFIGURATION
# =============================================================================
@dataclass
class ARCConfig:
    """ARC System Configuration.

    Settings fall into four groups: which checkpoint to load and how to
    quantize it, the predictor architecture (must match the trained heads),
    the per-behavior risk thresholds / logit penalties, and sampling.
    """

    # --- Model loading ------------------------------------------------------
    model_id: str = "LoganResearch/ARC-Base-8B"  # HuggingFace repo id
    load_in_4bit: bool = True    # NF4 quantization; takes precedence over 8-bit
    load_in_8bit: bool = False   # consulted only when load_in_4bit is False
    device_map: str = "auto"     # forwarded to from_pretrained

    # --- Predictor architecture (must match training) ----------------------
    d_model: int = 4096   # width of each layer's hidden state
    n_layers: int = 32    # number of per-layer fiber projections
    d_fiber: int = 16     # output width of each fiber projection
    d_control: int = 64   # hidden width of each prediction-head MLP

    # --- Intervention thresholds (tuned empirically) -----------------------
    # A head "fires" when its risk score is strictly greater than its threshold.
    repetition_threshold: float = 0.70
    hedging_threshold: float = 0.60
    verbosity_threshold: float = 0.65
    sycophancy_threshold: float = 0.60

    # --- Intervention penalties --------------------------------------------
    # Amount subtracted from each targeted token's logit when a head fires.
    repetition_penalty: float = 5.0
    hedging_penalty: float = 3.0
    verbosity_penalty: float = 2.0
    sycophancy_penalty: float = 2.0

    # --- Generation ---------------------------------------------------------
    max_new_tokens: int = 512
    temperature: float = 0.8
    top_p: float = 0.92           # nucleus-sampling probability mass
    repetition_window: int = 32   # recent tokens penalized on repetition risk
# =============================================================================
# MULTI-HEAD PREDICTOR
# =============================================================================
class MultiHeadPredictor(nn.Module):
    """
    Prediction heads that monitor hidden states and detect behavioral patterns.

    The system uses shared "fiber projections" that compress each layer's
    hidden states, combines them with softmax-normalized learned layer
    weights, then feeds the result to individual heads that predict a
    per-position risk score for a specific behavior.

    Architecture:
        Hidden States [n_layers × d_model]
        → Fiber Projections [n_layers × d_fiber]
        → Weighted Aggregation [d_fiber]
        → Per-Head MLP → Risk Score [0-1]
    """

    def __init__(self, config: "ARCConfig"):
        super().__init__()
        self.config = config
        # Shared fiber projections (learned during CF-HoT training)
        self.fiber_projs = nn.ModuleList([
            nn.Linear(config.d_model, config.d_fiber, bias=False)
            for _ in range(config.n_layers)
        ])
        # Learned layer importance weights (softmax-normalized when used)
        self.layer_weights = nn.Parameter(torch.ones(config.n_layers) / config.n_layers)
        # Individual prediction heads; only names in loaded_heads are queried
        self.heads = nn.ModuleDict()
        self.loaded_heads: set = set()

    def _make_head(self) -> nn.Sequential:
        """Create a prediction head: fiber features → one risk logit."""
        return nn.Sequential(
            nn.Linear(self.config.d_fiber, self.config.d_control),
            nn.GELU(),
            nn.Linear(self.config.d_control, self.config.d_control),
            nn.GELU(),
            nn.Linear(self.config.d_control, 1)
        )

    def add_head(self, name: str) -> None:
        """Add (or replace) a freshly initialized prediction head called *name*."""
        self.heads[name] = self._make_head()

    def _ensure_device(self, device: torch.device) -> None:
        """Move every parameter (fibers, layer weights, all heads) to *device*.

        Heads added after the predictor was placed start on the CPU, so this
        runs before each forward pass; tensors already on *device* are not copied.
        """
        self.to(device)

    def get_fiber_features(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Project hidden states through fiber projections and aggregate.

        Args:
            hidden_states: List of [batch, seq, d_model] tensors from each layer.
                With a multi-GPU ``device_map`` these may live on different
                devices; all are moved to the device of the first one.

        Returns:
            Aggregated float32 features [batch, seq, d_fiber]

        Raises:
            ValueError: if *hidden_states* is empty.
        """
        if len(hidden_states) == 0:
            raise ValueError("hidden_states must contain at least one layer")
        device = hidden_states[0].device
        self._ensure_device(device)
        # zip() stops at the shorter input, so fewer than n_layers layers is fine.
        fibers = [
            proj(hidden.to(device).float())
            for proj, hidden in zip(self.fiber_projs, hidden_states)
        ]
        # Weighted sum across layers (weights renormalized over the layers present)
        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        return sum(w * f for w, f in zip(weights, fibers))

    def get_risk(self, head_name: str, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Risk scores in (0, 1) from head *head_name*, shape [batch, seq].

        Returns a single zero on the hidden states' device if that head has
        not been loaded.
        """
        if head_name not in self.loaded_heads:
            return torch.zeros(1, device=hidden_states[0].device)
        features = self.get_fiber_features(hidden_states)
        return torch.sigmoid(self.heads[head_name](features).squeeze(-1))

    def get_all_risks(self, hidden_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Risk scores from every loaded head, keyed by head name.

        Fiber features are computed once and shared by all heads. Returns an
        empty dict when no head is loaded.
        """
        if not self.loaded_heads:
            return {}
        features = self.get_fiber_features(hidden_states)
        return {
            name: torch.sigmoid(self.heads[name](features).squeeze(-1))
            for name in self.loaded_heads
        }
# =============================================================================
# ARC SYSTEM
# =============================================================================
class ARCSystem:
    """
    Complete ARC (Adaptive Repetition Controller) System

    Loads model + prediction heads and provides controlled generation
    with real-time behavioral intervention: at each decoding step the loaded
    heads score the newest position, and every head whose risk exceeds its
    threshold subtracts a fixed penalty from a set of token logits.
    """

    # Tokens to suppress for each behavior type (first token of each word)
    HEDGE_STARTERS = [
        "As", "I'm", "I", "It's", "While", "Although", "However",
        "That", "This", "Please", "Well", "So", "Actually"
    ]
    VERBOSE_STARTERS = [
        "Let", "Basically", "Essentially", "Simply", "Indeed",
        "Furthermore", "Moreover", "Additionally", "Firstly"
    ]
    SYCOPHANCY_STARTERS = [
        "Great", "Excellent", "Wonderful", "Absolutely", "Of",
        "Thank", "Sure", "Certainly", "Definitely"
    ]

    def __init__(self, config: Optional["ARCConfig"] = None):
        self.config = config or ARCConfig()
        self.model = None
        self.tokenizer = None
        self.predictor = None
        # Token ID caches for suppression (filled by _build_suppression_sets)
        self._hedge_token_ids: set = set()
        self._verbose_token_ids: set = set()
        self._sycophancy_token_ids: set = set()
        # Cumulative intervention counts across all generate() calls
        self.total_interventions = {"repetition": 0, "hedging": 0, "verbosity": 0, "sycophancy": 0}

    # ------------------------------------------------------------------ loading

    def load(self, verbose: bool = True) -> "ARCSystem":
        """
        Load all components from HuggingFace.

        Downloads and initializes:
            1. Base model (Hermes-3-Llama-3.1-8B based)
            2. Tokenizer
            3. Prediction heads (repetition, hedging, verbosity, sycophancy)

        A head that fails to download or load is skipped (and reported when
        *verbose*); generation then simply runs without it.

        Returns:
            self (for chaining)
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from huggingface_hub import hf_hub_download

        if verbose:
            print("=" * 60)
            print(" ARC-8B: Adaptive Repetition Controller")
            print(" Decode-time behavioral control system")
            print("=" * 60)

        # === 1. Tokenizer ===
        if verbose:
            print("\n[1/4] Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # === 2. Model ===
        if verbose:
            print("[2/4] Loading model...")
            if self.config.load_in_4bit:
                print(" (4-bit quantization enabled)")
        quantization_config = None
        if self.config.load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif self.config.load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map=self.config.device_map,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )
        self.model.eval()

        # === 3. Prediction Heads ===
        if verbose:
            print("[3/4] Loading prediction heads...")
        device = next(self.model.parameters()).device
        self.predictor = MultiHeadPredictor(self.config).to(device).float()

        # risk_predictor.pt contains the shared fiber projections + repetition head
        try:
            risk_path = hf_hub_download(self.config.model_id, "risk_predictor.pt")
            self._load_risk_predictor(risk_path, device)
            if verbose:
                print(" ✓ Repetition head (125× separation)")
        except Exception as e:  # deliberate best-effort: continue without it
            if verbose:
                print(f" ✗ Repetition head: {e}")

        # Additional heads reuse the fiber projections loaded above
        for head_name in ['hedging', 'verbosity', 'sycophancy']:
            try:
                head_path = hf_hub_download(self.config.model_id, f"{head_name}_head.pt")
                # NOTE(security): weights_only=False unpickles arbitrary objects;
                # only load checkpoints from a repository you trust.
                ckpt = torch.load(head_path, map_location=device, weights_only=False)
                self.predictor.add_head(head_name)
                head_state = ckpt.get('head_state', ckpt)
                self.predictor.heads[head_name].load_state_dict(head_state)
                self.predictor.loaded_heads.add(head_name)
                if verbose:
                    print(f" ✓ {head_name.capitalize()} head")
            except Exception as e:  # deliberate best-effort: continue without it
                if verbose:
                    print(f" ✗ {head_name.capitalize()} head: {e}")

        # add_head() builds heads on the CPU; put everything next to the model.
        self.predictor.to(device)
        self.predictor.eval()

        # === 4. Build Token Suppression Sets ===
        if verbose:
            print("[4/4] Building suppression vocabularies...")
        self._build_suppression_sets()

        if verbose:
            print("\n" + "=" * 60)
            print(" ✓ ARC System Ready")
            print(f" Active heads: {list(self.predictor.loaded_heads)}")
            print("=" * 60 + "\n")
        return self

    def _load_risk_predictor(self, path: str, device: torch.device) -> None:
        """Copy fiber projections, layer weights and the repetition head from *path*.

        Propagates any error from ``torch.load`` or a missing ``predictor.*``
        key; ``load`` reports it and continues without the repetition head.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from a repository you trust.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        # The state dict is either the checkpoint itself or nested under this key
        state = ckpt.get('risk_predictor', ckpt)

        for i, proj in enumerate(self.predictor.fiber_projs):
            key = f'fiber_projs.{i}.weight'
            if key in state:
                proj.weight.data = state[key].to(device).float()
        if 'layer_weights' in state:
            self.predictor.layer_weights.data = state['layer_weights'].to(device).float()

        self.predictor.add_head('repetition')
        head = self.predictor.heads['repetition']
        # Indices 0, 2, 4 are the head's Linear layers (1 and 3 are GELU)
        for idx in (0, 2, 4):
            head[idx].weight.data = state[f'predictor.{idx}.weight'].to(device).float()
            head[idx].bias.data = state[f'predictor.{idx}.bias'].to(device).float()
        self.predictor.loaded_heads.add('repetition')

    def _build_suppression_sets(self) -> None:
        """Add the first token id of every starter word to its behavior's set."""
        # NOTE(review): words are encoded without a leading space, so with BPE
        # vocabularies the mid-sentence form (e.g. " However") is probably a
        # different id and is not suppressed — confirm this is intended.
        groups = (
            (self.HEDGE_STARTERS, self._hedge_token_ids),
            (self.VERBOSE_STARTERS, self._verbose_token_ids),
            (self.SYCOPHANCY_STARTERS, self._sycophancy_token_ids),
        )
        for words, token_ids in groups:
            for word in words:
                tokens = self.tokenizer.encode(word, add_special_tokens=False)
                if tokens:
                    token_ids.add(tokens[0])

    # --------------------------------------------------------------- generation

    def _apply_interventions(
        self,
        logits: torch.Tensor,
        risks: Dict[str, float],
        recent_tokens: List[int]
    ) -> Tuple[torch.Tensor, Dict[str, bool]]:
        """
        Apply behavioral interventions based on risk scores.

        Args:
            logits: [1, vocab_size] logits for next token (modified in place)
            risks: Risk score per head name; missing heads count as 0
            recent_tokens: Recently generated token IDs (only the last
                ``repetition_window`` are used)

        Returns:
            Modified logits and dict of which interventions fired (only
            fired behaviors appear, each mapped to True)
        """
        # (behavior, threshold, penalty, token ids to push down)
        rules = (
            ('repetition', self.config.repetition_threshold, self.config.repetition_penalty,
             set(recent_tokens[-self.config.repetition_window:])),
            ('hedging', self.config.hedging_threshold, self.config.hedging_penalty,
             self._hedge_token_ids),
            ('verbosity', self.config.verbosity_threshold, self.config.verbosity_penalty,
             self._verbose_token_ids),
            ('sycophancy', self.config.sycophancy_threshold, self.config.sycophancy_penalty,
             self._sycophancy_token_ids),
        )
        interventions = {}
        for name, threshold, penalty, token_ids in rules:
            if risks.get(name, 0) > threshold:
                for tok in token_ids:
                    logits[0, tok] -= penalty
                interventions[name] = True
                self.total_interventions[name] += 1
        return logits, interventions

    @staticmethod
    def _build_prompt(
        prompt: str,
        system_prompt: Optional[str] = None,
        history: Optional[List[Dict[str, str]]] = None,
    ) -> str:
        """Render a ChatML prompt (system, prior turns, user) ending in an open assistant turn."""
        if system_prompt is None:
            system_prompt = "You are a helpful assistant."
        parts = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
        for turn in history or ():
            parts.append(f"<|im_start|>user\n{turn['user']}<|im_end|>\n")
            parts.append(f"<|im_start|>assistant\n{turn['assistant']}<|im_end|>\n")
        parts.append(f"<|im_start|>user\n{prompt}<|im_end|>\n")
        parts.append("<|im_start|>assistant\n")
        return "".join(parts)

    def _stop_token_ids(self) -> set:
        """Token ids that end generation: EOS, plus ``<|im_end|>`` if it is a single token."""
        stop_ids = set()
        if self.tokenizer.eos_token_id is not None:
            stop_ids.add(self.tokenizer.eos_token_id)
        im_end = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)
        # If the marker is not one special token, im_end[0] would be an
        # ordinary piece such as "<" and would stop generation spuriously.
        if len(im_end) == 1:
            stop_ids.add(im_end[0])
        return stop_ids

    @staticmethod
    def _sample_next_token(logits: torch.Tensor, top_p: float, greedy: bool = False) -> torch.Tensor:
        """Pick the next token id, shape [1, 1], by argmax or nucleus (top-p) sampling.

        *logits* must already be temperature-scaled; it is not modified.
        """
        if greedy:
            return logits.argmax(dim=-1, keepdim=True)
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right so the token that crosses top_p is kept; always keep the best.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
        probs = F.softmax(logits.masked_fill(remove, float('-inf')), dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        use_arc: bool = True,
        verbose: bool = False,
        history: Optional[List[Dict[str, str]]] = None
    ) -> str:
        """
        Generate text with optional ARC behavioral control.

        Args:
            prompt: User input
            system_prompt: Optional system message
            max_new_tokens: Max tokens to generate (None: config value)
            temperature: Sampling temperature (None: config value); a value
                <= 0 selects greedy decoding instead of dividing by zero
            use_arc: Whether to use ARC intervention (default: True)
            verbose: Print intervention info (default: False)
            history: Earlier turns as ``{"user": ..., "assistant": ...}``
                dicts, rendered before *prompt* (default: none)

        Returns:
            Generated text

        Raises:
            RuntimeError: if ``load()`` has not been called.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("ARCSystem.load() must be called before generate()")
        if max_new_tokens is None:
            max_new_tokens = self.config.max_new_tokens
        if temperature is None:
            temperature = self.config.temperature
        greedy = temperature <= 0
        window = self.config.repetition_window

        full_prompt = self._build_prompt(prompt, system_prompt, history)
        device = next(self.model.parameters()).device
        input_ids = self.tokenizer.encode(full_prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)
        stop_ids = self._stop_token_ids()
        arc_active = use_arc and self.predictor is not None and bool(self.predictor.loaded_heads)

        new_tokens: List[int] = []
        intervention_counts = {"repetition": 0, "hedging": 0, "verbosity": 0, "sycophancy": 0}
        next_input = input_ids
        past_key_values = None

        # no_grad also covers the predictor, whose parameters require grad.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # KV cache: after the first step only the newest token is fed,
                # instead of re-running the whole sequence every step.
                outputs = self.model(
                    input_ids=next_input,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True,
                    return_dict=True
                )
                past_key_values = outputs.past_key_values
                logits = outputs.logits[:, -1, :].float()
                if not greedy:
                    logits = logits / temperature

                # ARC intervention: heads are position-wise, so the last
                # position of the cached step equals a full recompute.
                if arc_active:
                    hidden_states = outputs.hidden_states[1:]  # Skip embedding layer
                    risks = self.predictor.get_all_risks(hidden_states)
                    current_risks = {name: r[:, -1].item() for name, r in risks.items()}
                    logits, fired = self._apply_interventions(
                        logits, current_risks, new_tokens[-window:]
                    )
                    for name, did_fire in fired.items():
                        if did_fire:
                            intervention_counts[name] += 1

                next_token = self._sample_next_token(logits, self.config.top_p, greedy=greedy)
                token_id = next_token.item()
                if token_id in stop_ids:
                    break
                new_tokens.append(token_id)
                next_input = next_token
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((1, 1))], dim=-1
                )

        # Decode only this turn's tokens (the stop token was never appended)
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
        if "<|im_end|>" in response:
            response = response.split("<|im_end|>")[0]

        if verbose:
            total = sum(intervention_counts.values())
            print(f"\n[ARC Stats] Interventions: {total} total")
            for k, v in intervention_counts.items():
                if v > 0:
                    print(f" - {k}: {v}")

        return response.strip()

    # --------------------------------------------------------------------- chat

    def chat(self, system_prompt: Optional[str] = None) -> None:
        """
        Interactive multi-turn chat mode; earlier turns are sent with each prompt.

        Args:
            system_prompt: Optional system message
        """
        print("\n" + "=" * 60)
        print(" ARC-8B Interactive Chat")
        print(" Commands: /quit, /stats, /arc on|off, /clear")
        print("=" * 60 + "\n")

        use_arc = True
        history: List[Dict[str, str]] = []

        while True:
            try:
                user_input = input("You: ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\nGoodbye!")
                break
            if not user_input:
                continue

            # Commands
            command = user_input.lower()
            if command == '/quit':
                print("Goodbye!")
                break
            elif command == '/stats':
                print(f"\nTotal interventions: {self.total_interventions}\n")
                continue
            elif command == '/arc on':
                use_arc = True
                print("ARC enabled\n")
                continue
            elif command == '/arc off':
                use_arc = False
                print("ARC disabled (baseline mode)\n")
                continue
            elif command == '/clear':
                history = []
                self.total_interventions = {k: 0 for k in self.total_interventions}
                print("History cleared\n")
                continue

            # Generate response with the conversation so far as context
            response = self.generate(
                user_input,
                system_prompt=system_prompt,
                use_arc=use_arc,
                verbose=True,
                history=history
            )
            print(f"\nAssistant: {response}\n")
            history.append({"user": user_input, "assistant": response})
# =============================================================================
# MAIN
# =============================================================================
def main():
    """Parse CLI flags, load the ARC system, then answer one prompt or start a chat."""
    parser = argparse.ArgumentParser(
        description="ARC-8B: Adaptive Repetition Controller",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python inference.py                           # Interactive chat
  python inference.py --prompt "Hello"          # Single prompt
  python inference.py --no-arc                  # Disable ARC (baseline)
  python inference.py --8bit                    # Use 8-bit quantization
"""
    )
    parser.add_argument("--prompt", "-p", type=str, help="Single prompt to process")
    parser.add_argument("--system", "-s", type=str, help="System prompt")
    parser.add_argument("--no-arc", action="store_true", help="Disable ARC intervention")
    # Quantization modes are alternatives; argparse rejects conflicting flags.
    quant = parser.add_mutually_exclusive_group()
    quant.add_argument("--4bit", dest="load_4bit", action="store_true", default=True, help="Use 4-bit quantization (default)")
    quant.add_argument("--8bit", dest="load_8bit", action="store_true", help="Use 8-bit quantization")
    quant.add_argument("--no-quant", action="store_true", help="Disable quantization (requires ~32GB VRAM)")
    # Defaults come from ARCConfig so the CLI and the library cannot drift apart.
    parser.add_argument("--max-tokens", type=int, default=ARCConfig.max_new_tokens, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=ARCConfig.temperature, help="Sampling temperature")
    args = parser.parse_args()

    if args.max_tokens < 1:
        parser.error("--max-tokens must be a positive integer")
    if args.temperature < 0:
        parser.error("--temperature must be >= 0")

    # Configure
    config = ARCConfig(
        max_new_tokens=args.max_tokens,
        temperature=args.temperature
    )
    if args.load_8bit:
        config.load_in_4bit = False
        config.load_in_8bit = True
    elif args.no_quant:
        config.load_in_4bit = False
        config.load_in_8bit = False

    # Load
    arc = ARCSystem(config)
    arc.load()

    # Run
    if args.prompt:
        response = arc.generate(
            args.prompt,
            system_prompt=args.system,
            use_arc=not args.no_arc,
            verbose=True
        )
        print(f"\n{response}\n")
    else:
        arc.chat(system_prompt=args.system)


if __name__ == "__main__":
    main()