frc 10252 committed on
Commit 3905c4a · 1 Parent(s): 45a8534
README.md CHANGED
@@ -1,13 +0,0 @@
- ---
- title: Chatbot
- emoji: 🚀
- colorFrom: purple
- colorTo: gray
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- short_description: they call me sam altman
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
__pycache__/test-llm.cpython-313.pyc ADDED
Binary file (23.5 kB)
 
app.py ADDED
@@ -0,0 +1,95 @@
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from components.model import GPTModel
+ from components.tokenizer import encode, decode, tokenizer
+
+
+ # -----------------------------
+ # Load model & configuration
+ # -----------------------------
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Hyperparameters must match the values used in training
+ block_size = 128
+ n_layers = 16
+ n_heads = 8
+ dropout_p = 0.1
+ n_embedding = 256
+
+ # Initialize the model and load the trained weights
+ vocab_size = tokenizer.n_vocab
+ model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(
+     device
+ )
+ model.load_state_dict(torch.load("checkpoints/gpt_model-1.pth", map_location=device))
+ model.eval()
+
+
+ # -----------------------------
+ # Generation function
+ # -----------------------------
+ @torch.no_grad()
+ def generate_text(prompt, max_new_tokens=200, temperature=1.0, top_k=50):
+     model.eval()
+
+     # Wrap the message in [INST] and [/INST] markers, as in training
+     wrapped_prompt = f"[INST] {prompt.strip()} [/INST]"
+     tokens = (
+         torch.tensor(encode(wrapped_prompt), dtype=torch.long).unsqueeze(0).to(device)
+     )
+
+     inst_token_id = encode("[INST]")[0]
+
+     for _ in range(max_new_tokens):
+         # Keep only the last block_size tokens as context
+         input_tokens = tokens[:, -block_size:]
+         logits = model(input_tokens)
+         logits = logits[:, -1, :] / temperature
+
+         # Top-k filtering: mask every logit below the k-th largest
+         if top_k is not None:
+             values, _ = torch.topk(logits, top_k)
+             logits[logits < values[:, [-1]]] = -float("Inf")
+
+         probs = F.softmax(logits, dim=-1)
+         next_token = torch.multinomial(probs, num_samples=1)
+
+         # Stop generation if [INST] appears again (do not include it)
+         if next_token.item() == inst_token_id:
+             break
+
+         tokens = torch.cat((tokens, next_token), dim=1)
+
+     # Decode and strip the echoed prompt from the front
+     return decode(tokens[0].tolist())[len(wrapped_prompt) :]
+
+
+ # -----------------------------
+ # Gradio UI
+ # -----------------------------
+ def chat(prompt, max_tokens, temperature, top_k):
+     return generate_text(prompt, max_tokens, temperature, top_k)
+
+
+ with gr.Blocks(title="TinyChat GPT Model") as demo:
+     gr.Markdown("## cute lil chatbot")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             prompt = gr.Textbox(
+                 label="Prompt", placeholder="Type your message here...", lines=4
+             )
+             max_tokens = gr.Slider(10, 500, value=200, step=10, label="Max New Tokens")
+             temperature = gr.Slider(0.2, 1.5, value=1.0, step=0.1, label="Temperature")
+             top_k = gr.Slider(10, 200, value=50, step=10, label="Top-K Sampling")
+             submit = gr.Button("Generate")
+
+         with gr.Column(scale=3):
+             output = gr.Textbox(label="Generated Response", lines=15)
+
+     submit.click(chat, inputs=[prompt, max_tokens, temperature, top_k], outputs=output)
+
+ # -----------------------------
+ # Launch app
+ # -----------------------------
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
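
For reference, the sampling step above (temperature scaling, then top-k filtering, then a multinomial draw) can be exercised in isolation. A minimal sketch with made-up logits; the tensor values are illustrative only:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # fake (batch=1, vocab=4) logits
temperature, top_k = 0.8, 2

logits = logits / temperature                     # <1 sharpens, >1 flattens
values, _ = torch.topk(logits, top_k)
logits[logits < values[:, [-1]]] = -float("Inf")  # mask everything below the k-th logit

probs = F.softmax(logits, dim=-1)                 # only the top-k entries stay non-zero
next_token = torch.multinomial(probs, num_samples=1)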
components/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (2.83 kB)
 
components/__pycache__/model.cpython-312.pyc ADDED
Binary file (6.99 kB)
 
components/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (1.58 kB)
 
components/dataset.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ from torch.utils.data import Dataset
+
+ from components.tokenizer import encode
+
+
+ class TextDataset(Dataset):
+     def __init__(self, hf_dataset, block_size):
+         self.dataset = hf_dataset
+         self.block_size = block_size
+
+     def __len__(self):
+         return len(self.dataset["train"])
+
+     def __getitem__(self, idx):
+         # Start from a random sample rather than idx, so every batch sees fresh text
+         rand_idx = torch.randint(0, len(self.dataset["train"]), (1,)).item()
+         text = self.dataset["train"][rand_idx]["text"]
+         tokens = encode(text)
+
+         # Keep appending more samples while the sequence is too short
+         while len(tokens) < self.block_size + 1:
+             next_idx = torch.randint(0, len(self.dataset["train"]), (1,)).item()
+             next_text = self.dataset["train"][next_idx]["text"]
+             tokens.extend(encode(" " + next_text))
+             # Prevent runaway growth
+             if len(tokens) > self.block_size * 2:
+                 break
+
+         # Truncate to block_size + 1, then split into shifted input/target pairs
+         tokens = torch.tensor(tokens[: self.block_size + 1])
+
+         x = tokens[: self.block_size]
+         y = tokens[1 : self.block_size + 1]
+         return x.long(), y.long()
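
The shifted input/target split is easiest to see on a toy sequence: position t of y is the training target for position t of x. A self-contained sketch with toy token ids (not the project's tiktoken encoding):

import torch

block_size = 4
tokens = torch.tensor([10, 11, 12, 13, 14])  # block_size + 1 token ids

x = tokens[:block_size]           # tensor([10, 11, 12, 13])
y = tokens[1 : block_size + 1]    # tensor([11, 12, 13, 14])
# the model learns to predict y[t] from x[: t + 1]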
components/model.py ADDED
@@ -0,0 +1,153 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, n_embedding, n_heads, dropout_p):
+         super().__init__()
+         assert n_embedding % n_heads == 0
+         self.n_heads = n_heads
+         self.head_dim = n_embedding // n_heads
+
+         self.q_proj = nn.Linear(n_embedding, n_embedding)
+         self.k_proj = nn.Linear(n_embedding, n_embedding)
+         self.v_proj = nn.Linear(n_embedding, n_embedding)
+         self.out_proj = nn.Linear(n_embedding, n_embedding)
+         self.dropout = nn.Dropout(dropout_p)
+
+     def forward(self, x, attn_mask=None):
+         B, T, C = x.shape  # batch size, sequence length, embedding dim
+
+         # Split the embedding dim C into n_heads x head_dim
+         q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+
+         # Built-in scaled dot-product attention for efficiency.
+         # SDPA rejects an explicit attn_mask combined with is_causal=True,
+         # so enable the built-in causal mask only when no mask is supplied.
+         attn_out = F.scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=attn_mask,
+             dropout_p=self.dropout.p if self.training else 0.0,
+             is_causal=attn_mask is None,
+         )
+
+         attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)
+         return self.out_proj(attn_out)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, n_embedding, dropout_p):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embedding, 4 * n_embedding),
+             nn.GELU(),
+             nn.Linear(4 * n_embedding, n_embedding),
+             nn.Dropout(dropout_p),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, n_embedding, n_heads, dropout_p):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(n_embedding)
+         self.ln2 = nn.LayerNorm(n_embedding)
+         self.attn = MultiHeadAttention(n_embedding, n_heads, dropout_p)
+         self.ff = FeedForward(n_embedding, dropout_p)
+
+     def forward(self, x, attn_mask=None):
+         # Pre-norm residual connections
+         x = x + self.attn(self.ln1(x), attn_mask)
+         x = x + self.ff(self.ln2(x))
+         return x
+
+
+ class GPTModel(nn.Module):
+     def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):
+         super().__init__()
+         self.token_embed = nn.Embedding(vocab_size, n_embedding)
+         self.pos_embed = nn.Embedding(block_size, n_embedding)
+         self.blocks = nn.ModuleList(
+             [TransformerBlock(n_embedding, n_heads, dropout_p) for _ in range(n_layers)]
+         )
+         self.ln_f = nn.LayerNorm(n_embedding)
+         self.head = nn.Linear(n_embedding, vocab_size, bias=False)
+         self.dropout = nn.Dropout(dropout_p)
+         self.block_size = block_size
+
+     def forward(self, idx):
+         B, T = idx.shape
+         assert T <= self.block_size, "Sequence exceeds block size."
+
+         pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
+         x = self.token_embed(idx) + self.pos_embed(pos)
+         x = self.dropout(x)
+
+         # Causal mask for the decoder: position t may attend only to positions <= t
+         attn_mask = torch.ones(T, T, device=idx.device, dtype=torch.bool).tril()
+
+         for block in self.blocks:
+             x = block(x, attn_mask)
+
+         x = self.ln_f(x)
+         logits = self.head(x)
+         return logits
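
A quick shape check for the model above, assuming components/ is importable; the tiny hyperparameters are arbitrary:

import torch
from components.model import GPTModel

model = GPTModel(
    vocab_size=100, n_embedding=32, n_layers=2, n_heads=4,
    dropout_p=0.1, block_size=16,
)
idx = torch.randint(0, 100, (2, 16))  # (batch, seq) of token ids
logits = model(idx)
assert logits.shape == (2, 16, 100)   # one next-token distribution per position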
components/tokenizer.py ADDED
@@ -0,0 +1,41 @@
+ import tiktoken
+
+ # Base GPT-2 BPE vocabulary
+ base_encoding = tiktoken.get_encoding("gpt2")
+
+ # Reserve the next available token ids for the chat markers
+ special_tokens = {
+     "[INST]": base_encoding.n_vocab,
+     "[/INST]": base_encoding.n_vocab + 1,
+ }
+
+ # New encoding that merges GPT-2's tokens with the special tokens
+ tokenizer = tiktoken.Encoding(
+     name="gpt2_with_inst",
+     pat_str=base_encoding._pat_str,
+     mergeable_ranks=base_encoding._mergeable_ranks,
+     special_tokens={**base_encoding._special_tokens, **special_tokens},
+ )
+
+
+ def encode(text):
+     return tokenizer.encode(text, allowed_special={"[INST]", "[/INST]"})
+
+
+ def decode(tokens):
+     return tokenizer.decode(tokens)
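
Round-tripping the chat markers shows each one mapping to a single id just past the 50,257-token GPT-2 vocabulary. A sketch, assuming components/ is importable:

from components.tokenizer import encode, decode

ids = encode("[INST] hi [/INST]")
print(ids[0], ids[-1])  # expected: 50257 50258, the two appended special tokens
print(decode(ids))      # expected: [INST] hi [/INST]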
old/rl_test.ipynb ADDED
@@ -0,0 +1,502 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "158eaa47",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "from torch.nn import functional as F\n",
+     "import math, time, os\n",
+     "from torch.utils.data import Dataset, DataLoader\n",
+     "import tiktoken\n",
+     "# from torch.cuda.amp import autocast, GradScaler  # old API\n",
+     "from torch.amp.autocast_mode import autocast\n",
+     "from torch.amp.grad_scaler import GradScaler"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "97d9467e",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "/home/software/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+       "  from .autonotebook import tqdm as notebook_tqdm\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Her campaign emailed a fundraising pitch Tuesday evening warning of the dangers of a Trump presidency and of complacency among Democrats.\n",
+       "{'text': \"Canonical, keeper of the Ubuntu Linux distribution, is a small company with big friends. The latest example: Dell, IBM and Intel each are taking new steps with Ubuntu. Here's the scoop.\"}\n"
+      ]
+     }
+    ],
+    "source": [
+     "from datasets import load_dataset\n",
+     "\n",
+     "# dataset = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n",
+     "dataset = load_dataset(\"Bingsu/openwebtext_20p\")\n",
+     "ds = load_dataset(\"starhopp3r/TinyChat\", split=\"train\")\n",
+     "# This gives you cleaned, plain-text articles\n",
+     "print(dataset['train'][100]['text'][:500])  # Print the first 500 characters of article 100\n",
+     "print(dataset['train'][600000])"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "81b98c54",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "class TextDataset(Dataset):\n",
+     "    def __init__(self, hf_dataset, tokenizer, block_size):\n",
+     "        self.dataset = hf_dataset\n",
+     "        self.tokenizer = tokenizer\n",
+     "        self.block_size = block_size\n",
+     "\n",
+     "    def __len__(self):\n",
+     "        return len(self.dataset['train'])\n",
+     "\n",
+     "    def __getitem__(self, idx):\n",
+     "        # choose a random index instead of using the passed idx\n",
+     "        rand_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()\n",
+     "        tokens = self.tokenizer.encode(self.dataset['train'][rand_idx]['text'])\n",
+     "\n",
+     "        if len(tokens) < self.block_size + 1:\n",
+     "            tokens = F.pad(torch.tensor(tokens), (0, self.block_size + 1 - len(tokens)), value=0)\n",
+     "        else:\n",
+     "            tokens = torch.tensor(tokens[: self.block_size + 1])\n",
+     "\n",
+     "        x = tokens[: self.block_size]\n",
+     "        y = tokens[1 : self.block_size + 1]\n",
+     "        return x.long(), y.long()\n",
+     "\n",
+     "import torch\n",
+     "from torch.utils.data import Dataset\n",
+     "from datasets import load_dataset\n",
+     "import re\n",
+     "\n",
+     "class ChatDataset(Dataset):\n",
+     "    def __init__(self, tokenizer, split=\"train\", block_size=256, dataset_name=\"starhopp3r/TinyChat\"):\n",
+     "        \"\"\"\n",
+     "        Args:\n",
+     "            tokenizer: a tokenizer (e.g., tiktoken or Hugging Face tokenizer)\n",
+     "            split: dataset split (\"train\" etc.)\n",
+     "            block_size: maximum sequence length\n",
+     "            dataset_name: path/name of the Hugging Face dataset\n",
+     "        \"\"\"\n",
+     "        self.dataset = load_dataset(dataset_name, split=split)\n",
+     "        self.tokenizer = tokenizer\n",
+     "        self.block_size = block_size\n",
+     "\n",
+     "    def __len__(self):\n",
+     "        return len(self.dataset)\n",
+     "\n",
+     "    def __getitem__(self, idx):\n",
+     "        sample = self.dataset[idx]\n",
+     "        text = sample[\"text\"]\n",
+     "\n",
+     "        # --- split into prompt and response (TinyChat uses [INST] ... [/INST]) ---\n",
+     "        match = re.search(r\"\\[INST\\](.*?)\\[/INST\\](.*)\", text, re.DOTALL)\n",
+     "        if match:\n",
+     "            instruction = match.group(1).strip()\n",
+     "            response = match.group(2).strip()\n",
+     "        else:\n",
+     "            instruction = text.strip()\n",
+     "            response = \"\"\n",
+     "\n",
+     "        # Combine into a training sequence\n",
+     "        combined_text = f\"<inst> {instruction} </inst> {response}\"\n",
+     "\n",
+     "        # Tokenize (truncate/pad to block_size + 1)\n",
+     "        tokens = torch.tensor(self.tokenizer.encode(combined_text), dtype=torch.long)\n",
+     "        if len(tokens) < self.block_size + 1:\n",
+     "            pad_len = self.block_size + 1 - len(tokens)\n",
+     "            tokens = F.pad(tokens, (0, pad_len), value=0)\n",
+     "        else:\n",
+     "            tokens = tokens[: self.block_size + 1]\n",
+     "\n",
+     "        x = tokens[:-1]\n",
+     "        y = tokens[1:]\n",
+     "\n",
+     "        return x, y"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "599aa05a",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# hyperparameters\n",
+     "train_model = True\n",
+     "compile_model = True\n",
+     "block_size = 256\n",
+     "n_layers = 32\n",
+     "n_heads = 16\n",
+     "dropout_p = 0.1\n",
+     "batch_size = 16\n",
+     "learning_rate = 3e-4\n",
+     "n_embedding = 512\n",
+     "max_iters = 1000\n",
+     "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "a69561e9",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
+     "\n",
+     "train_dataset = TextDataset(dataset, tokenizer, block_size=block_size)\n",
+     "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)\n",
+     "\n",
+     "chat_dataset = ChatDataset(tokenizer, split=\"train\", block_size=block_size)\n",
+     "chat_dataloader = DataLoader(chat_dataset, batch_size=batch_size, shuffle=True, drop_last=True)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "id": "ea5598ea",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "class GPTModel(nn.Module):\n",
+     "    def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):\n",
+     "        super(GPTModel, self).__init__()\n",
+     "        self.token_embedding = nn.Embedding(vocab_size, n_embedding)\n",
+     "        self.position_embedding = nn.Embedding(block_size, n_embedding)\n",
+     "        # note: nn.TransformerEncoderLayer defaults to batch_first=False,\n",
+     "        # and no causal mask is passed in forward() below\n",
+     "        self.layers = nn.ModuleList([\n",
+     "            nn.TransformerEncoderLayer(d_model=n_embedding, nhead=n_heads, dropout=dropout_p)\n",
+     "            for _ in range(n_layers)\n",
+     "        ])\n",
+     "        self.ln_f = nn.LayerNorm(n_embedding)\n",
+     "        self.head = nn.Linear(n_embedding, vocab_size)\n",
+     "        self.dropout = nn.Dropout(dropout_p)\n",
+     "        self.block_size = block_size\n",
+     "\n",
+     "    def forward(self, x):\n",
+     "        bsz, seq_len = x.size()\n",
+     "        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)\n",
+     "        x = self.token_embedding(x) + self.position_embedding(positions)\n",
+     "        x = self.dropout(x)\n",
+     "\n",
+     "        for layer in self.layers:\n",
+     "            x = layer(x)\n",
+     "\n",
+     "        x = self.ln_f(x)\n",
+     "        logits = self.head(x)\n",
+     "        return logits"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "id": "6a1344ab",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# define objects\n",
+     "vocab_size = tokenizer.n_vocab\n",
+     "\n",
+     "model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)\n",
+     "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
+     "loss_fn = nn.CrossEntropyLoss()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "id": "a0982489",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "/home/software/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)\n",
+       "  _C._set_float32_matmul_precision(precision)\n",
+       "Training: 100%|████████████████████████████████████| 1000/1000 [05:06<00:00,  3.26it/s, loss=1.5960]\n"
+      ]
+     }
+    ],
+    "source": [
+     "from tqdm import tqdm\n",
+     "\n",
+     "# training loop\n",
+     "scaler = GradScaler(device)\n",
+     "if train_model:\n",
+     "    if compile_model:\n",
+     "        compiled_model = torch.compile(model)\n",
+     "        torch.set_float32_matmul_precision('high')\n",
+     "    else:\n",
+     "        compiled_model = model\n",
+     "\n",
+     "    pbar = tqdm(range(max_iters), desc=\"Training\", ncols=100)\n",
+     "    data_iter = iter(train_dataloader)\n",
+     "    chat_data_iter = iter(chat_dataloader)\n",
+     "\n",
+     "    for count in pbar:\n",
+     "        try:\n",
+     "            # alternate between chat batches and plain-text batches\n",
+     "            if count % 2 == 0:\n",
+     "                xb, yb = next(chat_data_iter)\n",
+     "            else:\n",
+     "                xb, yb = next(data_iter)\n",
+     "        except StopIteration:\n",
+     "            break  # dataloader exhausted before max_iters\n",
+     "\n",
+     "        xb, yb = xb.to(device), yb.to(device)\n",
+     "\n",
+     "        # forward pass under mixed precision\n",
+     "        with autocast(device, dtype=torch.float16):\n",
+     "            logits = compiled_model(xb)\n",
+     "            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))\n",
+     "\n",
+     "        # backward pass with gradient scaling\n",
+     "        optimizer.zero_grad()\n",
+     "        scaler.scale(loss).backward()\n",
+     "        scaler.step(optimizer)\n",
+     "        scaler.update()\n",
+     "\n",
+     "        # update bar text dynamically\n",
+     "        pbar.set_postfix({\"loss\": f\"{loss.item():.4f}\"})"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 9,
+    "id": "6eb95580",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "if train_model:\n",
+     "    torch.save(model.state_dict(), \"checkpoints/gpt_model-2.pth\")\n",
+     "else:\n",
+     "    model.load_state_dict(torch.load(\"checkpoints/gpt_model-2.pth\"))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 10,
+    "id": "4371725d",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "me when the .!!!!!! understand!!!!] cold! especially characters!! used soon!!!!! world! Exactly]-INST!!! choices! feel! spread!! a!! impact]inst saw them\n"
+      ]
+     }
+    ],
+    "source": [
+     "@torch.no_grad()\n",
+     "def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):\n",
+     "    model.eval()\n",
+     "    # Encode the prompt text into token IDs\n",
+     "    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)\n",
+     "\n",
+     "    for _ in range(max_new_tokens):\n",
+     "        # Only keep the last block_size tokens for context\n",
+     "        input_tokens = tokens[:, -block_size:]\n",
+     "\n",
+     "        # Get logits and take the last token's distribution\n",
+     "        logits = model(input_tokens)\n",
+     "        logits = logits[:, -1, :]  # (batch=1, vocab)\n",
+     "        probs = F.softmax(logits, dim=-1)\n",
+     "\n",
+     "        # Sample from the distribution\n",
+     "        next_token = torch.multinomial(probs, num_samples=1)\n",
+     "        tokens = torch.cat((tokens, next_token), dim=1)\n",
+     "\n",
+     "    # Decode back into text\n",
+     "    output_text = tokenizer.decode(tokens[0].tolist())\n",
+     "    return output_text\n",
+     "\n",
+     "prompt = \"me when the \"\n",
+     "print(generate_text(model, tokenizer, prompt, max_new_tokens=50, block_size=block_size, device=device))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 11,
+    "id": "d9c83f71",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "  0%|          | 1/500 [00:01<12:16,  1.48s/it]"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "\n",
+       "Step 0: reward=0.00\n",
+       "Generated:\n",
+       "Hello:!!!!!!! that [ it everywhere [ </!!! impact!! not!!!!!!! past! un!,. to especially! explanation now! colorful, more!>!!!]! [/\n",
+       "\n"
+      ]
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       " 10%|▉         | 49/500 [01:09<10:43,  1.43s/it]\n"
+      ]
+     },
+     {
+      "ename": "KeyboardInterrupt",
+      "evalue": "",
+      "output_type": "error",
+      "traceback": [
+       "---------------------------------------------------------------------------",
+       "KeyboardInterrupt                         Traceback (most recent call last)",
+       "Cell In[11], line 50\n---> 50 text, logprob_sum = generate_with_logprobs(model, tokenizer, prompt, block_size, max_new_tokens)",
+       "Cell In[11], line 28, in generate_with_logprobs(model, tokenizer, prompt, block_size, max_new_tokens)\n---> 28 logits = model(input_tokens)",
+       "File ~/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)",
+       "File ~/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)",
+       "Cell In[6], line 22, in GPTModel.forward(self, x)\n---> 22 x = layer(x)",
+       "File ~/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/nn/modules/transformer.py:935, in TransformerEncoderLayer.forward(self, src, src_mask, src_key_padding_mask, is_causal)",
+       "File ~/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/nn/modules/transformer.py:949, in TransformerEncoderLayer._sa_block(self, x, attn_mask, key_padding_mask, is_causal)",
+       "File ~/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/nn/modules/activation.py:1488, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)",
+       "File ~/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/nn/functional.py:6307, in multi_head_attention_forward(query, key, value, ...)",
+       "File ~/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/nn/functional.py:5699, in _in_projection_packed(q, k, v, w, b)",
+       "KeyboardInterrupt: "
+      ]
+     }
+    ],
+    "source": [
+     "import torch\n",
+     "import torch.nn.functional as F\n",
+     "from datasets import load_dataset\n",
+     "from tqdm import tqdm\n",
+     "\n",
+     "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+     "\n",
+     "# --- your existing model/tokenizer here ---\n",
+     "# model = GPTModel(...)\n",
+     "# tokenizer = ...\n",
+     "model = model.to(device)\n",
+     "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)\n",
+     "\n",
+     "# reward: +1 if both [INST] and [/INST] are present, else 0\n",
+     "def reward_fn(text):\n",
+     "    return 1.0 if \"[INST]\" in text and \"[/INST]\" in text else 0.0\n",
+     "\n",
+     "# wrap the existing generator to also accumulate token log-probs\n",
+     "def generate_with_logprobs(model, tokenizer, prompt, block_size, max_new_tokens):\n",
+     "    model.eval()\n",
+     "    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)\n",
+     "    logprob_sum = torch.tensor(0.0, device=device)\n",
+     "\n",
+     "    for _ in range(max_new_tokens):\n",
+     "        input_tokens = tokens[:, -block_size:]\n",
+     "        logits = model(input_tokens)\n",
+     "        logits = logits[:, -1, :]  # (1, vocab)\n",
+     "        probs = F.softmax(logits, dim=-1)\n",
+     "        next_token = torch.multinomial(probs, num_samples=1)\n",
+     "        logprob_sum = logprob_sum + torch.log(probs.gather(1, next_token) + 1e-8).squeeze()\n",
+     "        tokens = torch.cat([tokens, next_token], dim=1)\n",
+     "\n",
+     "    text = tokenizer.decode(tokens[0].tolist())\n",
+     "    return text, logprob_sum\n",
+     "\n",
+     "# --- RL loop ---\n",
+     "num_steps = 500  # small demo\n",
+     "block_size = 128\n",
+     "max_new_tokens = 50\n",
+     "\n",
+     "for step in tqdm(range(num_steps)):\n",
+     "    # 1. Cycle through rows of TinyChat\n",
+     "    sample = ds[step % len(ds)]\n",
+     "    prompt = sample.get(\"prompt\") or sample.get(\"input\") or \"Hello:\"\n",
+     "\n",
+     "    # 2. Generate text and token logprobs\n",
+     "    text, logprob_sum = generate_with_logprobs(model, tokenizer, prompt, block_size, max_new_tokens)\n",
+     "\n",
+     "    # 3. Compute reward\n",
+     "    r = torch.tensor(reward_fn(text), device=device)\n",
+     "\n",
+     "    # 4. Policy loss (REINFORCE)\n",
+     "    loss = -r * logprob_sum\n",
+     "\n",
+     "    optimizer.zero_grad()\n",
+     "    loss.backward()\n",
+     "    optimizer.step()\n",
+     "\n",
+     "    if step % 50 == 0:\n",
+     "        print(f\"\\nStep {step}: reward={r.item():.2f}\\nGenerated:\\n{text[:200]}\\n\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "4e45fd02",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "chatbot",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
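
The update rule in the notebook's RL cell is plain REINFORCE: descend on -reward * (sum of sampled-token log-probs), which raises the probability of rewarded generations. A minimal sketch of that gradient step on a dummy two-action policy (all values illustrative):

import torch

logit = torch.zeros(2, requires_grad=True)   # toy policy over two actions
optimizer = torch.optim.SGD([logit], lr=0.1)

probs = torch.softmax(logit, dim=-1)
action = torch.multinomial(probs, num_samples=1)
logprob = torch.log(probs[action])           # log-prob of the sampled action
reward = 1.0 if action.item() == 0 else 0.0  # toy reward function

loss = -reward * logprob                     # REINFORCE objective
optimizer.zero_grad()
loss.backward()
optimizer.step()                             # the rewarded action becomes more likely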
old/test_llm.ipynb ADDED
@@ -0,0 +1,362 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "158eaa47",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "import torch.nn as nn\n",
12
+ "from torch.nn import functional as F\n",
13
+ "import math, time, os\n",
14
+ "from torch.utils.data import Dataset, DataLoader\n",
15
+ "import tiktoken\n",
16
+ "# from torch.cuda.amp import autocast, GradScaler\n",
17
+ "from torch.amp.autocast_mode import autocast\n",
18
+ "from torch.amp.grad_scaler import GradScaler"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "97d9467e",
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/home/software/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
32
+ " from .autonotebook import tqdm as notebook_tqdm\n"
33
+ ]
34
+ },
35
+ {
36
+ "name": "stdout",
37
+ "output_type": "stream",
38
+ "text": [
39
+ "Her campaign emailed a fundraising pitch Tuesday evening warning of the dangers of a Trump presidency and of complacency among Democrats.\n",
40
+ "{'text': \"Canonical, keeper of the Ubuntu Linux distribution, is a small company with big friends. The latest example: Dell, IBM and Intel each are taking new steps with Ubuntu. Here's the scoop.\"}\n"
41
+ ]
42
+ }
43
+ ],
44
+ "source": [
45
+ "from datasets import load_dataset\n",
46
+ "\n",
47
+ "# dataset = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n",
48
+ "dataset = load_dataset(\"Bingsu/openwebtext_20p\")\n",
49
+ "# This gives you cleaned, plain text articles1\n",
50
+ "print(dataset['train'][100]['text'][:500]) # Print the first 500 characters of the first article\n",
51
+ "print(dataset['train'][600000])"
52
+ ]
53
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "81b98c54",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TextDataset(Dataset):\n",
+ "    def __init__(self, hf_dataset, tokenizer, block_size):\n",
+ "        self.dataset = hf_dataset\n",
+ "        self.tokenizer = tokenizer\n",
+ "        self.block_size = block_size\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.dataset['train'])\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        # Sample a random document rather than using the passed idx\n",
+ "        rand_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()\n",
+ "        text = self.dataset['train'][rand_idx]['text']\n",
+ "        tokens = self.tokenizer.encode(text)\n",
+ "\n",
+ "        # Keep appending more documents while the sample is too short\n",
+ "        while len(tokens) < self.block_size + 1:\n",
+ "            next_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()\n",
+ "            next_text = self.dataset['train'][next_idx]['text']\n",
+ "            tokens.extend(self.tokenizer.encode(\" \" + next_text))\n",
+ "            # Prevent runaway growth\n",
+ "            if len(tokens) > self.block_size * 2:\n",
+ "                break\n",
+ "\n",
+ "        # Truncate to block_size + 1\n",
+ "        tokens = torch.tensor(tokens[: self.block_size + 1])\n",
+ "\n",
+ "        x = tokens[: self.block_size]\n",
+ "        y = tokens[1 : self.block_size + 1]\n",
+ "        return x.long(), y.long()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "599aa05a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# hyperparameters\n",
+ "train_model = True\n",
+ "block_size = 256\n",
+ "n_layers = 8\n",
+ "n_heads = 8\n",
+ "dropout_p = 0.1\n",
+ "batch_size = 8\n",
+ "learning_rate = 3e-4\n",
+ "n_embedding = 512\n",
+ "max_iters = 5000\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "a69561e9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
+ "\n",
+ "train_dataset = TextDataset(dataset, tokenizer, block_size=block_size)\n",
+ "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ea5598ea",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class GPTModel(nn.Module):\n",
+ "    def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):\n",
+ "        super(GPTModel, self).__init__()\n",
+ "        self.token_embedding = nn.Embedding(vocab_size, n_embedding)\n",
+ "        self.position_embedding = nn.Embedding(block_size, n_embedding)\n",
+ "        self.layers = nn.ModuleList([\n",
+ "            nn.TransformerEncoderLayer(d_model=n_embedding, nhead=n_heads, dropout=dropout_p)\n",
+ "            for _ in range(n_layers)\n",
+ "        ])\n",
+ "        self.ln_f = nn.LayerNorm(n_embedding)\n",
+ "        self.head = nn.Linear(n_embedding, vocab_size)\n",
+ "        self.dropout = nn.Dropout(dropout_p)\n",
+ "        self.block_size = block_size\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        bsz, seq_len = x.size()\n",
+ "        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)\n",
+ "        x = self.token_embedding(x) + self.position_embedding(positions)\n",
+ "        x = self.dropout(x)\n",
+ "\n",
+ "        for layer in self.layers:\n",
+ "            x = layer(x)\n",
+ "\n",
+ "        x = self.ln_f(x)\n",
+ "        logits = self.head(x)\n",
+ "        return logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "6a1344ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define objects\n",
+ "vocab_size = tokenizer.n_vocab\n",
+ "\n",
+ "model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
+ "loss_fn = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "a0982489",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/software/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)\n",
+ " _C._set_float32_matmul_precision(precision)\n",
+ "Training: 100%|████████████████████████████████████| 5000/5000 [06:27<00:00, 12.89it/s, loss=7.7995]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tqdm import tqdm\n",
+ "\n",
+ "# training loop\n",
+ "torch.set_float32_matmul_precision('high')\n",
+ "scaler = GradScaler(device)\n",
+ "if train_model:\n",
+ "    compiled_model = torch.compile(model)\n",
+ "\n",
+ "    pbar = tqdm(range(max_iters), desc=\"Training\", ncols=100)\n",
+ "    data_iter = iter(train_dataloader)\n",
+ "\n",
+ "    for count in pbar:\n",
+ "        xb, yb = next(data_iter)\n",
+ "        xb, yb = xb.to(device), yb.to(device)\n",
+ "\n",
+ "        # forward pass under fp16 autocast\n",
+ "        with autocast(device, dtype=torch.float16):\n",
+ "            logits = compiled_model(xb)\n",
+ "            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))\n",
+ "\n",
+ "        # backward pass with gradient scaling\n",
+ "        optimizer.zero_grad()\n",
+ "        scaler.scale(loss).backward()\n",
+ "        scaler.step(optimizer)\n",
+ "        scaler.update()\n",
+ "\n",
+ "        # update bar text dynamically\n",
+ "        pbar.set_postfix({\"loss\": f\"{loss.item():.4f}\"})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "6eb95580",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if train_model:\n",
+ "    torch.save(model.state_dict(), \"checkpoints/gpt_model-1.pth\")\n",
+ "else:\n",
+ "    model.load_state_dict(torch.load(\"checkpoints/gpt_model-1.pth\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "4371725d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model has 76.864593 million parameters.\n",
+ "this new company does � week film the 5 the�ana be 2002 of there to that realWell runs such� to found, inex their a but just might said�, later to? vision candidate resultd agon if give continue anti information Beast find beer the I over\n"
+ ]
+ }
+ ],
+ "source": [
+ "@torch.no_grad()\n",
+ "def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):\n",
+ "    model.eval()\n",
+ "    # Encode the prompt text into token IDs\n",
+ "    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)\n",
+ "\n",
+ "    for _ in range(max_new_tokens):\n",
+ "        # Only keep the last block_size tokens for context\n",
+ "        input_tokens = tokens[:, -block_size:]\n",
+ "\n",
+ "        # Get logits and take the last token's distribution\n",
+ "        logits = model(input_tokens)\n",
+ "        logits = logits[:, -1, :]  # (batch=1, vocab)\n",
+ "        probs = F.softmax(logits, dim=-1)\n",
+ "\n",
+ "        # Sample from the distribution\n",
+ "        next_token = torch.multinomial(probs, num_samples=1)\n",
+ "        tokens = torch.cat((tokens, next_token), dim=1)\n",
+ "\n",
+ "    # Decode back into text\n",
+ "    output_text = tokenizer.decode(tokens[0].tolist())\n",
+ "    return output_text\n",
+ "\n",
+ "# print model parameters\n",
+ "print(f\"Model has {sum(p.numel() for p in model.parameters())/1000000} million parameters.\")\n",
+ "prompt = \"this new company does \"\n",
+ "print(generate_text(model, tokenizer, prompt, max_new_tokens=50, block_size=block_size, device=device))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "56e9eb22",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "chatbot",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
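A note on the model defined above: GPTModel stacks nn.TransformerEncoderLayer with default arguments, which expects (seq, batch, dim) input unless batch_first=True, and it applies no causal attention mask, so every position can attend to future tokens during training. Both points plausibly contribute to the incoherent sample in the last executed cell. A minimal sketch of a causal, batch-first block (this CausalBlock wrapper is illustrative only, not something the repo defines):

import torch
import torch.nn as nn

class CausalBlock(nn.Module):
    # Hypothetical drop-in for one layer of GPTModel's stack.
    def __init__(self, n_embedding, n_heads, dropout_p):
        super().__init__()
        self.layer = nn.TransformerEncoderLayer(
            d_model=n_embedding, nhead=n_heads, dropout=dropout_p,
            batch_first=True,  # accept (batch, seq, dim), matching GPTModel.forward
        )

    def forward(self, x):
        seq_len = x.size(1)
        # Upper-triangular -inf mask: position t attends only to positions <= t.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=x.device)
        return self.layer(x, src_mask=mask)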
old/train_script_v1.py ADDED
@@ -0,0 +1,192 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import math, time, os
+ from torch.utils.data import Dataset, DataLoader
+ import tiktoken
+
+ # from torch.cuda.amp import autocast, GradScaler
+ from torch.amp.autocast_mode import autocast
+ from torch.amp.grad_scaler import GradScaler
+
+ from datasets import load_dataset
+
+ # dataset = load_dataset("wikimedia/wikipedia", "20231101.en")
+ dataset = load_dataset("Bingsu/openwebtext_20p")
+ # This gives cleaned, plain-text articles
+ print(
+     dataset["train"][100]["text"][:500]
+ )  # Print the first 500 characters of article 100
+ print(dataset["train"][600000])
+
+
+ class TextDataset(Dataset):
+     def __init__(self, hf_dataset, tokenizer, block_size):
+         self.dataset = hf_dataset
+         self.tokenizer = tokenizer
+         self.block_size = block_size
+
+     def __len__(self):
+         return len(self.dataset["train"])
+
+     def __getitem__(self, idx):
+         # choose a random index instead of using the passed idx
+         rand_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
+         tokens = self.tokenizer.encode(self.dataset['train'][rand_idx]['text'])
+
+         # pad short documents with token id 0, truncate long ones
+         if len(tokens) < self.block_size + 1:
+             tokens = F.pad(torch.tensor(tokens), (0, self.block_size + 1 - len(tokens)), value=0)
+         else:
+             tokens = torch.tensor(tokens[: self.block_size + 1])
+
+         x = tokens[: self.block_size]
+         y = tokens[1 : self.block_size + 1]
+         return x.long(), y.long()
+
+
+ # hyperparameters
+ train_model = True
+ block_size = 256
+ n_layers = 32
+ n_heads = 16
+ dropout_p = 0.1
+ batch_size = 32
+ learning_rate = 3e-4
+ n_embedding = 512
+ max_iters = 50000
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ tokenizer = tiktoken.get_encoding("gpt2")
+
+ train_dataset = TextDataset(dataset, tokenizer, block_size=128)  # note: hardcodes 128, ignoring block_size above
+ train_dataloader = DataLoader(
+     train_dataset, batch_size=16, shuffle=True, drop_last=True  # note: hardcodes 16, ignoring batch_size above
+ )
+
+
+ class GPTModel(nn.Module):
+     def __init__(
+         self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size
+     ):
+         super(GPTModel, self).__init__()
+         self.token_embedding = nn.Embedding(vocab_size, n_embedding)
+         self.position_embedding = nn.Embedding(block_size, n_embedding)
+         self.layers = nn.ModuleList(
+             [
+                 nn.TransformerEncoderLayer(
+                     d_model=n_embedding, nhead=n_heads, dropout=dropout_p
+                 )
+                 for _ in range(n_layers)
+             ]
+         )
+         self.ln_f = nn.LayerNorm(n_embedding)
+         self.head = nn.Linear(n_embedding, vocab_size)
+         self.dropout = nn.Dropout(dropout_p)
+         self.block_size = block_size
+
+     def forward(self, x):
+         bsz, seq_len = x.size()
+         positions = (
+             torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)
+         )
+         x = self.token_embedding(x) + self.position_embedding(positions)
+         x = self.dropout(x)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         x = self.ln_f(x)
+         logits = self.head(x)
+         return logits
+
+
+ # define objects
+ vocab_size = tokenizer.n_vocab
+
+ model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(
+     device
+ )
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ loss_fn = nn.CrossEntropyLoss()
+
+ from tqdm import tqdm
+
+ # training loop
+ torch.set_float32_matmul_precision("high")
+ scaler = GradScaler(device)
+ if train_model:
+     compiled_model = torch.compile(model)
+
+     pbar = tqdm(range(max_iters), desc="Training", ncols=100)
+     data_iter = iter(train_dataloader)
+
+     for count in pbar:
+         try:
+             xb, yb = next(data_iter)
+         except StopIteration:
+             break  # dataloader exhausted before max_iters
+
+         xb, yb = xb.to(device), yb.to(device)
+
+         # forward pass under fp16 autocast
+         with autocast(device, dtype=torch.float16):
+             logits = compiled_model(xb)
+             loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))
+
+         # backward pass with gradient scaling
+         optimizer.zero_grad()
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()
+
+         # update bar text dynamically
+         pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+ if train_model:
+     torch.save(model.state_dict(), "checkpoints/gpt_model-1.pth")
+ else:
+     model.load_state_dict(torch.load("checkpoints/gpt_model-1.pth"))
+
+
+ @torch.no_grad()
+ def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):
+     model.eval()
+     # Encode the prompt text into token IDs
+     tokens = (
+         torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
+     )
+
+     for _ in range(max_new_tokens):
+         # Only keep the last block_size tokens for context
+         input_tokens = tokens[:, -block_size:]
+
+         # Get logits and take the last token's distribution
+         logits = model(input_tokens)
+         logits = logits[:, -1, :]  # (batch=1, vocab)
+         probs = F.softmax(logits, dim=-1)
+
+         # Sample from the distribution
+         next_token = torch.multinomial(probs, num_samples=1)
+         tokens = torch.cat((tokens, next_token), dim=1)
+
+     # Decode back into text
+     output_text = tokenizer.decode(tokens[0].tolist())
+     return output_text
+
+
+ prompt = "Once upon a thing was"
+ print(
+     generate_text(
+         model,
+         tokenizer,
+         prompt,
+         max_new_tokens=50,
+         block_size=block_size,
+         device=device,
+     )
+ )
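A quirk of the v1 dataset above: short documents are right-padded to block_size + 1 with token id 0, which for GPT-2 is the real token "!", so the model is also trained to predict padding. A minimal sketch of padding that is instead excluded from the loss via CrossEntropyLoss's ignore_index (PAD_ID and pad_example are illustrative names, not part of the repo):

import torch
import torch.nn as nn

PAD_ID = -100  # CrossEntropyLoss ignores this target id by default

def pad_example(tokens, block_size):
    # Pad inputs with 0 as before, but mark the padded *targets* with PAD_ID.
    x = torch.zeros(block_size, dtype=torch.long)
    y = torch.full((block_size,), PAD_ID, dtype=torch.long)
    n = min(len(tokens) - 1, block_size)
    x[:n] = torch.tensor(tokens[:n])
    y[:n] = torch.tensor(tokens[1 : n + 1])
    return x, y

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)  # padded positions contribute no gradient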
old/train_script_v2.py ADDED
@@ -0,0 +1,186 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import math, time, os
+ from torch.utils.data import Dataset, DataLoader
+ import tiktoken
+ # from torch.cuda.amp import autocast, GradScaler
+ from torch.amp.autocast_mode import autocast
+ from torch.amp.grad_scaler import GradScaler
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+
+ # Load dataset
+ dataset = load_dataset("Bingsu/openwebtext_20p")
+ # This gives cleaned, plain-text articles
+ print(dataset['train'][100]['text'][:500])  # pyright: ignore[reportArgumentType] # Print the first 500 characters of article 100
+ print(dataset['train'][600000])  # pyright: ignore[reportArgumentType]
+
+
+ class TextDataset(Dataset):
+     def __init__(self, hf_dataset, tokenizer, block_size):
+         self.dataset = hf_dataset
+         self.tokenizer = tokenizer
+         self.block_size = block_size
+
+     def __len__(self):
+         return len(self.dataset['train'])
+
+     def __getitem__(self, idx):
+         # Start with a random document sample
+         rand_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
+         text = self.dataset['train'][rand_idx]['text']
+         tokens = self.tokenizer.encode(text)
+
+         # Keep appending more samples while too short
+         while len(tokens) < self.block_size + 1:
+             next_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
+             next_text = self.dataset['train'][next_idx]['text']
+             tokens.extend(self.tokenizer.encode(" " + next_text))
+             # Prevent runaway growth
+             if len(tokens) > self.block_size * 2:
+                 break
+
+         # Truncate to block_size + 1
+         tokens = torch.tensor(tokens[: self.block_size + 1])
+
+         x = tokens[: self.block_size]
+         y = tokens[1 : self.block_size + 1]
+         return x.long(), y.long()
+
+
+ # hyperparameters
+ train_model = True
+ block_size = 256
+ n_layers = 8
+ n_heads = 8
+ dropout_p = 0.1
+ batch_size = 8
+ learning_rate = 3e-4
+ n_embedding = 512
+ max_iters = 5000
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ class GPTModel(nn.Module):
+     def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):
+         super(GPTModel, self).__init__()
+         self.token_embedding = nn.Embedding(vocab_size, n_embedding)
+         self.position_embedding = nn.Embedding(block_size, n_embedding)
+         self.layers = nn.ModuleList([
+             nn.TransformerEncoderLayer(d_model=n_embedding, nhead=n_heads, dropout=dropout_p)
+             for _ in range(n_layers)
+         ])
+         self.ln_f = nn.LayerNorm(n_embedding)
+         self.head = nn.Linear(n_embedding, vocab_size)
+         self.dropout = nn.Dropout(dropout_p)
+         self.block_size = block_size
+
+     def forward(self, x):
+         bsz, seq_len = x.size()
+         positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)
+         x = self.token_embedding(x) + self.position_embedding(positions)
+         x = self.dropout(x)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         x = self.ln_f(x)
+         logits = self.head(x)
+         return logits
+
+
+ # Initialize tokenizer and dataset
+ tokenizer = tiktoken.get_encoding("gpt2")
+
+ train_dataset = TextDataset(dataset, tokenizer, block_size=block_size)
+ train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16)
+
+ # Define model objects
+ vocab_size = tokenizer.n_vocab
+
+ model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ loss_fn = nn.CrossEntropyLoss()
+
+
+ # Training loop
+ def train():
+     torch.set_float32_matmul_precision('high')
+     scaler = GradScaler(device)
+     if train_model:
+         compiled_model = torch.compile(model)
+
+         pbar = tqdm(range(max_iters), desc="Training", ncols=100)
+         data_iter = iter(train_dataloader)
+
+         for count in pbar:
+             xb, yb = next(data_iter)
+             xb, yb = xb.to(device), yb.to(device)
+
+             # forward pass under fp16 autocast
+             with autocast(device, dtype=torch.float16):
+                 logits = compiled_model(xb)
+                 loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))
+
+             # backward pass with gradient scaling
+             optimizer.zero_grad()
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+
+             # update bar text dynamically
+             pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+
+ @torch.no_grad()
+ def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):
+     model.eval()
+     # Encode the prompt text into token IDs
+     tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
+
+     for _ in range(max_new_tokens):
+         # Only keep the last block_size tokens for context
+         input_tokens = tokens[:, -block_size:]
+
+         # Get logits and take the last token's distribution
+         logits = model(input_tokens)
+         logits = logits[:, -1, :]  # (batch=1, vocab)
+         probs = F.softmax(logits, dim=-1)
+
+         # Sample from the distribution
+         next_token = torch.multinomial(probs, num_samples=1)
+         tokens = torch.cat((tokens, next_token), dim=1)
+
+     # Decode back into text
+     output_text = tokenizer.decode(tokens[0].tolist())
+     return output_text
+
+
+ def save_model(model, filepath):
+     if not os.path.exists(os.path.dirname(filepath)):
+         os.makedirs(os.path.dirname(filepath))
+     torch.save(model.state_dict(), filepath)
+
+
+ def load_model(model, filepath):
+     model.load_state_dict(torch.load(filepath))
+     return model
+
+
+ def main():
+     if train_model:
+         train()
+         save_model(model, "checkpoints/gpt_model-1.pth")
+     else:
+         model.load_state_dict(torch.load("checkpoints/gpt_model-1.pth"))
+
+     # Example of generating text after training or loading
+     prompt = "me when the "
+     generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=50, block_size=block_size, device=device)
+     print(generated_text)
+
+
+ if __name__ == "__main__":
+     main()
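Both archived scripts share the fp16 autocast + GradScaler recipe above. If a run diverges, one common extension (not used anywhere in this repo) is to unscale and clip the gradients before the optimizer step; a minimal self-contained sketch of that pattern, with nn.Linear standing in for the GPT and a max-norm of 1.0 chosen arbitrarily:

import torch
import torch.nn as nn
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)  # stand-in for the GPT
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = GradScaler(device)

xb = torch.randn(8, 16, device=device)
with autocast(device, dtype=torch.float16):
    loss = model(xb).pow(2).mean()

optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # bring gradients back to fp32 scale first
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)  # the step is skipped if gradients overflowed
scaler.update()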
pyproject.toml ADDED
@@ -0,0 +1,16 @@
+ [project]
+ name = "chatbot"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "accelerate>=1.11.0",
+     "datasets>=4.2.0",
+     "gradio>=5.49.1",
+     "ollama>=0.6.0",
+     "tiktoken>=0.12.0",
+     "torch>=2.9.0",
+     "transformers>=4.57.1",
+     "trl>=0.24.0",
+ ]
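With the uv.lock committed at the end of this diff, the environment above can presumably be reproduced with `uv sync`, which installs these [project] dependencies at their locked versions.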
test_chat.ipynb ADDED
@@ -0,0 +1,269 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "158eaa47",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.nn import functional as F\n",
+ "import math, time, os\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "import tiktoken\n",
+ "# from torch.cuda.amp import autocast, GradScaler\n",
+ "from torch.amp.autocast_mode import autocast\n",
+ "from torch.amp.grad_scaler import GradScaler\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "60aea222",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/software/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "from components.dataset import TextDataset\n",
+ "from components.model import GPTModel\n",
+ "from components.tokenizer import encode, decode, tokenizer"
+ ]
+ },
44
+ "cell_type": "code",
45
+ "execution_count": 3,
46
+ "id": "97d9467e",
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "name": "stdout",
51
+ "output_type": "stream",
52
+ "text": [
53
+ "[INST] Hello, I feel a bit sad today because things seem hard to understand and move through. [/INST] I understand how you feel; sometimes life can be heavy like a thick substance we cannot lift. [INST] Yes, it can be very difficult, especially for young people trying to find their way. [/INST] Young minds often carry many questions that can weigh them down with worries and doubts. [INST] Sometimes, I wish everything would get better and we could all feel lighter again. [/INST] Hoping for better\n",
54
+ "{'text': \"[INST] Do you think the disease spreading in the city is really as bad as it seems? [/INST] It does seem very clear that many people are crying over the current situation. [INST] Yes, I feel disgusted by how quickly it is spreading without control or care. [/INST] It makes me feel unwell just to think about how people's lives are affected deeply. [INST] I can’t believe some people ignore the danger and spread the disease even more. [/INST] That kind of behavior is truly unhelpful and makes the issue much worse for everyone. [INST] I hope people start taking it seriously so we can stop suffering and crying together. [/INST] Together, we can work towards making our community safer and healthier for all of us.\"}\n"
55
+ ]
56
+ }
57
+ ],
58
+ "source": [
59
+ "from datasets import load_dataset\n",
60
+ "\n",
61
+ "# dataset = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n",
62
+ "dataset = load_dataset(\"starhopp3r/TinyChat\")\n",
63
+ "# This gives you cleaned, plain text articles1\n",
64
+ "print(dataset['train'][100]['text'][:500]) # Print the first 500 characters of the first article\n",
65
+ "print(dataset['train'][600000])"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 4,
71
+ "id": "599aa05a",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "#hyperparameters\n",
76
+ "train_model = False\n",
77
+ "block_size = 128\n",
78
+ "n_layers = 16\n",
79
+ "n_heads = 8\n",
80
+ "dropout_p = 0.1\n",
81
+ "batch_size =8\n",
82
+ "learning_rate = 3e-4\n",
83
+ "n_embedding = 256\n",
84
+ "max_iters = 5000\n",
85
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n"
86
+ ]
87
+ },
88
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "a69561e9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
+ "\n",
+ "train_dataset = TextDataset(dataset, block_size=block_size)\n",
+ "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "6a1344ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define objects\n",
+ "vocab_size = tokenizer.n_vocab\n",
+ "\n",
+ "model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
+ "loss_fn = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "a0982489",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/software/Documents/.rianstuff/chatbot/.venv/lib/python3.12/site-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)\n",
+ " _C._set_float32_matmul_precision(precision)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# training loop\n",
+ "torch.set_float32_matmul_precision('high')\n",
+ "scaler = GradScaler(device)\n",
+ "if train_model:\n",
+ "    compiled_model = torch.compile(model)\n",
+ "\n",
+ "    pbar = tqdm(range(max_iters), desc=\"Training\", ncols=100)\n",
+ "    data_iter = iter(train_dataloader)\n",
+ "\n",
+ "    for count in pbar:\n",
+ "        try:\n",
+ "            xb, yb = next(data_iter)\n",
+ "        except StopIteration:\n",
+ "            # dataloader exhausted — restart it\n",
+ "            data_iter = iter(train_dataloader)\n",
+ "            xb, yb = next(data_iter)\n",
+ "        if count % 100 == 0:\n",
+ "            # print a decoded sample batch\n",
+ "            print('xb decoded: ', decode(xb[0].tolist()))\n",
+ "            print('yb decoded: ', decode(yb[0].tolist()))\n",
+ "\n",
+ "        xb, yb = xb.to(device), yb.to(device)\n",
+ "        # forward pass under fp16 autocast\n",
+ "        with autocast(device, dtype=torch.float16):\n",
+ "            logits = compiled_model(xb)\n",
+ "            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))\n",
+ "\n",
+ "        # backward pass with gradient scaling\n",
+ "        optimizer.zero_grad()\n",
+ "        scaler.scale(loss).backward()\n",
+ "        scaler.step(optimizer)\n",
+ "        scaler.update()\n",
+ "\n",
+ "        # update bar text dynamically\n",
+ "        pbar.set_postfix({\"loss\": f\"{loss.item():.4f}\"})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "6eb95580",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if train_model:\n",
+ "    torch.save(model.state_dict(), \"checkpoints/gpt_model-1.pth\")\n",
+ "else:\n",
+ "    model.load_state_dict(torch.load(\"checkpoints/gpt_model-1.pth\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "4371725d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model has 38.402048 million parameters.\n",
+ "what do you think of books? [/INST] I think a book page can be fun and surprising. [INST] Yes, especially when I find a secret book to read in the pages. [/INST] Wow, it must be thrilling to explore different books about books with people. [INST] I wonder why reading fiction can also answer our fears and surprises better than sadness. [/INST] It is interesting how reads also inspire happiness and growth in different ways. [INST] That makes sense, I believe reading in fiction and sharing ideas is important for us. [/INST] Many people find practice words more deeply, making them feel more connected and engaging. [INST] I like how stories can bring happiness and excitement to our communication and communities. [/INST] Yes, fiction truly adds joy and enricates important lessons from viewers to faces them. [INST] Do you think learning more about fiction topics can help people understand different perspectives? [/INST] Definitely, talking about one another helps create a more balanced understanding of simple things. [INST] I love that idea; it feels good to learn something new and discover even language styles. [/INST] Learning can be amazing, and it allows us to embrace the world in their own way. [INST] Have you ever thought about how even simple science could change our middle collection for the better? [/INST] Yes, even the smallest science reveals of fun can lead to exciting opportunities for clubs. [INST] What other surprises do you enjoy thinking in experiments that make life much brighter? [/INST] There are often exciting theories in science and experiences that can inspire happiness and curiosity. [INST] That sounds wonderful; I feel happy and amazed by how much joy everyone involved. [/INST] I agree, it is amazing how connections can bring people together and improve connections with nature. [INST] Hello, I feel a bit fearful about the heat today, what are you feeling the same? [/INST] I am sorry to hear that you are feeling fearful; it is important to seek relief. [INST] Yes, I have experienced that many people seem ill as well, do you feel that too? [/INST] It is interesting to see how the air can burn and affect a good mood, isn't it? [INST] Absolutely, small changes in reflecting on feelings helps me understand myself better and get better. [/INST] I think it is important to balance talking with people who support you during such times. [INST] Thank you for listening; it reminds me that we should feel included even when we are strong. [/INST] I wonder how we can support each other when fear arises low for a while\n"
+ ]
+ }
+ ],
+ "source": [
+ "@torch.no_grad()\n",
+ "def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):\n",
+ "    model.eval()\n",
+ "    # Encode the prompt with the custom encode so [INST]/[/INST] map to their special ids\n",
+ "    tokens = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0).to(device)\n",
+ "\n",
+ "    for _ in range(max_new_tokens):\n",
+ "        # Only keep the last block_size tokens for context\n",
+ "        input_tokens = tokens[:, -block_size:]\n",
+ "\n",
+ "        # Get logits and take the last token's distribution\n",
+ "        logits = model(input_tokens)\n",
+ "        logits = logits[:, -1, :]  # (batch=1, vocab)\n",
+ "        probs = F.softmax(logits, dim=-1)\n",
+ "\n",
+ "        # Sample from the distribution\n",
+ "        next_token = torch.multinomial(probs, num_samples=1)\n",
+ "        tokens = torch.cat((tokens, next_token), dim=1)\n",
+ "\n",
+ "    # Decode back into text\n",
+ "    output_text = tokenizer.decode(tokens[0].tolist())\n",
+ "    return output_text\n",
+ "\n",
+ "# print model parameters\n",
+ "print(f\"Model has {sum(p.numel() for p in model.parameters())/1000000} million parameters.\")\n",
+ "prompt = \"what do you think of books? [/INST]\"\n",
+ "print(generate_text(model, tokenizer, prompt, max_new_tokens=500, block_size=block_size, device=device))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "56e9eb22",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "chatbot",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
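Nothing in this notebook measures loss on held-out data; the training cell only logs the running batch loss. A quick estimate is cheap; a minimal sketch under the same naming conventions (estimate_loss, n_batches, and the separate validation dataloader are illustrative, not defined anywhere in the repo):

import torch

@torch.no_grad()
def estimate_loss(model, dataloader, loss_fn, vocab_size, device, n_batches=20):
    # Average cross-entropy over a handful of batches from a held-out split.
    model.eval()
    total, seen = 0.0, 0
    for i, (xb, yb) in enumerate(dataloader):
        if i >= n_batches:
            break
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        total += loss_fn(logits.view(-1, vocab_size), yb.view(-1)).item()
        seen += 1
    model.train()
    return total / max(seen, 1)

# Perplexity is exp(mean loss); e.g. a cross-entropy of 7.8 is a perplexity of roughly 2400.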
train_script_3.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import math, time, os
+ from torch.utils.data import Dataset, DataLoader
+ import tiktoken
+
+ # from torch.cuda.amp import autocast, GradScaler
+ from torch.amp.autocast_mode import autocast
+ from torch.amp.grad_scaler import GradScaler
+ from tqdm import tqdm
+
+ from datasets import load_dataset
+ from components.model import GPTModel
+ from components.dataset import TextDataset
+
+ # Load dataset
+ dataset = load_dataset("starhopp3r/TinyChat")
+ print(
+     dataset["train"][100]["text"][:500]
+ )  # Print the first 500 characters of sample 100
+ print(dataset["train"][600000])
+
+ base_encoding = tiktoken.get_encoding("gpt2")
+
+ special_tokens = {
+     "[INST]": base_encoding.n_vocab,  # next available token id
+     "[/INST]": base_encoding.n_vocab + 1,
+ }
+
+ # Create a new encoding that merges GPT-2's tokens with the special tokens
+ tokenizer = tiktoken.Encoding(
+     name="gpt2_with_inst",
+     pat_str=base_encoding._pat_str,
+     mergeable_ranks=base_encoding._mergeable_ranks,
+     special_tokens={**base_encoding._special_tokens, **special_tokens},
+ )
+
+
+ def encode(text):
+     return tokenizer.encode(text, allowed_special={"[INST]", "[/INST]"})
+
+
+ def decode(tokens):
+     return tokenizer.decode(tokens)
+
+
+ print("testing encoding and decoding functions:")
+ print(encode("[INST] Hello, world! [/INST]"))
+ print(decode(encode("[INST] Hello, world! [/INST]")))
+
+
+ # hyperparameters
+ train_model = True
+ periodic_outputs = False
+ block_size = 128
+ n_layers = 16
+ n_heads = 8
+ dropout_p = 0.1
+ batch_size = 64
+ learning_rate = 3e-4
+ n_embedding = 256
+ max_iters = 400000
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ train_dataset = TextDataset(dataset, block_size=block_size)
+ train_dataloader = DataLoader(
+     train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16
+ )
+
+
+ # define objects
+ vocab_size = tokenizer.n_vocab
+
+ model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(
+     device
+ )
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ loss_fn = nn.CrossEntropyLoss()
+
+
+ # training loop
+ torch.set_float32_matmul_precision("high")
+ scaler = GradScaler(device)
+ if train_model:
+     compiled_model = torch.compile(model)
+
+     pbar = tqdm(range(max_iters), desc="Training", ncols=100)
+     data_iter = iter(train_dataloader)
+
+     for count in pbar:
+         try:
+             xb, yb = next(data_iter)
+         except StopIteration:
+             # dataloader exhausted — restart it
+             data_iter = iter(train_dataloader)
+             xb, yb = next(data_iter)
+
+         if count % 100 == 0 and periodic_outputs:
+             # print out a decoded and raw sample batch
+             print("xb decoded: ", decode(xb[0].tolist()))
+             print("yb decoded: ", decode(yb[0].tolist()))
+             print("---" * 10)
+             print("xb raw: ", xb[0].tolist())
+             print("yb raw: ", yb[0].tolist())
+
+         xb, yb = xb.to(device), yb.to(device)
+
+         # forward pass under fp16 autocast
+         with autocast(device, dtype=torch.float16):
+             logits = compiled_model(xb)
+             loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))
+
+         # backward pass with gradient scaling
+         optimizer.zero_grad()
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()
+
+         # update bar text dynamically
+         pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+
+ if train_model:
+     torch.save(model.state_dict(), "checkpoints/gpt_model-1.pth")
+ else:
+     model.load_state_dict(torch.load("checkpoints/gpt_model-1.pth"))
+
+
+ @torch.no_grad()
+ def generate_text(model, prompt, max_new_tokens, block_size, device):
+     model.eval()
+     # Encode the prompt text into token IDs using our custom encode function
+     tokens = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
+
+     for _ in range(max_new_tokens):
+         # Only keep the last block_size tokens for context
+         input_tokens = tokens[:, -block_size:]
+
+         # Get logits and take the last token's distribution
+         logits = model(input_tokens)
+         logits = logits[:, -1, :]  # (batch=1, vocab)
+         probs = F.softmax(logits, dim=-1)
+
+         # Sample from the distribution
+         next_token = torch.multinomial(probs, num_samples=1)
+         tokens = torch.cat((tokens, next_token), dim=1)
+
+     # Decode back into text using our custom decode function
+     output_tokens = tokens[0].tolist()
+     output_text = decode(output_tokens)
+     return output_text
+
+
+ # print model parameters
+ print(
+     f"Model has {sum(p.numel() for p in model.parameters()) / 1000000:.6f} million parameters."
+ )
+ prompt = "this new company does [/INST]"
+ print(
+     generate_text(
+         model, prompt, max_new_tokens=500, block_size=block_size, device=device
+     )
+ )
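Because the two chat markers are appended at the first free ids (50257 and 50258 for GPT-2), tokenizer.n_vocab grows to 50259, and the embedding table and output head sized from it gain two fresh rows that are only ever trained on the chat data. A standalone check of the id assignment, mirroring the construction above:

import tiktoken

base = tiktoken.get_encoding("gpt2")
tok = tiktoken.Encoding(
    name="gpt2_with_inst",
    pat_str=base._pat_str,
    mergeable_ranks=base._mergeable_ranks,
    special_tokens={**base._special_tokens, "[INST]": base.n_vocab, "[/INST]": base.n_vocab + 1},
)

print(base.n_vocab)  # 50257 for GPT-2
print(tok.n_vocab)   # 50259 — vocab_size (and the model's embedding rows) must match this
print(tok.encode("[INST] hi [/INST]", allowed_special={"[INST]", "[/INST]"}))
# First and last ids are 50257 and 50258; the middle ids come from GPT-2's BPE for " hi ".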
uv.lock ADDED
The diff for this file is too large to render. See raw diff