LoganResearch commited on
Commit
1f68519
·
verified ·
1 Parent(s): a5f7b81

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +627 -0
inference.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ARC-8B: Adaptive Repetition Controller
4
+ =======================================
5
+ Decode-time behavioral control for language models.
6
+
7
+ This script loads the complete ARC system and runs inference with
8
+ multi-head cognitive control that detects and suppresses:
9
+ - Repetition loops (125× separation)
10
+ - Hedging phrases (1.5× separation)
11
+ - Verbosity/filler (2.1× separation)
12
+ - Sycophancy (experimental)
13
+
14
+ Usage:
15
+ python inference.py # Interactive mode
16
+ python inference.py --prompt "Hello" # Single prompt
17
+ python inference.py --no-arc # Disable ARC (baseline)
18
+
19
+ Requirements:
20
+ pip install torch transformers accelerate bitsandbytes
21
+
22
+ Model: LoganResearch/ARC-Base-8B (16GB, runs in ~10GB with 4-bit)
23
+ """
24
+
25
+ import os
26
+ import sys
27
+ import argparse
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from typing import Dict, List, Optional, Tuple
32
+ from dataclasses import dataclass
33
+
34
+
35
+ # =============================================================================
36
+ # CONFIGURATION
37
+ # =============================================================================
38
+
39
@dataclass
class ARCConfig:
    """ARC System Configuration.

    Defaults mirror the released ARC-Base-8B checkpoint. The architecture
    fields must match the values the prediction heads were trained with,
    or the checkpoint weights loaded in ARCSystem.load() will not fit.
    """
    # Model
    model_id: str = "LoganResearch/ARC-Base-8B"  # HuggingFace Hub repo id
    load_in_4bit: bool = True    # NF4 4-bit quantization (runs in ~10GB per module docstring)
    load_in_8bit: bool = False   # 8-bit quantization; ignored when 4-bit is enabled
    device_map: str = "auto"     # passed through to transformers from_pretrained

    # Architecture (must match training)
    d_model: int = 4096   # base model hidden size
    n_layers: int = 32    # number of transformer layers monitored
    d_fiber: int = 16     # compressed per-layer feature dimension
    d_control: int = 64   # hidden width of each prediction-head MLP

    # Intervention thresholds (tuned empirically): a head intervenes when
    # its sigmoid risk score exceeds its threshold.
    repetition_threshold: float = 0.70
    hedging_threshold: float = 0.60
    verbosity_threshold: float = 0.65
    sycophancy_threshold: float = 0.60

    # Intervention penalties: amount subtracted from the logits of
    # suppressed tokens when the corresponding head fires.
    repetition_penalty: float = 5.0
    hedging_penalty: float = 3.0
    verbosity_penalty: float = 2.0
    sycophancy_penalty: float = 2.0

    # Generation
    max_new_tokens: int = 512
    temperature: float = 0.8
    top_p: float = 0.92
    repetition_window: int = 32  # how many recent tokens repetition suppression covers
71
+
72
+
73
+ # =============================================================================
74
+ # MULTI-HEAD PREDICTOR
75
+ # =============================================================================
76
+
77
class MultiHeadPredictor(nn.Module):
    """
    Prediction heads that monitor hidden states and detect behavioral patterns.

    The system uses shared "fiber projections" that compress hidden states,
    then individual heads that predict risk scores for specific behaviors.

    Architecture:
        Hidden States [n_layers × d_model]
            → Fiber Projections [n_layers × d_fiber]
            → Weighted Aggregation [d_fiber]
            → Per-Head MLP → Risk Score [0-1]
    """

    def __init__(self, config: "ARCConfig"):
        super().__init__()
        self.config = config

        # Shared fiber projections (learned during CF-HoT training)
        self.fiber_projs = nn.ModuleList([
            nn.Linear(config.d_model, config.d_fiber, bias=False)
            for _ in range(config.n_layers)
        ])

        # Learned layer importance weights (uniform init; softmaxed at use time)
        self.layer_weights = nn.Parameter(torch.ones(config.n_layers) / config.n_layers)

        # Individual prediction heads, registered via add_head(). Only names
        # present in `loaded_heads` are treated as trained and usable.
        self.heads = nn.ModuleDict()
        self.loaded_heads: set = set()

    def _make_head(self) -> nn.Sequential:
        """Create a prediction head: fiber features → raw risk logit."""
        return nn.Sequential(
            nn.Linear(self.config.d_fiber, self.config.d_control),
            nn.GELU(),
            nn.Linear(self.config.d_control, self.config.d_control),
            nn.GELU(),
            nn.Linear(self.config.d_control, 1)
        )

    def add_head(self, name: str) -> None:
        """Register a new (untrained) prediction head under `name`."""
        self.heads[name] = self._make_head()

    def get_fiber_features(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Project hidden states through fiber projections and aggregate.

        Args:
            hidden_states: List of [batch, seq, d_model] tensors from each layer.
                May contain fewer than n_layers entries; surplus projections
                are simply unused.

        Returns:
            Aggregated features [batch, seq, d_fiber]
        """
        # zip() already truncates to the shorter sequence, so the original
        # `if i < len(hidden_states)` guard was always true; removed.
        fibers = [proj(hidden.float())
                  for proj, hidden in zip(self.fiber_projs, hidden_states)]

        # Weighted sum across layers; weights are renormalized over only the
        # layers actually present.
        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        return sum(w * f for w, f in zip(weights, fibers))

    def get_risk(self, head_name: str, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Get sigmoid risk score from one head (zeros(1) if head not loaded)."""
        if head_name not in self.loaded_heads:
            return torch.zeros(1, device=hidden_states[0].device)

        features = self.get_fiber_features(hidden_states)
        logits = self.heads[head_name](features).squeeze(-1)
        return torch.sigmoid(logits)

    def get_all_risks(self, hidden_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Get risk scores [batch, seq] from all loaded heads, keyed by name."""
        if not self.loaded_heads:
            return {}

        # Shared features are computed once and reused for every head.
        features = self.get_fiber_features(hidden_states)
        risks = {}
        for name in self.loaded_heads:
            logits = self.heads[name](features).squeeze(-1)
            risks[name] = torch.sigmoid(logits)
        return risks
162
+
163
+
164
+ # =============================================================================
165
+ # ARC SYSTEM
166
+ # =============================================================================
167
+
168
class ARCSystem:
    """
    Complete ARC (Adaptive Repetition Controller) System

    Loads model + prediction heads and provides controlled generation
    with real-time behavioral intervention.
    """

    # Starter-word vocabularies to suppress for each behavior type. Only the
    # first sub-token of each word is penalized (see _build_suppression_sets).
    HEDGE_STARTERS = [
        "As", "I'm", "I", "It's", "While", "Although", "However",
        "That", "This", "Please", "Well", "So", "Actually"
    ]
    VERBOSE_STARTERS = [
        "Let", "Basically", "Essentially", "Simply", "Indeed",
        "Furthermore", "Moreover", "Additionally", "Firstly"
    ]
    SYCOPHANCY_STARTERS = [
        "Great", "Excellent", "Wonderful", "Absolutely", "Of",
        "Thank", "Sure", "Certainly", "Definitely"
    ]

    def __init__(self, config: Optional["ARCConfig"] = None):
        self.config = config or ARCConfig()

        # Populated by load()
        self.model = None
        self.tokenizer = None
        self.predictor = None

        # Token ID caches for suppression (first sub-token of each starter word)
        self._hedge_token_ids: set = set()
        self._verbose_token_ids: set = set()
        self._sycophancy_token_ids: set = set()

        # Lifetime intervention counters, accumulated across generate() calls
        self.total_interventions = {"repetition": 0, "hedging": 0, "verbosity": 0, "sycophancy": 0}

    def load(self, verbose: bool = True) -> "ARCSystem":
        """
        Load all components from HuggingFace.

        Downloads and initializes:
        1. Base model (Hermes-3-Llama-3.1-8B based)
        2. Tokenizer
        3. Prediction heads (repetition, hedging, verbosity, sycophancy)

        Missing/broken head checkpoints are reported and skipped rather than
        aborting the load (best-effort by design).

        Args:
            verbose: Print progress information.

        Returns:
            self (for chaining)
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from huggingface_hub import hf_hub_download

        if verbose:
            print("=" * 60)
            print(" ARC-8B: Adaptive Repetition Controller")
            print(" Decode-time behavioral control system")
            print("=" * 60)

        # === 1. Tokenizer ===
        if verbose:
            print("\n[1/4] Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # === 2. Model ===
        if verbose:
            print("[2/4] Loading model...")
            if self.config.load_in_4bit:
                print(" (4-bit quantization enabled)")

        # 4-bit takes precedence over 8-bit when both flags are set.
        quantization_config = None
        if self.config.load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif self.config.load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map=self.config.device_map,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )
        self.model.eval()

        # === 3. Prediction Heads ===
        if verbose:
            print("[3/4] Loading prediction heads...")

        device = next(self.model.parameters()).device
        self.predictor = MultiHeadPredictor(self.config).to(device).float()

        # Load risk_predictor.pt (contains fiber projections + repetition head)
        try:
            risk_path = hf_hub_download(self.config.model_id, "risk_predictor.pt")
            # NOTE: weights_only=False deserializes arbitrary pickle objects;
            # only safe because the checkpoint comes from our own repo.
            ckpt = torch.load(risk_path, map_location=device, weights_only=False)

            # The checkpoint contains the full state dict
            state = ckpt.get('risk_predictor', ckpt)

            # Load fiber projections
            for i in range(self.config.n_layers):
                key = f'fiber_projs.{i}.weight'
                if key in state:
                    self.predictor.fiber_projs[i].weight.data = state[key].to(device).float()

            # Load layer weights
            if 'layer_weights' in state:
                self.predictor.layer_weights.data = state['layer_weights'].to(device).float()

            # Load repetition head: checkpoint keys 'predictor.{0,2,4}.*' map
            # onto the three Linear layers of the head's Sequential (indices
            # 1 and 3 are parameter-free GELUs).
            self.predictor.add_head('repetition')
            for idx in (0, 2, 4):
                layer = self.predictor.heads['repetition'][idx]
                layer.weight.data = state[f'predictor.{idx}.weight'].to(device).float()
                layer.bias.data = state[f'predictor.{idx}.bias'].to(device).float()
            self.predictor.loaded_heads.add('repetition')

            if verbose:
                print(" ✓ Repetition head (125× separation)")
        except Exception as e:
            if verbose:
                print(f" ✗ Repetition head: {e}")

        # Load additional heads (each stored as a standalone state dict)
        for head_name in ['hedging', 'verbosity', 'sycophancy']:
            try:
                head_path = hf_hub_download(self.config.model_id, f"{head_name}_head.pt")
                ckpt = torch.load(head_path, map_location=device, weights_only=False)

                self.predictor.add_head(head_name)
                head_state = ckpt.get('head_state', ckpt)
                self.predictor.heads[head_name].load_state_dict(head_state)
                self.predictor.loaded_heads.add(head_name)

                if verbose:
                    print(f" ✓ {head_name.capitalize()} head")
            except Exception as e:
                if verbose:
                    print(f" ✗ {head_name.capitalize()} head: {e}")

        self.predictor.eval()

        # === 4. Build Token Suppression Sets ===
        if verbose:
            print("[4/4] Building suppression vocabularies...")

        self._build_suppression_sets()

        if verbose:
            print("\n" + "=" * 60)
            print(" ✓ ARC System Ready")
            print(f" Active heads: {list(self.predictor.loaded_heads)}")
            print("=" * 60 + "\n")

        return self

    def _build_suppression_sets(self) -> None:
        """Build token ID sets for behavioral suppression.

        Each starter word contributes only its first sub-token ID, since
        suppressing the first token is enough to steer away from the phrase.
        """
        for words, token_ids in (
            (self.HEDGE_STARTERS, self._hedge_token_ids),
            (self.VERBOSE_STARTERS, self._verbose_token_ids),
            (self.SYCOPHANCY_STARTERS, self._sycophancy_token_ids),
        ):
            for word in words:
                tokens = self.tokenizer.encode(word, add_special_tokens=False)
                if tokens:
                    token_ids.add(tokens[0])

    def _apply_interventions(
        self,
        logits: torch.Tensor,
        risks: Dict[str, torch.Tensor],
        recent_tokens: List[int]
    ) -> Tuple[torch.Tensor, Dict[str, bool]]:
        """
        Apply behavioral interventions based on risk scores.

        Args:
            logits: [1, vocab_size] logits for next token (modified in place)
            risks: Risk score per head; values may be Python floats or
                   scalar tensors (generate() passes floats)
            recent_tokens: Recently generated token IDs

        Returns:
            Modified logits and dict of which interventions fired
        """
        interventions = {}

        def _risk(name: str) -> float:
            # float() accepts both plain floats and scalar tensors. The
            # original called .item(), which raised AttributeError on the
            # plain floats generate() actually passes.
            return float(risks.get(name, 0.0))

        # Repetition: suppress tokens used within the recent window
        if _risk('repetition') > self.config.repetition_threshold:
            for tok in set(recent_tokens[-self.config.repetition_window:]):
                logits[0, tok] -= self.config.repetition_penalty
            interventions['repetition'] = True
            self.total_interventions['repetition'] += 1

        # Hedging / verbosity / sycophancy: suppress their starter-token sets
        for name, token_ids, threshold, penalty in (
            ('hedging', self._hedge_token_ids,
             self.config.hedging_threshold, self.config.hedging_penalty),
            ('verbosity', self._verbose_token_ids,
             self.config.verbosity_threshold, self.config.verbosity_penalty),
            ('sycophancy', self._sycophancy_token_ids,
             self.config.sycophancy_threshold, self.config.sycophancy_penalty),
        ):
            if _risk(name) > threshold:
                for tok in token_ids:
                    logits[0, tok] -= penalty
                interventions[name] = True
                self.total_interventions[name] += 1

        return logits, interventions

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        use_arc: bool = True,
        verbose: bool = False
    ) -> str:
        """
        Generate text with optional ARC behavioral control.

        Args:
            prompt: User input
            system_prompt: Optional system message
            max_new_tokens: Max tokens to generate (default: config value)
            temperature: Sampling temperature (default: config value)
            use_arc: Whether to use ARC intervention (default: True)
            verbose: Print intervention info (default: False)

        Returns:
            Generated text (assistant turn only, stripped)
        """
        max_new_tokens = max_new_tokens or self.config.max_new_tokens
        temperature = temperature or self.config.temperature

        # Build ChatML-style prompt
        if system_prompt is None:
            system_prompt = "You are a helpful assistant."

        full_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        full_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n"
        full_prompt += "<|im_start|>assistant\n"

        device = next(self.model.parameters()).device
        input_ids = self.tokenizer.encode(full_prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)

        generated_ids = input_ids.clone()
        intervention_counts = {"repetition": 0, "hedging": 0, "verbosity": 0, "sycophancy": 0}

        # Hoisted out of the loop: the original re-encoded "<|im_end|>"
        # every decoding step just to compare one token ID.
        im_end_tokens = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)
        im_end_id = im_end_tokens[0] if im_end_tokens else None

        # Generation loop. No KV cache: the predictor needs every layer's
        # hidden states, so a full forward pass runs per token.
        for step in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True
                )

            logits = outputs.logits[:, -1, :] / temperature

            # ARC intervention
            if use_arc and self.predictor.loaded_heads:
                hidden_states = outputs.hidden_states[1:]  # Skip embedding layer
                risks = self.predictor.get_all_risks(hidden_states)
                # Risk at the last position only (the token about to be sampled)
                current_risks = {name: r[:, -1].item() for name, r in risks.items()}

                recent = generated_ids[0, -self.config.repetition_window:].tolist()
                logits, fired = self._apply_interventions(logits, current_risks, recent)

                for k, v in fired.items():
                    if v:
                        intervention_counts[k] += 1

            # Top-p (nucleus) sampling
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > self.config.top_p
            # Shift the mask right so the first token crossing top_p is kept
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            # Append with the mask's own dtype; the original appended float
            # ones to an integer mask, silently promoting it to float.
            attention_mask = torch.cat(
                [attention_mask, torch.ones(1, 1, dtype=attention_mask.dtype, device=device)],
                dim=-1
            )

            # Check for EOS
            if next_token.item() == self.tokenizer.eos_token_id:
                break

            # Check for end of turn
            if im_end_id is not None and next_token.item() == im_end_id:
                break

        # Decode response
        full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=False)

        # Extract assistant response
        if "<|im_start|>assistant\n" in full_output:
            response = full_output.split("<|im_start|>assistant\n")[-1]
            if "<|im_end|>" in response:
                response = response.split("<|im_end|>")[0]
        else:
            response = full_output

        if verbose:
            total = sum(intervention_counts.values())
            print(f"\n[ARC Stats] Interventions: {total} total")
            for k, v in intervention_counts.items():
                if v > 0:
                    print(f" - {k}: {v}")

        return response.strip()

    def chat(self, system_prompt: Optional[str] = None) -> None:
        """
        Interactive chat mode (REPL). Runs until /quit, Ctrl-C, or EOF.

        Args:
            system_prompt: Optional system message applied to every turn.
        """
        print("\n" + "=" * 60)
        print(" ARC-8B Interactive Chat")
        print(" Commands: /quit, /stats, /arc on|off, /clear")
        print("=" * 60 + "\n")

        use_arc = True
        # NOTE: history is recorded but not fed back into the prompt;
        # each turn is generated independently.
        history = []

        while True:
            try:
                user_input = input("You: ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\nGoodbye!")
                break

            if not user_input:
                continue

            # Commands
            if user_input.lower() == '/quit':
                print("Goodbye!")
                break
            elif user_input.lower() == '/stats':
                print(f"\nTotal interventions: {self.total_interventions}\n")
                continue
            elif user_input.lower() == '/arc on':
                use_arc = True
                print("ARC enabled\n")
                continue
            elif user_input.lower() == '/arc off':
                use_arc = False
                print("ARC disabled (baseline mode)\n")
                continue
            elif user_input.lower() == '/clear':
                history = []
                self.total_interventions = {k: 0 for k in self.total_interventions}
                print("History cleared\n")
                continue

            # Generate response
            response = self.generate(
                user_input,
                system_prompt=system_prompt,
                use_arc=use_arc,
                verbose=True
            )

            print(f"\nAssistant: {response}\n")
            history.append({"user": user_input, "assistant": response})
567
+
568
+
569
+ # =============================================================================
570
+ # MAIN
571
+ # =============================================================================
572
+
573
def main():
    """Command-line entry point: parse flags, build config, load ARC, run."""
    arg_parser = argparse.ArgumentParser(
        description="ARC-8B: Adaptive Repetition Controller",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python inference.py                      # Interactive chat
  python inference.py --prompt "Hello"     # Single prompt
  python inference.py --no-arc             # Disable ARC (baseline)
  python inference.py --8bit               # Use 8-bit quantization
        """
    )
    arg_parser.add_argument("--prompt", "-p", type=str, help="Single prompt to process")
    arg_parser.add_argument("--system", "-s", type=str, help="System prompt")
    arg_parser.add_argument("--no-arc", action="store_true", help="Disable ARC intervention")
    arg_parser.add_argument("--4bit", dest="load_4bit", action="store_true", default=True, help="Use 4-bit quantization (default)")
    arg_parser.add_argument("--8bit", dest="load_8bit", action="store_true", help="Use 8-bit quantization")
    arg_parser.add_argument("--no-quant", action="store_true", help="Disable quantization (requires ~32GB VRAM)")
    arg_parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens to generate")
    arg_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")

    cli = arg_parser.parse_args()

    # Build the configuration from CLI flags. Quantization: --8bit wins over
    # the 4-bit default; --no-quant turns both off.
    config = ARCConfig(
        max_new_tokens=cli.max_tokens,
        temperature=cli.temperature
    )
    if cli.load_8bit:
        config.load_in_4bit, config.load_in_8bit = False, True
    elif cli.no_quant:
        config.load_in_4bit = config.load_in_8bit = False

    # Load model, tokenizer, and prediction heads (load() returns self)
    system = ARCSystem(config).load()

    # Single-prompt mode when --prompt is given, otherwise interactive chat
    if cli.prompt:
        answer = system.generate(
            cli.prompt,
            system_prompt=cli.system,
            use_arc=not cli.no_arc,
            verbose=True
        )
        print(f"\n{answer}\n")
    else:
        system.chat(system_prompt=cli.system)


if __name__ == "__main__":
    main()