# Repository path: TaoNet-mini-T2/code/TaoTrain/scripts/diagnostics/generate_checkpoint_samples.py
# Hosting-page residue (not part of the script): "StarMist0012's picture"
# Commit message: Add files using upload-large-folder tool
# Commit: e2bfccc (verified)
"""Generate a few text samples from a saved checkpoint."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
# Directory two levels above the folder that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[2]
# The `src` directory under REPO_ROOT; presumably where the `taoTrain`
# package imported below lives — TODO confirm against the repository layout.
SRC_ROOT = REPO_ROOT / "src"
# Put SRC_ROOT at the front of the import path (once) so the imports that
# follow can be resolved when the script is run directly.
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))
from taoTrain.checkpointing.checkpoint import CheckpointManager
from taoTrain.config import ModelConfig
from taoTrain.inference.inferencer import Inferencer
from taoTrain.models import get_model
def clear_kernel_caches(model) -> None:
    """Call ``clear_kernel_cache()`` on every submodule of *model* that has one.

    Iterates ``model.modules()`` in order. For each module, the
    ``clear_kernel_cache`` attribute is looked up; when it exists and is
    callable it is invoked with no arguments, otherwise the module is skipped.
    """
    hooks = (
        getattr(submodule, "clear_kernel_cache", None)
        for submodule in model.modules()
    )
    for hook in hooks:
        if not callable(hook):
            continue
        hook()
def generate_once(
    model,
    tokenizer,
    prompt: str,
    *,
    device: torch.device,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    dtype: torch.dtype,
) -> str:
    """Sample one completion for *prompt* and return it as decoded text.

    The prompt is encoded with ``tokenizer.encode(..., return_tensors="pt")``
    and the model is called on the whole sequence so far at every step; the
    newly sampled token is appended to the input for the next step. Sampling
    uses temperature scaling followed by optional nucleus (top-p) filtering.
    The function handles a single sequence: each sampled token is read back
    with ``.item()``.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments and returning a mapping with a
            ``"logits"`` entry. It is switched to eval mode and left there.
        tokenizer: Object providing ``encode`` and ``decode``; its optional
            ``eos_token_id`` attribute stops generation early.
        prompt: Text to condition on.
        device: Device the input ids are moved to.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Softmax temperature; values below ``1e-6`` are clamped
            to ``1e-6``.
        top_p: Nucleus threshold; filtering is skipped when ``>= 1.0``.
        dtype: Autocast dtype. Autocast is only enabled on CUDA with
            ``float16`` or ``bfloat16``.

    Returns:
        The decoded generated tokens (prompt and EOS token excluded).
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    # Loop-invariant: clamp once so a zero/negative temperature cannot divide by zero.
    scale = max(temperature, 1e-6)
    generated = []

    model.eval()
    device_type = "cuda" if device.type == "cuda" else "cpu"
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
        for _ in range(max_new_tokens):
            clear_kernel_caches(model)
            outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None)
            # Sample in float32: under float16 autocast, dividing half-precision
            # logits by a tiny temperature overflows to inf and yields NaN
            # probabilities, which torch.multinomial rejects.
            logits = outputs["logits"][:, -1, :].float() / scale
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                remove = cumulative > top_p
                # Shift right by one so the first token that crosses the
                # threshold is kept, and never drop the most likely token.
                remove[..., 1:] = remove[..., :-1].clone()
                remove[..., 0] = False
                # Map the mask from sorted order back to vocabulary order, row by row.
                vocab_mask = torch.zeros_like(remove).scatter(-1, sorted_indices, remove)
                logits = logits.masked_fill(vocab_mask, float("-inf"))
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            token_id = int(next_token.item())
            if eos_token_id is not None and token_id == eos_token_id:
                break
            generated.append(token_id)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    clear_kernel_caches(model)
    return tokenizer.decode(generated, skip_special_tokens=True)
def main() -> None:
    """Load a checkpoint, sample a completion per prompt, and emit a JSON report.

    Command-line arguments:
        --checkpoint: Path to the checkpoint file (required).
        --tokenizer-path: Path passed to the tokenizer loader (required).
        --output: Destination JSON file; parent directories are created (required).
        --prompt: May be repeated; three built-in prompts are used when omitted.
        --max-new-tokens, --temperature, --top-p: Sampling settings.
        --device: Requested device; anything other than "cpu" falls back to
            "cpu" when CUDA is not available.
        --dtype: Autocast dtype name ("float32", "bfloat16" or "float16").

    The report (settings plus prompt/completion pairs) is written to
    ``--output`` as UTF-8 JSON and also printed to stdout. Weights are loaded
    non-strictly; any missing or unexpected state-dict keys are reported on
    stderr so that stdout stays valid JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer-path", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--prompt", action="append", default=[])
    parser.add_argument("--max-new-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
    args = parser.parse_args()

    prompts = args.prompt or [
        "The purpose of artificial intelligence is",
        "In a small village,",
        "<user>Hello, who are you?<assistant>",
    ]
    # Honour the requested device only if it is "cpu" or CUDA is usable.
    device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[args.dtype]

    tokenizer = Inferencer._load_tokenizer(args.tokenizer_path)
    checkpoint_path = Path(args.checkpoint)
    checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
    model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {}))
    model = get_model(model_config, device=device)

    # strict=False tolerates key mismatches; surface them instead of silently
    # sampling from weights that were never loaded.
    load_result = model.load_state_dict(checkpoint["model_state"], strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", ()))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", ()))
    if missing_keys or unexpected_keys:
        print(
            "warning: checkpoint does not match the model exactly "
            f"(missing={len(missing_keys)}, unexpected={len(unexpected_keys)}); "
            "samples may come from partially initialized weights",
            file=sys.stderr,
        )
        for label, keys in (("missing", missing_keys), ("unexpected", unexpected_keys)):
            for key in keys:
                print(f"  {label}: {key}", file=sys.stderr)

    samples = [
        {
            "prompt": prompt,
            "completion": generate_once(
                model,
                tokenizer,
                prompt,
                device=device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                dtype=dtype,
            ),
        }
        for prompt in prompts
    ]
    result = {
        "checkpoint": args.checkpoint,
        "tokenizer_path": args.tokenizer_path,
        "device": str(device),
        "dtype": str(dtype),
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "samples": samples,
    }

    # Serialize once and reuse the text for both the file and stdout.
    rendered = json.dumps(result, indent=2, ensure_ascii=False)
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(rendered, encoding="utf-8")
    print(rendered)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()