File size: 2,086 Bytes
eff2702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import json
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from bigram import BigramModel, BigramTokenizer
from bigram.config import ModelConfig


def load_bigram_nano_1(repo_id="aevynt/bigram-nano-1", device="cpu"):
    """Fetch the bigram-nano-1 checkpoint from the Hugging Face Hub and build it.

    The downloaded snapshot must provide ``config.json`` (whose ``"model"``
    entry holds the keyword arguments for ``ModelConfig``),
    ``model.safetensors`` and ``tokenizer.json``.

    Args:
        repo_id: Hub repository id handed to ``snapshot_download``.
        device: Device the model is moved to once its weights are loaded.

    Returns:
        A ``(model, tokenizer)`` pair; the model has been put in eval mode.
    """
    snapshot_dir = snapshot_download(repo_id)

    with open(f"{snapshot_dir}/config.json", "r", encoding="utf-8") as config_file:
        model_section = json.load(config_file)["model"]

    network = BigramModel(ModelConfig(**model_section))
    # Weights are read from the safetensors file first, then the module is moved.
    weights = load_file(f"{snapshot_dir}/model.safetensors")
    network.load_state_dict(weights)
    network.to(device)
    network.eval()

    return network, BigramTokenizer.load(f"{snapshot_dir}/tokenizer.json")


@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=80, recurrence=4, temperature=0.7, top_k=20, device="cpu"):
    """Sample a continuation of ``prompt`` and return it as decoded text.

    The prompt is encoded without special tokens, and ``<bos>`` (paired with
    tone id 0) is prepended by hand. Only the tokens produced after the prompt
    are decoded, and decoding stops before the first ``<eos>`` if one appears.

    Args:
        model: Object exposing ``generate(token_ids, tone_ids, max_new_tokens=...,
            num_recurrence=..., temperature=..., top_k=...)`` that returns
            ``(token_ids, tone_ids, _)``. The returned sequences are sliced as if
            they start with the prompt; presumably ``generate`` echoes its
            input — confirm against ``BigramModel``.
        tokenizer: Provides ``encode``, ``token_to_id`` and ``decode``.
        prompt: Input text.
        max_new_tokens: Forwarded to ``model.generate``.
        recurrence: Forwarded as ``num_recurrence``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        device: Device the input tensors are created on; should match the model's.

    Returns:
        The decoded continuation with surrounding whitespace stripped.
    """
    prompt_tokens, prompt_tones = tokenizer.encode(prompt, add_special=False)
    seq_tokens = [tokenizer.token_to_id("<bos>")] + prompt_tokens
    seq_tones = [0] + prompt_tones
    n_prompt = len(seq_tokens)

    all_tokens, all_tones, _ = model.generate(
        torch.tensor([seq_tokens], dtype=torch.long, device=device),
        torch.tensor([seq_tones], dtype=torch.long, device=device),
        max_new_tokens=max_new_tokens,
        num_recurrence=recurrence,
        temperature=temperature,
        top_k=top_k,
    )
    # Keep only batch row 0 and drop the echoed prompt prefix.
    new_tokens = all_tokens[0, n_prompt:].tolist()
    new_tones = all_tones[0, n_prompt:].tolist()

    eos_id = tokenizer.token_to_id("<eos>")
    try:
        stop = new_tokens.index(eos_id)
    except ValueError:
        pass  # no <eos> sampled: keep the full continuation
    else:
        new_tokens = new_tokens[:stop]
        new_tones = new_tones[:stop]
    return tokenizer.decode(new_tokens, new_tones).strip()


if __name__ == "__main__":
    # Smoke test: run a few Vietnamese chat prompts through the model.
    # The prompts must keep their diacritics. The tokenizer returns separate
    # tone ids alongside token ids, so ASCII-mangled text (diacritics replaced
    # by "?") would presumably strip the tone signal and inject bogus "?" tokens.
    # Prompts: "hello!", "who are you?", "have you eaten yet?" (informal), "goodbye".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_bigram_nano_1(device=device)
    for prompt in ["xin chào!", "bạn là ai?", "mày ăn cơm chưa?", "tạm biệt"]:
        print("Prompt:", prompt)
        print("Response:", generate_text(model, tokenizer, prompt, device=device))
        print()