Arthur Samuel Galego Panucci Figueiredo
committed on
Upload 5 files
Browse files- checkpoint.pt +3 -0
- infer.py +97 -0
- minitext.pt +3 -0
- model.py +15 -0
- train.py +79 -0
checkpoint.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22eb08b8bfa508f28e0d5d4e531a4c4cff7207375afe34aeab7d7787a92e198e
|
| 3 |
+
size 129845
|
infer.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from model import MiniText
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
# -----------------------
# config
# -----------------------
MODEL_PATH = "minitext.pt"   # weights produced by train.py
DEVICE = "cpu"               # inference runs comfortably on CPU for a model this small

# -----------------------
# load model
# -----------------------
# Restore the trained byte-level model and switch it to inference mode.
model = MiniText().to(DEVICE)
state = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state)
model.eval()
|
| 18 |
+
|
| 19 |
+
# -----------------------
|
| 20 |
+
# sampling utils
|
| 21 |
+
# -----------------------
|
| 22 |
+
def sample_logits(logits, temperature=1.0, top_k=0):
    """Sample a single token id from a row of logits.

    Args:
        logits: tensor of shape (batch, vocab); in practice batch is 1,
            since a single Python int is returned.
        temperature: softmax temperature. Values <= 0 fall back to greedy
            argmax (the original code divided by zero / a negative number).
        top_k: if > 0, restrict sampling to the k highest-scoring tokens.

    Returns:
        int: the sampled token id.
    """
    if temperature <= 0:
        # Degenerate temperature: be greedy instead of producing NaNs.
        return torch.argmax(logits, dim=-1).item()

    logits = logits / temperature

    if top_k > 0:
        # Clamp k so a top_k larger than the vocabulary cannot crash topk().
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k)
        min_val = values[:, -1].unsqueeze(-1)
        # Push everything below the k-th best score to effectively -inf.
        logits = torch.where(logits < min_val, torch.full_like(logits, -1e9), logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()
|
| 32 |
+
|
| 33 |
+
# -----------------------
|
| 34 |
+
# text generation
|
| 35 |
+
# -----------------------
|
| 36 |
+
def generate(
    prompt="o",
    max_new_tokens=300,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    seed=None,
    h=None
):
    """Autoregressively generate UTF-8 text from the module-level `model`.

    Args:
        prompt: seed text, encoded to raw bytes (invalid chars ignored).
        max_new_tokens: number of bytes to generate after the prompt.
        temperature: softmax temperature forwarded to sample_logits.
        top_k: top-k cutoff forwarded to sample_logits.
        top_p: nucleus cutoff in (0, 1); tokens outside the smallest set
            with cumulative probability >= top_p are masked out.
            (Previously accepted but silently ignored.)
        repetition_penalty: > 1.0 damps logits of already-emitted bytes
            (CTRL-style). (Previously accepted but silently ignored.)
        seed: optional RNG seed for reproducible sampling.
        h: optional GRU hidden state to continue a conversation.

    Returns:
        (text, h): decoded output (prompt included) and the final hidden state.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    prompt_bytes = list(prompt.encode("utf-8", errors="ignore"))
    output = prompt_bytes.copy()

    # The RNN needs at least one input byte; an empty prompt would crash on
    # logits[:, -1]. Feed a space as a neutral seed without adding it to output.
    feed = prompt_bytes if prompt_bytes else [0x20]

    # feed prompt
    x = torch.tensor([feed], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        _, h = model(x, h)

    last = x[:, -1:]

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, h = model(last, h)

        step = logits[:, -1].clone()

        # CTRL-style repetition penalty: damp bytes we already emitted.
        if repetition_penalty and repetition_penalty != 1.0 and output:
            seen = torch.tensor(sorted(set(output)), dtype=torch.long, device=DEVICE)
            picked = step[0, seen]
            step[0, seen] = torch.where(
                picked > 0, picked / repetition_penalty, picked * repetition_penalty
            )

        # Nucleus (top-p) filtering: drop the probability tail.
        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(step, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cum = torch.cumsum(probs, dim=-1)
            # Shift by one token so the first token crossing top_p is kept.
            drop = (cum - probs) > top_p
            sorted_logits = sorted_logits.masked_fill(drop, -1e9)
            step = torch.full_like(step, -1e9).scatter(-1, sorted_idx, sorted_logits)

        next_byte = sample_logits(
            step,
            temperature=temperature,
            top_k=top_k
        )

        output.append(next_byte)
        last = torch.tensor([[next_byte]], device=DEVICE)

    return bytes(output).decode(errors="ignore"), h
|
| 74 |
+
|
| 75 |
+
# -----------------------
# interactive chat loop
# -----------------------
h = None  # GRU hidden state carried across turns so replies keep context

print("MiniText-v1.5 Chat | digite 'exit' para sair")

while True:
    user = input("usuario: ").strip()
    # Bug fix: the banner advertises 'exit' but only 'quit' was accepted;
    # accept both so the documented command actually works.
    if user.lower() in ("exit", "quit"):
        break

    prompt = f"usuario: {user}\nia: "
    text, h = generate(
        prompt=prompt,
        max_new_tokens=120,
        temperature=0.5,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        h=h
    )

    # Keep only the model's reply after the last "ia:" marker.
    reply = text.split("ia:")[-1].strip()
    print("ia:", reply)
|
| 97 |
+
|
minitext.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:23e5a67e5738c67a9dbf49dc1e26bf43d0ff224330863415561fff082c42c41a
|
| 3 |
+
size 43614
|
model.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class MiniText(nn.Module):
    """Tiny byte-level language model: embedding -> single GRU -> vocab head.

    Works directly on raw bytes, so the vocabulary size is fixed at 256.
    """

    def __init__(self):
        super().__init__()
        vocab, hidden = 256, 16
        self.embed = nn.Embedding(vocab, hidden)
        self.gru = nn.GRU(hidden, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, vocab)

    def forward(self, x, h=None):
        """Return (logits, hidden) for byte ids `x` of shape (batch, seq)."""
        emb = self.embed(x)
        seq_out, h = self.gru(emb, h)
        return self.fc(seq_out), h
|
train.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import os
|
| 4 |
+
from model import MiniText
|
| 5 |
+
|
| 6 |
+
# -----------------------
# hyperparameters
# -----------------------
SEQ_LEN = 64                      # bytes per training sequence
EPOCHS = 12000                    # total steps (one random batch each)
LR = 1e-4
SAVE_EVERY = 2000                 # save a checkpoint every X epochs
CHECKPOINT_PATH = "checkpoint.pt"


# -----------------------
# dataset
# -----------------------
# Read the corpus as raw bytes so the vocabulary is simply 0..255.
with open("dataset.txt", "rb") as f:
    raw = f.read()
data = torch.tensor(list(raw), dtype=torch.long)
|
| 21 |
+
|
| 22 |
+
# -----------------------
# model + optimizer
# -----------------------
model = MiniText()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

start_epoch = 0

# -----------------------
# resume from checkpoint (if one exists)
# -----------------------
if os.path.exists(CHECKPOINT_PATH):
    print("Checkpoint encontrado, retomando treino...")
    ckpt = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    # resume at the step after the last completed one
    start_epoch = ckpt["epoch"] + 1
else:
    print("Nenhum checkpoint encontrado, treino do zero.")
|
| 42 |
+
|
| 43 |
+
# -----------------------
|
| 44 |
+
# batch sampler
|
| 45 |
+
# -----------------------
|
| 46 |
+
def get_batch():
    """Return one random (input, target) pair, each of shape (1, SEQ_LEN).

    Targets are the inputs shifted one byte to the right (next-byte
    prediction over the module-level `data` tensor).
    """
    # .item() converts the 1-element tensor to a plain int. The original
    # sliced with the tensor itself, which only works by accident through
    # Tensor.__index__ and is fragile/unidiomatic.
    start = torch.randint(0, len(data) - SEQ_LEN - 1, (1,)).item()
    x = data[start:start + SEQ_LEN].unsqueeze(0)
    y = data[start + 1:start + SEQ_LEN + 1].unsqueeze(0)
    return x, y
|
| 51 |
+
|
| 52 |
+
# -----------------------
# training loop
# -----------------------
for epoch in range(start_epoch, EPOCHS):
    # One random sequence per step (an "epoch" here is really one batch).
    x, y = get_batch()
    logits, _ = model(x)
    loss = loss_fn(logits.view(-1, 256), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {loss.item():.4f}")

    # periodically persist full training state so training can resume
    if (epoch + 1) % SAVE_EVERY == 0:
        state = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        torch.save(state, CHECKPOINT_PATH)
        print("Checkpoint salvo.")

# -----------------------
# save final model
# -----------------------
torch.save(model.state_dict(), "minitext.pt")
print("Treino finalizado. Modelo salvo em minitext.pt")
|