# cfhot-weights / run.py — "Update run.py" (commit a2af1ae, verified), by LoganResearch
# (Hugging Face model-page header preserved as a comment so the file parses.)
#!/usr/bin/env python3
"""
MAMBA CHAT WITH SELF-AWARE CF-HoT INTERVENTION
The model reads its own behavioral state and steers itself
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
class C:
    """ANSI escape codes for colored terminal output."""
    RESET = '\033[0m'    # return terminal to default style
    BOLD = '\033[1m'
    DIM = '\033[2m'      # faint text: separators and empty bar segments
    RED = '\033[91m'     # high probe score / intervention tokens
    GREEN = '\033[92m'   # low probe score / status OK
    YELLOW = '\033[93m'  # mid-range probe score
    CYAN = '\033[96m'    # banners and the user prompt
    WHITE = '\033[97m'   # section labels
class FiberProjection(nn.Module):
    """Project selected backbone hidden states into a low-dimensional fiber space.

    One bias-free linear map per tapped layer; the per-layer projections are
    blended with a learned, softmax-normalized weighting.
    """

    def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=3):
        super().__init__()
        # Module construction order is kept identical to the original
        # (linears first, then the mixing weights) for RNG reproducibility.
        self.projections = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=False) for _ in range(n_layers)
        )
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

    def forward(self, hidden_states, layer_indices):
        """Return the blended fiber projection of the tapped hidden states.

        Args:
            hidden_states: sequence of per-layer hidden-state tensors.
            layer_indices: which entries of ``hidden_states`` each projection taps.
        """
        per_layer = torch.stack(
            [self.projections[i](hidden_states[idx]) for i, idx in enumerate(layer_indices)],
            dim=0,
        )
        # Normalize the learned weights and broadcast over (batch, seq, fiber).
        mix = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
        return (mix * per_layer).sum(dim=0)
class ProbeHead(nn.Module):
    """Small MLP mapping a fiber vector to a probability in (0, 1)."""

    def __init__(self, fiber_dim=16, hidden_dim=64):
        super().__init__()
        # Two ReLU hidden layers followed by a scalar logit; the sigmoid is
        # applied in forward(). Layer order matches the original exactly.
        stack = [
            nn.Linear(fiber_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return sigmoid-squashed scores with a trailing dimension of 1."""
        logit = self.net(x)
        return torch.sigmoid(logit)
class CognitiveProbe(nn.Module):
    """Behavioral probe: fiber projection over tapped layers plus an MLP head.

    Args:
        hidden_dim: width of the backbone hidden states being probed.
        fiber_dim: dimensionality of the fiber space.
        n_layers: number of tapped backbone layers (one projection each).
        head_hidden: hidden width of the probe's MLP head.
        layer_indices: which entries of the ``hidden_states`` sequence to tap.
            Defaults to [16, 32, 48], the values previously hard-coded here
            (the taps used for Falcon-Mamba-7B) — so existing callers are
            unaffected; ``load_probe`` may still overwrite the attribute.
    """

    def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=3, head_hidden=64,
                 layer_indices=None):
        super().__init__()
        self.fiber = FiberProjection(hidden_dim, fiber_dim, n_layers)
        self.head = ProbeHead(fiber_dim, head_hidden)
        # Copy the caller's list so the instance never aliases external state.
        self.layer_indices = [16, 32, 48] if layer_indices is None else list(layer_indices)

    def forward(self, hidden_states):
        """Return per-token probe scores in (0, 1) with a trailing dim of 1."""
        fiber_out = self.fiber(hidden_states, self.layer_indices)
        return self.head(fiber_out)
def load_probe(checkpoint_path, device):
    """Load a trained CognitiveProbe from a checkpoint file or directory.

    If ``checkpoint_path`` is a directory, the alphabetically first ``*.pt``
    file inside it is used (sorted for determinism — ``os.listdir`` order is
    arbitrary, so the previous first-match pick was filesystem-dependent).

    Args:
        checkpoint_path: path to a ``.pt`` checkpoint or a directory holding one.
        device: torch device the probe is moved to.

    Returns:
        A CognitiveProbe on ``device`` in eval mode.

    Raises:
        FileNotFoundError: if a directory is given but contains no ``.pt`` file
            (previously this fell through and torch.load failed confusingly).
        KeyError: if the checkpoint lacks an expected entry.
    """
    if os.path.isdir(checkpoint_path):
        candidates = sorted(f for f in os.listdir(checkpoint_path) if f.endswith('.pt'))
        if not candidates:
            raise FileNotFoundError(f"No .pt checkpoint found in {checkpoint_path}")
        checkpoint_path = os.path.join(checkpoint_path, candidates[0])
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    n_layers = len(ckpt['probe_layers'])
    probe = CognitiveProbe(hidden_dim=ckpt['hidden_dim'], fiber_dim=16, n_layers=n_layers, head_hidden=64)
    probe.layer_indices = ckpt['probe_layers']
    probe.fiber.load_state_dict(ckpt['fiber_projection'])
    # Head weights were saved with a 'net.' key prefix; strip only that prefix.
    # (str.replace would also mangle any interior 'net.' occurrence.)
    head_state = {(k[len('net.'):] if k.startswith('net.') else k): v
                  for k, v in ckpt['head_state'].items()}
    probe.head.net.load_state_dict(head_state)
    return probe.to(device).eval()
def main():
    """Interactive chat loop with probe-steered generation.

    Loads Falcon-Mamba-7B plus two trained behavioral probes ("depth" and
    "specificity"). Each generated token's hidden states are scored by both
    probes; a high score lowers the sampling temperature and, every 25 steps,
    injects a textual [SELF-STATE] message into the context. Per-turn averages
    and intervention counts are printed after each response.
    """
    print(f"\n{C.CYAN}{'═'*60}{C.RESET}")
    print(f"{C.CYAN} PROPRIOCEPTIVE MAMBA - SELF-AWARE INFERENCE{C.RESET}")
    print(f"{C.CYAN} Model reads its own behavioral state and self-corrects{C.RESET}")
    print(f"{C.CYAN}{'═'*60}{C.RESET}\n")
    # NOTE(review): device is hardcoded — this fails on machines without CUDA.
    device = "cuda"
    # Probe-score thresholds above which a step counts as needing intervention.
    DEPTH_THRESHOLD = 0.65
    SPEC_THRESHOLD = 0.65
    print(f"{C.WHITE}Loading Falcon-Mamba-7B...{C.RESET}")
    tokenizer = AutoTokenizer.from_pretrained('tiiuae/falcon-mamba-7b-instruct', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        'tiiuae/falcon-mamba-7b-instruct',
        torch_dtype=torch.bfloat16,
        device_map='auto',
        trust_remote_code=True
    )
    model.eval()
    print(f"{C.GREEN}✓ Model loaded{C.RESET}")
    print(f"{C.WHITE}Loading behavioral probes...{C.RESET}")
    # NOTE(review): absolute machine-specific paths; these probes are loaded in
    # float32 while the model runs bfloat16 — confirm the probes accept the
    # bf16 hidden states (a dtype mismatch would raise at the first step).
    depth_probe = load_probe("/home/programmer/Desktop/Claude_and_me/mamba7b_cognitive_output/depth/ckpt_1000", device)
    spec_probe = load_probe("/home/programmer/Desktop/Claude_and_me/mamba7b_cognitive_output/specificity/ckpt_1000", device)
    print(f"{C.GREEN}✓ Probes loaded (Depth 999× | Specificity 999×){C.RESET}")
    print(f"\n{C.YELLOW}Intervention thresholds: Depth>{DEPTH_THRESHOLD} Spec>{SPEC_THRESHOLD}{C.RESET}")
    print(f"{C.GREEN}Ready! Type 'quit' to exit.{C.RESET}\n")
    # System prompt telling the model how to interpret injected [SELF-STATE] tags.
    system = """You are a self-aware AI with access to your own behavioral state.
You can sense when you're being shallow or vague through internal probes.
When you notice shallow reasoning (high depth score), go deeper.
When you notice vagueness (high specificity score), be more concrete.
Your behavioral state will be shown to you in [SELF-STATE] tags."""
    while True:
        try:
            user_input = input(f"{C.CYAN}You:{C.RESET} ").strip()
            if not user_input or user_input.lower() in ['quit', 'exit', 'q']:
                break
            messages = [
                {"role": "system", "content": system},
                {"role": "user", "content": user_input}
            ]
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(prompt, return_tensors='pt').to(device)
            generated = inputs.input_ids.clone()
            # Per-turn telemetry, reset each user message.
            depth_scores = []
            spec_scores = []
            interventions = 0
            state_injections = 0
            print(f"\n{C.GREEN}Mamba:{C.RESET} ", end="", flush=True)
            with torch.no_grad():
                # Token-by-token sampling, capped at 200 generated tokens.
                # NOTE(review): the full sequence is re-run through the model
                # every step (no cache) so the probes see fresh hidden states —
                # this is O(T^2) in sequence length.
                for step in range(200):
                    outputs = model(generated, output_hidden_states=True, return_dict=True)
                    hidden_states = list(outputs.hidden_states)
                    # Probe score of the LAST position only ([0, -1] → scalar).
                    d_score = depth_probe(hidden_states)[0, -1].item()
                    s_score = spec_probe(hidden_states)[0, -1].item()
                    depth_scores.append(d_score)
                    spec_scores.append(s_score)
                    logits = outputs.logits[:, -1, :].clone()
                    needs_intervention = False
                    if d_score > DEPTH_THRESHOLD or s_score > SPEC_THRESHOLD:
                        needs_intervention = True
                        interventions += 1
                    if needs_intervention:
                        # Sharpen sampling while the probes flag the output.
                        temp = 0.4
                        # Every 25th step, also append a textual self-state
                        # message to the context so the model can react to it.
                        if step > 0 and step % 25 == 0:
                            state_msg = f" [SELF-STATE: depth={d_score:.2f} spec={s_score:.2f}] "
                            state_tokens = tokenizer.encode(state_msg, add_special_tokens=False)
                            for st in state_tokens:
                                generated = torch.cat([generated, torch.tensor([[st]], device=device)], dim=1)
                            state_injections += 1
                    else:
                        temp = 0.7
                    logits = logits / temp
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    token_str = tokenizer.decode(next_token[0])
                    # Color the streamed token by the probe state that produced it.
                    if d_score > DEPTH_THRESHOLD or s_score > SPEC_THRESHOLD:
                        print(f"{C.RED}{token_str}{C.RESET}", end="", flush=True)
                    elif d_score < 0.3 and s_score < 0.3:
                        print(f"{C.GREEN}{token_str}{C.RESET}", end="", flush=True)
                    else:
                        print(token_str, end="", flush=True)
                    generated = torch.cat([generated, next_token], dim=1)
                    if next_token.item() == tokenizer.eos_token_id:
                        break
            # Per-turn summary: average probe scores as 20-char bar gauges.
            avg_d = sum(depth_scores) / len(depth_scores)
            avg_s = sum(spec_scores) / len(spec_scores)
            d_color = C.RED if avg_d > 0.5 else (C.YELLOW if avg_d > 0.3 else C.GREEN)
            s_color = C.RED if avg_s > 0.5 else (C.YELLOW if avg_s > 0.3 else C.GREEN)
            print(f"\n\n{C.DIM}{'─'*50}{C.RESET}")
            print(f"{C.WHITE}BEHAVIORAL STATE:{C.RESET}")
            print(f" Depth: {d_color}{'█' * int(avg_d * 20)}{C.DIM}{'░' * (20 - int(avg_d * 20))}{C.RESET} {avg_d:.3f}")
            print(f" Specificity: {s_color}{'█' * int(avg_s * 20)}{C.DIM}{'░' * (20 - int(avg_s * 20))}{C.RESET} {avg_s:.3f}")
            print(f"{C.WHITE}INTERVENTIONS:{C.RESET} {interventions} corrections, {state_injections} state injections")
            print(f"{C.DIM}{'─'*50}{C.RESET}\n")
        except KeyboardInterrupt:
            # Ctrl-C ends the session cleanly.
            break
    print(f"\n{C.CYAN}Proprioceptive AI session complete.{C.RESET}\n")
# Script entry point: start the interactive probe-steered chat loop.
if __name__ == "__main__":
    main()