#!/usr/bin/env python3
"""Interactive chat loop for SymbolicLight checkpoints."""

import argparse
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

# Make the sibling ``eval_08`` module importable when this file is run as a
# script; this must happen before the import below.
sys.path.insert(0, str(Path(__file__).parent))

from eval_08 import (  # noqa: E402
    DEFAULT_CHECKPOINT,
    _SL_TOKENIZER_PATH,
    TokenizerWrapper,
    _resolve_load_dtype,
    build_config_from_checkpoint,
    build_model_from_checkpoint,
    build_model_from_checkpoint_zip,
    load_checkpoint,
    load_checkpoint_metadata,
)

# Instruction lines that open every "answer"-format prompt.  The same
# sentences are scrubbed from replies in case the model echoes them back.
_ANSWER_INSTRUCTIONS = (
    "Answer the question below in short, natural, coherent English.",
    "Do not repeat yourself, do not produce outlines, and do not drift into unrelated textbook content.",
)

# Substrings that end generation as soon as they appear in the decoded reply,
# and at which "answer"-format replies are cut during clean-up.
_STOP_MARKERS = ("Question:", "Short answer:", "User:", "Assistant:", "###", "Answer:")


class _ArgsShim:
    """Stand-in for the args object passed to ``build_config_from_checkpoint``.

    It only defines ``seq_len = None``.
    """

    seq_len = None


def parse_args():
    """Parse the command-line options of the chat script."""
    parser = argparse.ArgumentParser(description="Interactive chat for SymbolicLight checkpoints")
    parser.add_argument("--checkpoint_path", type=str, default=DEFAULT_CHECKPOINT,
                        help="Path to .pt checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=_SL_TOKENIZER_PATH,
                        help="Path to tokenizer model")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"],
                        help="Execution device")
    parser.add_argument("--allow_windows_cuda", action="store_true",
                        help="Allow CUDA transfer on Windows")
    parser.add_argument("--load_dtype", type=str, default="auto", choices=["auto", "fp32", "fp16"],
                        help="Weight dtype during loading")
    parser.add_argument("--max_new_tokens", type=int, default=96,
                        help="Maximum new tokens per reply")
    parser.add_argument("--temperature", type=float, default=0.6,
                        help="Sampling temperature")
    parser.add_argument("--top_k", type=int, default=20,
                        help="Top-k sampling cutoff")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p sampling cutoff")
    parser.add_argument("--repetition_penalty", type=float, default=1.15,
                        help="Penalty for tokens that already appeared in the reply")
    parser.add_argument("--no_repeat_ngram_size", type=int, default=3,
                        help="Block repeated n-grams in generated text")
    parser.add_argument("--history_turns", type=int, default=0,
                        help="Number of recent turns to keep in the prompt")
    parser.add_argument("--prompt_format", type=str, default="answer",
                        choices=["raw", "qa", "chat", "answer"],
                        help="Prompt template style")
    # NOTE(review): this flag is parsed but not read anywhere in this file.
    parser.add_argument("--no_adaptive_temperature", action="store_true",
                        help="Disable entropy-based adaptive temperature")
    return parser.parse_args()


def resolve_device(args) -> torch.device:
    """Translate ``args.device`` into a ``torch.device``.

    ``"cpu"`` and ``"cuda"`` are honoured literally; any other value ("auto")
    selects CUDA when it is available and CPU otherwise.

    Raises:
        RuntimeError: if ``"cuda"`` was requested but CUDA is unavailable.
    """
    if args.device == "cpu":
        return torch.device("cpu")
    if args.device == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available on this machine.")
        return torch.device("cuda")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_runtime(args):
    """Load the model and tokenizer described by ``args``.

    Returns:
        ``(device, config, global_step, model, tokenizer, load_info)``.

    Raises:
        RuntimeError: on Windows when the resolved device is CUDA and
            ``--allow_windows_cuda`` was not given (or from ``resolve_device``).
    """
    device = resolve_device(args)
    load_dtype = _resolve_load_dtype(args.load_dtype)

    on_windows = os.name == "nt"
    if on_windows and device.type == "cuda" and not args.allow_windows_cuda:
        raise RuntimeError(
            "Windows CUDA loading is disabled by default for this checkpoint. "
            "Re-run with --allow_windows_cuda to enable the GPU path."
        )

    if on_windows:
        # Windows goes through the metadata + zip loader; this is the only
        # path that forwards ``load_dtype``.
        ckpt_meta = load_checkpoint_metadata(args.checkpoint_path)
        config, global_step = build_config_from_checkpoint(ckpt_meta, _ArgsShim())
        model, load_info = build_model_from_checkpoint_zip(
            config,
            args.checkpoint_path,
            ckpt_meta,
            target_float_dtype=load_dtype,
        )
    else:
        # Load onto the CPU first; the model is moved to ``device`` below.
        ckpt = load_checkpoint(args.checkpoint_path, device="cpu")
        config, global_step = build_config_from_checkpoint(ckpt, _ArgsShim())
        model, load_info = build_model_from_checkpoint(config, ckpt)

    model = model.to(device)
    model.eval()
    tokenizer = TokenizerWrapper(args.tokenizer_path)
    return device, config, global_step, model, tokenizer, load_info


def trim_history(history, keep_turns):
    """Return the last ``keep_turns`` entries of ``history`` (``[]`` if <= 0)."""
    if keep_turns <= 0:
        return []
    return history[-keep_turns:]


def build_prompt(history, user_text, prompt_format):
    """Render ``history`` plus the new ``user_text`` as a single prompt string.

    Args:
        history: iterable of ``(user, assistant)`` string pairs.
        user_text: the new user message.
        prompt_format: ``"answer"``, ``"raw"`` or ``"qa"``; any other value
            produces the ``User:`` / ``Assistant:`` chat layout.
    """
    if prompt_format == "answer":
        lines = list(_ANSWER_INSTRUCTIONS)
        for old_user, old_assistant in history:
            lines.append(f"Question: {old_user}")
            lines.append(f"Short answer: {old_assistant}")
        lines.append(f"Question: {user_text}")
        lines.append("Short answer:")
        return "\n".join(lines)

    if prompt_format == "raw":
        if not history:
            return user_text
        parts = []
        for old_user, old_assistant in history:
            parts.append(old_user)
            parts.append(old_assistant)
        parts.append(user_text)
        return "\n".join(parts)

    if prompt_format == "qa":
        lines = []
        for old_user, old_assistant in history:
            lines.append(f"Question: {old_user}")
            lines.append(f"Answer: {old_assistant}")
        lines.append(f"Question: {user_text}")
        lines.append("Answer:")
        return "\n".join(lines)

    lines = []
    for old_user, old_assistant in history:
        lines.append(f"User: {old_user}")
        lines.append(f"Assistant: {old_assistant}")
    lines.append(f"User: {user_text}")
    lines.append("Assistant:")
    return "\n".join(lines)


def clean_reply(text, prompt_format):
    """Post-process raw generated ``text`` into the reply shown to the user.

    The text is cut at the format's turn markers, duplicate lines are dropped,
    echoed instruction phrases are removed, and the result is limited to two
    sentences.
    """
    reply = text.strip()
    if prompt_format == "answer":
        for marker in _STOP_MARKERS:
            reply = reply.split(marker)[0].strip()
    elif prompt_format == "qa":
        reply = reply.split("Question:")[0].strip()
    elif prompt_format == "chat":
        reply = reply.split("User:")[0].strip()

    reply = dedupe_lines(reply)
    reply = strip_leaked_instructions(reply)
    reply = trim_to_sentences(reply, max_sentences=2)
    return reply.strip()


def dedupe_lines(text):
    """Strip every line, drop blank lines and repeats, keep first-seen order."""
    seen = set()
    cleaned = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line in seen:
            continue
        seen.add(line)
        cleaned.append(line)
    return "\n".join(cleaned)


def trim_to_sentences(text, max_sentences=2):
    """Cut ``text`` after its ``max_sentences``-th sentence terminator.

    Terminators are the characters of the literal below; the ASCII period is
    not among them.  If the truncated text is empty after stripping, the
    stripped original is returned instead.
    """
    if not text:
        return text

    out = []
    sentence_count = 0
    for ch in text:
        out.append(ch)
        if ch in "。!?!?":
            sentence_count += 1
            if sentence_count >= max_sentences:
                break

    trimmed = "".join(out).strip()
    if trimmed:
        return trimmed
    return text.strip()


def strip_leaked_instructions(text):
    """Remove prompt instructions and answer labels echoed into ``text``."""
    cleaned = text
    leaked_phrases = [
        *_ANSWER_INSTRUCTIONS,
        "Coherent answer:",
        "Short answer:",
    ]
    for phrase in leaked_phrases:
        cleaned = cleaned.replace(phrase, "").strip()
    return cleaned


def canned_reply(user_text):
    """Return a fixed reply for a few known inputs, or ``None`` otherwise.

    Matching is done on the stripped, lower-cased ``user_text``.
    """
    normalized = user_text.strip().lower()
    if normalized in {"你好", "您好", "hi", "hello", "hey"}:
        return "Hello. You can ask a specific question and I will try to answer briefly."
    if normalized in {"你是谁", "你是谁?", "请介绍一下你自己", "你好,请介绍一下你自己"}:
        return "I am the local SymbolicLight demo script. The loaded weights are a pre-training checkpoint, so replies can run end to end but may not always be stable."
    if normalized in {"你在说什么", "你在说什么?", "are you ok?", "are you ok"}:
        return "The previous reply was unstable. This checkpoint is still a pre-training artifact, not a polished dialogue model. Try a shorter and more specific question."
    if normalized in {"说中文", "请说中文"}:
        return "I can try, but this public demo is configured primarily for brief English replies."
    return None


def apply_repetition_penalty(logits, token_ids, penalty):
    """Down-weight, in place, the logits of tokens listed in ``token_ids``.

    Positive logits are divided by ``penalty`` and the others multiplied by
    it.  Nothing happens when ``penalty <= 1.0`` or ``token_ids`` is empty.
    Only row 0 of ``logits`` is modified.

    Returns:
        The same ``logits`` tensor that was passed in.
    """
    if penalty <= 1.0 or not token_ids:
        return logits

    # One indexed update instead of a Python loop that branches on a tensor
    # comparison (and therefore synchronises) once per token.
    seen = torch.tensor(sorted(set(token_ids)), dtype=torch.long, device=logits.device)
    values = logits[0, seen]
    logits[0, seen] = torch.where(values > 0, values / penalty, values * penalty)
    return logits


def calc_banned_tokens(generated_ids, ngram_size):
    """Return the token ids that would complete an already-seen n-gram.

    Args:
        generated_ids: list of token ids produced so far.
        ngram_size: n-gram length; ``<= 0`` disables banning, and ``1`` bans
            every token that has already been generated.

    Returns:
        A set of token ids that must not be sampled next.
    """
    if ngram_size <= 0:
        return set()

    prefix_len = ngram_size - 1
    total = len(generated_ids)
    if total < prefix_len:
        return set()

    # Slice with an explicit start index: ``generated_ids[-0:]`` would be the
    # whole list rather than the empty prefix needed when ``ngram_size == 1``.
    prefix = tuple(generated_ids[total - prefix_len:])

    banned = set()
    for start in range(total - ngram_size + 1):
        if tuple(generated_ids[start:start + prefix_len]) == prefix:
            banned.add(generated_ids[start + prefix_len])
    return banned


def sample_next_token(logits, top_k, top_p):
    """Sample one token id per row of ``logits`` with top-k / top-p filtering.

    ``top_k <= 0`` disables the top-k filter; ``top_p`` is applied only when
    it lies strictly between 0 and 1.  Returns the ``torch.multinomial``
    result with ``num_samples=1``.
    """
    if top_k > 0:
        top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        cutoff = top_k_vals[:, -1].unsqueeze(-1)
        logits = logits.masked_fill(logits < cutoff, float("-inf"))

    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_mask = cumulative_probs > top_p
        # Shift the mask right by one so the token that crosses the threshold
        # is kept, and never remove the most likely token.
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        removal_mask = torch.zeros_like(sorted_mask, dtype=torch.bool)
        removal_mask.scatter_(1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(removal_mask, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def generate_reply(model, tokenizer, device, prompt, args):
    """Generate a cleaned reply for ``prompt`` by sampling token by token.

    Generation stops after ``args.max_new_tokens`` tokens, when the EOS id is
    sampled, or as soon as the decoded reply contains one of the stop
    markers.  The decoded reply is passed through ``clean_reply``.
    """
    input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
    generated_reply_ids = []
    # One cache dict per model block plus one extra entry.
    # NOTE(review): the purpose of the extra entry is defined by the model,
    # not visible here.
    past_key_values = [{} for _ in range(len(model.blocks) + 1)]
    eos_id = tokenizer.eos_id()
    temperature = max(args.temperature, 1e-5)  # guard against division by zero

    logits = model.forward(input_ids, use_cache=True, past_key_values=past_key_values)
    for _ in range(args.max_new_tokens):
        next_logits = logits[:, -1, :] / temperature
        next_logits = apply_repetition_penalty(
            next_logits, generated_reply_ids, args.repetition_penalty
        )

        banned_tokens = calc_banned_tokens(generated_reply_ids, args.no_repeat_ngram_size)
        if banned_tokens:
            banned_tensor = torch.tensor(
                sorted(banned_tokens), device=next_logits.device, dtype=torch.long
            )
            next_logits.index_fill_(1, banned_tensor, float("-inf"))

        next_token = sample_next_token(next_logits, args.top_k, args.top_p)
        token_id = next_token.item()
        if token_id == eos_id:
            break

        generated_reply_ids.append(token_id)
        partial_text = tokenizer.decode(generated_reply_ids)
        if any(stop in partial_text for stop in _STOP_MARKERS):
            break

        # Feed only the new token; earlier context lives in the cache.
        logits = model.forward(next_token, use_cache=True, past_key_values=past_key_values)

    return clean_reply(tokenizer.decode(generated_reply_ids), args.prompt_format)


def main():
    """Load the runtime, print a banner and run the interactive loop."""
    args = parse_args()
    device, config, global_step, model, tokenizer, load_info = load_runtime(args)

    print("=" * 60)
    print(" SymbolicLight Interactive Chat")
    print("=" * 60)
    print(f"Checkpoint: {Path(args.checkpoint_path).resolve()}")
    print(f"Tokenizer: {Path(args.tokenizer_path).resolve()}")
    print(f"Device: {device}")
    print(f"Step: {global_step}")
    print(f"Seq len: {config.max_seq_len}")
    print(f"Format: {args.prompt_format}")
    if load_info["missing_keys"] or load_info["unexpected_keys"]:
        print(f"Missing keys: {len(load_info['missing_keys'])}")
        print(f"Unexpected keys: {len(load_info['unexpected_keys'])}")
    print("Type 'exit' or 'quit' to stop.")
    print()

    history = []
    while True:
        try:
            user_text = input("You> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n[Exit]")
            break

        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            print("[Exit]")
            break

        reply = canned_reply(user_text)
        if reply is None:
            prompt = build_prompt(
                trim_history(history, args.history_turns), user_text, args.prompt_format
            )
            with torch.no_grad():
                reply = generate_reply(model, tokenizer, device, prompt, args)

        print(f"Model> {reply}\n")
        history.append((user_text, reply))
        # Only the last ``history_turns`` entries are ever read, so drop the
        # rest instead of letting the list grow for the whole session.
        history = trim_history(history, args.history_turns)


if __name__ == "__main__":
    main()