| | import torch
|
| | import torch.nn.functional as F
|
| | from model import MiniText
|
| | import random
|
| |
|
| |
|
| |
|
| |
|
# Path to the trained checkpoint (a state_dict saved by the training script).
MODEL_PATH = "minitext.pt"

# Inference device; hard-coded to CPU.
DEVICE = "cpu"
|
| |
|
| |
|
| |
|
| |
|
# Build the model and load the trained weights for inference.
model = MiniText().to(DEVICE)
# weights_only=True restricts torch.load to plain tensor data instead of
# running the full pickle machinery, which can execute arbitrary code from
# a tampered checkpoint. Safe here: load_state_dict only needs tensors.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
# Inference mode: disables dropout / other training-only behavior.
model.eval()
|
| |
|
| |
|
| |
|
| |
|
def sample_logits(logits, temperature=1.0, top_k=0):
    """Sample one token id from a (batch, vocab) logits tensor.

    Args:
        logits: float tensor of shape (B, V). Assumes B == 1, since the
            result is collapsed to a single int via ``.item()``.
        temperature: softmax temperature. Values <= 0 now fall back to
            greedy argmax instead of dividing by zero (the original code
            would produce Inf/NaN probabilities).
        top_k: if > 0, restrict sampling to the k highest-scoring tokens.
            Clamped to the vocabulary size so ``torch.topk`` cannot raise
            when top_k > V.

    Returns:
        int: the sampled token id.
    """
    # Degenerate temperature -> deterministic greedy decoding.
    if temperature <= 0:
        return torch.argmax(logits, dim=-1).item()

    logits = logits / temperature

    if top_k > 0:
        # Clamp so topk() cannot fail when top_k exceeds the vocab size.
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k)
        # Threshold = smallest kept logit; everything below it is masked
        # to a large negative so softmax assigns it (effectively) zero.
        min_val = values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < min_val, torch.full_like(logits, -1e9), logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()
|
| |
|
| |
|
| |
|
| |
|
def generate(
    prompt="o",
    max_new_tokens=300,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    seed=None,
    h=None,
):
    """Autoregressively generate bytes from the module-level ``model``.

    Args:
        prompt: seed text, encoded to UTF-8 bytes (byte-level model). An
            empty prompt previously crashed on ``x[:, -1:]``; it now falls
            back to a single space byte.
        max_new_tokens: number of bytes to generate after the prompt.
        temperature, top_k: forwarded to ``sample_logits``.
        top_p: nucleus-sampling threshold — keep the smallest set of tokens
            whose cumulative probability reaches ``top_p``. BUG FIX: this
            parameter was previously accepted but silently ignored.
        repetition_penalty: penalize logits of already-emitted bytes
            (CTRL-style: divide positive logits, multiply negative ones).
            BUG FIX: previously accepted but silently ignored.
        seed: optional RNG seed for reproducible sampling.
        h: recurrent hidden state carried across calls (None starts fresh).

    Returns:
        tuple[str, Any]: the decoded UTF-8 text (prompt + continuation,
        undecodable bytes dropped) and the updated hidden state.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    bytes_in = list(prompt.encode("utf-8", errors="ignore"))
    if not bytes_in:
        # Guard: an empty prompt would make the -1 slice below fail.
        bytes_in = [ord(" ")]
    output = bytes_in.copy()

    # Prime the hidden state with the full prompt in one pass.
    x = torch.tensor([bytes_in], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        _, h = model(x, h)

    last = x[:, -1:]

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, h = model(last, h)

        step_logits = logits[:, -1].clone()

        # Repetition penalty: discourage bytes already present in the output.
        if repetition_penalty and repetition_penalty != 1.0:
            for b in set(output):
                val = step_logits[0, b]
                step_logits[0, b] = (
                    val / repetition_penalty if val > 0 else val * repetition_penalty
                )

        # Nucleus (top-p) filtering: mask everything outside the smallest
        # set of tokens whose cumulative probability reaches top_p.
        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(step_logits, descending=True, dim=-1)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            remove[:, 1:] = remove[:, :-1].clone()  # shift: keep first token over the line
            remove[:, 0] = False  # always keep the single most likely token
            mask = torch.zeros_like(remove).scatter(1, sorted_idx, remove)
            step_logits = step_logits.masked_fill(mask, -1e9)

        next_byte = sample_logits(step_logits, temperature=temperature, top_k=top_k)

        output.append(next_byte)
        last = torch.tensor([[next_byte]], device=DEVICE)

    return bytes(output).decode(errors="ignore"), h
|
| |
|
# --- Interactive chat loop ---------------------------------------------
# The hidden state persists across turns so the model keeps conversational
# context between user messages.
h = None

print("MiniText-v1.5 Chat | digite 'exit' para sair")

while True:
    user = input("usuario: ")
    # BUG FIX: the banner tells the user to type 'exit', but the loop only
    # broke on 'quit'. Accept both, case-insensitively, ignoring whitespace.
    if user.strip().lower() in ("exit", "quit"):
        break

    prompt = f"usuario: {user}\nia: "
    text, h = generate(
        prompt=prompt,
        max_new_tokens=120,
        temperature=0.5,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        h=h,
    )

    # Keep only the model's reply: the text after the last "ia:" marker.
    reply = text.split("ia:")[-1].strip()
    print("ia:", reply)
|
| |
|
| |
|