File size: 2,302 Bytes
c0741ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import torch
import torch.nn.functional as F
from model import MiniText
import random
# -----------------------
# config
# -----------------------
MODEL_PATH = "minitext.pt"
DEVICE = "cpu"
# -----------------------
# load model
# -----------------------
# Restore the trained weights and switch to inference mode.
state = torch.load(MODEL_PATH, map_location=DEVICE)
model = MiniText()
model.load_state_dict(state)
model = model.to(DEVICE)
model.eval()
# -----------------------
# sampling utils
# -----------------------
def sample_logits(logits, temperature=1.0, top_k=0):
    """Sample one token id from a (batch, vocab) logits tensor.

    Args:
        logits: float tensor of shape (batch, vocab); a single id is drawn.
        temperature: softmax temperature. Values <= 0 fall back to greedy
            argmax instead of dividing by zero (the original crashed here).
        top_k: if > 0, restrict sampling to the k highest logits. k is
            clamped to the vocab size so an oversized value cannot make
            ``torch.topk`` raise.

    Returns:
        int: the sampled token id.
    """
    # Greedy fallback: temperature 0 (or negative) means deterministic argmax.
    if temperature <= 0:
        return torch.argmax(logits, dim=-1).item()
    logits = logits / temperature
    if top_k > 0:
        # Clamp so top_k larger than the vocabulary does not raise.
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k)
        min_val = values[:, -1].unsqueeze(-1)
        # Push everything below the k-th value to ~-inf before softmax.
        logits = torch.where(logits < min_val, torch.full_like(logits, -1e9), logits)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()
# -----------------------
# text generation
# -----------------------
def generate(
    prompt="o",
    max_new_tokens=300,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    seed=None,
    h=None
):
    """Autoregressively generate UTF-8 bytes from the global ``model``.

    BUG FIX: ``top_p`` and ``repetition_penalty`` were accepted but silently
    ignored; both are now applied before each sampling step.

    Args:
        prompt: seed text, encoded to bytes (invalid chars dropped).
        max_new_tokens: number of bytes to generate.
        temperature: sampling temperature (forwarded to ``sample_logits``).
        top_k: top-k cutoff (forwarded to ``sample_logits``).
        top_p: nucleus cutoff in (0, 1); tokens outside the smallest set
            with cumulative probability >= top_p are masked out.
        repetition_penalty: > 1.0 damps logits of bytes already emitted
            (CTRL-style: positive logits divided, negative multiplied).
        seed: optional RNG seed for reproducible output.
        h: recurrent state from a previous call, or None to start fresh.

    Returns:
        tuple[str, Any]: decoded text (prompt + continuation) and the
        updated recurrent state ``h``.

    Raises:
        ValueError: if the prompt encodes to zero bytes (the model needs
        at least one byte of context).
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
    bytes_in = list(prompt.encode("utf-8", errors="ignore"))
    if not bytes_in:
        raise ValueError("prompt must encode to at least one byte")
    output = bytes_in.copy()
    # feed prompt to warm up the recurrent state
    x = torch.tensor([bytes_in], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        _, h = model(x, h)
    last = x[:, -1:]
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, h = model(last, h)
        step_logits = logits[:, -1].clone()  # (1, vocab); clone before in-place edits
        # Repetition penalty: damp every byte already present in the output.
        if repetition_penalty != 1.0 and repetition_penalty > 0:
            seen = torch.tensor(sorted(set(output)), dtype=torch.long, device=DEVICE)
            vals = step_logits[0, seen]
            step_logits[0, seen] = torch.where(
                vals > 0, vals / repetition_penalty, vals * repetition_penalty
            )
        # Nucleus (top-p) filtering: mask tokens outside the smallest set
        # whose cumulative (temperature-scaled) probability reaches top_p.
        if 0.0 < top_p < 1.0:
            t = temperature if temperature > 0 else 1.0
            probs = F.softmax(step_logits / t, dim=-1)
            sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            # Drop tokens whose preceding cumulative mass already exceeds
            # top_p (the first token is always kept).
            drop = (cumulative - sorted_probs) > top_p
            step_logits[0, sorted_idx[drop]] = -1e9
        next_byte = sample_logits(
            step_logits,
            temperature=temperature,
            top_k=top_k
        )
        output.append(next_byte)
        last = torch.tensor([[next_byte]], device=DEVICE)
    return bytes(output).decode(errors="ignore"), h
# -----------------------
# interactive chat loop
# -----------------------
h = None  # recurrent state carried across turns so the model keeps context
print("MiniText-v1.5 Chat | digite 'exit' para sair")
while True:
    try:
        user = input("usuario: ")
    except EOFError:
        # stdin closed (e.g. piped input exhausted): leave cleanly.
        break
    # BUG FIX: the banner advertises 'exit' but the loop only honored
    # 'quit' — accept both, ignoring case and surrounding whitespace.
    if user.strip().lower() in ("exit", "quit"):
        break
    prompt = f"usuario: {user}\nia: "
    text, h = generate(
        prompt=prompt,
        max_new_tokens=120,
        temperature=0.5,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        h=h
    )
    # Keep only the model's continuation after the final "ia:" marker.
    reply = text.split("ia:")[-1].strip()
    print("ia:", reply)
|