"""Interactive/sample generation with the RepoBridge-style SSM inference fix. This intentionally overrides the checkpoint config at inference time: - ssm_finite_tail_correction = True - ssm_kernel_mode = recurrent Those settings match the temporary chat-quality fix used in RepoBridge Model Chat. """ from __future__ import annotations import argparse import json import sys import time from pathlib import Path ROOT = Path(__file__).resolve().parent TAOTRAIN_SRC = ROOT / "code" / "TaoTrain" / "src" SSM_SRC = ROOT / "code" / "Taotern_SSM" for path in (TAOTRAIN_SRC, SSM_SRC): if str(path) not in sys.path: sys.path.insert(0, str(path)) import torch from taoTrain.checkpointing.checkpoint import CheckpointManager from taoTrain.config import ModelConfig from taoTrain.inference.inferencer import Inferencer from taoTrain.models import get_model def apply_ssm_overrides(model: torch.nn.Module, *, kernel_mode: str, finite_tail: bool) -> int: count = 0 for module in model.modules(): changed = False if hasattr(module, "kernel_mode"): module.kernel_mode = kernel_mode changed = True if hasattr(module, "finite_tail_correction"): module.finite_tail_correction = finite_tail changed = True clear = getattr(module, "clear_kernel_cache", None) if callable(clear): clear() if changed: count += 1 return count def load_fixed(checkpoint_path: Path, tokenizer_path: Path, device: torch.device, dtype: torch.dtype): checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device) config_dict = checkpoint.get("config", {}) model_config_dict = dict(config_dict.get("model", {})) model_config_dict["ssm_finite_tail_correction"] = True model_config_dict["ssm_kernel_mode"] = "recurrent" model_config = ModelConfig(**model_config_dict) tokenizer = Inferencer._load_tokenizer(tokenizer_path) model = get_model(model_config, device=device) model.load_state_dict(checkpoint["model_state"], strict=False) model.to(device=device) model.eval() override_count = apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True) return model, tokenizer, override_count def generate( model, tokenizer, prompt: str, *, device: torch.device, dtype: torch.dtype, max_new_tokens: int, temperature: float, top_p: float, repetition_penalty: float, greedy: bool, ) -> str: input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) prompt_len = input_ids.shape[1] generated_ids: list[int] = [] eos_token_id = getattr(tokenizer, "eos_token_id", None) device_type = "cuda" if device.type == "cuda" else "cpu" autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled): for _ in range(max_new_tokens): apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True) outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None) logits = outputs["logits"][:, -1, :] if not greedy: logits = logits / max(temperature, 1e-6) if repetition_penalty != 1.0: for token_id in torch.unique(input_ids[0, prompt_len:]): logits[0, token_id] /= repetition_penalty if greedy: next_token = torch.argmax(logits, dim=-1, keepdim=True) else: if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) sorted_probs = torch.softmax(sorted_logits, dim=-1) cumulative = torch.cumsum(sorted_probs, dim=-1) remove = cumulative > top_p remove[..., 1:] = remove[..., :-1].clone() remove[..., 0] = False indices_to_remove = sorted_indices[remove] logits[0, indices_to_remove] = float("-inf") probs = torch.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) token_id = int(next_token.item()) if eos_token_id is not None and token_id == eos_token_id: break generated_ids.append(token_id) input_ids = torch.cat([input_ids, next_token], dim=-1) apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True) return tokenizer.decode(generated_ids, skip_special_tokens=True) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", default=str(ROOT / "model" / "pretrain_final_model.pt")) parser.add_argument("--tokenizer", default=str(ROOT / "tokenizer" / "tokenizer.model")) parser.add_argument("--device", default="cuda") parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16") parser.add_argument("--max-new-tokens", type=int, default=64) parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--top-p", type=float, default=0.85) parser.add_argument("--repetition-penalty", type=float, default=1.2) parser.add_argument("--decode", choices=["greedy", "sample"], default="greedy") parser.add_argument("--prompt", action="append", default=[]) parser.add_argument("--output", default=str(ROOT / "artifacts" / "local_test_samples_ssm_fixed.json")) parser.add_argument("--interactive", action="store_true") args = parser.parse_args() checkpoint_path = Path(args.checkpoint) if not checkpoint_path.exists() and checkpoint_path.name == "pretrain_final_model.pt": checkpoint_path = ROOT / "model" / "final_model.pt" tokenizer_path = Path(args.tokenizer) device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu") dtype = { "float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16, }[args.dtype] print(f"Loading checkpoint: {checkpoint_path}") print("SSM fix: ssm_finite_tail_correction=true, ssm_kernel_mode=recurrent") model, tokenizer, override_count = load_fixed(checkpoint_path, tokenizer_path, device, dtype) print(f"device={device}") print(f"ssm_overrides={override_count}") if args.interactive: print("Type 'quit' or 'exit' to stop.") while True: prompt = input("\nYou: ").strip() if prompt.lower() in {"quit", "exit"}: break if not prompt: continue start = time.time() completion = generate( model, tokenizer, prompt, device=device, dtype=dtype, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_p=args.top_p, repetition_penalty=args.repetition_penalty, greedy=args.decode == "greedy", ) elapsed = time.time() - start print(f"\nAssistant: {completion}") print(f"\n[{elapsed:.1f}s]") return prompts = args.prompt or [ "Fruit is now expensive so we should", "Hello, who are you?", "Explain what artificial intelligence is in simple words.", ] samples = [] for prompt in prompts: completion = generate( model, tokenizer, prompt, device=device, dtype=dtype, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_p=args.top_p, repetition_penalty=args.repetition_penalty, greedy=args.decode == "greedy", ) samples.append({"prompt": prompt, "completion": completion}) result = { "checkpoint": str(checkpoint_path), "tokenizer": str(tokenizer_path), "device": str(device), "dtype": str(dtype), "ssm_finite_tail_correction": True, "ssm_kernel_mode": "recurrent", "ssm_overrides": override_count, "decode": args.decode, "temperature": args.temperature, "top_p": args.top_p, "repetition_penalty": args.repetition_penalty, "max_new_tokens": args.max_new_tokens, "samples": samples, } output = Path(args.output) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8") print(json.dumps(result, indent=2, ensure_ascii=False)) if __name__ == "__main__": main()