Arthur Samuel Galego Panucci Figueiredo committed on
Commit
c0741ab
·
verified ·
1 Parent(s): 5231594

Upload 5 files

Browse files
Files changed (5) hide show
  1. checkpoint.pt +3 -0
  2. infer.py +97 -0
  3. minitext.pt +3 -0
  4. model.py +15 -0
  5. train.py +79 -0
checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22eb08b8bfa508f28e0d5d4e531a4c4cff7207375afe34aeab7d7787a92e198e
3
+ size 129845
infer.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn.functional as F
from model import MiniText
import random

# -----------------------
# configuration
# -----------------------
MODEL_PATH = "minitext.pt"   # trained weights produced by train.py
DEVICE = "cpu"

# -----------------------
# model setup
# -----------------------
# Build the network, restore its trained weights, and switch to eval mode
# so dropout/batchnorm-style layers (none here, but future-proof) are inert.
model = MiniText().to(DEVICE)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()
# -----------------------
# sampling utils
# -----------------------
def sample_logits(logits, temperature=1.0, top_k=0):
    """Sample one token id from a (1, vocab_size) row of logits.

    Args:
        logits: 2-D tensor of shape (batch, vocab_size); batch is 1 here.
        temperature: softmax temperature, must be > 0. Values < 1 sharpen
            the distribution, values > 1 flatten it.
        top_k: if > 0, restrict sampling to the k highest-scoring tokens.
            Clamped to the vocabulary size, so oversized values are safe.

    Returns:
        int: the sampled token id.

    Raises:
        ValueError: if temperature is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    logits = logits / temperature

    if top_k > 0:
        # Clamp k so that top_k larger than the vocabulary no longer
        # crashes torch.topk (the original raised a RuntimeError).
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k)
        # Mask everything below the k-th best score to ~-inf so softmax
        # assigns it (effectively) zero probability.
        min_val = values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < min_val, torch.full_like(logits, -1e9), logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()
# -----------------------
# text generation
# -----------------------
def generate(
    prompt="o",
    max_new_tokens=300,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    seed=None,
    h=None
):
    """Generate up to max_new_tokens bytes continuing *prompt*.

    Args:
        prompt: seed text; encoded as UTF-8 bytes (the model is byte-level).
        max_new_tokens: number of bytes to sample after the prompt.
        temperature: softmax temperature forwarded to sample_logits.
        top_k: top-k filter forwarded to sample_logits.
        top_p: nucleus-sampling threshold in (0, 1]; 1.0 disables it.
            (Previously accepted but silently ignored — now applied.)
        repetition_penalty: CTRL-style penalty (> 1 discourages bytes that
            were already emitted). 1.0 disables it.
            (Previously accepted but silently ignored — now applied.)
        seed: optional RNG seed for reproducible output.
        h: optional GRU hidden state to continue from (enables chat context).

    Returns:
        (text, h): the decoded prompt+continuation and the final hidden state.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    bytes_in = list(prompt.encode("utf-8", errors="ignore"))
    if not bytes_in:
        # An empty prompt would produce a zero-length tensor and crash
        # x[:, -1:]; fall back to a single null byte as a neutral seed.
        bytes_in = [0]
    output = bytes_in.copy()

    # Feed the whole prompt once to warm up the hidden state.
    x = torch.tensor([bytes_in], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        _, h = model(x, h)

    last = x[:, -1:]

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, h = model(last, h)

        step_logits = logits[:, -1]

        # Repetition penalty: down-weight bytes already present in output.
        if repetition_penalty != 1.0:
            for b in set(output):
                if step_logits[0, b] > 0:
                    step_logits[0, b] /= repetition_penalty
                else:
                    step_logits[0, b] *= repetition_penalty

        # Nucleus (top-p) filtering: keep the smallest set of tokens whose
        # cumulative probability reaches top_p, mask the rest to ~-inf.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(step_logits, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            mask = torch.zeros_like(remove).scatter(1, sorted_idx, remove)
            step_logits = step_logits.masked_fill(mask, -1e9)

        next_byte = sample_logits(
            step_logits,
            temperature=temperature,
            top_k=top_k
        )

        output.append(next_byte)
        last = torch.tensor([[next_byte]], device=DEVICE)

    return bytes(output).decode(errors="ignore"), h
# -----------------------
# interactive chat loop
# -----------------------
h = None  # GRU hidden state carried across turns so the chat keeps context

print("MiniText-v1.5 Chat | digite 'exit' para sair")

while True:
    user = input("usuario: ")
    # Bug fix: the banner advertises 'exit' but only "quit" used to break
    # the loop; accept both (and tolerate surrounding whitespace).
    if user.strip().lower() in ("exit", "quit"):
        break

    # Frame the turn the same way the training data is expected to look.
    prompt = f"usuario: {user}\nia: "
    text, h = generate(
        prompt=prompt,
        max_new_tokens=120,
        temperature=0.5,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        h=h
    )

    # Keep only the model's reply, i.e. everything after the last "ia:".
    reply = text.split("ia:")[-1].strip()
    print("ia:", reply)
minitext.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23e5a67e5738c67a9dbf49dc1e26bf43d0ff224330863415561fff082c42c41a
3
+ size 43614
model.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
class MiniText(nn.Module):
    """Tiny byte-level language model: embedding -> single GRU -> projection.

    Works directly on raw bytes (vocabulary of 256) with a 16-dimensional
    hidden size, keeping checkpoints only a few kilobytes.
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names are part of the state_dict and must not
        # change, or saved checkpoints stop loading.
        self.embed = nn.Embedding(256, 16)            # byte id -> 16-d vector
        self.gru = nn.GRU(16, 16, batch_first=True)   # sequence model
        self.fc = nn.Linear(16, 256)                  # hidden -> next-byte logits

    def forward(self, x, h=None):
        """Run the model over a (batch, seq) tensor of byte ids.

        Returns (logits, h): logits has shape (batch, seq, 256); h is the
        GRU hidden state, pass it back in for incremental decoding.
        """
        embedded = self.embed(x)
        hidden_seq, h = self.gru(embedded, h)
        return self.fc(hidden_seq), h
train.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
import os
from model import MiniText

# -----------------------
# hyperparameters
# -----------------------
SEQ_LEN = 64
EPOCHS = 12000
LR = 1e-4
SAVE_EVERY = 2000  # save a checkpoint every X epochs
CHECKPOINT_PATH = "checkpoint.pt"


# -----------------------
# dataset
# -----------------------
# The model is byte-level: read the corpus as raw bytes, one token each.
with open("dataset.txt", "rb") as f:
    data = torch.tensor(list(f.read()), dtype=torch.long)

# -----------------------
# model + optimizer
# -----------------------
model = MiniText()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

start_epoch = 0

# -----------------------
# resume from checkpoint, if one exists
# -----------------------
if os.path.exists(CHECKPOINT_PATH):
    print("Checkpoint encontrado, retomando treino...")
    ckpt = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch = ckpt["epoch"] + 1
else:
    print("Nenhum checkpoint encontrado, treino do zero.")
# -----------------------
# batch sampler
# -----------------------
def get_batch():
    """Return one random (input, target) training pair from *data*.

    Picks a random window of SEQ_LEN bytes; the target is the same window
    shifted one byte ahead. Both tensors have shape (1, SEQ_LEN).
    """
    # .item() converts the 1-element tensor to a plain int: slicing with a
    # tensor only works through implicit __index__ conversion, which is
    # fragile and has been deprecation-warned in some torch versions.
    idx = torch.randint(0, len(data) - SEQ_LEN - 1, (1,)).item()
    x = data[idx:idx + SEQ_LEN].unsqueeze(0)
    y = data[idx + 1:idx + SEQ_LEN + 1].unsqueeze(0)
    return x, y
# -----------------------
# training loop
# -----------------------
for epoch in range(start_epoch, EPOCHS):
    inputs, targets = get_batch()

    optimizer.zero_grad()
    logits, _ = model(inputs)
    # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for cross-entropy.
    loss = loss_fn(logits.view(-1, 256), targets.view(-1))
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {loss.item():.4f}")

    # Periodic checkpoint so a long run can be resumed after interruption.
    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save(
            {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            CHECKPOINT_PATH,
        )
        print("Checkpoint salvo.")

# -----------------------
# save final model
# -----------------------
torch.save(model.state_dict(), "minitext.pt")
print("Treino finalizado. Modelo salvo em minitext.pt")