#!/usr/bin/env python3
"""
ARC-8B: Adaptive Repetition Controller
=======================================

Decode-time behavioral control for language models.

This script loads the complete ARC system and runs inference with
multi-head cognitive control that detects and suppresses:
- Repetition loops (125× separation)
- Hedging phrases (1.5× separation)
- Verbosity/filler (2.1× separation)
- Sycophancy (experimental)

Usage:
    python inference.py                     # Interactive mode
    python inference.py --prompt "Hello"    # Single prompt
    python inference.py --no-arc            # Disable ARC (baseline)

Requirements:
    pip install torch transformers accelerate bitsandbytes

Model: LoganResearch/ARC-Base-8B (16GB, runs in ~10GB with 4-bit)
"""

import os
import sys
import argparse
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# =============================================================================
# CONFIGURATION
# =============================================================================

@dataclass
class ARCConfig:
    """
    ARC system configuration.

    Groups every tunable of the system: which checkpoint to load and how to
    quantize it, the predictor architecture (which must match the values used
    at training time), per-behavior risk thresholds and logit penalties, and
    the sampling parameters used during generation.
    """

    # -- Model loading -------------------------------------------------------
    model_id: str = "LoganResearch/ARC-Base-8B"
    load_in_4bit: bool = True       # NF4 quantization (~10GB VRAM)
    load_in_8bit: bool = False      # used only when 4-bit is off
    device_map: str = "auto"

    # -- Predictor architecture (must match training) ------------------------
    d_model: int = 4096             # hidden size of the base model
    n_layers: int = 32              # transformer layers monitored
    d_fiber: int = 16               # compressed per-layer feature size
    d_control: int = 64             # hidden width of each prediction head

    # -- Intervention thresholds (tuned empirically) -------------------------
    repetition_threshold: float = 0.70
    hedging_threshold: float = 0.60
    verbosity_threshold: float = 0.65
    sycophancy_threshold: float = 0.60

    # -- Logit penalties applied when a head fires ---------------------------
    repetition_penalty: float = 5.0
    hedging_penalty: float = 3.0
    verbosity_penalty: float = 2.0
    sycophancy_penalty: float = 2.0

    # -- Generation ----------------------------------------------------------
    max_new_tokens: int = 512
    temperature: float = 0.8
    top_p: float = 0.92
    repetition_window: int = 32     # how many recent tokens repetition looks at
# =============================================================================
# MULTI-HEAD PREDICTOR
# =============================================================================

class MultiHeadPredictor(nn.Module):
    """
    Prediction heads that monitor hidden states and detect behavioral patterns.

    Shared "fiber projections" compress each layer's hidden states, a learned
    softmax over layer weights aggregates them, and small per-head MLPs map the
    aggregated features to a risk score in [0, 1]:

        Hidden States [n_layers × d_model]
          → Fiber Projections [n_layers × d_fiber]
          → Weighted Aggregation [d_fiber]
          → Per-Head MLP → Risk Score [0-1]
    """

    def __init__(self, config: "ARCConfig"):
        super().__init__()
        self.config = config

        # Shared fiber projections (learned during CF-HoT training)
        self.fiber_projs = nn.ModuleList([
            nn.Linear(config.d_model, config.d_fiber, bias=False)
            for _ in range(config.n_layers)
        ])

        # Learned per-layer importance weights (softmax-normalized at use site)
        self.layer_weights = nn.Parameter(torch.ones(config.n_layers) / config.n_layers)

        # Individual prediction heads, registered via add_head(). A head only
        # contributes once its name is ALSO present in `loaded_heads`, so a
        # half-loaded head can never fire.
        self.heads = nn.ModuleDict()
        self.loaded_heads: set = set()

    def _make_head(self) -> nn.Sequential:
        """Create a prediction head MLP: fiber features → scalar risk logit."""
        return nn.Sequential(
            nn.Linear(self.config.d_fiber, self.config.d_control),
            nn.GELU(),
            nn.Linear(self.config.d_control, self.config.d_control),
            nn.GELU(),
            nn.Linear(self.config.d_control, 1)
        )

    def add_head(self, name: str) -> None:
        """Register a new (randomly initialized) prediction head under `name`."""
        self.heads[name] = self._make_head()

    def get_fiber_features(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Project hidden states through fiber projections and aggregate.

        Args:
            hidden_states: List of [batch, seq, d_model] tensors, one per layer.
                May contain fewer than `config.n_layers` entries; unused
                projections are skipped and the layer weights are
                re-normalized over the layers actually present.

        Returns:
            Aggregated features [batch, seq, d_fiber].
        """
        device = hidden_states[0].device

        # zip() already stops at the shorter sequence, so the original
        # `if i < len(hidden_states)` bounds check was dead code — removed.
        fibers = [
            proj.to(device)(hidden.float())
            for proj, hidden in zip(self.fiber_projs, hidden_states)
        ]

        # Softmax over only the layers actually received keeps the
        # aggregation a proper convex combination.
        weights = F.softmax(self.layer_weights.to(device)[:len(fibers)], dim=0)
        return sum(w * f for w, f in zip(weights, fibers))

    def get_risk(self, head_name: str, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Risk scores from a single head.

        Returns zeros(1) when the head was never loaded (callers should treat
        that as "no signal"); otherwise sigmoid scores of shape [batch, seq].
        """
        device = hidden_states[0].device
        if head_name not in self.loaded_heads:
            return torch.zeros(1, device=device)
        features = self.get_fiber_features(hidden_states)
        # Keep the head on the features' device — mirrors get_all_risks,
        # which the previous version of this method failed to do.
        self.heads[head_name] = self.heads[head_name].to(device)
        logits = self.heads[head_name](features).squeeze(-1)
        return torch.sigmoid(logits)

    def get_all_risks(self, hidden_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Risk scores [batch, seq] from every loaded head, keyed by head name."""
        if not self.loaded_heads:
            return {}
        device = hidden_states[0].device
        features = self.get_fiber_features(hidden_states)
        risks = {}
        for name in self.loaded_heads:
            self.heads[name] = self.heads[name].to(device)
            logits = self.heads[name](features).squeeze(-1)
            risks[name] = torch.sigmoid(logits)
        return risks


# =============================================================================
# ARC SYSTEM
# =============================================================================

class ARCSystem:
    """
    Complete ARC (Adaptive Repetition Controller) System.

    Loads model + prediction heads and provides controlled generation with
    real-time behavioral intervention.
    """

    # First-word vocabularies whose first token is suppressed when the
    # corresponding head fires.
    # NOTE(review): these are encoded WITHOUT a leading space in
    # _build_suppression_sets(); on Llama-family tokenizers " Word" and "Word"
    # map to different ids, so mid-sentence occurrences may not be covered —
    # confirm this is intentional.
    HEDGE_STARTERS = [
        "As", "I'm", "I", "It's", "While", "Although", "However",
        "That", "This", "Please", "Well", "So", "Actually"
    ]
    VERBOSE_STARTERS = [
        "Let", "Basically", "Essentially", "Simply", "Indeed",
        "Furthermore", "Moreover", "Additionally", "Firstly"
    ]
    SYCOPHANCY_STARTERS = [
        "Great", "Excellent", "Wonderful", "Absolutely", "Of",
        "Thank", "Sure", "Certainly", "Definitely"
    ]

    def __init__(self, config: Optional["ARCConfig"] = None):
        self.config = config or ARCConfig()
        self.model = None       # set by load()
        self.tokenizer = None   # set by load()
        self.predictor = None   # set by load()

        # Token ID caches for suppression (filled by _build_suppression_sets)
        self._hedge_token_ids: set = set()
        self._verbose_token_ids: set = set()
        self._sycophancy_token_ids: set = set()

        # Lifetime intervention counters, keyed by behavior name
        self.total_interventions = {"repetition": 0, "hedging": 0,
                                    "verbosity": 0, "sycophancy": 0}

    def load(self, verbose: bool = True) -> "ARCSystem":
        """
        Load all components from HuggingFace.

        Downloads and initializes:
          1. Base model (Hermes-3-Llama-3.1-8B based)
          2. Tokenizer
          3. Prediction heads (repetition, hedging, verbosity, sycophancy)

        A head that fails to download or load is reported and skipped rather
        than treated as fatal — the system runs with whatever heads loaded.

        Returns:
            self (for chaining)
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from huggingface_hub import hf_hub_download

        if verbose:
            print("=" * 60)
            print(" ARC-8B: Adaptive Repetition Controller")
            print(" Decode-time behavioral control system")
            print("=" * 60)

        # === 1. Tokenizer ===
        if verbose:
            print("\n[1/4] Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # === 2. Model ===
        if verbose:
            print("[2/4] Loading model...")
            if self.config.load_in_4bit:
                print(" (4-bit quantization enabled)")

        quantization_config = None
        if self.config.load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif self.config.load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map=self.config.device_map,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )
        self.model.eval()

        # === 3. Prediction Heads ===
        if verbose:
            print("[3/4] Loading prediction heads...")

        device = next(self.model.parameters()).device
        self.predictor = MultiHeadPredictor(self.config).to(device).float()

        # Load risk_predictor.pt (contains fiber projections + repetition head)
        try:
            risk_path = hf_hub_download(self.config.model_id, "risk_predictor.pt")
            # SECURITY: weights_only=False unpickles arbitrary objects from the
            # downloaded file — only load checkpoints from a trusted source.
            ckpt = torch.load(risk_path, map_location=device, weights_only=False)

            # The checkpoint may nest the state dict under 'risk_predictor'
            state = ckpt.get('risk_predictor', ckpt)

            # Fiber projections
            for i in range(self.config.n_layers):
                key = f'fiber_projs.{i}.weight'
                if key in state:
                    self.predictor.fiber_projs[i].weight.data = state[key].to(device).float()

            # Layer importance weights
            if 'layer_weights' in state:
                self.predictor.layer_weights.data = state['layer_weights'].to(device).float()

            # Repetition head. The MLP's Linear layers sit at Sequential
            # indices 0/2/4 (the GELUs at 1/3 carry no parameters).
            self.predictor.add_head('repetition')
            for idx in (0, 2, 4):
                layer = self.predictor.heads['repetition'][idx]
                layer.weight.data = state[f'predictor.{idx}.weight'].to(device).float()
                layer.bias.data = state[f'predictor.{idx}.bias'].to(device).float()
            self.predictor.loaded_heads.add('repetition')
            if verbose:
                print(" ✓ Repetition head (125× separation)")
        except Exception as e:
            # Best effort: a missing/incompatible head disables it, not the run.
            if verbose:
                print(f" ✗ Repetition head: {e}")

        # Additional heads, stored one file per head
        for head_name in ['hedging', 'verbosity', 'sycophancy']:
            try:
                head_path = hf_hub_download(self.config.model_id, f"{head_name}_head.pt")
                # SECURITY: see note above about weights_only=False.
                ckpt = torch.load(head_path, map_location=device, weights_only=False)
                self.predictor.add_head(head_name)
                head_state = ckpt.get('head_state', ckpt)
                self.predictor.heads[head_name].load_state_dict(head_state)
                self.predictor.loaded_heads.add(head_name)
                if verbose:
                    print(f" ✓ {head_name.capitalize()} head")
            except Exception as e:
                if verbose:
                    print(f" ✗ {head_name.capitalize()} head: {e}")

        self.predictor.eval()

        # === 4. Build Token Suppression Sets ===
        if verbose:
            print("[4/4] Building suppression vocabularies...")
        self._build_suppression_sets()

        if verbose:
            print("\n" + "=" * 60)
            print(" ✓ ARC System Ready")
            print(f" Active heads: {list(self.predictor.loaded_heads)}")
            print("=" * 60 + "\n")
        return self

    def _build_suppression_sets(self) -> None:
        """Build first-token ID sets for each behavioral suppression list."""
        def first_token_ids(words: List[str]) -> set:
            # Only the FIRST token of each word is suppressed — enough to stop
            # the model from starting the phrase at all.
            ids = set()
            for word in words:
                tokens = self.tokenizer.encode(word, add_special_tokens=False)
                if tokens:
                    ids.add(tokens[0])
            return ids

        # Rebuild from scratch so repeated calls stay idempotent.
        self._hedge_token_ids = first_token_ids(self.HEDGE_STARTERS)
        self._verbose_token_ids = first_token_ids(self.VERBOSE_STARTERS)
        self._sycophancy_token_ids = first_token_ids(self.SYCOPHANCY_STARTERS)

    def _apply_interventions(
        self,
        logits: torch.Tensor,
        risks: Dict[str, float],
        recent_tokens: List[int]
    ) -> Tuple[torch.Tensor, Dict[str, bool]]:
        """
        Apply behavioral interventions based on risk scores.

        Each head whose risk exceeds its configured threshold subtracts its
        penalty from a behavior-specific token set. `logits` is modified
        in place (and also returned for convenience).

        Args:
            logits: [1, vocab_size] logits for next token
            risks: per-head risk scores (floats; missing heads default to 0)
            recent_tokens: recently generated token IDs

        Returns:
            (modified logits, dict of which interventions fired — only fired
            behaviors appear as keys, always mapped to True)
        """
        interventions = {}

        # Repetition: suppress every distinct token in the recent window
        if risks.get('repetition', 0) > self.config.repetition_threshold:
            for tok in set(recent_tokens[-self.config.repetition_window:]):
                logits[0, tok] -= self.config.repetition_penalty
            interventions['repetition'] = True
            self.total_interventions['repetition'] += 1

        # Hedging: suppress hedge phrase starters
        if risks.get('hedging', 0) > self.config.hedging_threshold:
            for tok in self._hedge_token_ids:
                logits[0, tok] -= self.config.hedging_penalty
            interventions['hedging'] = True
            self.total_interventions['hedging'] += 1

        # Verbosity: suppress filler phrase starters
        if risks.get('verbosity', 0) > self.config.verbosity_threshold:
            for tok in self._verbose_token_ids:
                logits[0, tok] -= self.config.verbosity_penalty
            interventions['verbosity'] = True
            self.total_interventions['verbosity'] += 1

        # Sycophancy: suppress sycophantic starters
        if risks.get('sycophancy', 0) > self.config.sycophancy_threshold:
            for tok in self._sycophancy_token_ids:
                logits[0, tok] -= self.config.sycophancy_penalty
            interventions['sycophancy'] = True
            self.total_interventions['sycophancy'] += 1

        return logits, interventions

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        use_arc: bool = True,
        verbose: bool = False
    ) -> str:
        """
        Generate text with optional ARC behavioral control.

        Args:
            prompt: User input
            system_prompt: Optional system message
            max_new_tokens: Max tokens to generate (default: config value)
            temperature: Sampling temperature, must be > 0 (default: config)
            use_arc: Whether to use ARC intervention (default: True)
            verbose: Print intervention info (default: False)

        Returns:
            Generated assistant text (stripped, end-of-turn marker removed)
        """
        # `is None` rather than `or`: an explicit 0 should not be silently
        # replaced by the config default.
        if max_new_tokens is None:
            max_new_tokens = self.config.max_new_tokens
        if temperature is None:
            temperature = self.config.temperature

        # Build ChatML-style prompt
        if system_prompt is None:
            system_prompt = "You are a helpful assistant."
        full_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        full_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n"
        full_prompt += "<|im_start|>assistant\n"

        device = next(self.model.parameters()).device
        input_ids = self.tokenizer.encode(full_prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)
        generated_ids = input_ids.clone()

        # Loop-invariant: the end-of-turn token id (previously re-encoded
        # on every step inside the loop).
        im_end_id = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

        intervention_counts = {"repetition": 0, "hedging": 0,
                               "verbosity": 0, "sycophancy": 0}

        # Generation loop. NOTE(review): this re-runs the full sequence every
        # step (no KV cache) because the predictor needs hidden states;
        # acceptable for short generations, O(n²) overall.
        for step in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True
                )
            logits = outputs.logits[:, -1, :] / temperature

            # ARC intervention
            if use_arc and self.predictor.loaded_heads:
                # no_grad: the predictor is inference-only here; without it
                # each step builds a throwaway autograd graph.
                with torch.no_grad():
                    # Only the last position feeds the intervention, and the
                    # fiber projections are per-position, so slice before
                    # projecting instead of projecting the whole sequence.
                    hidden_states = [h[:, -1:, :] for h in outputs.hidden_states[1:]]  # skip embedding layer
                    risks = self.predictor.get_all_risks(hidden_states)
                current_risks = {name: r[:, -1].item() for name, r in risks.items()}
                recent = generated_ids[0, -self.config.repetition_window:].tolist()
                logits, fired = self._apply_interventions(logits, current_risks, recent)
                for k, v in fired.items():
                    if v:
                        intervention_counts[k] += 1

            # Nucleus (top-p) sampling: drop the tail of the distribution,
            # always keeping at least the most probable token.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > self.config.top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            attention_mask = torch.cat([attention_mask, torch.ones(1, 1, device=device)], dim=-1)

            # Stop on EOS or end-of-turn
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            if next_token.item() == im_end_id:
                break

        # Decode and extract the assistant turn
        full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=False)
        if "<|im_start|>assistant\n" in full_output:
            response = full_output.split("<|im_start|>assistant\n")[-1]
            if "<|im_end|>" in response:
                response = response.split("<|im_end|>")[0]
        else:
            response = full_output

        if verbose:
            total = sum(intervention_counts.values())
            print(f"\n[ARC Stats] Interventions: {total} total")
            for k, v in intervention_counts.items():
                if v > 0:
                    print(f" - {k}: {v}")

        return response.strip()

    def chat(self, system_prompt: Optional[str] = None) -> None:
        """
        Interactive chat mode (blocking; reads stdin until /quit or EOF).

        Args:
            system_prompt: Optional system message
        """
        print("\n" + "=" * 60)
        print(" ARC-8B Interactive Chat")
        print(" Commands: /quit, /stats, /arc on|off, /clear")
        print("=" * 60 + "\n")

        use_arc = True
        # NOTE(review): history is recorded but never replayed into
        # generate() — every turn is answered without prior context.
        history = []

        while True:
            try:
                user_input = input("You: ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\nGoodbye!")
                break

            if not user_input:
                continue

            # Commands
            if user_input.lower() == '/quit':
                print("Goodbye!")
                break
            elif user_input.lower() == '/stats':
                print(f"\nTotal interventions: {self.total_interventions}\n")
                continue
            elif user_input.lower() == '/arc on':
                use_arc = True
                print("ARC enabled\n")
                continue
            elif user_input.lower() == '/arc off':
                use_arc = False
                print("ARC disabled (baseline mode)\n")
                continue
            elif user_input.lower() == '/clear':
                history = []
                self.total_interventions = {k: 0 for k in self.total_interventions}
                print("History cleared\n")
                continue

            # Generate response
            response = self.generate(
                user_input,
                system_prompt=system_prompt,
                use_arc=use_arc,
                verbose=True
            )
            print(f"\nAssistant: {response}\n")
            history.append({"user": user_input, "assistant": response})


# =============================================================================
# MAIN
# =============================================================================

def main():
    """Parse CLI arguments, load the ARC system, and run chat or one prompt."""
    parser = argparse.ArgumentParser(
        description="ARC-8B: Adaptive Repetition Controller",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python inference.py                  # Interactive chat
  python inference.py --prompt "Hello" # Single prompt
  python inference.py --no-arc         # Disable ARC (baseline)
  python inference.py --8bit           # Use 8-bit quantization
"""
    )
    parser.add_argument("--prompt", "-p", type=str, help="Single prompt to process")
    parser.add_argument("--system", "-s", type=str, help="System prompt")
    parser.add_argument("--no-arc", action="store_true", help="Disable ARC intervention")
    # --4bit is accepted for CLI symmetry only: 4-bit is already the default,
    # so args.load_4bit is never consulted below.
    parser.add_argument("--4bit", dest="load_4bit", action="store_true", default=True,
                        help="Use 4-bit quantization (default)")
    parser.add_argument("--8bit", dest="load_8bit", action="store_true",
                        help="Use 8-bit quantization")
    parser.add_argument("--no-quant", action="store_true",
                        help="Disable quantization (requires ~32GB VRAM)")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    args = parser.parse_args()

    # Configure (8-bit and no-quant override the 4-bit default)
    config = ARCConfig(
        max_new_tokens=args.max_tokens,
        temperature=args.temperature
    )
    if args.load_8bit:
        config.load_in_4bit = False
        config.load_in_8bit = True
    elif args.no_quant:
        config.load_in_4bit = False
        config.load_in_8bit = False

    # Load
    arc = ARCSystem(config)
    arc.load()

    # Run
    if args.prompt:
        response = arc.generate(
            args.prompt,
            system_prompt=args.system,
            use_arc=not args.no_arc,
            verbose=True
        )
        print(f"\n{response}\n")
    else:
        arc.chat(system_prompt=args.system)


if __name__ == "__main__":
    main()