#!/usr/bin/env python3
"""
SymbolicLight-PoC text generation script.
Load a checkpoint and run single-prompt or interactive text generation.
Usage:
# Interactive mode, using the checkpoint next to this script
python generate.py
# Specify checkpoint
python generate.py --checkpoint best.pt
# Single prompt generation
python generate.py --prompt "Hello world"
# Enable experimental STDP updates
python generate.py --enable_stdp
"""
import argparse
import sys
from pathlib import Path
import torch
import tiktoken
# Absolute directory that contains this script file.
SCRIPT_DIR = Path(__file__).resolve().parent
# Checkpoint path used when --checkpoint is not given on the command line.
DEFAULT_CHECKPOINT = SCRIPT_DIR / "best.pt"
# Put the script directory at the front of sys.path so that modules located
# next to this script can be imported regardless of the current working
# directory. This must run before the `from model import ...` line below.
sys.path.insert(0, str(SCRIPT_DIR))
from model import SymbolicLightConfig, SymbolicLightModel
def parse_args():
    """Parse and validate the command-line options of the generator.

    Options are read from ``sys.argv``. After parsing, the numeric sampling
    options are range-checked; the first violated constraint is reported
    through ``ArgumentParser.error`` (which prints usage and exits).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="SymbolicLight-PoC Generator")

    # (flag, keyword arguments) pairs, registered in this exact order so the
    # generated --help output keeps its layout.
    option_specs = (
        ("--checkpoint", dict(type=str, default=str(DEFAULT_CHECKPOINT),
                              help="Checkpoint path")),
        ("--prompt", dict(type=str, default=None,
                          help="Single prompt generation mode (skip interactive chat)")),
        ("--max_tokens", dict(type=int, default=200,
                              help="Max number of tokens to generate")),
        ("--temperature", dict(type=float, default=0.8,
                               help="Sampling temperature (higher = more random, lower = more conservative)")),
        ("--top_k", dict(type=int, default=50,
                         help="Top-K sampling")),
        ("--enable_stdp", dict(action="store_true",
                               help="Enable experimental STDP updates during inference")),
        ("--save_stdp", dict(type=str, default=None,
                             help="Save updated weights here after STDP learning")),
        ("--allow_random_init", dict(action="store_true",
                                     help="Allow random initialization when checkpoint is missing")),
        ("--trust_checkpoint_pickle", dict(action="store_true",
                                           help="Allow unsafe pickle checkpoint loading if weights_only=True fails")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()

    # Range checks, evaluated in order; parser.error() exits on the first hit.
    violations = (
        (parsed.max_tokens < 1, "--max_tokens must be >= 1"),
        (parsed.temperature <= 0, "--temperature must be > 0"),
        (parsed.top_k < 0, "--top_k must be >= 0"),
    )
    for violated, message in violations:
        if violated:
            parser.error(message)
    return parsed
class TiktokenWrapper:
    """Thin adapter around the tiktoken "gpt2" encoding.

    ``vocab_size`` is stored on the instance for callers to read; the
    methods of this class do not consult it.
    """

    def __init__(self, vocab_size=50257):
        self.enc = tiktoken.get_encoding("gpt2")
        self.vocab_size = vocab_size

    def encode(self, text: str) -> list:
        """Return the token ids for *text*, with no special tokens allowed."""
        # NOTE(review): allowed_special is empty, so under tiktoken's default
        # `disallowed_special` setting, text that spells out a special token
        # is presumably rejected with an error -- confirm against tiktoken docs.
        token_ids = self.enc.encode(text, allowed_special=set())
        return token_ids

    def decode(self, ids: list) -> str:
        """Return the text for *ids*; every id is coerced to ``int`` first."""
        as_ints = list(map(int, ids))
        return self.enc.decode(as_ints)
def _load_checkpoint(path: Path, device: torch.device, trust_pickle: bool):
    """Load a checkpoint file, preferring the restricted ``weights_only`` mode.

    Args:
        path: Location of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.
        trust_pickle: When True, a failed ``weights_only=True`` load is
            retried with ``weights_only=False`` (full pickle loading).

    Returns:
        Whatever object ``torch.load`` produced.

    Raises:
        RuntimeError: The restricted load failed and ``trust_pickle`` is
            False. The original exception is attached as the cause.
    """
    try:
        checkpoint = torch.load(path, map_location=device, weights_only=True)
    except Exception as exc:
        if trust_pickle:
            # SECURITY: full unpickling can execute arbitrary code; this path
            # is only reached when the caller explicitly opted in.
            print("[Load] WARNING: falling back to weights_only=False for a trusted checkpoint.")
            return torch.load(path, map_location=device, weights_only=False)
        raise RuntimeError(
            "Failed to load checkpoint with weights_only=True. "
            "If this is a trusted local checkpoint that requires pickle, "
            "rerun with --trust_checkpoint_pickle."
        ) from exc
    return checkpoint
def _format_metric(value) -> str:
    """Render a metric for display.

    ``None`` becomes ``"?"``; anything ``float()`` accepts is shown with four
    decimal places; everything else falls back to ``str(value)``.
    """
    if value is None:
        return "?"
    try:
        number = float(value)
    except (TypeError, ValueError):
        # Not convertible to a float: show it verbatim.
        return str(value)
    return f"{number:.4f}"
def _select_device() -> torch.device:
    """Return the CUDA device when CUDA is available with at least one GPU, else the CPU device."""
    has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
    return torch.device("cuda" if has_gpu else "cpu")
def load_model(checkpoint_path: str, enable_stdp: bool = False,
               allow_random_init: bool = False,
               trust_checkpoint_pickle: bool = False):
    """Build a ``SymbolicLightModel`` in eval mode, restoring weights if possible.

    Args:
        checkpoint_path: Checkpoint file path; ``~`` is expanded.
        enable_stdp: Value written to the config's ``enable_stdp`` field.
        allow_random_init: When the checkpoint file does not exist, build a
            randomly initialised model instead of failing.
        trust_checkpoint_pickle: Forwarded to ``_load_checkpoint``.

    Returns:
        ``(model, config, device)``.

    Raises:
        FileNotFoundError: Checkpoint missing and ``allow_random_init`` is False.
        ValueError: The loaded checkpoint is not a dict.
        KeyError: The checkpoint has no ``'config'`` dict, or no weights under
            ``'model'`` / ``'model_state_dict'``.
    """
    device = _select_device()
    ckpt_path = Path(checkpoint_path).expanduser()

    # --- No checkpoint on disk: fail, or fall back to random weights. ---
    if not ckpt_path.exists():
        if not allow_random_init:
            raise FileNotFoundError(
                f"Checkpoint not found: {ckpt_path}. "
                "Pass --allow_random_init only for code smoke tests."
            )
        print(f"[Load] WARNING: checkpoint not found at {ckpt_path}")
        print("[Load] WARNING: initializing a random model for smoke testing only")
        config = SymbolicLightConfig(enable_stdp=enable_stdp)
        model = SymbolicLightModel(config).to(device)
        model.eval()
        return model, config, device

    # --- Checkpoint present: validate its structure before building. ---
    print(f"[Load] Loading checkpoint: {ckpt_path}")
    ckpt = _load_checkpoint(ckpt_path, device, trust_checkpoint_pickle)
    if not isinstance(ckpt, dict):
        raise ValueError(f"Checkpoint must be a dict, got {type(ckpt).__name__}")

    config_dict = ckpt.get("config")
    if not isinstance(config_dict, dict):
        raise KeyError("Checkpoint is missing a 'config' dictionary")

    # First matching key wins: 'model' takes precedence over 'model_state_dict'.
    for weights_key in ("model", "model_state_dict"):
        if weights_key in ckpt:
            state_dict = ckpt[weights_key]
            break
    else:
        raise KeyError("Checkpoint is missing model weights under 'model' or 'model_state_dict'")

    config = SymbolicLightConfig(**config_dict)
    # The caller's flag overrides whatever the checkpoint's config stored.
    config.enable_stdp = enable_stdp
    model = SymbolicLightModel(config).to(device)

    # Non-strict load: mismatches are reported as warnings, not errors.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = load_result.missing_keys
    if missing:
        print(f"[Load] WARNING: missing keys: {missing}")
    # 'spike_encoder.v_mem' is deliberately left out of the unexpected-key report.
    unexpected_keys = [
        name for name in load_result.unexpected_keys
        if name != "spike_encoder.v_mem"
    ]
    if unexpected_keys:
        print(f"[Load] WARNING: unexpected keys: {unexpected_keys}")

    step = ckpt["global_step"] if "global_step" in ckpt else ckpt.get("step", "?")
    raw_loss = ckpt["best_loss"] if "best_loss" in ckpt else ckpt.get("loss")
    loss = _format_metric(raw_loss)
    print(f"[Load] Model loaded (step={step}, loss={loss})")

    model.eval()
    return model, config, device
def generate_text(model, tokenizer, prompt: str, device,
                  max_tokens=200, temperature=0.8, top_k=50):
    """Generate a continuation of *prompt* and measure encoder sparsity.

    Args:
        model: Object exposing ``generate(...)`` and ``spike_encoder(...)``;
            ``model.config.vocab_size`` is used for validation when present.
        tokenizer: Object exposing ``encode(str)`` and ``decode(list)``.
        prompt: Text to continue.
        device: Device on which the input tensor is created.
        max_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``; a positive value is capped
            at the model's vocab size when that is known.

    Returns:
        ``(generated_text, sparsity)``: the decoded tokens that follow the
        prompt in the model output, and ``1 - mean`` of the first value
        returned by ``model.spike_encoder`` on at most the first 32 prompt
        tokens.

    Raises:
        ValueError: The prompt encodes to no tokens, or contains a token id
            outside ``[0, vocab_size)``.
    """
    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        raise ValueError("Prompt must contain at least one token")
    prompt_len = len(prompt_ids)

    # vocab_size stays None when the model carries no config / no such field.
    model_config = getattr(model, "config", None)
    vocab_size = getattr(model_config, "vocab_size", None)
    if vocab_size:
        out_of_range = [tok for tok in prompt_ids if not 0 <= tok < vocab_size]
        if out_of_range:
            # Only the first few offenders are shown in the message.
            raise ValueError(
                f"Prompt contains token IDs outside model vocab_size={vocab_size}: {out_of_range[:5]}"
            )

    input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Cap top_k at the vocab size; top_k <= 0 is passed through unchanged.
    if top_k > 0 and vocab_size:
        effective_top_k = min(top_k, vocab_size)
    else:
        effective_top_k = top_k

    with torch.no_grad():
        output_ids = model.generate(
            input_tensor,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=effective_top_k,
        )

    # Drop the echoed prompt; decode only what comes after it.
    continuation = output_ids[0, prompt_len:].tolist()
    generated_text = tokenizer.decode(continuation)

    # Sparsity probe on a short prefix (slicing clamps to the prompt length).
    with torch.no_grad():
        probe = input_tensor[:, :32]
        spikes, _ = model.spike_encoder(probe)
        firing_rate = spikes.mean().item()

    return generated_text, 1.0 - firing_rate
def interactive_chat(model, tokenizer, device, args):
    """Run a terminal chat loop until the user quits.

    Each non-command line is appended to the running conversation text, the
    whole conversation is fed to ``generate_text`` as the prompt, and the
    response is printed and appended to the conversation.

    Commands (case-insensitive):
        ``quit``      exit the loop (EOF / Ctrl-C at the prompt also exit).
        ``sparsity``  print ``model.get_sparsity_stats()``.
        ``save``      only when ``args.enable_stdp`` is set: write the model
                      state dict and config to ``args.save_stdp``, or to
                      ``stdp_updated.pt`` next to ``args.checkpoint``.
                      Without ``--enable_stdp`` the word is treated as an
                      ordinary message.

    Args:
        model: Loaded model; must expose ``get_sparsity_stats``,
            ``state_dict`` and ``config`` for the respective commands.
        tokenizer: Tokenizer passed through to ``generate_text``.
        device: Device passed through to ``generate_text``.
        args: Parsed CLI namespace (``temperature``, ``max_tokens``,
            ``top_k``, ``enable_stdp``, ``save_stdp``, ``checkpoint``).
    """
    print("\n" + "=" * 60)
    print(" SymbolicLight Interactive Chat")
    print("=" * 60)
    print(f" Temperature: {args.temperature}")
    print(f" Max tokens: {args.max_tokens}")
    print(f" STDP Learn: {'ON' if args.enable_stdp else 'OFF'}")
    print(f" Device: {device}")
    print("-" * 60)
    print(" Type your message and press Enter.")
    print(" Type 'quit' to exit.")
    print(" Type 'sparsity' to see network sparsity stats.")
    if args.enable_stdp:
        print(" Type 'save' to save STDP-updated weights.")
    print("=" * 60 + "\n")

    conversation_history = ""
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not user_input:
            continue

        command = user_input.lower()
        if command == 'quit':
            print("Bye!")
            break

        if command == 'sparsity':
            try:
                stats = model.get_sparsity_stats()
            except Exception as exc:
                print(f"\n[Sparsity Stats] unavailable: {exc}\n")
                continue
            print("\n[Sparsity Stats]")
            for k, v in stats.items():
                print(f" {k}: {v*100:.1f}% silent")
            print()
            continue

        if command == 'save' and args.enable_stdp:
            save_path = (
                Path(args.save_stdp) if args.save_stdp
                else Path(args.checkpoint).with_name("stdp_updated.pt")
            )
            # A failed save (e.g. unwritable path) is reported, like the other
            # command branches, instead of aborting the whole session.
            try:
                save_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    "model": model.state_dict(),
                    "config": model.config.__dict__,
                }, save_path)
            except Exception as exc:
                print(f"[STDP] Save failed: {exc}\n")
                continue
            print(f"[STDP] Weights saved to {save_path}\n")
            continue

        # Build the prompt without touching the history yet: if generation
        # fails, the offending input must not stay in the history, otherwise
        # every following turn would fail on the same text.
        prompt = f"{conversation_history}{user_input} "
        try:
            response, sparsity = generate_text(
                model, tokenizer, prompt, device,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
            )
        except Exception as exc:
            print(f"[Error] {exc}\n")
            continue

        # Commit this turn (user text + response) only after success.
        conversation_history = f"{prompt}{response} "

        print(f"SymbolicLight: {response}")
        print(f" [sparsity: {sparsity*100:.1f}% | "
              f"stdp: {'learning' if args.enable_stdp else 'off'}]\n")
def main():
    """Script entry point: parse options, load the model, then generate.

    When ``--prompt`` is given, a single response is generated and printed;
    otherwise the interactive chat loop is started. Failures while loading
    the model or generating for a single prompt are printed to stderr and
    turned into exit status 1.
    """
    args = parse_args()

    # Load model
    try:
        model, config, device = load_model(
            args.checkpoint,
            enable_stdp=args.enable_stdp,
            allow_random_init=args.allow_random_init,
            trust_checkpoint_pickle=args.trust_checkpoint_pickle,
        )
    except Exception as exc:
        print(f"[Error] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc

    # Initialize tokenizer
    tokenizer = TiktokenWrapper(config.vocab_size)

    # Compare against None rather than truthiness: an explicitly supplied but
    # empty --prompt "" must stay in single-prompt mode (and fail there with
    # a clear error) instead of silently dropping into the interactive loop.
    if args.prompt is not None:
        # Single prompt generation mode
        print(f"\nPrompt: {args.prompt}")
        try:
            response, sparsity = generate_text(
                model, tokenizer, args.prompt, device,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
            )
        except Exception as exc:
            print(f"[Error] {exc}", file=sys.stderr)
            raise SystemExit(1) from exc
        print(f"Response: {response}")
        print(f"Sparsity: {sparsity*100:.1f}%")
    else:
        # Interactive chat mode
        interactive_chat(model, tokenizer, device, args)


if __name__ == "__main__":
    main()
|