| import json
|
| import torch
|
| from huggingface_hub import snapshot_download
|
| from safetensors.torch import load_file
|
|
|
| from bigram import BigramModel, BigramTokenizer
|
| from bigram.config import ModelConfig
|
|
|
|
|
def load_bigram_nano_1(repo_id="aevynt/bigram-nano-1", device="cpu"):
    """Fetch a bigram checkpoint from the Hugging Face Hub and rebuild it.

    The repository snapshot is expected to contain ``config.json`` (with a
    ``"model"`` section of ``ModelConfig`` keyword arguments),
    ``model.safetensors`` and ``tokenizer.json``.

    Args:
        repo_id: Hub repository to download with ``snapshot_download``.
        device: Device the model is moved to after its weights are loaded.

    Returns:
        A ``(model, tokenizer)`` pair; ``model.eval()`` has been called.
    """
    local_root = snapshot_download(repo_id)

    def artifact(filename):
        # Every checkpoint file lives directly inside the downloaded snapshot.
        return f"{local_root}/{filename}"

    with open(artifact("config.json"), "r", encoding="utf-8") as config_file:
        model_section = json.load(config_file)["model"]

    net = BigramModel(ModelConfig(**model_section))
    weights = load_file(artifact("model.safetensors"))
    net.load_state_dict(weights)
    net.to(device)
    net.eval()

    tok = BigramTokenizer.load(artifact("tokenizer.json"))
    return net, tok
|
|
|
|
|
@torch.no_grad()
def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=80,
    recurrence=4,
    temperature=0.7,
    top_k=20,
    device="cpu",
):
    """Continue ``prompt`` with ``model`` and return only the new text.

    Runs with autograd disabled. The prompt is encoded without special
    tokens, then prefixed with the ``<bos>`` id, paired with tone id 0.
    The token and tone sequences are sent to ``model.generate`` as
    ``(1, seq_len)`` long tensors.

    Args:
        model: Object exposing ``generate(token_ids, tone_ids, *,
            max_new_tokens, num_recurrence, temperature, top_k)`` that
            returns a 3-tuple whose first two items are batched id tensors.
            It is assumed those sequences begin with the prompt, since the
            prompt length is sliced off them.
        tokenizer: Provides ``encode(text, add_special=...) -> (ids, tones)``,
            ``token_to_id(token)`` and ``decode(ids, tones)``.
        prompt: Text to continue.
        max_new_tokens: Forwarded unchanged to ``model.generate``.
        recurrence: Forwarded to ``model.generate`` as ``num_recurrence``.
        temperature: Forwarded unchanged to ``model.generate``.
        top_k: Forwarded unchanged to ``model.generate``.
        device: Device on which the input tensors are created; presumably
            it must match the model's device.

    Returns:
        The decoded continuation, cut before the first ``<eos>`` (if any)
        and stripped of surrounding whitespace.
    """
    encoded_ids, encoded_tones = tokenizer.encode(prompt, add_special=False)
    start_id = tokenizer.token_to_id("<bos>")

    # The <bos> position gets tone id 0.
    seq_ids = [start_id] + encoded_ids
    seq_tones = [0] + encoded_tones
    n_prompt = len(seq_ids)

    generated_ids, generated_tones, _ = model.generate(
        torch.tensor([seq_ids], dtype=torch.long, device=device),
        torch.tensor([seq_tones], dtype=torch.long, device=device),
        max_new_tokens=max_new_tokens,
        num_recurrence=recurrence,
        temperature=temperature,
        top_k=top_k,
    )

    # Keep only the freshly generated part of the single batch row.
    new_ids = generated_ids[0, n_prompt:].tolist()
    new_tones = generated_tones[0, n_prompt:].tolist()

    stop_id = tokenizer.token_to_id("<eos>")
    try:
        cut = new_ids.index(stop_id)
    except ValueError:
        cut = None  # no <eos> emitted: slicing with None keeps everything

    return tokenizer.decode(new_ids[:cut], new_tones[:cut]).strip()
|
|
|
|
|
if __name__ == "__main__":
    # Demo: load the published checkpoint and answer a few Vietnamese prompts.
    # Run on the GPU when one is available, otherwise on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_bigram_nano_1(device=device)

    # The prompts must keep their diacritics: the tokenizer returns a tone id
    # alongside each token id, which presumably encodes these marks. The
    # previous ASCII-mangled versions ("xin ch?o!") fed corrupted text to the
    # model.
    demo_prompts = [
        "xin chào!",
        "bạn là ai?",
        "mày ăn cơm chưa?",
        "tạm biệt",
    ]
    for prompt in demo_prompts:
        print("Prompt:", prompt)
        print("Response:", generate_text(model, tokenizer, prompt, device=device))
        print()
|
|
|