|
|
| """
|
| SymbolicLight — Validation Script
|
| ================================
|
| Evaluate the trained model on the TinyStories validation set.
|
|
|
| Metrics:
|
| 1. Validation Loss / Perplexity
|
| 2. Sparsity statistics (average/min/max)
|
| 3. Simple text generation demo
|
|
|
| Usage:
|
| python validate.py
|
| python validate.py --checkpoint checkpoints/best.pt
|
| python validate.py --generate --prompt "Once upon a time"
|
| """
|
|
|
import argparse
import contextlib
import math
import os
import sys
import time

import torch
import torch.nn.functional as F

from model import SymbolicLightConfig, SymbolicLightModel
|
|
|
|
|
# On Windows, switch the console code page to UTF-8 (65001) and make stdout
# encode as UTF-8, replacing unencodable characters instead of raising.
if sys.platform == 'win32':
    os.system('chcp 65001 > nul')
    # sys.stdout can be None (e.g. pythonw) or a replacement stream without
    # reconfigure(); skip it rather than crash at import time.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
|
|
|
|
def parse_args():
    """Define the validation CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``checkpoint``, ``max_samples``,
        ``batch_size``, ``seq_len``, ``generate``, ``prompt``,
        ``max_new_tokens``, ``temperature`` and ``top_k``.
    """
    parser = argparse.ArgumentParser(description="SymbolicLight Validation")

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--checkpoint", {"type": str, "default": "./checkpoints/best.pt",
                          "help": "Model checkpoint path"}),
        ("--max_samples", {"type": int, "default": 5000,
                           "help": "Maximum number of validation samples (to reduce wait time)"}),
        ("--batch_size", {"type": int, "default": 16,
                          "help": "Validation batch size"}),
        ("--seq_len", {"type": int, "default": 256,
                       "help": "Sequence length"}),
        ("--generate", {"action": "store_true",
                        "help": "Whether to run text generation demo"}),
        ("--prompt", {"type": str, "default": "Once upon a time",
                      "help": "Prompt for generation"}),
        ("--max_new_tokens", {"type": int, "default": 200,
                              "help": "Maximum number of generated tokens"}),
        ("--temperature", {"type": float, "default": 0.8,
                           "help": "Generation temperature"}),
        ("--top_k", {"type": int, "default": 50,
                     "help": "Top-K sampling"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
|
|
|
|
def _checkpoint_config(ckpt):
    """Build a SymbolicLightConfig from ``ckpt["config"]``, or use defaults."""
    config = SymbolicLightConfig()
    if "config" in ckpt:
        # Copy only fields the config class defines; unknown keys are ignored.
        for key, value in ckpt["config"].items():
            if hasattr(config, key):
                setattr(config, key, value)
        print("[Model] Config loaded from checkpoint")
    else:
        # 50257 is the size of the GPT-2 BPE vocabulary used for the data.
        config.vocab_size = 50257
        print("[Model] Using default config")
    return config


def _checkpoint_state_dict(ckpt):
    """Return the weights in *ckpt*: a nested entry, or *ckpt* itself."""
    if "model_state_dict" in ckpt:
        return ckpt["model_state_dict"]
    if "model" in ckpt:
        return ckpt["model"]
    return ckpt


def _strip_wrapper_prefix(state_dict, model_keys):
    """Drop a wrapper prefix shared by every checkpoint key.

    DistributedDataParallel prefixes keys with ``module.`` and torch.compile
    with ``_orig_mod.``. A prefix is removed only when every checkpoint key
    has it and none of the target model's own keys do.
    """
    model_keys = list(model_keys)
    for prefix in ("module.", "_orig_mod."):
        keys = list(state_dict)
        if (keys
                and all(isinstance(k, str) and k.startswith(prefix) for k in keys)
                and not any(k.startswith(prefix) for k in model_keys)):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def load_model(checkpoint_path, device):
    """Load a SymbolicLight model from a checkpoint and prepare it for evaluation.

    Args:
        checkpoint_path: File written by ``torch.save``. Either a dict that may
            hold "config", "model_state_dict" or "model", "step", and
            "best_loss" or "loss", or a bare state dict.
        device: Passed to ``torch.load(map_location=...)`` and ``Module.to``.

    Returns:
        (model, config): the model on ``device`` in eval mode, with entropy-gate
        early exit and STDP disabled, and the config used to build it.
    """
    print(f"[Model] Loading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    config = _checkpoint_config(ckpt)
    model = SymbolicLightModel(config)

    state_dict = _strip_wrapper_prefix(_checkpoint_state_dict(ckpt),
                                       model.state_dict().keys())
    # strict=False tolerates partial matches. Report mismatches so a wrong
    # checkpoint is not silently evaluated with freshly initialized weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(incompatible, "missing_keys", []))
    unexpected = list(getattr(incompatible, "unexpected_keys", []))
    if missing:
        print(f"[Model] WARNING: {len(missing)} model keys not found in "
              f"checkpoint (first few: {missing[:5]})")
    if unexpected:
        print(f"[Model] WARNING: {len(unexpected)} checkpoint keys unused by "
              f"model (first few: {unexpected[:5]})")

    model = model.to(device)
    model.eval()

    # Setting each block's entropy-gate threshold to 0.0 disables early exit,
    # and STDP is switched off, so validation runs the full fixed network.
    for block in model.blocks:
        block.entropy_gate.threshold = 0.0
    model.stdp.enabled = False
    print("[Model] Disabled entropy gate early exit and STDP for validation")

    n_params = sum(p.numel() for p in model.parameters())
    step = ckpt.get("step", "?")
    loss = ckpt.get("best_loss", ckpt.get("loss", "?"))
    print(f"[Model] Parameters: {n_params:,} ({n_params/1e6:.1f}M)")
    print(f"[Model] Checkpoint step: {step}, loss: {loss}")

    return model, config
|
|
|
|
|
def load_validation_data(seq_len, max_samples):
    """Load and tokenize the TinyStories validation split.

    Stories are GPT-2 tokenized with tiktoken and concatenated into one token
    stream. The stream is then cut into non-overlapping ``seq_len`` windows.

    Args:
        seq_len: Tokens per validation sample; must be positive.
        max_samples: Upper bound on the number of samples returned.

    Returns:
        (tokens_tensor, n_samples, enc):
            tokens_tensor: 1-D LongTensor of ``n_samples * seq_len + 1``
                tokens. The extra token is the target for the last position.
            n_samples: Number of full windows available (0 if too little data).
            enc: The tiktoken encoder.

    Raises:
        ValueError: if ``seq_len`` is not positive.
    """
    import tiktoken
    from datasets import load_dataset

    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    enc = tiktoken.get_encoding("gpt2")
    print(f"[Data] Loading TinyStories (validation) from HuggingFace...")
    ds = load_dataset("roneneldan/TinyStories", split="validation")
    print(f"[Data] Loaded {len(ds):,} validation stories")

    # Enough for max_samples full windows plus one trailing target token.
    tokens_needed = max(0, max_samples) * seq_len + 1

    print(f"[Data] Tokenizing...")
    # NOTE(review): stories are concatenated with no separator token between
    # them. Confirm this matches how the training stream was built.
    all_tokens = []
    for i, example in enumerate(ds):
        text = example.get("text", "")
        # disallowed_special=() encodes literal special-token text such as
        # "<|endoftext|>" as ordinary tokens instead of raising ValueError.
        all_tokens.extend(
            enc.encode(text, allowed_special=set(), disallowed_special=())
        )
        if len(all_tokens) >= tokens_needed:
            break
        if (i + 1) % 50000 == 0:
            print(f"  ... tokenized {i+1:,} stories ({len(all_tokens):,} tokens)")

    # Clamp at 0: with an empty stream or max_samples <= 0 this would go negative.
    n_samples = max(0, min(max_samples, (len(all_tokens) - 1) // seq_len))
    print(f"[Data] Total: {len(all_tokens):,} tokens, {n_samples:,} validation samples")

    tokens_tensor = torch.tensor(all_tokens[:n_samples * seq_len + 1], dtype=torch.long)

    return tokens_tensor, n_samples, enc
|
|
|
|
|
| @torch.no_grad()
|
| def validate(model, tokens_tensor, n_samples, seq_len, batch_size, device):
|
| """Calculate loss, perplexity, and sparsity on the validation set"""
|
| model.eval()
|
|
|
| total_loss = 0.0
|
| total_tokens = 0
|
| sparsity_list = []
|
| n_batches = 0
|
|
|
| print(f"\n{'='*60}")
|
| print(f" VALIDATION ({n_samples:,} samples, batch_size={batch_size})")
|
| print(f"{'='*60}")
|
|
|
| start_time = time.time()
|
|
|
| for start_idx in range(0, n_samples, batch_size):
|
| end_idx = min(start_idx + batch_size, n_samples)
|
| actual_bs = end_idx - start_idx
|
|
|
|
|
| x_list = []
|
| y_list = []
|
| for i in range(start_idx, end_idx):
|
| offset = i * seq_len
|
| x_list.append(tokens_tensor[offset:offset + seq_len])
|
| y_list.append(tokens_tensor[offset + 1:offset + seq_len + 1])
|
|
|
| x = torch.stack(x_list).to(device)
|
| y = torch.stack(y_list).to(device)
|
|
|
|
|
| with torch.amp.autocast('cuda', dtype=torch.float16):
|
| logits = model(x)
|
|
|
|
|
| loss = F.cross_entropy(
|
| logits.view(-1, logits.size(-1)),
|
| y.view(-1)
|
| )
|
|
|
| total_loss += loss.item() * actual_bs * seq_len
|
| total_tokens += actual_bs * seq_len
|
|
|
|
|
|
|
|
|
|
|
| n_batches += 1
|
| if n_batches % 50 == 0:
|
| avg_loss_so_far = total_loss / total_tokens
|
| avg_ppl_so_far = math.exp(min(avg_loss_so_far, 20))
|
| elapsed = time.time() - start_time
|
| print(f" Batch {n_batches:4d} | Loss: {avg_loss_so_far:.4f} | "
|
| f"PPL: {avg_ppl_so_far:7.2f} | "
|
| f"Time: {elapsed:.1f}s")
|
|
|
|
|
| avg_loss = total_loss / total_tokens
|
| avg_ppl = math.exp(avg_loss)
|
| elapsed = time.time() - start_time
|
|
|
| print(f"\n{'='*60}")
|
| print(f" VALIDATION RESULTS")
|
| print(f"{'='*60}")
|
| print(f" Validation Loss: {avg_loss:.4f}")
|
| print(f" Validation Perplexity: {avg_ppl:.2f}")
|
|
|
| if sparsity_list:
|
| avg_sp = sum(sparsity_list) / len(sparsity_list) * 100
|
| min_sp = min(sparsity_list) * 100
|
| max_sp = max(sparsity_list) * 100
|
| print(f" Sparsity (avg): {avg_sp:.1f}%")
|
| print(f" Sparsity (min/max): {min_sp:.1f}% / {max_sp:.1f}%")
|
|
|
| print(f" Total tokens: {total_tokens:,}")
|
| print(f" Time: {elapsed:.1f}s")
|
| print(f" Throughput: {total_tokens/elapsed:,.0f} tok/s")
|
| print(f"{'='*60}\n")
|
|
|
| return avg_loss, avg_ppl
|
|
|
|
|
@torch.no_grad()
def generate_text(model, enc, prompt, max_new_tokens, temperature, top_k, device,
                  context_len=256):
    """Autoregressively extend *prompt*, print the result and return it.

    Args:
        model: Module mapping a (1, T) LongTensor of token ids to logits
            indexed [batch, position, vocab].
        enc: Tokenizer with ``encode(text, allowed_special=...,
            disallowed_special=...)`` and ``decode(ids)``; tiktoken-compatible.
        prompt: Text to condition on; must encode to at least one token.
        max_new_tokens: Number of tokens to generate.
        temperature: Softmax temperature. Values <= 0 select the argmax
            (greedy decoding).
        top_k: Sample only among the k highest logits; <= 0 disables the
            filter. Values above the vocabulary size are clamped.
        device: torch.device or device string. fp16 autocast is used only
            on CUDA.
        context_len: Only the last ``context_len`` tokens are fed to the model
            at each step (default 256).

    Returns:
        The decoded prompt followed by the generated continuation.

    Raises:
        ValueError: if the prompt encodes to no tokens or context_len < 1.
    """
    model.eval()
    device = torch.device(device)
    use_amp = device.type == "cuda"
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    # disallowed_special=() treats special-token text in the prompt as plain
    # text instead of raising.
    token_ids = enc.encode(prompt, allowed_special=set(), disallowed_special=())
    if not token_ids:
        raise ValueError("prompt must encode to at least one token")

    print(f"\n{'='*60}")
    print(f" TEXT GENERATION")
    print(f"{'='*60}")
    print(f" Prompt: \"{prompt}\"")
    print(f" Temperature: {temperature}, Top-K: {top_k}")
    print(f" Max new tokens: {max_new_tokens}")
    print(f"{'='*60}\n")

    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)
    generated = list(token_ids)
    start_time = time.time()

    for _ in range(max_new_tokens):
        input_ids = tokens[:, -context_len:]

        amp_ctx = (torch.amp.autocast("cuda", dtype=torch.float16)
                   if use_amp else contextlib.nullcontext())
        with amp_ctx:
            logits = model(input_ids)

        # Upcast before temperature scaling and softmax; fp16 logits lose precision.
        next_logits = logits[:, -1, :].float()

        if temperature <= 0:
            # Greedy decoding avoids dividing by a zero temperature.
            next_token = next_logits.argmax(dim=-1, keepdim=True)
        else:
            next_logits = next_logits / temperature
            if top_k > 0:
                k = min(top_k, next_logits.size(-1))
                kth_value = torch.topk(next_logits, k).values[:, -1:]
                next_logits = next_logits.masked_fill(next_logits < kth_value,
                                                      float("-inf"))
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        generated.append(next_token.item())
        tokens = torch.cat([tokens, next_token], dim=1)

    elapsed = time.time() - start_time
    output_text = enc.decode(generated)
    rate = max_new_tokens / elapsed if elapsed > 0 else float("inf")

    print(f"--- Generated Text ---")
    print(output_text)
    print(f"--- End ---")
    print(f"\n[{max_new_tokens} tokens in {elapsed:.2f}s, "
          f"{rate:.1f} tok/s]")

    return output_text
|
|
|
|
|
def main():
    """Run validation and, if ``--generate`` is given, the generation demo.

    Loads the checkpoint and the TinyStories validation data, then reports
    loss and perplexity. With ``--generate``, also samples a continuation of
    ``--prompt`` using the ``--max_new_tokens``, ``--temperature`` and
    ``--top_k`` settings.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Device] {device}")

    model, _config = load_model(args.checkpoint, device)

    tokens_tensor, n_samples, enc = load_validation_data(args.seq_len, args.max_samples)

    validate(model, tokens_tensor, n_samples, args.seq_len, args.batch_size, device)

    # Honor the --generate flag and the user's --prompt; both were previously
    # ignored in favor of always running three hard-coded prompts.
    if args.generate:
        generate_text(model, enc, args.prompt,
                      args.max_new_tokens, args.temperature,
                      args.top_k, device)


if __name__ == "__main__":
    main()
|
|
|