|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from model import MiniText
|
|
|
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Path to the trained checkpoint (a state_dict saved by the training script).
MODEL_PATH = "minitext.pt"



# Inference device; change to "cuda" if a GPU is available — TODO confirm
# the checkpoint/model support it.
DEVICE = "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instantiate the network and move it to the inference device.
model = MiniText().to(DEVICE)



# map_location lets a checkpoint saved on GPU load onto CPU.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))



# Inference mode: disables dropout / training-only behavior.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_logits(logits, temperature=1.0, top_k=0):
    """Sample one token id from a ``(1, vocab)`` logits tensor.

    Args:
        logits: float tensor of shape ``(batch, vocab)``; batch is expected
            to be 1 since the result is collapsed with ``.item()``.
        temperature: softmax temperature. Non-positive values fall back to
            greedy (argmax) decoding instead of dividing by zero.
        top_k: if > 0, restrict sampling to the k highest logits. Clamped
            to the vocabulary size so ``torch.topk`` cannot be asked for
            more entries than exist.

    Returns:
        int: the sampled token id.
    """
    # Greedy decoding when temperature is non-positive; the previous code
    # divided by zero here.
    if temperature <= 0:
        return torch.argmax(logits, dim=-1).item()

    logits = logits / temperature

    if top_k > 0:
        # Clamp k so top-k filtering is safe for any caller-supplied value.
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k)
        # The smallest kept logit per row; everything below it is masked out.
        min_val = values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < min_val, torch.full_like(logits, -1e9), logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate(
    prompt="o",
    max_new_tokens=300,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    seed=None,
    h=None
):
    """Autoregressively generate text, byte by byte, from the global model.

    The prompt is encoded as UTF-8 bytes, fed through the model to warm up
    the recurrent state, then ``max_new_tokens`` bytes are sampled one at
    a time.

    Args:
        prompt: seed text; must encode to at least one byte.
        max_new_tokens: number of bytes to sample after the prompt.
        temperature: softmax temperature forwarded to ``sample_logits``.
        top_k: top-k cutoff forwarded to ``sample_logits``.
        top_p: nucleus-sampling threshold; applied when ``0 < top_p < 1``.
            (Previously accepted but silently ignored.)
        repetition_penalty: CTRL-style penalty applied to logits of bytes
            already emitted; values != 1.0 discourage repetition.
            (Previously accepted but silently ignored.)
        seed: optional RNG seed for reproducible output.
        h: optional recurrent state from a previous call, so conversations
            keep context across turns.

    Returns:
        tuple[str, Any]: the decoded text (prompt + continuation) and the
        final recurrent state.

    Raises:
        ValueError: if the prompt encodes to zero bytes.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    bytes_in = list(prompt.encode("utf-8", errors="ignore"))
    if not bytes_in:
        # The model cannot be primed with an empty sequence; fail clearly
        # instead of crashing inside the forward pass.
        raise ValueError("prompt must encode to at least one byte")

    output = bytes_in.copy()

    # Warm up the recurrent state on the whole prompt in one pass.
    x = torch.tensor([bytes_in], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        _, h = model(x, h)

    last = x[:, -1:]

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, h = model(last, h)

        step_logits = logits[:, -1]

        # CTRL-style repetition penalty: shrink positive logits and amplify
        # negative ones for bytes that already appear in the output.
        if repetition_penalty and repetition_penalty != 1.0:
            for byte_id in set(output):
                score = step_logits[0, byte_id]
                step_logits[0, byte_id] = (
                    score / repetition_penalty if score > 0 else score * repetition_penalty
                )

        # Nucleus (top-p) filtering: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p; always keep the best token.
        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(step_logits, descending=True, dim=-1)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            to_remove = cum_probs > top_p
            # Shift right so the first token crossing the threshold survives.
            to_remove[:, 1:] = to_remove[:, :-1].clone()
            to_remove[:, 0] = False
            sorted_logits = sorted_logits.masked_fill(to_remove, -1e9)
            # Scatter the filtered logits back to their original positions.
            step_logits = torch.zeros_like(step_logits).scatter(-1, sorted_idx, sorted_logits)

        next_byte = sample_logits(
            step_logits,
            temperature=temperature,
            top_k=top_k
        )

        output.append(next_byte)
        last = torch.tensor([[next_byte]], device=DEVICE)

    return bytes(output).decode(errors="ignore"), h
|
|
|
|
|
|
# Recurrent state carried across turns so the chat keeps conversational context.
h = None

print("MiniText-v1.5 Chat | digite 'exit' para sair")

while True:
    user = input("usuario: ")
    # The banner advertises 'exit', but the old code only checked for
    # "quit" — accept both so typing 'exit' actually quits.
    if user.lower() in ("exit", "quit"):
        break

    # Frame the turn the same way the training dialogs were formatted
    # (presumably — TODO confirm against the training data format).
    prompt = f"usuario: {user}\nia: "
    text, h = generate(
        prompt=prompt,
        max_new_tokens=120,
        temperature=0.5,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        h=h
    )

    # Keep only the model's reply, dropping the echoed prompt.
    reply = text.split("ia:")[-1].strip()
    print("ia:", reply)
|
|
|
|
|
|
|