# bigram-nano-1 / sample_inference.py
# Commit eff2702 (verified): "Add Bigram Nano 1 safetensors release" (author: lehungquangminh)
import json
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from bigram import BigramModel, BigramTokenizer
from bigram.config import ModelConfig
def load_bigram_nano_1(repo_id="aevynt/bigram-nano-1", device="cpu"):
    """Download a Bigram Nano 1 snapshot and return it ready for inference.

    The snapshot of ``repo_id`` is fetched with ``snapshot_download``. A
    ``BigramModel`` is built from the ``"model"`` section of ``config.json``,
    and its weights are loaded from ``model.safetensors``. The model is then
    moved to ``device`` and switched to eval mode. The tokenizer is read from
    ``tokenizer.json`` in the same snapshot.

    Args:
        repo_id: Hugging Face Hub repository to download.
        device: Torch device the model is moved to after loading weights.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    snapshot_dir = snapshot_download(repo_id)

    with open(f"{snapshot_dir}/config.json", "r", encoding="utf-8") as config_file:
        full_config = json.load(config_file)

    model = BigramModel(ModelConfig(**full_config["model"]))
    # load_file returns a plain state dict; the whole module is moved to
    # `device` afterwards.
    state_dict = load_file(f"{snapshot_dir}/model.safetensors")
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tokenizer = BigramTokenizer.load(f"{snapshot_dir}/tokenizer.json")
    return model, tokenizer


@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=80, recurrence=4, temperature=0.7, top_k=20, device="cpu"):
    """Generate a continuation of ``prompt`` and return it as decoded text.

    Steps:

    1. Encode the prompt without special tokens.
    2. Prefix it with ``<bos>``, paired with tone id 0.
    3. Wrap the ids in a batch of one and pass them to ``model.generate``
       with the sampling settings.
    4. Decode only the positions after the prompt, truncating at the first
       ``<eos>`` and stripping surrounding whitespace.

    Runs under ``torch.no_grad()``, so no autograd graph is recorded.

    Args:
        model: Object exposing ``generate(token_ids, tone_ids, *,
            max_new_tokens, num_recurrence, temperature, top_k)``. It must
            return a ``(token_ids, tone_ids, _)`` triple. The first
            ``len(prompt) + 1`` positions of row 0 are treated as the prompt
            and skipped.
        tokenizer: Object exposing ``encode(text, add_special=...)`` returning
            ``(token_ids, tone_ids)``, ``token_to_id(token)`` and
            ``decode(ids, tones)``.
        prompt: Input text.
        max_new_tokens: Forwarded to ``model.generate``.
        recurrence: Forwarded to ``model.generate`` as ``num_recurrence``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        device: Device the input tensors are created on. Presumably it must
            match the model's device (TODO confirm).

    Returns:
        The decoded continuation (prompt excluded) as a stripped string.
    """
    token_ids, tone_ids = tokenizer.encode(prompt, add_special=False)

    # Prepend <bos>. Tone id 0 is used for it (presumably the neutral /
    # "no tone" id; TODO confirm). Unpacking also accepts tuple outputs
    # from encode().
    token_ids = [tokenizer.token_to_id("<bos>"), *token_ids]
    tone_ids = [0, *tone_ids]
    prompt_len = len(token_ids)

    out_ids, out_tones, _ = model.generate(
        torch.tensor([token_ids], dtype=torch.long, device=device),
        torch.tensor([tone_ids], dtype=torch.long, device=device),
        max_new_tokens=max_new_tokens,
        num_recurrence=recurrence,
        temperature=temperature,
        top_k=top_k,
    )

    # Keep only the continuation: drop the prompt (including <bos>) from
    # batch row 0.
    ids = out_ids[0, prompt_len:].tolist()
    tones = out_tones[0, prompt_len:].tolist()

    # Truncate at the first <eos>. A single index() scan (EAFP) replaces
    # the previous `in` check followed by index().
    try:
        cut = ids.index(tokenizer.token_to_id("<eos>"))
    except ValueError:
        pass  # no <eos>: keep everything that was generated
    else:
        ids, tones = ids[:cut], tones[:cut]

    return tokenizer.decode(ids, tones).strip()


if __name__ == "__main__":
    # Use the GPU when present. The same device is passed to generate_text
    # so the input tensors are created where the model lives.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_bigram_nano_1(device=device)
    # Vietnamese demo prompts, restored from "?"-garbled text: "hello!",
    # "who are you?", "have you eaten yet?" (informal), "goodbye".
    for prompt in ["xin chào!", "bạn là ai?", "mày ăn cơm chưa?", "tạm biệt"]:
        print("Prompt:", prompt)
        print("Response:", generate_text(model, tokenizer, prompt, device=device))
        print()