"""Load the ``bigram-nano-1`` checkpoint from the Hugging Face Hub and sample text from it."""

import json

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from bigram import BigramModel, BigramTokenizer
from bigram.config import ModelConfig


def load_bigram_nano_1(repo_id="aevynt/bigram-nano-1", device="cpu"):
    """Download a checkpoint snapshot and build the model and tokenizer from it.

    The snapshot directory is expected to contain ``config.json`` (with a
    ``"model"`` section passed as keyword arguments to ``ModelConfig``),
    ``model.safetensors`` and ``tokenizer.json``.

    Args:
        repo_id: Hub repository to fetch with ``snapshot_download``.
        device: Device the model is moved to after its weights are loaded.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    model_dir = snapshot_download(repo_id)

    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)

    model_cfg = ModelConfig(**cfg["model"])
    model = BigramModel(model_cfg)
    model.load_state_dict(load_file(f"{model_dir}/model.safetensors"))
    model.to(device)
    model.eval()

    tokenizer = BigramTokenizer.load(f"{model_dir}/tokenizer.json")
    return model, tokenizer


@torch.no_grad()
def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=80,
    recurrence=4,
    temperature=0.7,
    top_k=20,
    device="cpu",
    *,
    bos_token="",
    eos_token="",
):
    """Generate a continuation of ``prompt`` and return it as decoded text.

    Args:
        model: Object exposing ``generate(token_ids, tone_ids, ...)`` that
            returns a 3-tuple whose first two items are the output token ids
            and output tone ids.
        tokenizer: Object exposing ``encode``, ``decode`` and ``token_to_id``.
        prompt: Text to condition on; encoded with ``add_special=False``.
        max_new_tokens: Forwarded to ``model.generate``.
        recurrence: Forwarded to ``model.generate`` as ``num_recurrence``.
        temperature: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        device: Device on which the input tensors are created. It should match
            the device the model lives on.
        bos_token: Token string looked up and prepended to the prompt.
        eos_token: Token string at whose first occurrence the output is cut.

    Returns:
        The decoded continuation with surrounding whitespace stripped.
    """
    # NOTE(review): both special-token names default to the empty string, which
    # is what this file contained. That looks like angle-bracket token names
    # lost in a copy/paste -- confirm the real names against tokenizer.json and
    # pass them via ``bos_token`` / ``eos_token``.
    token_ids, tone_ids = tokenizer.encode(prompt, add_special=False)

    bos = tokenizer.token_to_id(bos_token)
    token_ids = [bos] + token_ids
    tone_ids = [0] + tone_ids  # the prepended BOS position gets tone id 0
    prompt_len = len(token_ids)

    # Wrap in an outer list so the tensors carry a leading batch dimension of 1.
    token_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    tone_ids = torch.tensor([tone_ids], dtype=torch.long, device=device)

    out_ids, out_tones, _ = model.generate(
        token_ids,
        tone_ids,
        max_new_tokens=max_new_tokens,
        num_recurrence=recurrence,
        temperature=temperature,
        top_k=top_k,
    )

    # Skip the first ``prompt_len`` positions so only newly generated tokens
    # are decoded (assumes ``generate`` echoes the prompt at the front of its
    # output -- TODO confirm).
    ids = out_ids[0, prompt_len:].tolist()
    tones = out_tones[0, prompt_len:].tolist()

    eos = tokenizer.token_to_id(eos_token)
    if eos in ids:
        cut = ids.index(eos)
        ids, tones = ids[:cut], tones[:cut]

    return tokenizer.decode(ids, tones).strip()


def main():
    """Load the model and print a response for a few sample prompts."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_bigram_nano_1(device=device)

    # The diacritics in these prompts had been replaced by "?" in the file;
    # they are restored here (reconstruction -- verify against the upstream
    # example).
    prompts = ["xin chào!", "bạn là ai?", "mày ăn cơm chưa?", "tạm biệt"]
    for prompt in prompts:
        print("Prompt:", prompt)
        print("Response:", generate_text(model, tokenizer, prompt, device=device))
        print()


if __name__ == "__main__":
    main()