"""Interactive byte-level chat REPL for a MiniText model checkpoint.

Loads a trained MiniText model, then generates replies one byte at a
time using temperature / top-k / top-p sampling with an optional
repetition penalty.
"""

import random

import torch
import torch.nn.functional as F

from model import MiniText

# -----------------------
# config
# -----------------------
MODEL_PATH = "minitext.pt"
DEVICE = "cpu"

# -----------------------
# load model
# -----------------------
model = MiniText().to(DEVICE)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — fine for a locally-trained checkpoint, unsafe for untrusted files.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()


# -----------------------
# sampling utils
# -----------------------
def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0,
                  repetition_penalty=1.0, generated=None):
    """Sample one token id from `logits`.

    Args:
        logits: tensor of shape (1, vocab_size) — raw model scores.
        temperature: softmax temperature; clamped away from zero.
        top_k: if > 0, keep only the k highest-scoring tokens.
        top_p: if < 1.0, nucleus sampling — keep the smallest set of
            tokens whose cumulative probability exceeds top_p.
        repetition_penalty: if != 1.0, dampen logits of tokens already
            present in `generated` (CTRL-style penalty).
        generated: iterable of previously emitted token ids.

    Returns:
        int: the sampled token id.
    """
    logits = logits.clone()  # avoid mutating the caller's tensor

    # Repetition penalty: divide positive logits / multiply negative ones,
    # so the penalty always pushes probability down regardless of sign.
    if repetition_penalty != 1.0 and generated:
        for tok in set(generated):
            val = logits[0, tok]
            if val > 0:
                logits[0, tok] = val / repetition_penalty
            else:
                logits[0, tok] = val * repetition_penalty

    # Guard against temperature == 0 (would divide by zero).
    logits = logits / max(temperature, 1e-8)

    if top_k > 0:
        k = min(top_k, logits.size(-1))  # topk raises if k > vocab
        values, _ = torch.topk(logits, k)
        min_val = values[:, -1].unsqueeze(-1)
        logits = torch.where(
            logits < min_val, torch.full_like(logits, -1e9), logits
        )

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens past the nucleus; shift right so the first token
        # crossing the threshold is always kept.
        remove = cum_probs > top_p
        remove[:, 1:] = remove[:, :-1].clone()
        remove[:, 0] = False
        mask = torch.zeros_like(remove).scatter(1, sorted_idx, remove)
        logits = logits.masked_fill(mask, -1e9)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()


# -----------------------
# text generation
# -----------------------
def generate(
    prompt="o",
    max_new_tokens=300,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    seed=None,
    h=None
):
    """Generate up to `max_new_tokens` bytes continuing `prompt`.

    Args:
        prompt: UTF-8 text fed to the model before sampling starts.
        max_new_tokens: number of bytes to sample.
        temperature, top_k, top_p, repetition_penalty: see sample_logits.
        seed: optional RNG seed for reproducible output.
        h: recurrent hidden state carried across calls (or None).

    Returns:
        (text, h): the decoded prompt+continuation and the updated
        hidden state, so conversation context can be chained.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    bytes_in = list(prompt.encode("utf-8", errors="ignore"))
    if not bytes_in:
        bytes_in = [ord(" ")]  # empty prompt would feed a 0-length tensor
    output = bytes_in.copy()

    # Feed the whole prompt once to build up the hidden state.
    x = torch.tensor([bytes_in], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        _, h = model(x, h)

    last = x[:, -1:]
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, h = model(last, h)
        next_byte = sample_logits(
            logits[:, -1],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            generated=output,
        )
        output.append(next_byte)
        last = torch.tensor([[next_byte]], device=DEVICE)

    return bytes(output).decode(errors="ignore"), h


def main():
    """Run the interactive chat loop until the user types exit/quit."""
    h = None
    print("MiniText-v1.5 Chat | digite 'exit' para sair")
    while True:
        user = input("usuario: ")
        # Banner advertises 'exit'; accept both 'exit' and 'quit'.
        if user.strip().lower() in ("exit", "quit"):
            break
        prompt = f"usuario: {user}\nia: "
        text, h = generate(
            prompt=prompt,
            max_new_tokens=120,
            temperature=0.5,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
            h=h
        )
        # Keep only the model's reply after the last "ia:" marker.
        reply = text.split("ia:")[-1].strip()
        print("ia:", reply)


if __name__ == "__main__":
    main()