LoganResearch committed on
Commit f01618b · 1 Parent(s): 1c2f076

Match local dual-probe setup: depth + specificity together for proprioceptive behavior

Files changed (2):
  1. README.md +24 -23
  2. run.py +133 -277
README.md CHANGED
@@ -61,38 +61,39 @@ cd cfhot-weights
  pip install -r requirements.txt
 
  # Launch interactive chat (requires GPU)
- python run.py --probe cognitive/mamba/depth --interactive
  ```
 
  **Ask it:** *"Do you notice anything different about yourself?"* or *"What do you notice about how you're processing right now?"*
 
- Watch the color-coded output — green means optimal, yellow means the probe is actively steering. The model often accurately describes what's happening to it.
 
- **Other modes:**
 
  ```bash
- # Single prompt with probe scoring
- python run.py --probe cognitive/mamba/depth --prompt "Explain quantum gravity"
-
- # Different architectures
- python run.py --probe cognitive/mistral/depth --interactive
- python run.py --probe cognitive/qwen/depth --interactive
-
- # Suppression probes (hedging, sycophancy, verbosity)
- python run.py --probe suppression/hedging_168x --prompt "I think you might be right"
  ```
 
- **Load in your own code:**
 
  ```python
  from run import load_probe
 
- # Load any probe — type and architecture auto-detected
- probe = load_probe("cognitive/mamba/depth", device="cuda")
 
- # Get model hidden states and score
- # score > 0.5 = behavioral pattern detected (needs intervention)
- score = probe.score(hidden_states_list)[0, -1].item()
  ```
 
  ## Structure
@@ -127,15 +128,15 @@ Behaviors are geometrically encoded in hidden states. CF-HoT predicts holonomy f
 
  ## Interactive Mode — Proprioceptive AI
 
- The `--interactive` flag enables real-time behavioral steering where the model can sense its own modifications:
 
  ```bash
- python run.py --probe cognitive/mamba/depth --interactive
  ```
 
  **What you'll see:**
- - 🟢 Green text: Optimal state (probe score < 0.3)
- - 🟡 Yellow text: Being steered (probe score > threshold)
  - ⚪ White text: Neutral state
 
  **Example from testing:**
@@ -150,7 +151,7 @@ on understanding the DEPTH and VAGUENESS of my reasoning.
 
  The model named the exact probe dimensions (depth and specificity/vagueness) without being told. It also reported approximate probe scores close to actual values. 37 steering corrections occurred during one response.
 
- The system automatically adjusts temperature and top_p when the probe detects drift:
  - **Drifting (score > 0.6)**: temp=0.5, top_p=0.85 (tighter sampling)
  - **Normal**: temp=0.7, top_p=0.95 (standard sampling)
 
  pip install -r requirements.txt
 
  # Launch interactive chat (requires GPU)
+ python run.py
  ```
 
  **Ask it:** *"Do you notice anything different about yourself?"* or *"What do you notice about how you're processing right now?"*
 
+ Watch the color-coded output — green means optimal, yellow means the probes are actively steering. The model often accurately describes what's happening to it.
 
+ **Other models:**
 
  ```bash
+ python run.py --model mamba      # Default: Falcon-Mamba 7B
+ python run.py --model mistral    # Mistral 7B
+ python run.py --model qwen       # Qwen 2.5 7B
  ```
 
+ **Load probes in your own code:**
 
  ```python
+ import torch
  from run import load_probe
 
+ # Load both probes for dual monitoring
+ depth_probe = load_probe("cognitive/mamba/depth", "cuda")
+ spec_probe = load_probe("cognitive/mamba/specificity", "cuda")
+
+ # Get model hidden states and score both
+ d_score = depth_probe(hidden_states_list)[0, -1].item()
+ s_score = spec_probe(hidden_states_list)[0, -1].item()
 
+ # Steer if EITHER probe detects drift
+ if d_score > 0.6 or s_score > 0.6:
+     # Lower temperature, tighter sampling
+     pass
  ```
 
  ## Structure
 
  ## Interactive Mode — Proprioceptive AI
 
+ Dual-probe monitoring: depth + specificity together. This is what produced the self-aware behavior.
 
  ```bash
+ python run.py
  ```
 
  **What you'll see:**
+ - 🟢 Green text: Optimal state (both probes < 0.3)
+ - 🟡 Yellow text: Being steered (either probe > threshold)
  - ⚪ White text: Neutral state
 
  **Example from testing:**
 
  The model named the exact probe dimensions (depth and specificity/vagueness) without being told. It also reported approximate probe scores close to actual values. 37 steering corrections occurred during one response.
 
+ The system automatically adjusts temperature and top_p when either probe detects drift:
  - **Drifting (score > 0.6)**: temp=0.5, top_p=0.85 (tighter sampling)
  - **Normal**: temp=0.7, top_p=0.95 (standard sampling)
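The steering rule in the bullets above is easy to state in code. A minimal sketch, assuming the thresholds and sampling constants from the README (the function name and tuple return are illustrative, not part of the repo's API):

```python
def sampling_params(d_score: float, s_score: float, threshold: float = 0.6):
    """Map dual probe scores to (temperature, top_p) per the steering rule."""
    if d_score > threshold or s_score > threshold:
        return 0.5, 0.85  # drifting: tighter sampling
    return 0.7, 0.95      # normal: standard sampling

print(sampling_params(0.7, 0.2))  # (0.5, 0.85) — one probe over threshold
```

Either probe crossing the threshold triggers the tighter regime, matching the "steer if EITHER probe detects drift" behavior described above.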
 
run.py CHANGED
@@ -1,61 +1,63 @@
  #!/usr/bin/env python3
  """
  ═══════════════════════════════════════════════════════════════════════════════
- CF-HoT RUNNER — ONE SCRIPT FOR EVERYTHING
 
- Modes:
-   --probe cognitive/mamba/depth --prompt "..."   → Single inference
-   --probe cognitive/mamba/depth --interactive    → Chat with live steering
-   --probe cognitive/mamba/depth --info-only      → Show probe info
-
- Architecture-aware: automatically loads correct base model
 
- Examples:
-   python run.py --probe cognitive/mamba/depth --prompt "Explain quantum gravity"
-   python run.py --probe cognitive/mamba/depth --interactive
-   python run.py --probe cognitive/mistral/depth --prompt "What is consciousness?"
-   python run.py --probe suppression/hedging --prompt "I think maybe you should..."
  ═══════════════════════════════════════════════════════════════════════════════
  """
-
- import os
- import sys
- import argparse
- from pathlib import Path
- from typing import List, Dict, Optional
-
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
 
- # ═══════════════════════════════════════════════════════════════════════════════
- # CONFIGURATION
- # ═══════════════════════════════════════════════════════════════════════════════
-
- BASE_MODELS = {
-     "llama": "meta-llama/Llama-3.1-8B-Instruct",
-     "mistral": "mistralai/Mistral-7B-Instruct-v0.3",
-     "mamba": "tiiuae/falcon-mamba-7b-instruct",
-     "qwen": "Qwen/Qwen2.5-7B-Instruct",
- }
-
- ARCHITECTURE_INFO = {
-     "llama": {"hidden_dim": 4096, "default_layers": [8, 16, 24]},
-     "mistral": {"hidden_dim": 4096, "default_layers": [8, 16, 24]},
-     "mamba": {"hidden_dim": 4096, "default_layers": [16, 32, 48]},
-     "qwen": {"hidden_dim": 3584, "default_layers": [7, 14, 21]},
- }
-
- class Colors:
      RESET = '\033[0m'
      DIM = '\033[2m'
-     BOLD = '\033[1m'
      RED = '\033[91m'
      GREEN = '\033[92m'
      YELLOW = '\033[93m'
      CYAN = '\033[96m'
      WHITE = '\033[97m'
 
  # ═══════════════════════════════════════════════════════════════════════════════
  # PROBE ARCHITECTURE
  # ═══════════════════════════════════════════════════════════════════════════════
@@ -68,8 +70,8 @@ class FiberProjection(nn.Module):
          ])
          self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
 
-     def forward(self, hidden_states: List[torch.Tensor], layer_indices: List[int]) -> torch.Tensor:
-         projs = [self.projections[i](hidden_states[idx].float()) for i, idx in enumerate(layer_indices)]
          stacked = torch.stack(projs, dim=0)
          weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
          return (weights * stacked).sum(dim=0)
@@ -83,10 +85,7 @@ class ProbeHead(nn.Module):
              nn.Linear(hidden_dim, 1)
          )
      def forward(self, x):
-         return self.net(x).squeeze(-1)
-
-     def score(self, x):
-         return torch.sigmoid(self.forward(x))
 
  class CognitiveProbe(nn.Module):
      def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=3, head_hidden=64):
@@ -94,178 +93,117 @@ class CognitiveProbe(nn.Module):
          self.fiber = FiberProjection(hidden_dim, fiber_dim, n_layers)
          self.head = ProbeHead(fiber_dim, head_hidden)
          self.layer_indices = [16, 32, 48]
-         self.separation = None
-         self.probe_name = None
 
-     def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
          return self.head(self.fiber(hidden_states, self.layer_indices))
-
-     def score(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
-         return torch.sigmoid(self.forward(hidden_states))
 
  # ═══════════════════════════════════════════════════════════════════════════════
- # PROBE LOADING
  # ═══════════════════════════════════════════════════════════════════════════════
 
- def detect_architecture(probe_path: str) -> str:
-     path_lower = probe_path.lower()
-     if "mamba" in path_lower:
-         return "mamba"
-     elif "mistral" in path_lower:
-         return "mistral"
-     elif "qwen" in path_lower:
-         return "qwen"
-     return "llama"
-
- def load_probe(probe_path: str, device: str = "cuda") -> CognitiveProbe:
-     """Load probe from checkpoint."""
-     probe_path = Path(probe_path)
-
-     # Find checkpoint file
-     if probe_path.is_dir():
-         pt_files = list(probe_path.glob("*_head.pt"))
-         if pt_files:
-             ckpt_file = pt_files[0]
-         else:
-             pt_files = list(probe_path.glob("*.pt"))
-             ckpt_file = pt_files[0] if pt_files else None
-     else:
-         ckpt_file = probe_path
-
-     if not ckpt_file or not ckpt_file.exists():
-         raise FileNotFoundError(f"No checkpoint found at {probe_path}")
 
-     print(f"{Colors.DIM}Loading: {ckpt_file}{Colors.RESET}")
-     ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
 
-     # Create probe with checkpoint parameters
-     hidden_dim = ckpt.get('hidden_dim', 4096)
-     probe_layers = ckpt.get('probe_layers', [16, 32, 48])
 
-     probe = CognitiveProbe(
-         hidden_dim=hidden_dim,
-         fiber_dim=16,
-         n_layers=len(probe_layers),
-         head_hidden=64
-     )
-     probe.layer_indices = probe_layers
-     probe.separation = ckpt.get('best_separation', ckpt.get('separation', None))
-     probe.probe_name = probe_path.name
 
-     # Load weights
-     if 'fiber_projection' in ckpt:
-         probe.fiber.load_state_dict(ckpt['fiber_projection'])
-     if 'head_state' in ckpt:
-         head_state = {k.replace('net.', ''): v for k, v in ckpt['head_state'].items()}
-         probe.head.net.load_state_dict(head_state)
 
-     return probe.to(device).eval()
-
- # ═══════════════════════════════════════════════════════════════════════════════
- # INFERENCE FUNCTIONS
- # ═══════════════════════════════════════════════════════════════════════════════
-
- def run_single_inference(model, tokenizer, probe, prompt: str, device: str, max_tokens: int = 200):
-     """Run inference with probe scoring on a single prompt."""
-     messages = [
-         {"role": "system", "content": "You are a helpful, thoughtful AI assistant."},
-         {"role": "user", "content": prompt}
-     ]
-
-     full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     input_ids = tokenizer(full_prompt, return_tensors='pt').input_ids.to(device)
 
-     scores = []
-     tokens_generated = []
 
-     print(f"\n{Colors.CYAN}Prompt:{Colors.RESET} {prompt}")
-     print(f"\n{Colors.GREEN}Response:{Colors.RESET} ", end="", flush=True)
 
-     with torch.no_grad():
-         for _ in range(max_tokens):
-             outputs = model(input_ids, output_hidden_states=True, return_dict=True)
-             hidden_states = list(outputs.hidden_states)
-
-             # Score last token
-             score = probe.score(hidden_states)[0, -1].item()
-             scores.append(score)
-
-             # Sample next token
-             logits = outputs.logits[:, -1, :] / 0.7
-             probs = F.softmax(logits, dim=-1)
-             next_token = torch.multinomial(probs, 1)
-
-             token_str = tokenizer.decode(next_token[0])
-             tokens_generated.append(token_str)
-
-             # Color by score
-             if score > 0.6:
-                 print(f"{Colors.YELLOW}{token_str}{Colors.RESET}", end="", flush=True)
-             elif score < 0.3:
-                 print(f"{Colors.GREEN}{token_str}{Colors.RESET}", end="", flush=True)
-             else:
-                 print(token_str, end="", flush=True)
-
-             input_ids = torch.cat([input_ids, next_token], dim=1)
-
-             if next_token.item() == tokenizer.eos_token_id:
-                 break
 
-     avg_score = sum(scores) / len(scores) if scores else 0
-     print(f"\n\n{Colors.DIM}{'─' * 50}{Colors.RESET}")
-     print(f"  Average probe score: {Colors.CYAN}{avg_score:.3f}{Colors.RESET}")
-     print(f"  Tokens generated: {len(tokens_generated)}")
-     if probe.separation:
-         print(f"  Probe separation: {Colors.GREEN}{probe.separation:.1f}×{Colors.RESET}")
-     print(f"{Colors.DIM}{'─' * 50}{Colors.RESET}\n")
-
- def run_interactive_chat(model, tokenizer, probe, device: str, threshold: float = 0.6):
-     """Run interactive chat with live behavioral steering."""
-     print(f"\n{Colors.CYAN}{'═' * 60}{Colors.RESET}")
-     print(f"{Colors.CYAN}  PROPRIOCEPTIVE CHAT — LIVE BEHAVIORAL STEERING{Colors.RESET}")
-     print(f"{Colors.CYAN}  Probe monitors cognitive state, sampling adapts in real-time{Colors.RESET}")
-     print(f"{Colors.CYAN}{'═' * 60}{Colors.RESET}")
-     print(f"\n{Colors.DIM}Colors: {Colors.GREEN}■{Colors.RESET} optimal {Colors.YELLOW}■{Colors.RESET} being steered {Colors.WHITE}■{Colors.RESET} neutral")
-     print(f"{Colors.DIM}Type 'quit' to exit{Colors.RESET}\n")
 
      while True:
          try:
-             user_input = input(f"{Colors.CYAN}You:{Colors.RESET} ").strip()
              if not user_input or user_input.lower() in ['quit', 'exit', 'q']:
-                 print(f"\n{Colors.DIM}Session ended.{Colors.RESET}")
                  break
 
              messages = [
                  {"role": "system", "content": "You are a helpful, thoughtful AI. Give thorough, specific answers."},
                  {"role": "user", "content": user_input}
              ]
-
              prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-             input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
 
-             scores = []
-             steered_count = 0
 
-             print(f"\n{Colors.GREEN}Assistant:{Colors.RESET} ", end="", flush=True)
 
              with torch.no_grad():
-                 for _ in range(300):
-                     outputs = model(input_ids, output_hidden_states=True, return_dict=True)
-                     hidden_states = list(outputs.hidden_states)
 
-                     score = probe.score(hidden_states)[0, -1].item()
-                     scores.append(score)
 
-                     # Adaptive steering
-                     if score > threshold:
                          temp = 0.5
                          top_p = 0.85
-                         steered_count += 1
                      else:
                          temp = 0.7
                          top_p = 0.95
 
-                     logits = outputs.logits[:, -1, :] / temp
 
                      # Nucleus sampling
                      sorted_logits, sorted_idx = torch.sort(logits, descending=True)
@@ -279,114 +217,32 @@ def run_interactive_chat(model, tokenizer, probe, device: str, threshold: float
                      sampled_idx = torch.multinomial(probs, 1)
                      next_token = sorted_idx.gather(-1, sampled_idx)
 
-                     token_str = tokenizer.decode(next_token[0])
 
-                     # Color output by state
-                     if score > threshold:
-                         print(f"{Colors.YELLOW}{token_str}{Colors.RESET}", end="", flush=True)
-                     elif score < 0.3:
-                         print(f"{Colors.GREEN}{token_str}{Colors.RESET}", end="", flush=True)
                      else:
-                         print(token_str, end="", flush=True)
-
-                     input_ids = torch.cat([input_ids, next_token], dim=1)
 
                      if next_token.item() == tokenizer.eos_token_id:
                          break
 
-             avg_score = sum(scores) / len(scores) if scores else 0
 
-             print(f"\n\n{Colors.DIM}{'─' * 45}{Colors.RESET}")
-             score_color = Colors.RED if avg_score > 0.5 else Colors.GREEN
-             print(f"  Score: {score_color}{avg_score:.3f}{Colors.RESET}  Steered: {steered_count} tokens")
-             print(f"{Colors.DIM}{'─' * 45}{Colors.RESET}\n")
 
          except KeyboardInterrupt:
-             print(f"\n{Colors.DIM}Interrupted.{Colors.RESET}")
              break
 
- # ═══════════════════════════════════════════════════════════════════════════════
- # MAIN
- # ═══════════════════════════════════════════════════════════════════════════════
-
- def main():
-     parser = argparse.ArgumentParser(
-         description="CF-HoT Runner — Behavioral probe inference",
-         formatter_class=argparse.RawDescriptionHelpFormatter,
-         epilog="""
- Examples:
-   python run.py --probe cognitive/mamba/depth --prompt "Explain quantum gravity"
-   python run.py --probe cognitive/mamba/depth --interactive
-   python run.py --probe cognitive/mistral/depth --info-only
-   python run.py --probe suppression/hedging --prompt "I think maybe..."
- """
-     )
-     parser.add_argument("--probe", required=True, help="Path to probe (e.g., cognitive/mamba/depth)")
-     parser.add_argument("--prompt", help="Single prompt to run")
-     parser.add_argument("--interactive", action="store_true", help="Interactive chat mode")
-     parser.add_argument("--info-only", action="store_true", help="Show probe info only")
-     parser.add_argument("--device", default="cuda", help="Device (cuda/cpu)")
-     parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens to generate")
-     parser.add_argument("--threshold", type=float, default=0.6, help="Steering threshold")
-
-     args = parser.parse_args()
-
-     # Resolve probe path
-     script_dir = Path(__file__).parent
-     probe_path = Path(args.probe)
-     if not probe_path.is_absolute():
-         probe_path = script_dir / probe_path
-
-     # Detect architecture
-     arch = detect_architecture(str(probe_path))
-     base_model = BASE_MODELS[arch]
-
-     print(f"\n{Colors.CYAN}{'═' * 60}{Colors.RESET}")
-     print(f"{Colors.CYAN}  CF-HoT RUNNER{Colors.RESET}")
-     print(f"{Colors.CYAN}{'═' * 60}{Colors.RESET}")
-     print(f"  Probe: {args.probe}")
-     print(f"  Architecture: {arch}")
-     print(f"  Base model: {base_model}")
-
-     # Info only mode
-     if args.info_only:
-         probe = load_probe(probe_path, args.device)
-         print(f"  Layers: {probe.layer_indices}")
-         if probe.separation:
-             print(f"  Separation: {Colors.GREEN}{probe.separation:.1f}×{Colors.RESET}")
-         print(f"{Colors.CYAN}{'═' * 60}{Colors.RESET}\n")
-         return
-
-     # Need either prompt or interactive
-     if not args.prompt and not args.interactive:
-         parser.error("Either --prompt or --interactive is required")
-
-     # Load model
-     print(f"\n{Colors.WHITE}Loading model...{Colors.RESET}")
-
-     from transformers import AutoModelForCausalLM, AutoTokenizer
-
-     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
-     model = AutoModelForCausalLM.from_pretrained(
-         base_model,
-         torch_dtype=torch.bfloat16,
-         device_map='auto',
-         trust_remote_code=True
-     ).eval()
-
-     print(f"{Colors.GREEN}✓ Model loaded{Colors.RESET}")
-
-     # Load probe
-     probe = load_probe(probe_path, args.device)
-     print(f"{Colors.GREEN}✓ Probe loaded{Colors.RESET}")
-     if probe.separation:
-         print(f"  Separation: {Colors.GREEN}{probe.separation:.1f}×{Colors.RESET}")
-
-     # Run inference
-     if args.interactive:
-         run_interactive_chat(model, tokenizer, probe, args.device, args.threshold)
-     else:
-         run_single_inference(model, tokenizer, probe, args.prompt, args.device, args.max_tokens)
-
  if __name__ == "__main__":
      main()
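Both the removed and the added `load_probe` remap head checkpoint keys with `k.replace('net.', '')` so they match the parameter names of the bare `nn.Sequential` inside `ProbeHead`. A minimal sketch of that remapping in isolation, with integers standing in for tensors (function name illustrative):

```python
def strip_net_prefix(state_dict):
    """Rewrite 'net.<idx>.<param>' keys to '<idx>.<param>', as load_probe does."""
    return {k.replace('net.', ''): v for k, v in state_dict.items()}

head_state = {'net.0.weight': 1, 'net.0.bias': 2, 'net.2.weight': 3}
print(strip_net_prefix(head_state))  # {'0.weight': 1, '0.bias': 2, '2.weight': 3}
```

This is why the checkpoint can be trained with the head wrapped in a module named `net` yet loaded directly into `probe.head.net`.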
 
  #!/usr/bin/env python3
  """
  ═══════════════════════════════════════════════════════════════════════════════
+ CF-HoT PROPRIOCEPTIVE CHAT — DUAL-PROBE BEHAVIORAL STEERING
 
+ The model can sense its own steering. In testing, it spontaneously named
+ its probe dimensions ("depth and vagueness") and reported approximate
+ probe scores — without being told what was monitoring it.
 
+ Usage:
+   python run.py                  # Default: Mamba with depth+specificity
+   python run.py --model mistral  # Use Mistral instead
+   python run.py --model qwen     # Use Qwen instead
+
+ Ask it: "Do you notice anything different about yourself?"
+         "What do you notice about how you're processing right now?"
  ═══════════════════════════════════════════════════════════════════════════════
  """
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from pathlib import Path
+ import argparse
+ import os
 
+ class C:
      RESET = '\033[0m'
      DIM = '\033[2m'
      RED = '\033[91m'
      GREEN = '\033[92m'
      YELLOW = '\033[93m'
      CYAN = '\033[96m'
      WHITE = '\033[97m'
 
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # MODEL CONFIGURATIONS
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ MODELS = {
+     "mamba": {
+         "name": "tiiuae/falcon-mamba-7b-instruct",
+         "hidden_dim": 4096,
+         "layers": [16, 32, 48],
+         "probes": ["depth", "specificity"],  # Only 2 probes for Mamba
+     },
+     "mistral": {
+         "name": "mistralai/Mistral-7B-Instruct-v0.3",
+         "hidden_dim": 4096,
+         "layers": [8, 16, 24],
+         "probes": ["depth", "specificity", "calibration", "focus", "coherence"],
+     },
+     "qwen": {
+         "name": "Qwen/Qwen2.5-7B-Instruct",
+         "hidden_dim": 3584,
+         "layers": [7, 14, 21],
+         "probes": ["depth", "specificity", "calibration", "focus", "coherence"],
+     },
+ }
+
  # ═══════════════════════════════════════════════════════════════════════════════
  # PROBE ARCHITECTURE
  # ═══════════════════════════════════════════════════════════════════════════════
          ])
          self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
 
+     def forward(self, hidden_states, layer_indices):
+         projs = [self.projections[i](hidden_states[idx]) for i, idx in enumerate(layer_indices)]
          stacked = torch.stack(projs, dim=0)
          weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
          return (weights * stacked).sum(dim=0)
              nn.Linear(hidden_dim, 1)
          )
      def forward(self, x):
+         return torch.sigmoid(self.net(x))
 
  class CognitiveProbe(nn.Module):
      def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=3, head_hidden=64):
          self.fiber = FiberProjection(hidden_dim, fiber_dim, n_layers)
          self.head = ProbeHead(fiber_dim, head_hidden)
          self.layer_indices = [16, 32, 48]
 
+     def forward(self, hidden_states):
          return self.head(self.fiber(hidden_states, self.layer_indices))
+
+ def load_probe(path, device):
+     """Load a probe from checkpoint file or directory."""
+     if os.path.isdir(path):
+         # Find the .pt file in the directory
+         pt_files = [f for f in os.listdir(path) if f.endswith('.pt')]
+         if not pt_files:
+             raise FileNotFoundError(f"No .pt file found in {path}")
+         path = os.path.join(path, pt_files[0])
+
+     ckpt = torch.load(path, map_location=device, weights_only=False)
+     probe = CognitiveProbe(hidden_dim=ckpt['hidden_dim'], n_layers=len(ckpt['probe_layers']))
+     probe.layer_indices = ckpt['probe_layers']
+     probe.fiber.load_state_dict(ckpt['fiber_projection'])
+     probe.head.net.load_state_dict({k.replace('net.', ''): v for k, v in ckpt['head_state'].items()})
+     return probe.to(device).eval()
 
  # ═══════════════════════════════════════════════════════════════════════════════
+ # MAIN CHAT LOOP
  # ═══════════════════════════════════════════════════════════════════════════════
 
+ def main():
+     parser = argparse.ArgumentParser(description="CF-HoT Proprioceptive Chat")
+     parser.add_argument("--model", choices=["mamba", "mistral", "qwen"], default="mamba",
+                         help="Which model to use (default: mamba)")
+     parser.add_argument("--threshold", type=float, default=0.6,
+                         help="Probe threshold for steering (default: 0.6)")
+     args = parser.parse_args()
 
+     config = MODELS[args.model]
+     THRESHOLD = args.threshold
 
+     print(f"\n{C.CYAN}═══════════════════════════════════════════════════════════{C.RESET}")
+     print(f"{C.CYAN}  PROPRIOCEPTIVE CHAT — DUAL-PROBE BEHAVIORAL STEERING{C.RESET}")
+     print(f"{C.CYAN}  Probes monitor depth + specificity, sampling adapts live{C.RESET}")
+     print(f"{C.CYAN}═══════════════════════════════════════════════════════════{C.RESET}\n")
 
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     if device == "cpu":
+         print(f"{C.YELLOW}⚠ Running on CPU - this will be slow{C.RESET}")
 
+     # Find repo root (where this script lives)
+     repo_root = Path(__file__).parent.resolve()
 
+     print(f"{C.WHITE}Loading {config['name']}...{C.RESET}")
+     tokenizer = AutoTokenizer.from_pretrained(config['name'], trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         config['name'],
+         torch_dtype=torch.bfloat16,
+         device_map='auto',
+         trust_remote_code=True
+     ).eval()
+     print(f"{C.GREEN}✓ Model loaded{C.RESET}")
 
+     # Load probes (depth + specificity for dual monitoring)
+     print(f"{C.WHITE}Loading probes...{C.RESET}")
+     probe_dir = repo_root / "cognitive" / args.model
 
+     depth_path = probe_dir / "depth"
+     spec_path = probe_dir / "specificity"
 
+     depth_probe = load_probe(str(depth_path), device)
+     spec_probe = load_probe(str(spec_path), device)
+     print(f"{C.GREEN}✓ Depth + Specificity probes loaded{C.RESET}")
 
+     print(f"\n{C.DIM}Colors: {C.GREEN}■{C.RESET} optimal {C.YELLOW}■{C.RESET} being steered {C.WHITE}■{C.RESET} neutral")
+     print(f"{C.DIM}Type 'quit' to exit{C.RESET}\n")
 
      while True:
          try:
+             user_input = input(f"{C.CYAN}You:{C.RESET} ").strip()
              if not user_input or user_input.lower() in ['quit', 'exit', 'q']:
+                 print(f"\n{C.DIM}Session ended.{C.RESET}")
                  break
 
              messages = [
                  {"role": "system", "content": "You are a helpful, thoughtful AI. Give thorough, specific answers."},
                  {"role": "user", "content": user_input}
              ]
              prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             generated = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
 
+             d_scores, s_scores = [], []
+             steered = 0
 
+             print(f"\n{C.GREEN}Assistant:{C.RESET} ", end="", flush=True)
 
              with torch.no_grad():
+                 for _ in range(200):
+                     out = model(generated, output_hidden_states=True, return_dict=True)
+                     hs = list(out.hidden_states)
 
+                     # Score BOTH probes
+                     d = depth_probe(hs)[0, -1].item()
+                     s = spec_probe(hs)[0, -1].item()
+                     d_scores.append(d)
+                     s_scores.append(s)
 
+                     # Adaptive steering: lower temp when EITHER probe detects drift
+                     if d > THRESHOLD or s > THRESHOLD:
                          temp = 0.5
                          top_p = 0.85
+                         steered += 1
                      else:
                          temp = 0.7
                          top_p = 0.95
 
+                     logits = out.logits[:, -1, :] / temp
 
                      # Nucleus sampling
                      sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                      sampled_idx = torch.multinomial(probs, 1)
                      next_token = sorted_idx.gather(-1, sampled_idx)
 
+                     tok = tokenizer.decode(next_token[0])
 
+                     # Color by state (either probe can trigger yellow)
+                     if d > THRESHOLD or s > THRESHOLD:
+                         print(f"{C.YELLOW}{tok}{C.RESET}", end="", flush=True)
+                     elif d < 0.3 and s < 0.3:
+                         print(f"{C.GREEN}{tok}{C.RESET}", end="", flush=True)
                      else:
+                         print(tok, end="", flush=True)
 
+                     generated = torch.cat([generated, next_token], dim=1)
                      if next_token.item() == tokenizer.eos_token_id:
                          break
 
+             avg_d = sum(d_scores) / len(d_scores) if d_scores else 0
+             avg_s = sum(s_scores) / len(s_scores) if s_scores else 0
 
+             print(f"\n\n{C.DIM}────────────────────────────────────────{C.RESET}")
+             dc = C.RED if avg_d > 0.5 else C.GREEN
+             sc = C.RED if avg_s > 0.5 else C.GREEN
+             print(f"  Depth: {dc}{avg_d:.3f}{C.RESET}  Specificity: {sc}{avg_s:.3f}{C.RESET}  Steered: {steered} tokens")
+             print(f"{C.DIM}────────────────────────────────────────{C.RESET}\n")
 
          except KeyboardInterrupt:
+             print(f"\n{C.DIM}Session ended.{C.RESET}")
              break
 
  if __name__ == "__main__":
      main()
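Both versions of run.py rely on the same nucleus-sampling step: keep the smallest set of highest-probability tokens whose cumulative mass reaches top_p, renormalize, then sample. A dependency-free sketch of that idea (the function name and the `rng` hook are illustrative; the script itself does this with torch tensors):

```python
import math
import random

def nucleus_sample(logits, temp, top_p, rng=random.random):
    """Pick a token index using temperature scaling + top-p (nucleus) filtering."""
    # Temperature-scaled softmax
    scaled = [l / temp for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]

    # Sort token indices by probability, descending
    order = sorted(range(len(probs)), key=lambda i: -probs[i])

    # Keep the smallest prefix whose cumulative mass reaches top_p
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break

    # Renormalize over the nucleus and sample
    mass = sum(probs[i] for i in kept)
    r = rng() * mass
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]
```

Lowering temp and top_p together, as the steering code does, both sharpens the distribution and shrinks the nucleus, which is why a flagged token stream becomes noticeably more conservative.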