|
|
"""
SymbolicLight-PoC text generation script.

Load a checkpoint and run single-prompt or interactive text generation.

Usage:
    # Interactive mode, using the checkpoint next to this script
    python generate.py

    # Specify checkpoint
    python generate.py --checkpoint best.pt

    # Single prompt generation
    python generate.py --prompt "Hello world"

    # Enable experimental STDP updates
    python generate.py --enable_stdp
"""
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import tiktoken |
|
|
# Absolute directory containing this script (Path.resolve() makes it absolute).
SCRIPT_DIR = Path(__file__).resolve().parent
# Default checkpoint location: a file named "best.pt" next to this script.
DEFAULT_CHECKPOINT = SCRIPT_DIR / "best.pt"


# Put the script directory first on sys.path so imports are looked up there
# before anywhere else (presumably so the `model` module imported right after
# this line is found regardless of the current working directory).
sys.path.insert(0, str(SCRIPT_DIR))
| from model import SymbolicLightConfig, SymbolicLightModel |
|
|
|
|
def parse_args():
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``prompt``,
        ``max_tokens``, ``temperature``, ``top_k``, ``enable_stdp``,
        ``save_stdp``, ``allow_random_init`` and ``trust_checkpoint_pickle``.

    Out-of-range values are reported through ``ArgumentParser.error``, which
    prints a usage message and terminates the process.
    """
    p = argparse.ArgumentParser(description="SymbolicLight-PoC Generator")
    p.add_argument("--checkpoint", type=str, default=str(DEFAULT_CHECKPOINT),
                   help="Checkpoint path")
    p.add_argument("--prompt", type=str, default=None,
                   help="Single prompt generation mode (skip interactive chat)")
    p.add_argument("--max_tokens", type=int, default=200,
                   help="Max number of tokens to generate")
    p.add_argument("--temperature", type=float, default=0.8,
                   help="Sampling temperature (higher = more random, lower = more conservative)")
    p.add_argument("--top_k", type=int, default=50,
                   help="Top-K sampling")
    p.add_argument("--enable_stdp", action="store_true",
                   help="Enable experimental STDP updates during inference")
    p.add_argument("--save_stdp", type=str, default=None,
                   help="Save updated weights here after STDP learning")
    p.add_argument("--allow_random_init", action="store_true",
                   help="Allow random initialization when checkpoint is missing")
    p.add_argument("--trust_checkpoint_pickle", action="store_true",
                   help="Allow unsafe pickle checkpoint loading if weights_only=True fails")

    args = p.parse_args()
    if args.max_tokens < 1:
        p.error("--max_tokens must be >= 1")
    # float() accepts "nan" and "inf", and a plain `<= 0` test lets both
    # through (every comparison with NaN is False). The chained comparison
    # rejects non-positive, NaN and infinite temperatures in one check.
    if not 0 < args.temperature < float("inf"):
        p.error("--temperature must be a finite number > 0")
    # top_k == 0 is accepted on purpose; only negative values are rejected.
    if args.top_k < 0:
        p.error("--top_k must be >= 0")
    return args
|
|
|
|
class TiktokenWrapper:
    """Thin adapter around tiktoken's "gpt2" encoding.

    ``vocab_size`` is stored on the instance for callers that want it; the
    wrapper itself never reads it.
    """

    def __init__(self, vocab_size=50257):
        self.enc = tiktoken.get_encoding("gpt2")
        self.vocab_size = vocab_size

    def encode(self, text: str) -> list:
        """Return the token ids for *text*.

        No special tokens are allowed (``allowed_special=set()``).
        NOTE(review): with tiktoken's defaults, text that contains a
        special-token string presumably raises ValueError -- confirm.
        """
        token_ids = self.enc.encode(text, allowed_special=set())
        return token_ids

    def decode(self, ids: list) -> str:
        """Return the text for *ids*; every id is coerced to ``int`` first."""
        plain_ids = list(map(int, ids))
        return self.enc.decode(plain_ids)
|
|
|
|
| def _load_checkpoint(path: Path, device: torch.device, trust_pickle: bool): |
| try: |
| return torch.load(path, map_location=device, weights_only=True) |
| except Exception as exc: |
| if not trust_pickle: |
| raise RuntimeError( |
| "Failed to load checkpoint with weights_only=True. " |
| "If this is a trusted local checkpoint that requires pickle, " |
| "rerun with --trust_checkpoint_pickle." |
| ) from exc |
|
|
| print("[Load] WARNING: falling back to weights_only=False for a trusted checkpoint.") |
| return torch.load(path, map_location=device, weights_only=False) |
|
|
|
|
| def _format_metric(value) -> str: |
| if value is None: |
| return "?" |
| try: |
| return f"{float(value):.4f}" |
| except (TypeError, ValueError): |
| return str(value) |
|
|
|
|
| def _select_device() -> torch.device: |
| if torch.cuda.is_available() and torch.cuda.device_count() > 0: |
| return torch.device("cuda") |
| return torch.device("cpu") |
|
|
|
|
def load_model(checkpoint_path: str, enable_stdp: bool = False,
               allow_random_init: bool = False,
               trust_checkpoint_pickle: bool = False):
    """Build the model, restoring weights from a checkpoint when one exists.

    Args:
        checkpoint_path: Checkpoint location; ``~`` is expanded.
        enable_stdp: Value written to ``config.enable_stdp``.
        allow_random_init: Permit a randomly initialized model when the
            checkpoint path does not exist.
        trust_checkpoint_pickle: Forwarded to ``_load_checkpoint``.

    Returns:
        ``(model, config, device)`` with the model moved to ``device`` and
        switched to eval mode.

    Raises:
        FileNotFoundError: Checkpoint missing and *allow_random_init* is false.
        ValueError: The loaded checkpoint is not a dict.
        KeyError: The checkpoint has no ``config`` dict or no model weights.
    """
    device = _select_device()
    ckpt_path = Path(checkpoint_path).expanduser()

    # Missing checkpoint: either fail or fall back to random weights.
    if not ckpt_path.exists():
        if not allow_random_init:
            raise FileNotFoundError(
                f"Checkpoint not found: {ckpt_path}. "
                "Pass --allow_random_init only for code smoke tests."
            )
        print(f"[Load] WARNING: checkpoint not found at {ckpt_path}")
        print("[Load] WARNING: initializing a random model for smoke testing only")
        config = SymbolicLightConfig(enable_stdp=enable_stdp)
        model = SymbolicLightModel(config).to(device)
        model.eval()
        return model, config, device

    print(f"[Load] Loading checkpoint: {ckpt_path}")
    ckpt = _load_checkpoint(ckpt_path, device, trust_checkpoint_pickle)
    if not isinstance(ckpt, dict):
        raise ValueError(f"Checkpoint must be a dict, got {type(ckpt).__name__}")

    config_dict = ckpt.get("config")
    if not isinstance(config_dict, dict):
        raise KeyError("Checkpoint is missing a 'config' dictionary")

    # Weights may be stored under either key; "model" takes precedence.
    for weights_key in ("model", "model_state_dict"):
        if weights_key in ckpt:
            state_dict = ckpt[weights_key]
            break
    else:
        raise KeyError("Checkpoint is missing model weights under 'model' or 'model_state_dict'")

    config = SymbolicLightConfig(**config_dict)
    # The caller's flag overrides whatever the checkpoint config recorded.
    config.enable_stdp = enable_stdp
    model = SymbolicLightModel(config).to(device)

    # Non-strict load: key mismatches are reported below instead of raising.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"[Load] WARNING: missing keys: {load_result.missing_keys}")

    # This one unexpected key is deliberately not reported.
    ignored_unexpected = {"spike_encoder.v_mem"}
    unexpected_keys = []
    for key in load_result.unexpected_keys:
        if key in ignored_unexpected:
            continue
        unexpected_keys.append(key)
    if unexpected_keys:
        print(f"[Load] WARNING: unexpected keys: {unexpected_keys}")

    fallback_step = ckpt.get("step", "?")
    step = ckpt.get("global_step", fallback_step)
    fallback_loss = ckpt.get("loss")
    loss = _format_metric(ckpt.get("best_loss", fallback_loss))
    print(f"[Load] Model loaded (step={step}, loss={loss})")

    model.eval()
    return model, config, device
|
|
|
|
def generate_text(model, tokenizer, prompt: str, device,
                  max_tokens=200, temperature=0.8, top_k=50):
    """Generate a continuation of *prompt* and measure spike sparsity.

    Args:
        model: Object providing ``generate(...)``, ``spike_encoder(...)`` and
            optionally ``config.vocab_size``.
        tokenizer: Object providing ``encode(str)`` and ``decode(list)``.
        prompt: Text to continue.
        device: Device on which the prompt tensor is created.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        temperature: Passed through to ``model.generate``.
        top_k: Passed to ``model.generate``; clamped to the vocabulary size
            when positive and the vocabulary size is known.

    Returns:
        ``(generated_text, sparsity)``. ``generated_text`` decodes the output
        ids that follow the prompt. ``sparsity`` is ``1 - mean`` of the spike
        encoder output on at most the first 32 prompt tokens.

    Raises:
        ValueError: The prompt encodes to no tokens, or contains ids outside
            ``[0, vocab_size)``.
    """
    input_ids = tokenizer.encode(prompt)
    if not input_ids:
        raise ValueError("Prompt must contain at least one token")

    model_config = getattr(model, "config", None)
    vocab_size = getattr(model_config, "vocab_size", None)
    if vocab_size:
        # Reject ids the model has no embedding row for before calling it.
        invalid_ids = [tid for tid in input_ids if not 0 <= tid < vocab_size]
        if invalid_ids:
            sample = invalid_ids[:5]
            raise ValueError(
                f"Prompt contains token IDs outside model vocab_size={vocab_size}: {sample}"
            )

    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Only a positive top_k is clamped; other values pass through untouched.
    effective_top_k = top_k
    if top_k > 0 and vocab_size:
        effective_top_k = min(top_k, vocab_size)

    with torch.no_grad():
        output_ids = model.generate(
            input_tensor,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=effective_top_k,
        )

    # Drop the prompt prefix so only the continuation is decoded.
    prompt_len = len(input_ids)
    continuation = output_ids[0, prompt_len:].tolist()
    generated_text = tokenizer.decode(continuation)

    # Probe the spike encoder on a short prefix of the prompt (slicing clamps
    # to the actual length when the prompt has fewer than 32 tokens).
    with torch.no_grad():
        probe = input_tensor[:, :32]
        spikes, _ = model.spike_encoder(probe)
        firing_rate = spikes.mean().item()
    sparsity = 1.0 - firing_rate

    return generated_text, sparsity
|
|
|
|
|
def interactive_chat(model, tokenizer, device, args):
    """Run a read-generate-print loop on stdin/stdout until the user quits.

    Recognized inputs (case-insensitive): ``quit`` ends the loop, ``sparsity``
    prints ``model.get_sparsity_stats()``, and ``save`` (only when
    ``args.enable_stdp`` is set) writes the current weights and config to
    disk. Anything else is appended to the running conversation and sent to
    ``generate_text``.

    Args:
        model: Model passed to ``generate_text``; also used for the
            ``sparsity`` and ``save`` commands.
        tokenizer: Tokenizer passed to ``generate_text``.
        device: Device passed to ``generate_text`` and shown in the banner.
        args: Parsed CLI namespace; reads ``temperature``, ``max_tokens``,
            ``top_k``, ``enable_stdp``, ``save_stdp`` and ``checkpoint``.
    """
    print("\n" + "=" * 60)
    print(" SymbolicLight Interactive Chat")
    print("=" * 60)
    print(f" Temperature: {args.temperature}")
    print(f" Max tokens: {args.max_tokens}")
    print(f" STDP Learn: {'ON' if args.enable_stdp else 'OFF'}")
    print(f" Device: {device}")
    print("-" * 60)
    print(" Type your message and press Enter.")
    print(" Type 'quit' to exit.")
    print(" Type 'sparsity' to see network sparsity stats.")
    if args.enable_stdp:
        print(" Type 'save' to save STDP-updated weights.")
    print("=" * 60 + "\n")

    conversation_history = ""

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not user_input:
            continue

        command = user_input.lower()

        if command == 'quit':
            print("Bye!")
            break

        if command == 'sparsity':
            try:
                stats = model.get_sparsity_stats()
            except Exception as exc:
                print(f"\n[Sparsity Stats] unavailable: {exc}\n")
                continue

            print("\n[Sparsity Stats]")
            for k, v in stats.items():
                print(f" {k}: {v*100:.1f}% silent")
            print()
            continue

        # Without --enable_stdp, "save" is not a command and is treated as
        # ordinary chat input below.
        if command == 'save' and args.enable_stdp:
            save_path = Path(args.save_stdp) if args.save_stdp else Path(args.checkpoint).with_name("stdp_updated.pt")
            try:
                save_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    "model": model.state_dict(),
                    "config": model.config.__dict__,
                }, save_path)
            except (OSError, RuntimeError) as exc:
                # A failed write must not end the chat session.
                print(f"[STDP] Failed to save weights to {save_path}: {exc}\n")
            else:
                print(f"[STDP] Weights saved to {save_path}\n")
            continue

        # Build the prompt without committing the new input to the history
        # yet: if generation fails (for example the text cannot be encoded),
        # the offending input must not be replayed on every later turn.
        prompt = f"{conversation_history}{user_input} "

        try:
            response, sparsity = generate_text(
                model, tokenizer, prompt, device,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
            )
        except Exception as exc:
            print(f"[Error] {exc}\n")
            continue

        # Success: the history now holds the user turn plus the model's reply.
        conversation_history = f"{prompt}{response} "

        print(f"SymbolicLight: {response}")
        print(f" [sparsity: {sparsity*100:.1f}% | "
              f"stdp: {'learning' if args.enable_stdp else 'off'}]\n")
|
|
|
|
|
def main():
    """CLI entry point: load the model, then run one prompt or the chat loop.

    Load and single-prompt generation errors are printed to stderr and turned
    into ``SystemExit(1)``.
    """
    args = parse_args()

    try:
        model, config, device = load_model(
            args.checkpoint,
            enable_stdp=args.enable_stdp,
            allow_random_init=args.allow_random_init,
            trust_checkpoint_pickle=args.trust_checkpoint_pickle,
        )
    except Exception as exc:
        print(f"[Error] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc

    tokenizer = TiktokenWrapper(config.vocab_size)

    # Compare against None rather than truthiness: an explicit `--prompt ""`
    # is still a single-prompt request and should fail with the empty-prompt
    # error instead of silently starting an interactive session.
    if args.prompt is not None:
        print(f"\nPrompt: {args.prompt}")
        try:
            response, sparsity = generate_text(
                model, tokenizer, args.prompt, device,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
            )
        except Exception as exc:
            print(f"[Error] {exc}", file=sys.stderr)
            raise SystemExit(1) from exc

        print(f"Response: {response}")
        print(f"Sparsity: {sparsity*100:.1f}%")
    else:
        interactive_chat(model, tokenizer, device, args)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|