SymbolicLight-PoC / src /generate.py
symboliclight-ai's picture
Upload 8 files
eee03bb verified
#!/usr/bin/env python3
"""
SymbolicLight-PoC text generation script.
Load a checkpoint and run single-prompt or interactive text generation.
Usage:
# Interactive mode, using the checkpoint next to this script
python generate.py
# Specify checkpoint
python generate.py --checkpoint best.pt
# Single prompt generation
python generate.py --prompt "Hello world"
# Enable experimental STDP updates
python generate.py --enable_stdp
"""
import argparse
import sys
from pathlib import Path
import torch
import tiktoken
# Absolute directory holding this script. The default checkpoint is looked up
# here, and it is put at the front of sys.path so the sibling ``model`` module
# imports regardless of the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parents[0]
DEFAULT_CHECKPOINT = SCRIPT_DIR.joinpath("best.pt")
sys.path[:0] = [str(SCRIPT_DIR)]
from model import SymbolicLightConfig, SymbolicLightModel
def parse_args():
    """Parse and validate the command-line options for the generator.

    Returns:
        argparse.Namespace holding the parsed options.

    Exits via ``ArgumentParser.error`` (status 2) when ``--max_tokens`` is
    below 1, ``--temperature`` is not positive, or ``--top_k`` is negative.
    The checks run in that order, and the first failure wins.
    """
    parser = argparse.ArgumentParser(description="SymbolicLight-PoC Generator")
    add = parser.add_argument

    add("--checkpoint", type=str, default=str(DEFAULT_CHECKPOINT),
        help="Checkpoint path")
    add("--prompt", type=str, default=None,
        help="Single prompt generation mode (skip interactive chat)")

    # Sampling controls.
    add("--max_tokens", type=int, default=200,
        help="Max number of tokens to generate")
    add("--temperature", type=float, default=0.8,
        help="Sampling temperature (higher = more random, lower = more conservative)")
    add("--top_k", type=int, default=50,
        help="Top-K sampling")

    # STDP / loading switches.
    add("--enable_stdp", action="store_true",
        help="Enable experimental STDP updates during inference")
    add("--save_stdp", type=str, default=None,
        help="Save updated weights here after STDP learning")
    add("--allow_random_init", action="store_true",
        help="Allow random initialization when checkpoint is missing")
    add("--trust_checkpoint_pickle", action="store_true",
        help="Allow unsafe pickle checkpoint loading if weights_only=True fails")

    parsed = parser.parse_args()

    range_checks = (
        (parsed.max_tokens < 1, "--max_tokens must be >= 1"),
        (parsed.temperature <= 0, "--temperature must be > 0"),
        (parsed.top_k < 0, "--top_k must be >= 0"),
    )
    for violated, message in range_checks:
        if violated:
            parser.error(message)
    return parsed
class TiktokenWrapper:
    """tiktoken GPT-2 tokenizer wrapper.

    Args:
        vocab_size: Vocabulary size of the model this tokenizer feeds. It is
            kept for reference only. The model may sample ids at or above the
            GPT-2 encoding's own size (``self.enc.n_vocab``), e.g. if its
            vocabulary is larger than the encoding's. tiktoken cannot decode
            such ids, so ``decode`` skips them.
    """

    def __init__(self, vocab_size=50257):
        self.vocab_size = vocab_size
        self.enc = tiktoken.get_encoding("gpt2")

    def encode(self, text: str) -> list:
        """Encode ``text`` into GPT-2 token ids.

        Special tokens are not permitted here. With ``allowed_special=set()``
        and tiktoken's default ``disallowed_special="all"``, text containing a
        special token such as ``<|endoftext|>`` raises ``ValueError``.
        """
        return self.enc.encode(text, allowed_special=set())

    def decode(self, ids: list) -> str:
        """Decode token ids to text.

        Ids outside ``[0, self.enc.n_vocab)`` have no text in the GPT-2
        encoding and would make tiktoken raise. They are dropped instead, so
        a sampled out-of-range id cannot abort generation.
        """
        limit = self.enc.n_vocab
        decodable = [tok for tok in (int(i) for i in ids) if 0 <= tok < limit]
        return self.enc.decode(decodable)
def _load_checkpoint(path: Path, device: torch.device, trust_pickle: bool):
    """Deserialize a checkpoint, preferring PyTorch's restricted weights-only loader.

    Args:
        path: File to load.
        device: Passed to ``torch.load`` as ``map_location``.
        trust_pickle: When True, a failed weights-only load is retried with
            full unpickling (``weights_only=False``).

    Returns:
        Whatever object ``torch.load`` produced.

    Raises:
        RuntimeError: The weights-only load failed and ``trust_pickle`` is
            False. The original error is chained as ``__cause__``.
    """
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except Exception as restricted_error:
        if trust_pickle:
            print("[Load] WARNING: falling back to weights_only=False for a trusted checkpoint.")
            # Full unpickling can run arbitrary code; reached only on explicit opt-in.
            return torch.load(path, map_location=device, weights_only=False)
        raise RuntimeError(
            "Failed to load checkpoint with weights_only=True. "
            "If this is a trusted local checkpoint that requires pickle, "
            "rerun with --trust_checkpoint_pickle."
        ) from restricted_error
def _format_metric(value) -> str:
    """Render a checkpoint metric for log output.

    ``None`` becomes ``"?"``. Anything ``float()`` accepts is shown with four
    decimals. Anything else (``TypeError``/``ValueError`` from ``float``)
    falls back to ``str(value)``.
    """
    if value is not None:
        try:
            numeric = float(value)
        except (TypeError, ValueError):
            return str(value)
        return format(numeric, ".4f")
    return "?"
def _select_device() -> torch.device:
    """Return the CUDA device when PyTorch reports a usable GPU, else the CPU."""
    has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
    return torch.device("cuda" if has_gpu else "cpu")
def _restore_from_checkpoint(ckpt_path, device, enable_stdp, trust_checkpoint_pickle):
    """Build the model described by the checkpoint at ``ckpt_path`` and load its weights.

    Returns:
        ``(model, config)``. The model is already moved to ``device``.

    Raises:
        ValueError: The checkpoint is not a dict.
        KeyError: The ``config`` dict or the weights are missing.
    """
    print(f"[Load] Loading checkpoint: {ckpt_path}")
    ckpt = _load_checkpoint(ckpt_path, device, trust_checkpoint_pickle)
    if not isinstance(ckpt, dict):
        raise ValueError(f"Checkpoint must be a dict, got {type(ckpt).__name__}")

    config_dict = ckpt.get("config")
    if not isinstance(config_dict, dict):
        raise KeyError("Checkpoint is missing a 'config' dictionary")

    # Either weights key is accepted; "model" wins when both are present.
    for weights_key in ("model", "model_state_dict"):
        if weights_key in ckpt:
            state_dict = ckpt[weights_key]
            break
    else:
        raise KeyError("Checkpoint is missing model weights under 'model' or 'model_state_dict'")

    config = SymbolicLightConfig(**config_dict)
    # The CLI flag overrides whatever STDP setting the checkpoint was saved with.
    config.enable_stdp = enable_stdp
    model = SymbolicLightModel(config).to(device)

    # strict=False: key mismatches are reported as warnings, not errors.
    outcome = model.load_state_dict(state_dict, strict=False)
    if outcome.missing_keys:
        print(f"[Load] WARNING: missing keys: {outcome.missing_keys}")
    # NOTE(review): spike_encoder.v_mem is deliberately not warned about;
    # presumably runtime neuron state stored by some checkpoints — confirm.
    silenced = {"spike_encoder.v_mem"}
    stray_keys = [name for name in outcome.unexpected_keys if name not in silenced]
    if stray_keys:
        print(f"[Load] WARNING: unexpected keys: {stray_keys}")

    step = ckpt.get("global_step", ckpt.get("step", "?"))
    loss = _format_metric(ckpt.get("best_loss", ckpt.get("loss")))
    print(f"[Load] Model loaded (step={step}, loss={loss})")
    return model, config


def load_model(checkpoint_path: str, enable_stdp: bool = False,
               allow_random_init: bool = False,
               trust_checkpoint_pickle: bool = False):
    """Load model and checkpoint.

    Args:
        checkpoint_path: Checkpoint file; ``~`` is expanded.
        enable_stdp: Value written to ``config.enable_stdp``.
        allow_random_init: If the checkpoint is missing, build a randomly
            initialised default model instead of raising (smoke tests only).
        trust_checkpoint_pickle: Forwarded to ``_load_checkpoint``.

    Returns:
        ``(model, config, device)`` with the model switched to eval mode.

    Raises:
        FileNotFoundError: The checkpoint is missing and
            ``allow_random_init`` is False.
        RuntimeError / ValueError / KeyError: The checkpoint cannot be
            loaded or is malformed.
    """
    device = _select_device()
    ckpt_path = Path(checkpoint_path).expanduser()

    if ckpt_path.exists():
        model, config = _restore_from_checkpoint(
            ckpt_path, device, enable_stdp, trust_checkpoint_pickle
        )
    elif allow_random_init:
        print(f"[Load] WARNING: checkpoint not found at {ckpt_path}")
        print("[Load] WARNING: initializing a random model for smoke testing only")
        config = SymbolicLightConfig(enable_stdp=enable_stdp)
        model = SymbolicLightModel(config).to(device)
    else:
        raise FileNotFoundError(
            f"Checkpoint not found: {ckpt_path}. "
            "Pass --allow_random_init only for code smoke tests."
        )

    model.eval()
    return model, config, device
def generate_text(model, tokenizer, prompt: str, device,
                  max_tokens=200, temperature=0.8, top_k=50):
    """Generate a continuation of ``prompt`` and report spike sparsity.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...,
            temperature=..., top_k=...)`` and a ``spike_encoder`` callable
            that returns a ``(spikes, _)`` pair. ``model.config.vocab_size``
            is used for validation when present and truthy.
        tokenizer: Object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        prompt: Text to continue.
        device: Device on which the input tensor is created.
        max_tokens: Number of new tokens requested from ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-K cutoff. Values > 0 are clamped to the model vocab size.

    Returns:
        ``(generated_text, sparsity)``. ``generated_text`` decodes only the
        ids after the prompt. ``sparsity`` is ``1 - mean(spikes)`` from
        re-encoding the first (at most) 32 prompt tokens.

    Raises:
        ValueError: The prompt encodes to no tokens, or to ids outside the
            model vocabulary.
    """
    input_ids = tokenizer.encode(prompt)
    if not input_ids:
        raise ValueError("Prompt must contain at least one token")

    vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
    if vocab_size:
        invalid_ids = [tid for tid in input_ids if tid < 0 or tid >= vocab_size]
        if invalid_ids:
            sample = invalid_ids[:5]
            raise ValueError(
                f"Prompt contains token IDs outside model vocab_size={vocab_size}: {sample}"
            )

    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    effective_top_k = min(top_k, vocab_size) if top_k > 0 and vocab_size else top_k

    with torch.no_grad():
        output_ids = model.generate(
            input_tensor,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=effective_top_k,
        )

    # The slice assumes ``generate`` returns prompt + new ids along dim 1.
    new_ids = output_ids[0, len(input_ids):].tolist()
    generated_text = tokenizer.decode(new_ids)

    # Sparsity probe over at most the first 32 prompt tokens.
    with torch.no_grad():
        spikes, _ = model.spike_encoder(input_tensor[:, :32])
        # Tensor.mean() rejects bool/integer dtypes, and the spike dtype is
        # chosen by the encoder, so cast before averaging.
        sparsity = 1.0 - spikes.float().mean().item()

    return generated_text, sparsity
def interactive_chat(model, tokenizer, device, args):
    """Run a read-eval loop that feeds the running conversation back into the model.

    Commands (case-insensitive):

    - ``quit`` exits. EOF or Ctrl-C at the prompt also exits.
    - ``sparsity`` prints ``model.get_sparsity_stats()``.
    - ``save`` works only with ``--enable_stdp``. It writes the current
      weights and config to ``--save_stdp``, or to ``stdp_updated.pt`` next
      to ``--checkpoint``.

    Any other input is appended to the conversation, and the whole
    conversation is used as the prompt.

    Args:
        model: Loaded model; also needs ``get_sparsity_stats()`` and ``config``.
        tokenizer: Tokenizer with ``encode``/``decode``.
        device: Device for input tensors.
        args: Parsed CLI namespace (``temperature``, ``max_tokens``,
            ``top_k``, ``enable_stdp``, ``save_stdp``, ``checkpoint``).
    """
    print("\n" + "=" * 60)
    print(" SymbolicLight Interactive Chat")
    print("=" * 60)
    print(f" Temperature: {args.temperature}")
    print(f" Max tokens: {args.max_tokens}")
    print(f" STDP Learn: {'ON' if args.enable_stdp else 'OFF'}")
    print(f" Device: {device}")
    print("-" * 60)
    print(" Type your message and press Enter.")
    print(" Type 'quit' to exit.")
    print(" Type 'sparsity' to see network sparsity stats.")
    if args.enable_stdp:
        print(" Type 'save' to save STDP-updated weights.")
    print("=" * 60 + "\n")

    # NOTE(review): the history grows without bound; whether model.generate
    # copes with prompts longer than its context is not visible here.
    conversation_history = ""
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not user_input:
            continue

        command = user_input.lower()
        if command == 'quit':
            print("Bye!")
            break
        if command == 'sparsity':
            try:
                stats = model.get_sparsity_stats()
            except Exception as exc:
                print(f"\n[Sparsity Stats] unavailable: {exc}\n")
                continue
            print("\n[Sparsity Stats]")
            for k, v in stats.items():
                print(f" {k}: {v*100:.1f}% silent")
            print()
            continue
        if command == 'save' and args.enable_stdp:
            # expanduser() so "~/..." paths match how load_model resolves the
            # checkpoint, instead of creating a literal "~" directory.
            if args.save_stdp:
                save_path = Path(args.save_stdp).expanduser()
            else:
                save_path = Path(args.checkpoint).expanduser().with_name("stdp_updated.pt")
            try:
                save_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    "model": model.state_dict(),
                    "config": model.config.__dict__,
                }, save_path)
            except (OSError, RuntimeError) as exc:
                # Keep the session alive; the user can retry with another path.
                print(f"[STDP] Failed to save weights to {save_path}: {exc}\n")
                continue
            print(f"[STDP] Weights saved to {save_path}\n")
            continue

        # Build this turn's prompt without committing it yet. A turn that
        # fails (e.g. text the tokenizer rejects) must not be replayed into
        # every later prompt.
        prompt = f"{conversation_history}{user_input} "
        try:
            response, sparsity = generate_text(
                model, tokenizer, prompt, device,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
            )
        except Exception as exc:
            print(f"[Error] {exc}\n")
            continue

        conversation_history = f"{prompt}{response} "

        print(f"SymbolicLight: {response}")
        print(f" [sparsity: {sparsity*100:.1f}% | "
              f"stdp: {'learning' if args.enable_stdp else 'off'}]\n")
def main():
    """CLI entry point: parse args, load the model, then run one prompt or the chat loop.

    Loading or single-prompt generation failures print ``[Error] ...`` to
    stderr and exit with status 1. In single-prompt mode with both
    ``--enable_stdp`` and ``--save_stdp``, the weights are saved after
    generation. The chat loop offers a ``save`` command for the same purpose.
    """
    args = parse_args()

    try:
        model, config, device = load_model(
            args.checkpoint,
            enable_stdp=args.enable_stdp,
            allow_random_init=args.allow_random_init,
            trust_checkpoint_pickle=args.trust_checkpoint_pickle,
        )
    except Exception as exc:
        print(f"[Error] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc

    tokenizer = TiktokenWrapper(config.vocab_size)

    if not args.prompt:
        interactive_chat(model, tokenizer, device, args)
        return

    # Single prompt generation mode.
    print(f"\nPrompt: {args.prompt}")
    try:
        response, sparsity = generate_text(
            model, tokenizer, args.prompt, device,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
        )
    except Exception as exc:
        print(f"[Error] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc
    print(f"Response: {response}")
    print(f"Sparsity: {sparsity*100:.1f}%")

    # Previously --save_stdp was silently ignored outside the chat loop.
    if args.enable_stdp and args.save_stdp:
        save_path = Path(args.save_stdp).expanduser()
        try:
            save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save({
                "model": model.state_dict(),
                "config": model.config.__dict__,
            }, save_path)
        except (OSError, RuntimeError) as exc:
            print(f"[Error] Failed to save STDP weights to {save_path}: {exc}",
                  file=sys.stderr)
            raise SystemExit(1) from exc
        print(f"[STDP] Weights saved to {save_path}")
if __name__ == "__main__":
main()