File size: 2,302 Bytes
c0741ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn.functional as F
from model import MiniText
import random

# -----------------------
# config
# -----------------------
MODEL_PATH = "minitext.pt"  # checkpoint produced by the training script
DEVICE = "cpu"

# -----------------------
# load model
# -----------------------
# Restore trained weights onto the configured device and switch to
# inference mode (disables dropout / batch-norm updates).
model = MiniText()
model.to(DEVICE)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

# -----------------------
# sampling utils
# -----------------------
def sample_logits(logits, temperature=1.0, top_k=0):
    """Sample one token id from a ``(batch, vocab)`` logits tensor.

    Args:
        logits: float tensor of shape ``(batch, vocab)``; ``.item()`` is
            used on the result, so callers pass ``batch == 1``.
        temperature: softmax temperature. Non-positive values now fall
            back to greedy argmax — previously ``temperature == 0``
            divided by zero, yielding NaN probabilities and a
            ``torch.multinomial`` failure.
        top_k: if > 0, restrict sampling to the ``k`` highest logits.

    Returns:
        int: the sampled token id.
    """
    # Greedy decoding instead of a division-by-zero crash.
    if temperature <= 0:
        return torch.argmax(logits, dim=-1).item()

    logits = logits / temperature

    if top_k > 0:
        # Push everything below the k-th largest logit to -1e9 so
        # softmax assigns it (numerically) zero probability.
        values, _ = torch.topk(logits, top_k)
        min_val = values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < min_val, torch.full_like(logits, -1e9), logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()

# -----------------------
# text generation
# -----------------------
def _penalize_repeats(step_logits, history, penalty):
    """Dampen logits of bytes already emitted (HF-style repetition penalty:
    positive logits are divided by the penalty, negative ones multiplied)."""
    if penalty == 1.0:
        return step_logits
    for tok in set(history):
        val = step_logits[0, tok].item()
        step_logits[0, tok] = val / penalty if val > 0 else val * penalty
    return step_logits


def _nucleus_filter(step_logits, top_p):
    """Mask logits outside the smallest set whose cumulative probability
    reaches ``top_p`` (nucleus / top-p filtering); the top token is always kept."""
    if top_p >= 1.0:
        return step_logits
    probs = F.softmax(step_logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # A token is dropped when the mass accumulated *before* it already
    # exceeds top_p, so the first token can never be dropped.
    drop_sorted = (cumulative - sorted_probs) > top_p
    drop = torch.zeros_like(drop_sorted).scatter_(-1, sorted_idx, drop_sorted)
    return step_logits.masked_fill(drop, -1e9)


def generate(

    prompt="o",

    max_new_tokens=300,

    temperature=0.5,

    top_k=50,

    top_p=0.95,

    repetition_penalty=1.2,

    seed=None,

    h=None

):
    """Autoregressively extend ``prompt`` by up to ``max_new_tokens`` bytes.

    Fixes over the previous version:
      * ``top_p`` and ``repetition_penalty`` were accepted but silently
        ignored; they are now applied each step.
      * The last prompt byte was fed to the model twice (the prompt pass's
        logits were discarded and ``last`` was re-fed), skewing the hidden
        state; the prompt pass's final logits are now used directly.

    Args:
        prompt: UTF-8 text to condition on; must encode to >= 1 byte.
        max_new_tokens: number of bytes to generate.
        temperature: softmax temperature; <= 0 means greedy decoding.
        top_k: keep only the k most likely bytes (0 disables).
        top_p: nucleus-sampling mass (>= 1.0 disables).
        repetition_penalty: > 1.0 discourages already-emitted bytes.
        seed: optional RNG seed for reproducible output.
        h: optional recurrent hidden state carried over from a prior call.

    Returns:
        tuple[str, Any]: (decoded prompt + generated text, final hidden state).

    Raises:
        ValueError: if ``prompt`` encodes to zero bytes.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    bytes_in = list(prompt.encode("utf-8", errors="ignore"))
    if not bytes_in:
        raise ValueError("prompt must encode to at least one byte")
    output = bytes_in.copy()

    # Feed the whole prompt once and keep the final step's logits.
    x = torch.tensor([bytes_in], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        logits, h = model(x, h)

    for _ in range(max_new_tokens):
        step = logits[:, -1].clone()

        # Processor order follows the common HF convention:
        # repetition penalty -> temperature -> top-p -> top-k.
        step = _penalize_repeats(step, output, repetition_penalty)
        if temperature <= 0:
            next_byte = torch.argmax(step, dim=-1).item()
        else:
            step = step / temperature
            step = _nucleus_filter(step, top_p)
            # Temperature already applied above, so pass 1.0 here.
            next_byte = sample_logits(step, temperature=1.0, top_k=top_k)

        output.append(next_byte)
        last = torch.tensor([[next_byte]], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            logits, h = model(last, h)

    return bytes(output).decode(errors="ignore"), h

# -----------------------
# interactive chat loop
# -----------------------
h = None  # hidden state carried across turns so the model keeps conversational context

print("MiniText-v1.5 Chat | digite 'exit' para sair")

while True:
    user = input("usuario: ")
    # The banner advertises 'exit' but the loop previously only matched
    # "quit"; accept both (case-insensitive, whitespace-tolerant) so the
    # advertised command actually works.
    if user.strip().lower() in ("exit", "quit"):
        break

    prompt = f"usuario: {user}\nia: "
    text, h = generate(
        prompt=prompt,
        max_new_tokens=120,
        temperature=0.5,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        h=h
    )

    # Keep only the model's reply, dropping the echoed prompt prefix.
    reply = text.split("ia:")[-1].strip()
    print("ia:", reply)