TaoNet-mini-T2 / chat_ssm_fixed.py
StarMist0012's picture
Add files using upload-large-folder tool
388fd6e verified
"""Interactive/sample generation with the RepoBridge-style SSM inference fix.
This intentionally overrides the checkpoint config at inference time:
- ssm_finite_tail_correction = True
- ssm_kernel_mode = recurrent
Those settings match the temporary chat-quality fix used in RepoBridge Model Chat.
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
# Resolve the vendored source trees relative to this script and put them at the
# front of sys.path so the taoTrain / SSM packages import without installation.
# Entries are prepended in order, so SSM_SRC ends up ahead of TAOTRAIN_SRC.
ROOT = Path(__file__).resolve().parent
TAOTRAIN_SRC = ROOT / "code" / "TaoTrain" / "src"
SSM_SRC = ROOT / "code" / "Taotern_SSM"
for _src_dir in map(str, (TAOTRAIN_SRC, SSM_SRC)):
    if _src_dir in sys.path:
        continue  # already importable; avoid duplicate entries
    sys.path.insert(0, _src_dir)
import torch
from taoTrain.checkpointing.checkpoint import CheckpointManager
from taoTrain.config import ModelConfig
from taoTrain.inference.inferencer import Inferencer
from taoTrain.models import get_model
def apply_ssm_overrides(model: torch.nn.Module, *, kernel_mode: str, finite_tail: bool) -> int:
    """Force SSM inference settings onto every submodule of *model* that exposes them.

    For each module yielded by ``model.modules()``:
      * ``kernel_mode`` is set to *kernel_mode* if the attribute exists;
      * ``finite_tail_correction`` is set to *finite_tail* if the attribute exists;
      * ``clear_kernel_cache()`` is invoked if the module has a callable by that name
        (regardless of whether any attribute was set).

    Returns:
        The number of modules on which at least one of the two attributes was set.
    """
    overrides = (
        ("kernel_mode", kernel_mode),
        ("finite_tail_correction", finite_tail),
    )
    touched = 0
    for submodule in model.modules():
        hit = False
        for attr_name, value in overrides:
            if hasattr(submodule, attr_name):
                setattr(submodule, attr_name, value)
                hit = True
        reset_cache = getattr(submodule, "clear_kernel_cache", None)
        if callable(reset_cache):
            reset_cache()
        touched += int(hit)
    return touched
def load_fixed(checkpoint_path: Path, tokenizer_path: Path, device: torch.device, dtype: torch.dtype):
    """Load model and tokenizer from a checkpoint with the SSM inference fix forced on.

    The checkpoint's stored model config is copied and overridden with
    ``ssm_finite_tail_correction=True`` and ``ssm_kernel_mode="recurrent"`` before
    the model is built. The same settings are then pushed onto the live
    submodules via ``apply_ssm_overrides``.

    Args:
        checkpoint_path: Checkpoint file; its parent directory is handed to
            ``CheckpointManager``.
        tokenizer_path: Path passed to ``Inferencer._load_tokenizer``.
        device: Device the checkpoint is loaded onto and the model is moved to.
        dtype: Not used by this function; the weights are not cast here. The
            parameter is kept for call-site compatibility.

    Returns:
        ``(model, tokenizer, override_count)``. ``model`` is in eval mode.
        ``override_count`` is the number of submodules whose SSM attributes
        were overridden.
    """
    checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
    # `or {}` also covers keys that are present but stored as None.
    config_dict = checkpoint.get("config") or {}
    model_config_dict = dict(config_dict.get("model") or {})
    model_config_dict["ssm_finite_tail_correction"] = True
    model_config_dict["ssm_kernel_mode"] = "recurrent"
    model_config = ModelConfig(**model_config_dict)

    tokenizer = Inferencer._load_tokenizer(tokenizer_path)
    model = get_model(model_config, device=device)

    # strict=False tolerates key mismatches, but a silent mismatch would leave
    # layers at their random init. Report it so bad checkpoints are visible.
    load_result = model.load_state_dict(checkpoint["model_state"], strict=False)
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing or unexpected:
        print(
            f"WARNING: state dict mismatch (strict=False): "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            file=sys.stderr,
        )

    model.to(device=device)
    model.eval()
    override_count = apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True)
    return model, tokenizer, override_count
def _apply_repetition_penalty(
    logits: torch.Tensor, previous_ids: torch.Tensor, penalty: float
) -> torch.Tensor:
    """Penalize, in place in row 0 of *logits*, every token id present in *previous_ids*.

    Uses the CTRL / Hugging Face rule: positive logits are divided by *penalty*
    and negative logits are multiplied by it. A penalty > 1 therefore always
    lowers the score of a repeated token. Plain division would move a negative
    logit toward 0 and make the repeat *more* likely.
    """
    if penalty == 1.0 or previous_ids.numel() == 0:
        return logits
    token_ids = torch.unique(previous_ids)
    scores = logits[0, token_ids]
    logits[0, token_ids] = torch.where(scores < 0, scores * penalty, scores / penalty)
    return logits


def _top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Set logits outside the top-p (nucleus) set to -inf; top_p >= 1 is a no-op.

    The highest-probability token is always kept, so at least one candidate survives.
    """
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove_sorted = cumulative > top_p
    # Shift right by one so the token that crosses the threshold is kept too.
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = False
    # Map the mask from sorted order back to vocabulary order (works per batch row).
    remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
    return logits.masked_fill(remove, float("-inf"))


def generate(
    model,
    tokenizer,
    prompt: str,
    *,
    device: torch.device,
    dtype: torch.dtype,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    greedy: bool,
) -> str:
    """Autoregressively complete *prompt* and return only the newly generated text.

    Each step re-runs the model on the full sequence so far (no incremental
    state is passed between steps here) and takes the last-position logits.
    Generation stops after *max_new_tokens* tokens or when the tokenizer's
    ``eos_token_id`` (if any) is produced. The EOS token is not included in
    the output.

    Args:
        device: Device the prompt tensor is placed on.
        dtype: Autocast dtype. Autocast is enabled only on CUDA with fp16/bf16.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (ignored when *greedy*); clamped to >= 1e-6.
        top_p: Nucleus-sampling threshold (ignored when *greedy*).
        repetition_penalty: Penalty applied to tokens already generated (not
            prompt tokens); 1.0 disables it.
        greedy: If True use argmax, otherwise sample with temperature/top-p.

    Returns:
        The decoded completion, with special tokens skipped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]
    generated_ids: list[int] = []
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    device_type = "cuda" if device.type == "cuda" else "cpu"
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
        for _ in range(max_new_tokens):
            # Re-assert the SSM fix before every forward pass, as the original
            # script does. This also calls clear_kernel_cache on modules that have it.
            apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True)
            outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None)
            # Upcast so softmax/cumsum/multinomial don't run in fp16/bf16 under autocast.
            logits = outputs["logits"][:, -1, :].float()
            if not greedy:
                logits = logits / max(temperature, 1e-6)
            logits = _apply_repetition_penalty(logits, input_ids[0, prompt_len:], repetition_penalty)
            if greedy:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(_top_p_filter(logits, top_p), dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            token_id = int(next_token.item())
            if eos_token_id is not None and token_id == eos_token_id:
                break
            generated_ids.append(token_id)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True)
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def main() -> None:
    """Command-line entry point: load the fixed model, then chat or batch-sample.

    With ``--interactive``, prompts are read from stdin until 'quit'/'exit' or
    end of input, and each completion is printed with its wall time. Otherwise
    each ``--prompt`` (or a built-in default set) is completed once. The results
    are written as JSON to ``--output`` and echoed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default=str(ROOT / "model" / "pretrain_final_model.pt"))
    parser.add_argument("--tokenizer", default=str(ROOT / "tokenizer" / "tokenizer.model"))
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
    parser.add_argument("--max-new-tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.85)
    parser.add_argument("--repetition-penalty", type=float, default=1.2)
    parser.add_argument("--decode", choices=["greedy", "sample"], default="greedy")
    parser.add_argument("--prompt", action="append", default=[])
    parser.add_argument("--output", default=str(ROOT / "artifacts" / "local_test_samples_ssm_fixed.json"))
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    checkpoint_path = Path(args.checkpoint)
    # Fall back to final_model.pt only when the default-named checkpoint is absent.
    if not checkpoint_path.exists() and checkpoint_path.name == "pretrain_final_model.pt":
        checkpoint_path = ROOT / "model" / "final_model.pt"
    tokenizer_path = Path(args.tokenizer)
    # Any non-"cpu" device request falls back to CPU when CUDA is unavailable.
    device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[args.dtype]

    print(f"Loading checkpoint: {checkpoint_path}")
    print("SSM fix: ssm_finite_tail_correction=true, ssm_kernel_mode=recurrent")
    model, tokenizer, override_count = load_fixed(checkpoint_path, tokenizer_path, device, dtype)
    print(f"device={device}")
    print(f"ssm_overrides={override_count}")

    # Shared decoding settings for both the interactive and batch paths.
    generation_kwargs = {
        "device": device,
        "dtype": dtype,
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "greedy": args.decode == "greedy",
    }

    if args.interactive:
        print("Type 'quit' or 'exit' to stop.")
        while True:
            try:
                prompt = input("\nYou: ").strip()
            except (EOFError, KeyboardInterrupt):
                # End of piped input, Ctrl-D or Ctrl-C: leave the loop cleanly.
                print()
                break
            if prompt.lower() in {"quit", "exit"}:
                break
            if not prompt:
                continue
            start = time.perf_counter()
            completion = generate(model, tokenizer, prompt, **generation_kwargs)
            elapsed = time.perf_counter() - start
            print(f"\nAssistant: {completion}")
            print(f"\n[{elapsed:.1f}s]")
        return

    prompts = args.prompt or [
        "Fruit is now expensive so we should",
        "<user>Hello, who are you?<assistant>",
        "<user>Explain what artificial intelligence is in simple words.<assistant>",
    ]
    samples = [
        {"prompt": prompt, "completion": generate(model, tokenizer, prompt, **generation_kwargs)}
        for prompt in prompts
    ]
    result = {
        "checkpoint": str(checkpoint_path),
        "tokenizer": str(tokenizer_path),
        "device": str(device),
        "dtype": str(dtype),
        "ssm_finite_tail_correction": True,
        "ssm_kernel_mode": "recurrent",
        "ssm_overrides": override_count,
        "decode": args.decode,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "max_new_tokens": args.max_new_tokens,
        "samples": samples,
    }
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()