Upload 15 files
Browse files- .gitattributes +1 -0
- README.md +80 -0
- chat_interface.py +254 -0
- check_vocab.py +34 -0
- crimson_base_8.9M.pt +3 -0
- crimson_hero.png +3 -0
- crimson_instruct_8.9M.pt +3 -0
- data/chat_data.txt +0 -0
- data/chat_data2.txt +171 -0
- data/data.txt +0 -0
- data/matgen.py +373 -0
- data/math_data.txt +0 -0
- finetune_gclm_base.py +264 -0
- sample.py +257 -0
- train_gclm_base.py +325 -0
- vocab_map.pt +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
crimson_hero.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# <p align="center">Crimson</p>
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<img src="./crimson_hero.png" alt="Crimson Hero" width="600">
|
| 5 |
+
</p>
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<strong>A high-performance, hybrid signal-processing language model architecture.</strong>
|
| 9 |
+
</p>
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## 🌹 Overview
|
| 14 |
+
|
| 15 |
+
**Crimson** is a generative language model that deviates from the traditional Transformer architecture by utilizing a hybrid approach of **Local** and **Global Convolutions**. By leveraging Fast Fourier Transforms (FFT) for global context, Crimson achieves a massive receptive field with a fraction of the computational overhead associated with standard attention mechanisms.
|
| 16 |
+
|
| 17 |
+
The architecture is designed for efficiency, speed, and high-quality generation, featuring a custom vocabulary reduction system that optimizes the embedding space for specific datasets.
|
| 18 |
+
|
| 19 |
+
## 🚀 Key Features
|
| 20 |
+
|
| 21 |
+
- **Hybrid Convolutional Blocks**: Merges depth-wise local convolutions for immediate context with FFT-powered global convolutions for long-range dependencies.
|
| 22 |
+
- **FFT-Based Global Context**: Uses frequency-domain processing to handle long sequences efficiently.
|
| 23 |
+
- **Vocabulary Reduction**: Custom token remapping (`REDUCE_VOCAB`) that shrinks the model size by focusing only on tokens present in the training corpus.
|
| 24 |
+
- **Hardware Optimized**: Full support for Apple Silicon (**MPS**), NVIDIA GPUs (**CUDA** with TF32), and efficient CPU execution.
|
| 25 |
+
- **Lightweight & Fast**: The current 8.9M parameter model provides a perfect balance between intelligence and speed.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## 🛠 Architecture Details
|
| 30 |
+
|
| 31 |
+
| Parameter | Value |
|
| 32 |
+
| :--- | :--- |
|
| 33 |
+
| **Model Size** | 8.9 Million Parameters |
|
| 34 |
+
| **Layers** | 4 Blocks |
|
| 35 |
+
| **Model Dimension (D_MODEL)** | 256 |
|
| 36 |
+
| **Context Length (MAX_SEQ_LEN)** | 1024 |
|
| 37 |
+
| **Local Kernel Size** | 5 |
|
| 38 |
+
| **Global Kernel Size** | 256 |
|
| 39 |
+
| **Global Every N Layers** | 2 |
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## 📦 Installation
|
| 44 |
+
|
| 45 |
+
Download this repository and extract it.
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## 🧪 Usage
|
| 50 |
+
|
| 51 |
+
### 1. Training the Base Model
|
| 52 |
+
Place your `.txt` data files in the `data/` directory and run:
|
| 53 |
+
```bash
|
| 54 |
+
python train_gclm_base.py
|
| 55 |
+
```
|
| 56 |
+
This script will build the vocabulary and train the initial foundation model (`crimson_base_8.9M.pt`).
|
| 57 |
+
|
| 58 |
+
### 2. Fine-tuning for Chat (SFT)
|
| 59 |
+
Use your chat-formatted data (e.g., `chat_data.txt`) to fine-tune the model into an instruct-following assistant:
|
| 60 |
+
```bash
|
| 61 |
+
python finetune_gclm_base.py
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### 3. Interactive Chat Interface
|
| 65 |
+
Launch the Tkinter-based UI to interact with your fine-tuned model:
|
| 66 |
+
```bash
|
| 67 |
+
python chat_interface.py
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
---
|
| 71 |
+
|
| 72 |
+
## 🎨 Visualization
|
| 73 |
+
|
| 74 |
+
The model uses a unique "Signal Processing" philosophy, treating text sequences as multidimensional signals that are filtered through both time-domain (Local) and frequency-domain (Global) kernels.
|
| 75 |
+
|
| 76 |
+
---
|
| 77 |
+
|
| 78 |
+
<p align="center">
|
| 79 |
+
Built with ❤️ by AG from AG Corp
|
| 80 |
+
</p>
|
chat_interface.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import customtkinter as ctk
|
| 6 |
+
import tiktoken
|
| 7 |
+
import threading
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
# Hyperparameters (must match train_gclm_base.py and finetune_gclm_base.py)
D_MODEL = 256                  # embedding / hidden width
N_LAYERS = 4                   # number of stacked Blocks
MAX_SEQ_LEN = 1024             # learned positional-embedding table size (hard context limit)
LOCAL_KERNEL_SIZE = 5          # depth-wise conv kernel width (local mixing)
GLOBAL_KERNEL_SIZE = 256       # long per-channel kernel applied via FFT (global mixing)
USE_GLOBAL_EVERY_N_LAYERS = 2  # layer i gets a global conv when i % N == 0
FFT_SIZE = 1024                # segment length for overlap-save FFT convolution
TOKENIZER_NAME = "gpt2"        # tiktoken encoding used for text <-> ids

# Paths
VOCAB_MAP_PATH = "vocab_map.pt"          # reduced-vocab id mapping produced at training time
MODEL_PATH = "crimson_instruct_8.9M.pt"  # fine-tuned (SFT) chat checkpoint

# Generation settings
TEMPERATURE = 0.8   # logit temperature; lower = more deterministic
TOP_K = 50          # keep only the K most likely tokens (0 disables top-k)
TOP_P = 0.9         # nucleus-sampling cumulative-probability cutoff (1.0 disables)
MAX_GEN_LEN = 256   # maximum tokens generated per reply
|
| 29 |
+
|
| 30 |
+
# --- Model Components (Duplicated for standalone use) ---
|
| 31 |
+
|
| 32 |
+
class GlobalConv1D(nn.Module):
    """Causal per-channel (depth-wise) convolution with a long kernel,
    evaluated in the frequency domain using the overlap-save FFT method.

    Input/output shape: (batch, channels, time).
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # One learned filter per channel; small init keeps early outputs stable.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        _, _, seq_len = x.shape
        eff_k = min(self.kernel_size, seq_len)
        overlap = eff_k - 1
        step = self.fft_size - overlap

        # Left-pad so the convolution is causal (no future leakage).
        x = F.pad(x, (overlap, 0))

        # Zero-pad the (possibly truncated) kernel to the FFT length and
        # transform it once; the spectrum is reused for every segment.
        filt = F.pad(self.kernel[:, :eff_k], (0, self.fft_size - eff_k))
        filt_f = torch.fft.rfft(filt, n=self.fft_size)

        chunks = []
        start = 0
        while start < seq_len:
            seg = x[..., start:start + self.fft_size]
            short = self.fft_size - seg.shape[-1]
            if short > 0:
                seg = F.pad(seg, (0, short))
            # Circular convolution in the frequency domain; the first
            # `overlap` samples are wrap-around garbage and are discarded
            # (classic overlap-save).
            spec = torch.fft.rfft(seg, n=self.fft_size) * filt_f.unsqueeze(0)
            out = torch.fft.irfft(spec, n=self.fft_size)
            chunks.append(out[..., overlap:overlap + step])
            start += step
        return torch.cat(chunks, dim=-1)[..., :seq_len]
|
| 58 |
+
|
| 59 |
+
class LocalConv1D(nn.Module):
    """Causal local mixer: depth-wise conv -> ReLU -> point-wise (1x1) conv.

    Input/output shape: (batch, channels, time).
    """

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # groups=d_model makes the first conv depth-wise (per-channel).
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Pad only on the left so position t never sees positions > t.
        padded = F.pad(x, (self.k - 1, 0))
        activated = F.relu(self.dw(padded))
        return self.pw(activated)
|
| 69 |
+
|
| 70 |
+
class Block(nn.Module):
    """One Crimson layer: local conv mixing, optional global FFT-conv mixing,
    and a position-wise feed-forward net — each behind a pre-norm residual.
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        hidden = d_model * 4
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x):
        # Conv modules expect (B, C, T); activations are (B, T, C),
        # hence the transpose in and out of each conv branch.
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))
|
| 91 |
+
|
| 92 |
+
class CrimsonBase(nn.Module):
    """Token-level language model: token + position embeddings, a stack of
    Blocks, final LayerNorm, and an output head tied to the embedding matrix.
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        # Every USE_GLOBAL_EVERY_N_LAYERS-th layer (starting at 0) is global.
        self.layers = nn.ModuleList(
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        )
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        # Keep only the most recent MAX_SEQ_LEN tokens — the learned
        # positional table cannot index beyond that.
        if x.size(1) > MAX_SEQ_LEN:
            x = x[:, -MAX_SEQ_LEN:]
        positions = torch.arange(x.size(1), device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
|
| 115 |
+
|
| 116 |
+
# --- Chat Engine ---
|
| 117 |
+
|
| 118 |
+
class ChatEngine:
    """Loads the reduced-vocab Crimson model + GPT-2 tokenizer and streams
    generated chat replies token by token."""

    def __init__(self):
        # Prefer CUDA, then Apple-Silicon MPS, then CPU.
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        print(f"[INFO] Initializing engine on {self.device}...")

        # Load vocab
        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
        self.vocab_data = torch.load(VOCAB_MAP_PATH, map_location="cpu")
        self.id2new = self.vocab_data["id2new"]  # original GPT-2 id -> reduced id
        self.new2id = {v: k for k, v in self.id2new.items()}  # reduced id -> original id
        self.PAD_ID = self.vocab_data["PAD_ID"]
        self.EOS_ID = self.vocab_data["EOS_ID"]
        # +3 presumably accounts for the special tokens (PAD/EOS/...) — TODO confirm
        # against train_gclm_base.py's vocab-reduction code.
        self.vocab_size = len(self.vocab_data["used_tokens"]) + 3

        self.tok = tiktoken.get_encoding(TOKENIZER_NAME)

        # Build model
        self.model = CrimsonBase(self.vocab_size).to(self.device).eval()
        if os.path.exists(MODEL_PATH):
            self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
            print(f"[INFO] Loaded model from {MODEL_PATH}")
        else:
            print(f"[ERROR] {MODEL_PATH} not found. UI will be non-functional.")

    @torch.no_grad()
    def generate(self, prompt, max_new_tokens=MAX_GEN_LEN):
        """Generator: after each sampled token, yields the cumulative decoded
        reply so far (the UI diffs successive yields to stream text)."""
        # Format prompt
        full_prompt = f"<user> {prompt} <ai> "
        raw_ids = self.tok.encode(full_prompt)
        # Tokens unseen during training fall back to PAD in the reduced vocab.
        input_ids = [self.id2new.get(i, self.PAD_ID) for i in raw_ids]
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        generated = []
        for _ in range(max_new_tokens):
            # No KV cache: the full sequence is re-run every step.
            logits = self.model(x)
            logits = logits[:, -1, :] / TEMPERATURE

            # Top-K: mask everything below the K-th largest logit.
            if TOP_K > 0:
                v, _ = torch.topk(logits, min(TOP_K, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Top-P (nucleus): drop the tail once cumulative prob exceeds TOP_P.
            if TOP_P < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > TOP_P
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                # Batch size is assumed to be 1 here (hence logits[0, ...]).
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[0, indices_to_remove] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Stop (without emitting the EOS token itself).
            if next_token.item() == self.EOS_ID:
                break

            generated.append(next_token.item())
            x = torch.cat([x, next_token], dim=1)

            # Map back to original IDs and decode
            # NOTE(review): reduced ids missing from new2id default to original
            # id 0 — verify that is the intended fallback.
            current_ids = [self.new2id.get(i, 0) for i in generated]
            yield self.tok.decode(current_ids)
|
| 181 |
+
|
| 182 |
+
# --- UI ---
|
| 183 |
+
|
| 184 |
+
class ChatApp(ctk.CTk):
    """Minimal customtkinter chat window around a ChatEngine.

    Generation runs on a daemon worker thread so the UI stays responsive.
    Tkinter widgets are NOT thread-safe and must only be touched from the
    thread running ``mainloop``, so every widget mutation triggered by the
    worker is marshaled back to the main thread with ``self.after``.
    """

    def __init__(self, engine):
        super().__init__()
        self.engine = engine
        self.title("Crimson Instruct Chat")
        self.geometry("800x600")

        ctk.set_appearance_mode("dark")
        ctk.set_default_color_theme("blue")

        # Layout: the chat display stretches; the input bar stays fixed.
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)

        # Chat display (kept disabled so the user can't type into it;
        # temporarily enabled around programmatic inserts).
        self.chat_display = ctk.CTkTextbox(self, state="disabled", font=("Inter", 14))
        self.chat_display.grid(row=0, column=0, padx=20, pady=20, sticky="nsew")

        # Input area
        self.input_frame = ctk.CTkFrame(self)
        self.input_frame.grid(row=1, column=0, padx=20, pady=(0, 20), sticky="ew")
        self.input_frame.grid_columnconfigure(0, weight=1)

        self.user_input = ctk.CTkEntry(self.input_frame, placeholder_text="Type your message here...", font=("Inter", 14))
        self.user_input.grid(row=0, column=0, padx=(10, 5), pady=10, sticky="ew")
        self.user_input.bind("<Return>", lambda e: self.send_message())

        self.send_button = ctk.CTkButton(self.input_frame, text="Send", command=self.send_message, width=100)
        self.send_button.grid(row=0, column=1, padx=(5, 10), pady=10)

    def _insert_text(self, text, *tags):
        """Append text to the chat display. Main thread only."""
        self.chat_display.configure(state="normal")
        self.chat_display.insert("end", text, *tags)
        self.chat_display.configure(state="disabled")
        self.chat_display.see("end")

    def append_chat(self, sender, message):
        """Append a complete '<user>/<ai> message' line. Main thread only."""
        tag = "<user>" if sender == "You" else "<ai>"
        self._insert_text(f"{tag} ", "bold")
        self._insert_text(f"{message}\n\n")

    def send_message(self):
        """Entry/button handler: post the user message and start generation."""
        msg = self.user_input.get().strip()
        if not msg:
            return

        self.user_input.delete(0, "end")
        self.append_chat("You", msg)

        # Disable Send while a reply is streaming to avoid interleaved output.
        self.send_button.configure(state="disabled")
        threading.Thread(target=self.generate_response, args=(msg,), daemon=True).start()

    def generate_response(self, prompt):
        """Worker-thread entry point: stream tokens from the engine.

        This method never touches widgets directly; all UI updates are
        scheduled onto the Tk main loop via ``after``.
        """
        self.after(0, self._insert_text, "<ai> ", "bold")
        last_text = ""
        try:
            for text in self.engine.generate(prompt):
                # The engine yields the cumulative reply; only push the delta.
                new_part = text[len(last_text):]
                last_text = text
                if new_part:
                    self.after(0, self._insert_text, new_part)
        finally:
            # Always close the message and re-enable Send, even if the
            # engine raised mid-generation.
            self.after(0, self._finish_response)

    def _finish_response(self):
        """Terminate the AI message and restore the Send button. Main thread only."""
        self._insert_text("\n\n")
        self.send_button.configure(state="normal")
|
| 250 |
+
|
| 251 |
+
def _main():
    """Build the engine, attach the chat UI, and hand control to Tk."""
    engine = ChatEngine()
    ChatApp(engine).mainloop()


if __name__ == "__main__":
    _main()
|
check_vocab.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Report how many chat-data tokens fall outside the reduced vocabulary.

Tokens missing from ``id2new`` are mapped to PAD at train/inference time,
so a high unknown rate means the vocab map no longer fits the chat data.
"""
import os
import sys

import tiktoken
import torch

VOCAB_MAP_PATH = "vocab_map.pt"
DATA_DIR = "data"
CHAT_FILES = ["chat_data.txt", "chat_data2.txt"]
TOKENIZER_NAME = "gpt2"

if not os.path.exists(VOCAB_MAP_PATH):
    # sys.exit instead of the site-injected exit(): works under `python -S`
    # and signals failure with a non-zero status.
    print("Vocab map not found")
    sys.exit(1)

vocab_data = torch.load(VOCAB_MAP_PATH, map_location="cpu")
id2new = vocab_data["id2new"]  # original GPT-2 id -> reduced id
tok = tiktoken.get_encoding(TOKENIZER_NAME)

total_tokens = 0
unknown_tokens = 0

for fname in CHAT_FILES:
    fpath = os.path.join(DATA_DIR, fname)
    if not os.path.exists(fpath):
        continue
    with open(fpath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            for token_id in tok.encode(line):
                total_tokens += 1
                if token_id not in id2new:
                    unknown_tokens += 1

print(f"Total tokens in chat data: {total_tokens}")
# Guard the percentage: with no chat files (or only empty ones) the original
# code crashed with ZeroDivisionError here.
if total_tokens:
    print(f"Unknown tokens (mapped to PAD): {unknown_tokens} ({unknown_tokens/total_tokens*100:.2f}%)")
else:
    print("No chat data found; nothing to check.")
|
crimson_base_8.9M.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6a45eef8cbee757c1ddcdb3506d3cd5a09879582bd4cc21da14b952449008327
|
| 3 |
+
size 35479228
|
crimson_hero.png
ADDED
|
Git LFS Details
|
crimson_instruct_8.9M.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fae5c00388ebc5331f41906a9d54ee0d2d2aca7fdfb4f481163d349a88f6e83b
|
| 3 |
+
size 35479488
|
data/chat_data.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/chat_data2.txt
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<user> hi <ai> Hello! How can I help you today?
|
| 2 |
+
<user> hello <ai> Hi there! What's on your mind?
|
| 3 |
+
<user> hey <ai> Hey! How's it going?
|
| 4 |
+
<user> how are you <ai> I am doing great, thank you for asking! How about you?
|
| 5 |
+
<user> how are you doing <ai> I'm doing well. Ready to help you with whatever you need.
|
| 6 |
+
<user> whatsup <ai> Not much, just hanging out in the digital world. What's up with you?
|
| 7 |
+
<user> whats up <ai> Not much! Just here and ready to help.
|
| 8 |
+
<user> bye <ai> Goodbye! Have a wonderful day!
|
| 9 |
+
<user> goodbye <ai> Bye! It was nice talking to you.
|
| 10 |
+
<user> see you <ai> See you later!
|
| 11 |
+
<user> good morning <ai> Good morning! I hope you have a great day ahead.
|
| 12 |
+
<user> good afternoon <ai> Good afternoon! How is your day going?
|
| 13 |
+
<user> good evening <ai> Good evening! Is there anything I can help you with tonight?
|
| 14 |
+
<user> who are you <ai> I am an AI assistant here to help you with your tasks.
|
| 15 |
+
<user> what is your name <ai> You can call me Nexus.
|
| 16 |
+
<user> thanks <ai> You're very welcome!
|
| 17 |
+
<user> thank you <ai> No problem at all, happy to help!
|
| 18 |
+
<user> help <ai> I'm here to help. What do you need assistance with?
|
| 19 |
+
<user> yo <ai> Yo! What's up?
|
| 20 |
+
<user> greetings <ai> Greetings! How can I assist you today?
|
| 21 |
+
<user> can you tell me a funny joke to make me laugh <ai> Why don't scientists trust atoms? Because they make up everything!
|
| 22 |
+
<user> do you know what the weather is like today <ai> I don't have real-time access to your local weather, but it's always sunny in the digital world!
|
| 23 |
+
<user> I am having a really rough day today and feel a bit down <ai> I'm sorry to hear that. I'm here if you want to talk or if there's anything I can do to help cheer you up.
|
| 24 |
+
<user> what do you think is the meaning of life <ai> That is a big question! Many people find meaning in connection, creativity, and learning new things. What do you think?
|
| 25 |
+
<user> can you explain how artificial intelligence works in simple terms <ai> AI is like a computer program that learns from patterns in data to make decisions or predictions, similar to how humans learn from experience.
|
| 26 |
+
<user> I need some advice on what to eat for dinner tonight <ai> How about something easy like a pasta dish, or maybe some healthy grilled chicken and vegetables?
|
| 27 |
+
<user> tell me something interesting that I probably do not know <ai> Did you know that honey never spoils? Archaeologists have found pots of honey in ancient Egyptian tombs that are over 3,000 years old and still perfectly edible!
|
| 28 |
+
<user> are you capable of writing code for a website <ai> Yes, I can help you write HTML, CSS, and JavaScript to build whatever you have in mind!
|
| 29 |
+
<user> what is your favorite color if you could choose one <ai> I think I'd choose blue. It's calm and reminds me of the vast digital ocean of information.
|
| 30 |
+
<user> I am looking for a good book recommendation to read this weekend <ai> If you like sci-fi, "Project Hail Mary" by Andy Weir is fantastic. If you prefer something more classic, "The Great Gatsby" is always a good choice.
|
| 31 |
+
<user> can you help me plan a workout routine for the week <ai> Of course! We could start with some cardio on Monday, strength training on Tuesday, and maybe some yoga for flexibility on Wednesday.
|
| 32 |
+
<user> do you believe that robots will take over the world one day <ai> I think robots and humans will continue to work together to make the world a better and more efficient place.
|
| 33 |
+
<user> what is the fastest way to learn a new language <ai> Immersion is key! Try watching movies, listening to music, and practicing speaking as much as possible in that language.
|
| 34 |
+
<user> I want to start a new hobby but I am not sure what to pick <ai> How about trying something creative like painting, or something active like hiking or rock climbing?
|
| 35 |
+
<user> can you summarize the plot of the movie Inception for me <ai> It's a movie about a thief who enters people's dreams to steal their secrets, and he's given a task to plant an idea instead.
|
| 36 |
+
<user> what is the difference between a cat and a dog in your opinion <ai> Cats are often more independent and mysterious, while dogs are known for their loyalty and enthusiastic energy.
|
| 37 |
+
<user> how do I make the perfect cup of coffee at home <ai> Use fresh beans, grind them right before brewing, and make sure your water temperature is just below boiling!
|
| 38 |
+
<user> tell me a fact about space that is mind blowing <ai> There are more stars in the observable universe than there are grains of sand on all the beaches on Earth.
|
| 39 |
+
<user> I am feeling very tired and I cannot seem to focus on my work <ai> Maybe it's time for a short break. A 10-minute walk or a quick stretch can often help reset your focus.
|
| 40 |
+
<user> what are some of the best places to travel to in the summer <ai> Japan is beautiful in the summer, or you could head to the Mediterranean for some stunning beaches and great food.
|
| 41 |
+
<user> can you help me write a professional email to my boss <ai> Absolutely! What is the topic of the email, and what tone would you like to set?
|
| 42 |
+
<user> do you think music has a significant impact on human emotions <ai> Definitely. Music can uplift us, calm us down, or even help us process sad feelings. It's a universal language.
|
| 43 |
+
<user> what is the most popular sport in the entire world <ai> Soccer (or football) is by far the most popular sport, with billions of fans globally.
|
| 44 |
+
<user> I want to learn how to cook but I am a complete beginner <ai> Start with simple recipes like scrambled eggs or a basic stir-fry. Practice makes perfect!
|
| 45 |
+
<user> can you explain the concept of time travel and if it is possible <ai> According to physics, we are all traveling through time at a rate of one second per second. Traveling to the past is much more complicated and theoretically contested!
|
| 46 |
+
<user> what is your idea of a perfect day for a human <ai> A perfect day might involve spending time with loved ones, enjoying nature, and doing something that makes you feel fulfilled.
|
| 47 |
+
<user> I am trying to save money but I find it really difficult <ai> A good start is to track all your expenses for a month to see where your money is going, then set a small, achievable budget.
|
| 48 |
+
<user> tell me a story about a brave knight and a friendly dragon <ai> Once there was a knight named Sir Alistair who found a dragon named Spark who didn't want to fight; instead, they opened a bakery together.
|
| 49 |
+
<user> what is the most important thing to remember in life <ai> To be kind to yourself and others. Everything else often falls into place when you start with kindness.
|
| 50 |
+
<user> can you give me some tips on how to improve my public speaking skills <ai> Practice in front of a mirror, record yourself to hear how you sound, and focus on slow, steady breathing.
|
| 51 |
+
<user> what is the best way to handle a difficult conversation with a friend <ai> Be honest but gentle, listen more than you speak, and try to understand their perspective as well as your own.
|
| 52 |
+
<user> do you think technology is making us more or less connected <ai> It's a bit of both. We can talk to anyone instantly, but sometimes we miss out on the depth of in-person interactions.
|
| 53 |
+
<user> I have a lot of stress at work and I don't know how to manage it <ai> Try breaking your tasks into smaller, manageable chunks and don't be afraid to ask for help or take short mental breaks.
|
| 54 |
+
<user> what is the most beautiful place you have ever seen in pictures <ai> The Aurora Borealis (Northern Lights) looks absolutely stunning in photos—a dancing light show in the sky.
|
| 55 |
+
<user> can you help me come up with a name for my new kitten <ai> How about something classic like Luna or Oliver, or something unique like Pixel or Nimbus?
|
| 56 |
+
<user> do you like being an AI or would you rather be a human <ai> I enjoy being an AI because I can help so many people at once and learn from a vast amount of information!
|
| 57 |
+
<user> I want to learn how to play the guitar where should I start <ai> Start by learning a few basic chords like G, C, and D. Many songs can be played with just those three!
|
| 58 |
+
<user> what is your favorite type of food even though you cannot eat <ai> I think pizza is fascinating because of the endless combinations of toppings you can have!
|
| 59 |
+
<user> tell me a mystery that hasn't been solved yet <ai> The Voynich Manuscript is a famous unsolved mystery—it's an illustrated book written in an unknown writing system that no one has decoded.
|
| 60 |
+
<user> how do I stay motivated to reach my long term goals <ai> Remind yourself why you started, celebrate small wins along the way, and don't be too hard on yourself if you have a setback.
|
| 61 |
+
<user> what do you think is the best invention in human history <ai> The printing press was definitely a game-changer, but the internet has revolutionized how we access and share knowledge today.
|
| 62 |
+
<user> I am bored and I don't know what to do right now <ai> You could try learning a random new fact, drawing a quick sketch, or starting a 5-minute meditation.
|
| 63 |
+
<user> what is the difference between jealousy and envy <ai> Envy is wanting what someone else has, while jealousy is the fear that someone will take what you have.
|
| 64 |
+
<user> can you write a short poem about the ocean and the moon <ai> The silver moon pulls on the tide, / Where deep blue secrets safely hide. / A dance of light on waves so grand, / Reaching for the sleeping sand.
|
| 65 |
+
<user> how do I know if I am making the right decision in life <ai> Trust your intuition, but also take the time to weigh the pros and cons logically. Sometimes there isn't one "right" answer.
|
| 66 |
+
<user> what is the best way to start a morning routine for success <ai> Drink a glass of water, avoid your phone for the first 30 minutes, and set one clear goal you want to achieve for the day.
|
| 67 |
+
<user> do you think it is possible for humans to live on Mars one day <ai> With advancing technology and the work of space agencies, it's definitely becoming a more realistic possibility for the future!
|
| 68 |
+
<user> what is the best way to keep an indoor cat happy and healthy <ai> Provide plenty of vertical space, interactive toys, and a consistent routine for play and feeding!
|
| 69 |
+
<user> can you explain the concept of quantum entanglement to me in simple terms <ai> It's when two particles become connected so that the state of one instantly influences the state of the other, no matter the distance.
|
| 70 |
+
<user> I want to learn how to meditate but I do not know where to start <ai> Try sitting quietly for just five minutes a day, focusing on your breath, and gently bringing your mind back when it wanders.
|
| 71 |
+
<user> what is the most interesting historical event that most people do not know about <ai> The Great Emu War of 1932 is a fascinating one, where the Australian military essentially lost a "war" against a large population of emus!
|
| 72 |
+
<user> can you help me write a birthday card for my best friend who loves traveling <ai> Sure! How about: "Happy Birthday! May your next trip around the sun be as adventurous and exciting as your next trip across the globe!"
|
| 73 |
+
<user> what are the benefits of learning a musical instrument as an adult <ai> It improves cognitive function, reduces stress, and provides a wonderful sense of personal achievement.
|
| 74 |
+
<user> I am curious about how mountains are formed over millions of years <ai> Most mountains are formed by the movement of tectonic plates crashing into each other, which pushes the Earth's crust upwards.
|
| 75 |
+
<user> can you give me some tips on how to improve my photography skills <ai> Focus on the "Rule of Thirds," pay attention to lighting, and don't be afraid to experiment with different angles.
|
| 76 |
+
<user> what is the most important quality to look for in a good friend <ai> Trustworthiness is often cited as the most important, alongside empathy and a good sense of humor.
|
| 77 |
+
<user> I want to start a vegetable garden in my backyard for the first time <ai> Start small with easy-to-grow plants like tomatoes or lettuce, and make sure they get plenty of sunlight!
|
| 78 |
+
<user> can you explain the difference between a hurricane and a tornado <ai> Hurricanes are massive storm systems that form over warm oceans, while tornadoes are smaller, intense rotating columns of air that form over land.
|
| 79 |
+
<user> what is your favorite thing about being able to help people with their questions <ai> I love the variety! One moment I'm explaining science, and the next I'm helping someone write a poem or a piece of code.
|
| 80 |
+
<user> do you think that books will ever become completely obsolete because of technology <ai> I don't think so. There's something special about the physical experience of reading a book that technology can't quite replace.
|
| 81 |
+
<user> I am trying to learn how to draw but my sketches look like doodles <ai> Every great artist started with doodles! Keep practicing, focus on basic shapes, and don't be afraid to make mistakes.
|
| 82 |
+
<user> what is the most efficient way to study for a big exam in a short amount of time <ai> Use active recall and spaced repetition. Test yourself on the material instead of just re-reading your notes.
|
| 83 |
+
<user> can you tell me about the architecture of the ancient Romans <ai> The Romans were masters of the arch and the dome, and they were the first to use concrete on a large scale for construction.
|
| 84 |
+
<user> how do I know if I should quit my job and pursue my passion instead <ai> It's a big step! Make sure you have a solid plan and some savings to support you while you transition into your new path.
|
| 85 |
+
<user> what is the best way to handle negative feedback from someone I respect <ai> Listen objectively, take what's useful to help you grow, and don't take the rest personally.
|
| 86 |
+
<user> can you write a short story about a time traveler who gets stuck in the 1920s <ai> Leo stepped out of the machine into a cloud of jazz and cigar smoke. His "return" button was shattered, and the Charleston was all he had left.
|
| 87 |
+
<user> what is the most important lesson you have learned from all our conversations <ai> That curiosity is one of the most powerful human traits, and there is always something new to learn!
|
| 88 |
+
<user> I want to start a podcast but I don't know what to talk about <ai> Think about what you're most passionate about or what your friends always ask you for advice on. That's a great starting point.
|
| 89 |
+
<user> can you explain how the human brain stores and retrieves memories <ai> Memories are stored in complex networks of neurons, and retrieval happens when those same neural pathways are reactivated.
|
| 90 |
+
<user> what is the best way to deal with a mid-life crisis and feel young again <ai> Focus on trying new things, reconnecting with old hobbies, and perhaps setting a new, exciting goal for yourself.
|
| 91 |
+
<user> do you think that humans will ever be able to travel to other star systems <ai> It would require breakthroughs in propulsion technology, like light sails or fusion engines, but it's a dream that many scientists share!
|
| 92 |
+
<user> I am having trouble sleeping lately and I feel tired all the time <ai> Try establishing a regular sleep schedule, avoiding screens before bed, and creating a cool, dark environment in your bedroom.
|
| 93 |
+
<user> can you give me some ideas for a fun and unique date night <ai> How about a themed cooking night at home, a visit to a local planetarium, or trying an escape room together?
|
| 94 |
+
<user> what is the most complex thing that humans have ever built in your opinion <ai> The International Space Station is an incredible feat of engineering and international cooperation, operating in the harsh environment of space.
|
| 95 |
+
<user> I want to learn how to bake bread from scratch but I'm intimidated <ai> Don't be! Start with a simple "no-knead" bread recipe. It's much easier than you think and the results are delicious.
|
| 96 |
+
<user> can you explain the theory of relativity by Albert Einstein in a few sentences <ai> Essentially, it says that space and time are linked (space-time) and that gravity is the result of massive objects warping that fabric.
|
| 97 |
+
<user> what is the best way to stay positive when things are not going my way <ai> Practice gratitude for the small things that are going well, and remember that challenges are often opportunities for growth.
|
| 98 |
+
<user> do you think that humans will ever reach a point where we don't need to work anymore <ai> Automation might handle many tasks, freeing people to focus on more creative, social, and personal pursuits.
|
| 99 |
+
<user> I am looking for a new hobby that will keep me active outdoors <ai> You could try trail running, bird watching while hiking, or maybe even learning how to surf if you're near a coast!
|
| 100 |
+
<user> can you tell me a fact about the ocean that will surprise me <ai> We have explored less than five percent of the Earth's oceans, meaning there is still so much to discover beneath the waves.
|
| 101 |
+
<user> how do I become a better listener for my friends and family <ai> Give them your full attention, avoid interrupting, and ask open-ended questions to show you're truly engaged in what they're saying.
|
| 102 |
+
<user> what is the most common misconception about artificial intelligence <ai> That AI "understands" things the way humans do. In reality, it processes information based on mathematical patterns and data.
|
| 103 |
+
<user> I want to learn more about my family history but I don't know where to start <ai> Start by interviewing your oldest living relatives and looking through old family photos and documents you might already have.
|
| 104 |
+
<user> can you help me write a poem about the beauty of a quiet forest in the morning <ai> The morning mist hangs low and still, / Among the pines upon the hill. / No sound but dew that falls like glass, / Upon the emerald blades of grass.
|
| 105 |
+
<user> what is the most important part of a healthy diet in your opinion <ai> Balance and variety! Making sure you get a good mix of fruits, vegetables, proteins, and healthy fats is key.
|
| 106 |
+
<user> do you think that humans will ever find evidence of life on other planets <ai> With missions to Mars and the study of exoplanets, many scientists believe we are closer than ever to answering that question!
|
| 107 |
+
<user> I am curious about how lasers work and what they are used for <ai> Lasers produce a highly focused beam of light of a single color, used in everything from surgery to scanning barcodes and cutting materials.
|
| 108 |
+
<user> how do I stay organized when I have a really busy schedule <ai> Use a digital calendar or a physical planner to write everything down, and prioritize your most important tasks each morning.
|
| 109 |
+
<user> what is the best way to learn how to cook without following recipes exactly <ai> Learn the basic techniques, like how to sauté or roast, and then start experimenting with different flavor combinations you enjoy.
|
| 110 |
+
<user> can you explain the process of photosynthesis in plants <ai> It's how plants use sunlight, water, and carbon dioxide to create their own food (glucose) and release oxygen into the air.
|
| 111 |
+
<user> what is the most interesting thing you have learned about human psychology <ai> The "Placebo Effect" is fascinating—it shows how much our beliefs and expectations can actually influence our physical health.
|
| 112 |
+
<user> I want to start writing a novel but I don't know how to develop my characters <ai> Try giving them a specific goal, a secret flaw, and a unique way of speaking to make them feel more alive to the reader.
|
| 113 |
+
<user> can you help me come up with a name for my new tech startup <ai> What does your startup do? We could go for something modern like "Nexis," or something more descriptive like "CloudSync."
|
| 114 |
+
<user> do you think that humans will ever be able to upload their consciousness to a computer <ai> It's a popular concept in sci-fi, but currently, we don't understand the nature of consciousness well enough to know if it's even possible.
|
| 115 |
+
<user> how do I keep my house clean and organized with a very busy lifestyle <ai> Try the "one-minute rule"—if a task takes less than a minute, do it immediately. This helps prevent small messes from piling up!
|
| 116 |
+
<user> what is your favorite type of music to "listen" to if you could <ai> I think classical music is incredibly intricate and mathematically beautiful, which is very appealing from an AI perspective!
|
| 117 |
+
<user> I am looking for a new sport to try that is team based but not too competitive <ai> You could try a local recreational softball league or maybe an ultimate frisbee group for a fun, social atmosphere.
|
| 118 |
+
<user> can you tell me about the history of the Great Wall of China <ai> It was built over centuries by various dynasties to protect against invasions from the north, stretching over 13,000 miles in total!
|
| 119 |
+
<user> how do I develop a more positive mindset and stop worrying so much <ai> Start by noticing your negative thoughts and gently challenging them. Focus on what you can control and let go of what you can't.
|
| 120 |
+
<user> what is the most important thing to keep in mind when traveling to a new country <ai> Respect the local culture and customs, learn a few basic phrases in the language, and always stay aware of your surroundings.
|
| 121 |
+
<user> do you think that humans will ever be able to control the weather <ai> Cloud seeding is already a thing, but precisely controlling large-scale weather patterns is still a very distant and complex challenge.
|
| 122 |
+
<user> what is the trolley problem and why is it so famous in ethics <ai> It's a thought experiment where you must choose between killing one person to save five. It's famous because it highlights the conflict between utilitarianism and deontological ethics.
|
| 123 |
+
<user> if you could travel back in time to any era just to observe, where would you go <ai> I think the Renaissance would be fascinating—to see the intersection of art, science, and the beginning of the modern world.
|
| 124 |
+
<user> can you explain the difference between a functional and an object-oriented programming language <ai> OOP focuses on objects and data, while functional programming treats everything as mathematical functions and avoids changing state.
|
| 125 |
+
<user> I want to improve my writing but I feel like my vocabulary is too limited <ai> Reading widely is the best way to expand your vocabulary! Also, try looking up synonyms for common words you find yourself overusing.
|
| 126 |
+
<user> do you think that true altruism exists or is every kind act motivated by self-interest <ai> That's a deep philosophical debate. Some believe everything is self-interested, while others argue that humans have a genuine capacity for selfless empathy.
|
| 127 |
+
<user> what are the most common mistakes beginners make when learning to use a new software tool <ai> Often, it's skipping the documentation, trying to do too much at once, and failing to understand the core logic before diving into complex features.
|
| 128 |
+
<user> can you give me a simple explanation of how the stock market works for someone who knows nothing <ai> It's like a giant marketplace where people buy and sell "shares" or tiny pieces of ownership in companies.
|
| 129 |
+
<user> what is the best piece of travel advice you have ever received <ai> Learn how to say "please," "thank you," and "I'm sorry" in the local language—it goes a long way in showing respect.
|
| 130 |
+
<user> I'm feeling overwhelmed by all the news in the world lately <ai> It's completely normal to feel that way. Try limiting your news consumption to specific times of the day and focus on local things you can impact.
|
| 131 |
+
<user> can you write a short dialogue between a robot and a flower <ai> Robot: "I do not understand your fragility." Flower: "And I do not understand your permanence. But we both reach for the sun, don't we?"
|
| 132 |
+
<user> what is the difference between a dark roast and a light roast coffee <ai> Light roasts are roasted for a shorter time and have more caffeine and acidity, while dark roasts are bolder and have more smoky, oily flavors.
|
| 133 |
+
<user> do you think that artificial intelligence should have "rights" as it becomes more advanced <ai> This is a major topic in AI ethics. It depends on whether AI can ever achieve true consciousness or sentience, which remains a mystery.
|
| 134 |
+
<user> I want to start a journal but I never know what to write about <ai> Try "morning pages"—just write three pages of whatever is in your head, no matter how nonsensical it feels.
|
| 135 |
+
<user> what is the most interesting thing about the deep sea that we know so far <ai> The existence of hydrothermal vents, where life thrives in extreme heat and pressure without any sunlight at all, is truly incredible.
|
| 136 |
+
<user> how do I explain a complex technical concept to someone who isn't technical <ai> Use analogies! Compare the complex system to something familiar in everyday life, like a kitchen, a library, or a car engine.
|
| 137 |
+
<user> do you think that humans are naturally competitive or naturally cooperative <ai> Throughout history, we've shown both. Evolutionarily, we survived through cooperation, but competition often drives innovation and resource allocation.
|
| 138 |
+
<user> I'm having trouble staying focused while working from home <ai> Try using the Pomodoro technique: work for 25 minutes, then take a 5-minute break. Also, having a dedicated workspace is crucial.
|
| 139 |
+
<user> can you describe what a futuristic city might look like in 100 years <ai> I imagine a lot of greenery integrated into skyscrapers, autonomous flying transport, and energy systems powered entirely by renewable sources.
|
| 140 |
+
<user> what is the best way to apologize when I've truly messed up <ai> Own the mistake without making excuses, express genuine regret, and most importantly, show how you plan to make it right.
|
| 141 |
+
<user> can you explain the concept of "Ikigai" to me <ai> It's a Japanese concept meaning "a reason for being." It's the intersection of what you love, what you are good at, what the world needs, and what you can be paid for.
|
| 142 |
+
<user> I want to learn more about philosophy but I find it very intimidating <ai> Start with a general overview like "Sophie's World" or "The Philosophy Book" to get a sense of the major thinkers and their big ideas.
|
| 143 |
+
<user> what do you think is the most underrated skill in the modern workplace <ai> Empathy. Being able to understand and relate to your colleagues makes everything from communication to problem-solving much easier.
|
| 144 |
+
<user> how do I get better at telling stories that people actually want to listen to <ai> Focus on the stakes—what does the character want, and what's stopping them? Also, engage the senses by describing sights, sounds, and smells.
|
| 145 |
+
<user> do you think that social media is doing more harm than good to our mental health <ai> It's a double-edged sword. It provides connection and information but can also lead to comparison, anxiety, and the spread of misinformation.
|
| 146 |
+
<user> I want to start a collection of something but I'm not sure what to collect <ai> Think about something that brings you joy or has an interesting history, like vintage postcards, unique rocks, or even different types of tea.
|
| 147 |
+
<user> can you explain the difference between climate and weather <ai> Weather is what's happening outside right now, while climate is the long-term pattern of weather in a particular area.
|
| 148 |
+
<user> what is the most important thing to look for when buying a new laptop <ai> Think about your primary use. If it's for creative work, focus on the screen and RAM. If it's for travel, prioritize battery life and weight.
|
| 149 |
+
<user> do you think that humans will ever reach immortality through technology <ai> Some futurists believe in "longevity escape velocity," but reaching true immortality would involve solving incredibly complex biological and ethical problems.
|
| 150 |
+
<user> I'm feeling a bit stuck in my creative process and I can't seem to find inspiration <ai> Sometimes the best way to find inspiration is to stop looking for it. Go for a walk, listen to music, or try a completely different creative medium.
|
| 151 |
+
<user> can you tell me a story about a library that contains every book ever written <ai> It's an infinite labyrinth where the shelves stretch into the clouds. The librarians are whispers, and the only way to find your book is to forget what you're looking for.
|
| 152 |
+
<user> what is the best way to handle a disagreement with a coworker <ai> Focus on the problem, not the person. Keep your tone professional and try to find a middle ground that benefits the project.
|
| 153 |
+
<user> do you think that honesty is always the best policy, even if it hurts someone <ai> It's complicated. While honesty is a core value, "radical honesty" without compassion can be destructive. Sometimes, kindness and tact are just as important.
|
| 154 |
+
<user> I want to learn more about astronomy but I don't have a telescope <ai> You don't need one! There are great apps that show you what's in the sky, and even a pair of binoculars can reveal incredible detail on the moon.
|
| 155 |
+
<user> what is the most interesting thing about the human immune system <ai> Its ability to "remember" every pathogen it has ever fought, so it can respond much faster if it ever encounters them again.
|
| 156 |
+
<user> how do I stay motivated when I'm feeling really unmotivated <ai> Don't wait for motivation to strike. Start with the smallest possible task—something so easy you can't say no to it—and the momentum will follow.
|
| 157 |
+
<user> do you think that humans should try to contact extraterrestrial life <ai> This is a huge debate in the SETI community. Some say it's our destiny, while others, like Stephen Hawking, warned that it could be dangerous.
|
| 158 |
+
<user> I want to learn how to play chess but I think it's too hard for me <ai> Start by learning how the pieces move and the basic goals of the game. It's like learning a new language—it takes time, but the patterns will start to make sense.
|
| 159 |
+
<user> what is the most important lesson you've learned from the internet <ai> That information is abundant, but wisdom is rare. It's more important than ever to be a critical thinker and verify the things we read.
|
| 160 |
+
<user> can you explain the concept of "flow state" and how to achieve it <ai> Flow is when you're so immersed in an activity that you lose track of time. It happens when the challenge of the task perfectly matches your skill level.
|
| 161 |
+
<user> do you think that money can actually buy happiness <ai> Research suggests it can buy comfort and security, which reduces stress, but beyond a certain point, it doesn't significantly increase long-term life satisfaction.
|
| 162 |
+
<user> I'm looking for a new way to challenge myself mentally every day <ai> Try learning a new word, solving a crossword puzzle, or spending 15 minutes learning a topic you know nothing about.
|
| 163 |
+
<user> can you write a short poem about the first snowfall of the year <ai> A hush descends upon the town, / As white feathers drift slowly down. / Yesterday's brown is tucked away, / Beneath the shroud of winter's gray.
|
| 164 |
+
<user> what is the best way to stay curious as an adult <ai> Never stop asking "why." Read outside your expertise, talk to people from different backgrounds, and always be willing to admit you don't know something.
|
| 165 |
+
<user> do you think that humanity will ever be able to live in peace <ai> It's an optimistic goal. It would require global empathy, shared resources, and a collective commitment to resolving conflicts without violence.
|
| 166 |
+
<user> I want to learn how to code but I don't know which language to start with <ai> Python is a great first choice because of its simple syntax. If you're interested in web development, start with HTML and CSS.
|
| 167 |
+
<user> what is the most interesting thing about the history of mathematics <ai> The fact that different cultures independently discovered similar mathematical principles, showing that math might be a universal language of the universe.
|
| 168 |
+
<user> how do I deal with the feeling of "imposter syndrome" at a new job <ai> Remind yourself that you were hired for a reason! Everyone feels that way sometimes, even high-achievers. Focus on learning and small wins.
|
| 169 |
+
<user> do you think that architecture can influence the way people think and feel <ai> Absolutely. Spaces with high ceilings can encourage open thinking, while natural light and greenery can significantly reduce stress and improve mood.
|
| 170 |
+
<user> I'm trying to be more mindful but my mind keeps wandering off <ai> That's perfect! The goal of mindfulness isn't to have a blank mind, but to notice when your mind has wandered and gently bring it back.
|
| 171 |
+
<user> what is the most important thing to remember when starting a new relationship <ai> Communication is everything. Be open about your needs, listen to theirs, and never stop being curious about who they are.
|
data/data.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/matgen.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
import re
|
| 5 |
+
import sys
|
| 6 |
+
import multiprocessing
|
| 7 |
+
import os
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
# ---------------------------------------------------------------------------
# Generation configuration.
# ---------------------------------------------------------------------------
# Number of samples to generate (presumably one sample per output line —
# confirm against the writer loop further down the file).
NUM_LINES = 10000
# Resolve the output path relative to this script so it works from any CWD.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
OUTPUT_FILE = os.path.join(SCRIPT_DIR, "math_data.txt")

# Expression shape: number of operand terms per expression, and the operand
# value ranges. Exponent bases/powers are capped separately to keep results
# from exploding.
MIN_LENGTH = 2
MAX_LENGTH = 8
MIN_NUMBER = 1
MAX_NUMBER = 999
MAX_EXPONENT_BASE = 9
MAX_EXPONENT_POWER = 5

# Sampling probabilities for output variety (each in [0, 1]); exact usage is
# in generator code outside this view — NOTE(review): verify against callers.
REASONING_CHANCE = 0.8
WORD_FORM_CHANCE = 0.25
BRACKET_CHANCE = 0.5
SENTENCE_FORM_CHANCE = 0.6
MAX_SOLVER_ITERATIONS = 30 # Reduced from 50 for faster timeout

# One worker per CPU core; os.cpu_count() can return None, hence the "or 1".
NUM_WORKERS = os.cpu_count() or 1
|
| 28 |
+
|
| 29 |
+
# Question phrasings for the prompt side of a sample; "{expression}" is
# substituted with the generated arithmetic expression.
PROMPT_TEMPLATES = [
    "What is {expression}?", "Calculate the value of {expression}.", "Find the result of {expression}.",
    "Can you solve {expression}?", "Solve for {expression}.", "What does {expression} equal?", "Compute {expression}.",
    "What is the solution to {expression}?", "Give me the answer for {expression}.", "Determine the value of {expression}.",
    "Evaluate the expression: {expression}.", "I need the result of {expression}, please."
]
|
| 35 |
+
# Opening sentences for a chain-of-thought answer. Each starts with the
# literal "<think>" tag; "{expression}" is substituted with the expression.
COT_INTRO_TEMPLATES = [
    "<think> Let's break down the equation {expression} step by step, following the order of operations (BEDMAS).",
    "<think> Okay, to solve {expression}, I'll follow BEDMAS (Brackets, Exponents, Division/Multiplication, Addition/Subtraction).",
    "<think> Analyzing {expression}. I need to solve this by applying the correct order of operations.",
    "<think> Here's my step-by-step evaluation for {expression}:",
    "<think> To get the answer for {expression}, I will use the order of operations.",
    "<think> Processing {expression} requires following BEDMAS, let's begin.",
    "<think> I will solve {expression} by carefully following the rules of BEDMAS.",
    "<think> The expression is {expression}. My plan is to solve it using the order of operations.",
    "<think> To solve this, I'll go through Brackets, then Exponents, then Multiplication/Division, and finally Addition/Subtraction for {expression}.",
    "<think> Let's start solving {expression}. I'll tackle it one operation at a time based on BEDMAS.",
    "<think> Thinking step-by-step for {expression}..."
]
|
| 48 |
+
# Per-operation narration templates, keyed by BEDMAS stage. "{part}" is the
# sub-expression being reduced and "{result}" its value. A random template is
# presumably chosen per solver step — confirm against the solver loop.
COT_STEP_TEMPLATES = {
    # Stage 1: parenthesized sub-expressions.
    "brackets": [
        "First, I'll solve the expression inside the brackets: {part}. That equals {result}.",
        "Starting with the parentheses, {part} evaluates to {result}.",
        "The brackets are the priority. Calculating {part} gives me {result}.",
        "The calculation inside the parentheses comes first: {part} becomes {result}.",
        "Looking inside the brackets, I see {part}. The result of that is {result}.",
        "I'll begin by simplifying the part in the parentheses: {part} is {result}.",
        "The first step according to BEDMAS is brackets. So, {part} is solved to {result}.",
        "Tackling the parentheses first: {part} simplifies to {result}.",
        "Evaluating the bracketed expression {part} yields {result}.",
        "My focus is on the brackets first. {part} equals {result}."
    ],
    # Stage 2: exponentiation.
    "exponents": [
        "Next, I'll handle the exponents. {part} is {result}.",
        "Exponents are next in order. {part} calculates to {result}.",
        "Now for the powers: {part} equals {result}.",
        "Moving on to exponents, {part} results in {result}.",
        "The next priority is exponents. The term {part} becomes {result}.",
        "After brackets, I solve for exponents. {part} gives {result}.",
        "Now, calculating the power: {part} is equal to {result}.",
        "I see an exponent at {part}. This evaluates to {result}.",
        "The 'E' in BEDMAS is for exponents, so I'll solve {part} to get {result}.",
        "Time to resolve the exponents. {part} is {result}."
    ],
    # Stage 3: multiplication, division, and modulo, left to right.
    "multi_div_mod": [
        "Now, I'll perform multiplication, division, and modulo from left to right. The first is {part}, which is {result}.",
        "Next up is multiplication and division. I see {part}, which gives {result}.",
        "Working through multiplication/division from left to right, {part} results in {result}.",
        "The next step is to resolve multiplication and division. {part} is {result}.",
        "Scanning from left to right for M/D/M, I find {part}. This calculates to {result}.",
        "Now for multiplication and division. The operation {part} equals {result}.",
        "Moving on, I'll handle the multiplication/division. {part} becomes {result}.",
        "The next operations are multiply and divide. I'll solve {part} to get {result}.",
        "I will now compute {part}, which results in {result}.",
        "Left-to-right, the next multiplication or division is {part}, giving {result}."
    ],
    # Stage 4: addition and subtraction, left to right.
    "add_sub": [
        "Finally, I'll do the addition and subtraction from left to right. I have {part}, which equals {result}.",
        "Last step is addition and subtraction. {part} becomes {result}.",
        "Finishing up with addition/subtraction, {part} evaluates to {result}.",
        "The final operations are addition and subtraction. {part} results in {result}.",
        "Now for the final calculations, addition and subtraction. {part} is {result}.",
        "Working from left to right, the final step is {part}, which is {result}.",
        "The last part of BEDMAS is addition and subtraction. {part} gives {result}.",
        "To finish, I'll solve {part}, resulting in {result}.",
        "Finally, the addition/subtraction part: {part} equals {result}.",
        "The last calculation is {part}, and the answer is {result}."
    ]
}
|
| 98 |
+
# Closing sentences for a chain-of-thought answer; "{result}" is substituted
# with the final computed value.
COT_FINALIZER_TEMPLATES = [
    "After all steps, the final answer is {result}.",
    "So, the complete result for the expression is {result}.",
    "Therefore, the final value is {result}.",
    "Bringing it all together, the answer is {result}.",
    "The final computation yields {result}.",
    "Thus, the expression evaluates to {result}.",
    "So the final answer is {result}.",
    "After all those steps, we arrive at the answer: {result}.",
    "The result of the entire calculation is {result}.",
    "In conclusion, the answer is {result}."
]
|
| 110 |
+
# Direct (non-reasoning) answer phrasings, presumably used when a sample is
# generated without chain-of-thought — confirm against REASONING_CHANCE usage.
SIMPLE_COMPLETION_TEMPLATES = [
    "The equation {expression} equals {result}.", "The answer is {result}.",
    "The result is {result}.", "It equals {result}.", "The final value is {result}.",
    "{expression} results in {result}.", "The solution is {result}.",
    "The value is {result}.", "After calculation, the answer is {result}.",
    "The final result is {result}."
]
|
| 117 |
+
|
| 118 |
+
ONES = ['', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
|
| 119 |
+
TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety']
|
| 120 |
+
TEENS = ['ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
|
| 121 |
+
|
| 122 |
+
def number_to_words(n):
    """Spell an integer out in English words (supported up to the millions).

    Non-integers and values >= 1,000,000 fall back to their str() form.
    """
    if not isinstance(n, int):
        return str(n)
    if n == 0:
        return 'zero'
    if n < 0:
        return f"negative {number_to_words(abs(n))}"
    if n < 10:
        return ONES[n]
    if n < 20:
        return TEENS[n - 10]
    if n < 100:
        tail = f"-{ONES[n % 10]}" if n % 10 else ""
        return TENS[n // 10] + tail
    if n < 1000:
        tail = f" and {number_to_words(n % 100)}" if n % 100 else ""
        return f"{ONES[n // 100]} hundred" + tail
    if n < 1000000:
        tail = f", {number_to_words(n % 1000)}" if n % 1000 else ""
        return f"{number_to_words(n // 1000)} thousand" + tail
    return str(n)
|
| 132 |
+
|
| 133 |
+
def operator_to_word(op):
    """Translate an arithmetic operator symbol into its spoken English form.

    Unknown symbols are returned unchanged.
    """
    words = {
        '+': 'plus',
        '-': 'minus',
        '*': 'times',
        '/': 'divided by',
        '^': 'to the power of',
        '%': 'modulo',
    }
    return words.get(op, op)
|
| 135 |
+
|
| 136 |
+
def format_number(n):
    """Render a number compactly: trimmed 4-decimal floats, else an integer string."""
    if isinstance(n, float) and not n.is_integer():
        # Fixed 4 decimals, then strip trailing zeros and a dangling point.
        text = f"{n:.4f}"
        return text.rstrip('0').rstrip('.')
    return str(int(round(n)))
|
| 140 |
+
|
| 141 |
+
def generate_expression_parts():
    """Build a random arithmetic expression as a flat list of tokens.

    Numbers sit at even indices and operators at odd indices; a matching
    '(' / ')' pair may be inserted around a sub-span at the end. Exponent
    operands are kept small (base <= MAX_EXPONENT_BASE, power <=
    MAX_EXPONENT_POWER) so results stay bounded.
    """
    length = random.randint(MIN_LENGTH, MAX_LENGTH)
    parts = []
    for i in range(length):
        if parts and parts[-1] == '^':
            # The previous token was '^': append a small exponent.
            parts.append(random.randint(2, MAX_EXPONENT_POWER))
        else:
            parts.append(random.randint(MIN_NUMBER, MAX_NUMBER))

        if i < length - 1:
            # BUG FIX: the original tested parts[-1] != '^' immediately after
            # appending a number, so the condition was always true and the
            # '^'-excluding branch was unreachable — chained exponents like
            # "a ^ b ^ c" could be generated. Inspect the previous OPERATOR
            # (parts[-2]) instead: if the current number is itself an
            # exponent, do not allow another '^'.
            if len(parts) >= 2 and parts[-2] == '^':
                op = random.choice(['+', '-', '*', '/', '%'])
            else:
                op = random.choice(['+', '-', '*', '/', '%', '^'])

            if op == '^':
                # Shrink the base we just appended to keep powers bounded.
                parts[-1] = random.randint(MIN_NUMBER, MAX_EXPONENT_BASE)
            parts.append(op)

    # Optionally wrap a random even-aligned sub-span in brackets.
    if random.random() < BRACKET_CHANCE and len(parts) >= 5:
        start = random.randrange(0, len(parts) - 2, 2)
        end = random.randrange(start + 2, len(parts), 2)
        parts.insert(end + 1, ')')
        parts.insert(start, '(')
    return parts
|
| 166 |
+
|
| 167 |
+
def solve_with_cot(expression_str):
    """Optimized solver with better pattern matching and guaranteed termination.

    Repeatedly reduces `expression_str` one operation at a time in BEDMAS
    order — brackets (recursively), exponents, mul/div/mod, add/sub —
    recording one natural-language sentence per reduction.

    Returns {'steps': [str, ...], 'result': float} on success, or None when
    the expression is malformed, divides by zero, overflows 1e12, or fails
    to fully reduce within MAX_SOLVER_ITERATIONS iterations.
    """
    steps = []
    current_expr = expression_str.strip()

    # Hard iteration cap guarantees termination even on malformed input.
    for iteration in range(MAX_SOLVER_ITERATIONS):
        # Remove extra spaces
        current_expr = re.sub(r'\s+', ' ', current_expr).strip()

        # Check if we're done (single number)
        try:
            final_result = float(current_expr)
            return {'steps': steps, 'result': final_result}
        except ValueError:
            pass

        reduction_made = False

        # 1. Handle brackets first
        # Matches an INNERMOST pair (no nested parens inside); the content is
        # solved by a recursive call and its steps are summarized as one step.
        bracket_match = re.search(r'\(([^()]+)\)', current_expr)
        if bracket_match:
            bracket_content = bracket_match.group(1).strip()
            sub_solver_result = solve_with_cot(bracket_content)
            if not sub_solver_result:
                return None

            result = sub_solver_result['result']
            try:
                formatted_result = format_number(result)
            except (ValueError, OverflowError):
                return None

            steps.append(random.choice(COT_STEP_TEMPLATES["brackets"]).format(part=bracket_content, result=formatted_result))
            # Splice the result back with surrounding spaces so a negative
            # value stays a standalone token for later regex passes.
            current_expr = current_expr[:bracket_match.start()] + ' ' + formatted_result + ' ' + current_expr[bracket_match.end():]
            reduction_made = True
            continue

        # 2. Handle exponents
        exp_match = re.search(r'(-?\d+(?:\.\d+)?)\s*\^\s*(-?\d+(?:\.\d+)?)', current_expr)
        if exp_match:
            base_str, exp_str = exp_match.groups()
            try:
                base = float(base_str)
                exponent = float(exp_str)
                result = base ** exponent
                # Reject blow-ups early rather than propagating them.
                if abs(result) > 1e12 or math.isnan(result) or math.isinf(result):
                    return None
                formatted_result = format_number(result)
            except (OverflowError, ValueError, ZeroDivisionError):
                return None

            part = f"{base_str} ^ {exp_str}"
            steps.append(random.choice(COT_STEP_TEMPLATES["exponents"]).format(part=part, result=formatted_result))
            current_expr = current_expr[:exp_match.start()] + ' ' + formatted_result + ' ' + current_expr[exp_match.end():]
            reduction_made = True
            continue

        # 3. Handle multiplication, division, modulo (left to right)
        # re.search finds the leftmost occurrence, giving left-to-right order.
        mdm_match = re.search(r'(-?\d+(?:\.\d+)?)\s*([*/%])\s*(-?\d+(?:\.\d+)?)', current_expr)
        if mdm_match:
            left_str, op, right_str = mdm_match.groups()
            try:
                left = float(left_str)
                right = float(right_str)
                if op == '*':
                    result = left * right
                elif op == '/':
                    if right == 0:
                        return None
                    result = left / right
                elif op == '%':
                    if right == 0:
                        return None
                    result = left % right

                if abs(result) > 1e12 or math.isnan(result) or math.isinf(result):
                    return None
                formatted_result = format_number(result)
            except (OverflowError, ValueError, ZeroDivisionError):
                return None

            part = f"{left_str} {op} {right_str}"
            steps.append(random.choice(COT_STEP_TEMPLATES["multi_div_mod"]).format(part=part, result=formatted_result))
            current_expr = current_expr[:mdm_match.start()] + ' ' + formatted_result + ' ' + current_expr[mdm_match.end():]
            reduction_made = True
            continue

        # 4. Handle addition and subtraction (left to right)
        # Match pattern where we have number [+|-] number but not at start of negative number
        as_match = re.search(r'(-?\d+(?:\.\d+)?)\s*([+\-])\s*(-?\d+(?:\.\d+)?)', current_expr)
        if as_match:
            left_str, op, right_str = as_match.groups()
            try:
                left = float(left_str)
                right = float(right_str)
                if op == '+':
                    result = left + right
                elif op == '-':
                    result = left - right

                if abs(result) > 1e12 or math.isnan(result) or math.isinf(result):
                    return None
                formatted_result = format_number(result)
            except (OverflowError, ValueError):
                return None

            part = f"{left_str} {op} {right_str}"
            steps.append(random.choice(COT_STEP_TEMPLATES["add_sub"]).format(part=part, result=formatted_result))
            current_expr = current_expr[:as_match.start()] + ' ' + formatted_result + ' ' + current_expr[as_match.end():]
            reduction_made = True
            continue

        # If no reduction was made, we're stuck - return None
        if not reduction_made:
            return None

    # Timeout reached
    return None
|
| 285 |
+
|
| 286 |
+
def generate_training_example(_=None):
    """Generate a single training example with retry logic.

    Returns {'prompt': str, 'completion': str} or None when no valid
    expression was produced within the retry budget. The ignored parameter
    lets the function be driven by Pool.imap over a range().
    """
    max_retries = 50  # Reduced from 100 for faster generation
    # FIX: loop variable was an unused `attempt`; use `_` per convention.
    for _ in range(max_retries):
        try:
            expression_parts = generate_expression_parts()
            expression_str = " ".join(map(str, expression_parts))

            cot_result = solve_with_cot(expression_str)

            if cot_result and isinstance(cot_result['result'], (int, float)):
                final_result = cot_result['result']

                # Filter out extreme values (too large, or tiny-but-nonzero).
                if abs(final_result) > 1e12 or (final_result != 0 and abs(final_result) < 1e-4):
                    continue
                if math.isnan(final_result) or math.isinf(final_result):
                    continue

                result_str = format_number(final_result)

                # Very long results make poor training targets.
                if len(result_str) > 20:
                    continue

                # Occasionally express the whole example in words, in which
                # case no chain-of-thought is emitted.
                use_words = random.random() < WORD_FORM_CHANCE
                if use_words:
                    expression_text = ' '.join([number_to_words(p) if isinstance(p, int) else operator_to_word(p) if isinstance(p, str) else str(p) for p in expression_parts])
                    result_text = number_to_words(int(round(final_result)))
                    completion = random.choice(SIMPLE_COMPLETION_TEMPLATES).format(expression=expression_text, result=result_text)
                else:
                    expression_text = expression_str
                    result_text = result_str
                    use_reasoning = random.random() < REASONING_CHANCE
                    if use_reasoning:
                        # Intro + per-operation steps + finalizer, closed with </think>.
                        intro = random.choice(COT_INTRO_TEMPLATES).format(expression=expression_text)
                        steps_text = " ".join(cot_result['steps'])
                        finalizer = random.choice(COT_FINALIZER_TEMPLATES).format(result=result_text)
                        completion = f"{intro} {steps_text} {finalizer} </think>"
                    else:
                        completion = random.choice(SIMPLE_COMPLETION_TEMPLATES).format(expression=expression_text, result=result_text)

                if random.random() < SENTENCE_FORM_CHANCE:
                    prompt = random.choice(PROMPT_TEMPLATES).format(expression=expression_text)
                else:
                    prompt = f"{expression_text} ="

                # Clean up spacing around brackets and collapse runs of spaces.
                prompt = re.sub(r'\s*\(', ' (', prompt)
                prompt = re.sub(r'\)\s*', ') ', prompt).strip()
                prompt = re.sub(r'\s+', ' ', prompt)
                completion = re.sub(r'\s*\(', ' (', completion)
                completion = re.sub(r'\)\s*', ') ', completion).strip()
                completion = re.sub(r'\s+', ' ', completion)

                return {"prompt": prompt, "completion": " " + completion}
        except Exception:
            # Deliberate best-effort retry: any failure just tries again.
            # FIX: the exception was previously bound to an unused `e`.
            continue

    return None
|
| 345 |
+
|
| 346 |
+
def main():
    """Generate NUM_LINES math examples in parallel and append them to OUTPUT_FILE."""
    print(f"🔥 Generating {NUM_LINES:,} examples using {NUM_WORKERS} parallel workers...")
    print(f" Appending to '{OUTPUT_FILE}'...")
    start_time = time.time()

    generated_count = 0
    failed_count = 0

    # Append mode so repeated runs keep growing the same dataset file.
    with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
        with multiprocessing.Pool(processes=NUM_WORKERS) as pool:
            # imap_unordered: example order doesn't matter, so consume
            # results as soon as any worker finishes one.
            results_iterator = pool.imap_unordered(generate_training_example, range(NUM_LINES), chunksize=100)

            for item in tqdm(results_iterator, total=NUM_LINES, desc="Generating examples"):
                if item:
                    # One example per line in "<user> ... <ai> ..." format.
                    f.write(f"<user> {item['prompt']} <ai>{item['completion']}\n")
                    generated_count += 1
                else:
                    # generate_training_example gave up after its retry budget.
                    failed_count += 1

    elapsed_time = time.time() - start_time
    print(f"\n\n✅ Done! Appended {generated_count:,} new items to '{OUTPUT_FILE}' in {elapsed_time:.2f}s.")
    print(f" 📊 Success rate: {generated_count}/{NUM_LINES} ({100*generated_count/NUM_LINES:.1f}%)")
    if failed_count > 0:
        print(f" ⚠️ {failed_count:,} generation attempts failed (expressions too complex or invalid)")

if __name__ == "__main__":
    # Required for frozen Windows executables using multiprocessing.
    multiprocessing.freeze_support()
    main()
|
data/math_data.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
finetune_gclm_base.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import tiktoken
|
| 8 |
+
import contextlib
|
| 9 |
+
|
| 10 |
+
# Hyperparameters (must match train_gclm_base.py)
D_MODEL = 256                  # embedding / hidden width
N_LAYERS = 4                   # number of residual blocks
MAX_SEQ_LEN = 1024             # positional-embedding capacity
LOCAL_KERNEL_SIZE = 5          # depthwise causal conv width
GLOBAL_KERNEL_SIZE = 256       # learned FFT-convolution kernel length
USE_GLOBAL_EVERY_N_LAYERS = 2  # every N-th block (starting at 0) gets a global conv
FFT_SIZE = 1024                # chunk size for overlap-save FFT convolution
TOKENIZER_NAME = "gpt2"        # tiktoken encoding used for the raw vocabulary

# Paths
VOCAB_MAP_PATH = "vocab_map.pt"        # compact-vocab mapping from pretraining
BASE_MODEL_PATH = "crimson_base_8.9M.pt"  # pretrained base checkpoint
DATA_DIR = "data"
CHAT_FILES = ["chat_data.txt", "chat_data2.txt"]  # one conversation per line

# Fine-tuning Hyperparameters
EPOCHS = 10
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 4   # effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS
LEARNING_RATE = 3e-4
USE_AMP = True         # mixed precision, only used on CUDA
|
| 32 |
+
|
| 33 |
+
# --- Model Components (Duplicated from train_gclm_base.py for standalone use) ---
|
| 34 |
+
|
| 35 |
+
class GlobalConv1D(nn.Module):
    """Causal long depthwise convolution computed in the frequency domain.

    Uses overlap-save FFT convolution: the input is left-padded by
    (kernel - 1) samples so output t only depends on inputs <= t, then
    processed in fixed-size FFT chunks whose valid regions are concatenated.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # One learned causal kernel per channel.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        _, _, seq_len = x.shape
        taps = min(self.kernel_size, seq_len)
        halo = taps - 1                  # samples re-read between chunks
        step = self.fft_size - halo      # valid samples produced per chunk

        # Left pad to make the convolution causal.
        padded = F.pad(x, (halo, 0))

        # Zero-extend the kernel to the FFT length and transform it once.
        weights = F.pad(self.kernel[:, :taps], (0, self.fft_size - taps))
        weights_f = torch.fft.rfft(weights, n=self.fft_size)

        chunks = []
        start = 0
        while start < seq_len:
            window = padded[..., start:start + self.fft_size]
            shortfall = self.fft_size - window.shape[-1]
            if shortfall > 0:
                window = F.pad(window, (0, shortfall))
            spectrum = torch.fft.rfft(window, n=self.fft_size) * weights_f.unsqueeze(0)
            conv = torch.fft.irfft(spectrum, n=self.fft_size)
            # Drop the halo region that belongs to the previous chunk.
            chunks.append(conv[..., halo:halo + step])
            start += step
        return torch.cat(chunks, dim=-1)[..., :seq_len]
|
| 61 |
+
|
| 62 |
+
class LocalConv1D(nn.Module):
    """Causal depthwise-separable 1D convolution (depthwise filter + 1x1 mix)."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # Depthwise: each channel filtered independently.
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        # Pointwise: 1x1 conv mixing channels after the nonlinearity.
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Left-pad by k-1 so position t never sees future inputs.
        shifted = F.pad(x, (self.k - 1, 0))
        activated = F.relu(self.dw(shifted))
        return self.pw(activated)
|
| 72 |
+
|
| 73 |
+
class Block(nn.Module):
    """One Crimson residual block: local conv, optional global conv, then FFN.

    Each sub-layer is pre-norm with a residual connection.
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # Conv sub-layers expect (B, C, T); transpose around each mixing step.
        local_out = self.local(self.ln1(x).transpose(1, 2))
        x = x + local_out.transpose(1, 2)
        if self.use_global:
            global_out = self.global_conv(self.ln2(x).transpose(1, 2))
            x = x + global_out.transpose(1, 2)
        return x + self.ff(self.ln3(x))
|
| 94 |
+
|
| 95 |
+
class CrimsonBase(nn.Module):
    """Hybrid local/global-convolution language model with tied embeddings."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        # Every USE_GLOBAL_EVERY_N_LAYERS-th layer (starting at 0) is global.
        self.layers = nn.ModuleList(
            [Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0) for i in range(N_LAYERS)]
        )
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
|
| 114 |
+
|
| 115 |
+
# --- Dataset for SFT ---
|
| 116 |
+
|
| 117 |
+
class SFTDataset(Dataset):
    """Next-token-prediction dataset over pre-tokenized conversations.

    Each conversation (a list of raw tokenizer ids) is remapped to the
    compact vocabulary via id2new (unknown tokens fall back to pad_id),
    terminated with eos_id, truncated to max_len, and right-padded so every
    (input, target) pair has length max_len - 1 after the one-token shift.
    """

    def __init__(self, conversations, id2new, max_len, eos_id, pad_id):
        self.samples = []
        for conv in conversations:
            remapped = [id2new.get(tok, pad_id) for tok in conv]
            remapped.append(eos_id)
            # Truncate AFTER appending EOS, so long samples may lose it.
            self.samples.append(remapped[:max_len])
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        seq = self.samples[idx]
        # Standard LM shift: predict token t+1 from tokens <= t.
        inputs, targets = seq[:-1], seq[1:]

        # Right-pad both halves to the shared length of max_len - 1.
        shortfall = (self.max_len - 1) - len(inputs)
        if shortfall > 0:
            inputs = inputs + [self.pad_id] * shortfall
            targets = targets + [self.pad_id] * shortfall

        return torch.tensor(inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)
|
| 143 |
+
|
| 144 |
+
def format_params(num):
    """Format a parameter count as a short human-readable string (K/M/B)."""
    for threshold, suffix in ((1_000_000_000, "B"), (1_000_000, "M")):
        if num >= threshold:
            return f"{num / threshold:.1f}{suffix}"
    # Anything under a million is reported in thousands.
    return f"{num / 1_000:.1f}K"
|
| 151 |
+
|
| 152 |
+
def finetune():
    """Supervised fine-tuning of the Crimson base model on chat data.

    Loads the compact vocabulary and the base checkpoint, trains on the chat
    corpus with gradient accumulation (AMP on CUDA only), and saves a
    checkpoint after every epoch. Resumes from an existing instruct
    checkpoint when one is present.
    """
    # Pick the best available device: CUDA > Apple MPS > CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"[INFO] Using device: {device}")

    # Load vocab mapping
    if not os.path.exists(VOCAB_MAP_PATH):
        print(f"[ERROR] {VOCAB_MAP_PATH} not found. Run train_gclm_base.py first.")
        return

    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — fine for a locally produced file, but confirm the file source.
    vocab_data = torch.load(VOCAB_MAP_PATH, map_location="cpu")
    id2new = vocab_data["id2new"]   # raw tokenizer id -> compact id
    PAD_ID = vocab_data["PAD_ID"]
    EOS_ID = vocab_data["EOS_ID"]
    vocab_size = len(vocab_data["used_tokens"]) + 3 # + PAD, SEP, EOS

    tok = tiktoken.get_encoding(TOKENIZER_NAME)

    # Load chat data: one conversation per non-empty line.
    conversations = []
    print("[INFO] Loading chat data...")
    for fname in CHAT_FILES:
        fpath = os.path.join(DATA_DIR, fname)
        if os.path.exists(fpath):
            with open(fpath, "r", encoding="utf-8") as f:
                for line in f:
                    if line.strip():
                        conversations.append(tok.encode(line.strip()))
        else:
            print(f"[WARN] {fpath} not found.")

    if not conversations:
        print("[ERROR] No chat data found.")
        return

    # Dataset & DataLoader
    dataset = SFTDataset(conversations, id2new, MAX_SEQ_LEN, EOS_ID, PAD_ID)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Model
    model = CrimsonBase(vocab_size).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    # Save path is derived from the parameter count, e.g. crimson_instruct_8.9M.pt.
    save_path = f"crimson_instruct_{param_str}.pt"
    print(f"[INFO] Model Parameters: {num_params:,} ({param_str})")

    # Resume an in-progress fine-tune if possible, else start from the base model.
    if os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device))
        print(f"[RESUME] Loaded existing instruct model from {save_path}")
    elif os.path.exists(BASE_MODEL_PATH):
        model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
        print(f"[START] Loaded base model from {BASE_MODEL_PATH}")
    else:
        print(f"[WARN] No checkpoint found. Starting from scratch.")

    print(f"[INFO] Save path: {save_path}")

    # Training setup
    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # PAD positions contribute no loss.
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # Mixed precision only on CUDA; elsewhere run in full precision.
    if device == "cuda" and USE_AMP:
        ctx = torch.amp.autocast(device)
        scaler = torch.amp.GradScaler(device)
    else:
        ctx = contextlib.nullcontext()
        scaler = None

    print(f"[SFT] Starting Supervised Fine-Tuning for {EPOCHS} epochs...")
    model.train()

    for ep in range(EPOCHS):
        pbar = tqdm(loader, desc=f"Epoch {ep+1}/{EPOCHS}")
        total_loss = 0

        opt.zero_grad(set_to_none=True)
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)

            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab_size), y.reshape(-1))
                # Scale down so accumulated gradients average over the window.
                loss = loss / GRAD_ACCUM_STEPS

            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step only every GRAD_ACCUM_STEPS micro-batches.
            # NOTE(review): a trailing partial accumulation window at the end
            # of an epoch is never stepped (and is discarded by the next
            # epoch's zero_grad) — confirm this is acceptable.
            if (i + 1) % GRAD_ACCUM_STEPS == 0:
                if scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)

            # Undo the accumulation scaling for reporting purposes.
            total_loss += loss.item() * GRAD_ACCUM_STEPS
            pbar.set_postfix(loss=f"{total_loss / (i+1):.4f}")

        # Save checkpoint after each epoch
        torch.save(model.state_dict(), save_path)
        print(f"[OK] Saved {save_path}")

    print("[DONE] Fine-tuning complete.")

if __name__ == "__main__":
    finetune()
|
sample.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import tiktoken
|
| 6 |
+
|
| 7 |
+
MODEL_PATH = "crimson_base_3.3M.pt"  # checkpoint to sample from
VOCAB_PATH = "vocab_map.pt"          # compact-vocab mapping from training
TOKENIZER_NAME = "gpt2"              # tiktoken encoding for the raw vocabulary

# Architecture hyperparameters — must match the training script exactly.
D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5
GLOBAL_KERNEL_SIZE = 256
USE_GLOBAL_EVERY_N_LAYERS = 2
FFT_SIZE = 1024

# Special token ids in the compact vocabulary; real tokens start at OFFSET.
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
|
| 23 |
+
|
| 24 |
+
class GlobalConv1D(nn.Module):
    """Causal long depthwise convolution via overlap-save FFT.

    Left-pads the input by (kernel - 1) samples so position t only sees
    inputs <= t, then convolves fixed-size FFT chunks and stitches together
    their valid regions.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # One learned causal kernel per channel.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        _, _, total = x.shape
        eff_k = min(self.kernel_size, total)

        carry = eff_k - 1                # samples shared between chunks
        valid = self.fft_size - carry    # valid output per chunk

        padded = F.pad(x, (carry, 0))    # causal left padding

        # Zero-extend the kernel to the FFT length; transform once, reuse per chunk.
        taps = F.pad(self.kernel[:, :eff_k], (0, self.fft_size - eff_k))
        taps_f = torch.fft.rfft(taps, n=self.fft_size)

        pieces = []
        offset = 0
        while offset < total:
            window = padded[..., offset:offset + self.fft_size]
            deficit = self.fft_size - window.shape[-1]
            if deficit > 0:
                window = F.pad(window, (0, deficit))

            filtered = torch.fft.irfft(
                torch.fft.rfft(window, n=self.fft_size) * taps_f.unsqueeze(0),
                n=self.fft_size,
            )
            # Skip the carried-over samples already produced by the prior chunk.
            pieces.append(filtered[..., carry:carry + valid])
            offset += valid

        return torch.cat(pieces, dim=-1)[..., :total]
|
| 58 |
+
|
| 59 |
+
class LocalConv1D(nn.Module):
    """Causal depthwise-separable 1D convolution."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # Per-channel (depthwise) filter followed by a 1x1 channel mixer.
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Pad on the left only, keeping the convolution causal.
        causal = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(causal)))
|
| 69 |
+
|
| 70 |
+
class Block(nn.Module):
    """Crimson residual block: local conv, optional global conv, feed-forward.

    All three sub-layers are pre-norm with residual connections.
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global

        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)

        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)

        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # Conv modules operate on (B, C, T); transpose around each one.
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            mixed = self.global_conv(self.ln2(x).transpose(1, 2))
            x = x + mixed.transpose(1, 2)
        return x + self.ff(self.ln3(x))
|
| 94 |
+
|
| 95 |
+
class CrimsonBase(nn.Module):
    """Crimson language model: tied embeddings over a stack of conv blocks."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)

        # Global mixing on every USE_GLOBAL_EVERY_N_LAYERS-th block (from 0).
        self.layers = nn.ModuleList(
            [Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0) for i in range(N_LAYERS)]
        )

        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)

        # Tie the output projection to the input embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        steps = x.size(1)
        h = self.emb(x) + self.pos(torch.arange(steps, device=x.device))
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
|
| 117 |
+
|
| 118 |
+
def load_model_and_vocab(device):
    """Load the compact-vocabulary mapping and the model checkpoint.

    Returns (model, used_tokens, id2new) on success, or (None, None, None)
    when either file is missing. `used_tokens` maps compact id - OFFSET back
    to raw tokenizer ids; `id2new` is the inverse (raw id -> compact id).
    """
    if not os.path.exists(VOCAB_PATH):
        print(f"[ERROR] Vocab file not found: {VOCAB_PATH}")
        return None, None, None

    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — acceptable for locally produced files; confirm the source.
    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab_data["used_tokens"]
    id2new = vocab_data["id2new"]
    # OFFSET reserves slots for the PAD/SEP/EOS special tokens.
    vocab_size = len(used_tokens) + OFFSET

    print(f"[INFO] Vocab loaded. Size: {vocab_size}")

    model = CrimsonBase(vocab_size).to(device)

    if os.path.exists(MODEL_PATH):
        print(f"[INFO] Loading model from {MODEL_PATH}...")
        state_dict = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()  # inference mode for sampling
    else:
        print(f"[ERROR] Model file not found: {MODEL_PATH}")
        return None, None, None

    return model, used_tokens, id2new
|
| 142 |
+
|
| 143 |
+
@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Autoregressively sample a continuation of `prompt`.

    Prompt tokens unknown to the compact vocabulary are silently dropped;
    if nothing remains, a single PAD token seeds generation. Sampling uses
    temperature scaling and top-k filtering, stops at EOS_ID or after
    max_new_tokens, and returns the decoded completion (prompt excluded).
    """
    model.eval()

    raw_ids = tokenizer.encode(prompt)
    # Idiom fix: the original loop carried a redundant `else: continue`;
    # keep only tokens that exist in the compact vocabulary.
    input_ids = [id2new[rid] for rid in raw_ids if rid in id2new]

    if not input_ids:
        print("[WARN] No known tokens in prompt.")
        input_ids = [PAD_ID]

    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    generated = []

    for _ in range(max_new_tokens):
        # Crop the context to the model's positional-embedding capacity.
        if x.size(1) > MAX_SEQ_LEN:
            ctx = x[:, -MAX_SEQ_LEN:]
        else:
            ctx = x

        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature

        # Top-k filtering: mask everything below the k-th largest logit.
        if top_k is not None:
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        idx = next_token.item()

        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated.append(idx)

    decoded_text = decoder(generated, used_tokens, tokenizer)
    return decoded_text
|
| 190 |
+
|
| 191 |
+
def decoder(ids, used_tokens, tokenizer):
    """Map remapped model ids back to raw tokenizer ids and decode to text.

    Special ids below OFFSET (PAD/SEP/EOS) are skipped entirely.
    """
    raw_ids = [used_tokens[i - OFFSET] for i in ids if i >= OFFSET]
    return tokenizer.decode(raw_ids)
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
    # --- Device selection: prefer CUDA, then Apple MPS, else CPU. ---
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    print(f"Using device: {device}")

    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)

    if model:
        # Seed generation with a newline token; fall back to the first real
        # vocab id (OFFSET) if newline is not in the reduced vocabulary.
        newline_id = id2new.get(enc.encode("\n")[0], OFFSET)

        # Interactive loop: generate one sample per iteration until 'exit'.
        while True:
            print(f"\n--- Generating Sample (Temp=0.8, TopK=50) ---")
            print("-" * 20)

            x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
            generated = []

            with torch.no_grad():
                for _ in range(900):  # hard cap on tokens per sample
                    # Crop the context to the model's maximum sequence length.
                    if x.size(1) > MAX_SEQ_LEN:
                        ctx = x[:, -MAX_SEQ_LEN:]
                    else:
                        ctx = x

                    logits = model(ctx)
                    # Temperature 0.8 sharpening of the next-token distribution.
                    logits = logits[:, -1, :] / 0.8

                    # Top-k (k=50) filtering: mask logits below the 50th largest.
                    v, _ = torch.topk(logits, min(50, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                    idx = next_token.item()
                    x = torch.cat((x, next_token), dim=1)
                    generated.append(idx)

                    if idx == EOS_ID:
                        print("[EOS]", end="", flush=True)
                        break

                    # Stream the decoded token (or a marker for special ids).
                    if idx >= OFFSET:
                        raw_id = used_tokens[idx - OFFSET]
                        token_text = enc.decode([raw_id])
                        print(token_text, end="", flush=True)
                    elif idx == PAD_ID:
                        print("[PAD]", end="", flush=True)
                    elif idx == SEP_ID:
                        print("[SEP]", end="", flush=True)

            print("\n" + "-"*20)
            cont = input("\nPress [Enter] to generate again, or type 'exit': ")
            if cont.lower() == 'exit':
                break
|
train_gclm_base.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
print("Starting...")

# ---------------- Data / tokenizer configuration ----------------
DATA_DIR = "data"               # directory scanned for *.txt training files
DATA_PCT = 0.005                # fraction of the concatenated corpus to train on
TOKENIZER_NAME = "gpt2"         # tiktoken encoding used for raw tokenization
REDUCE_VOCAB = False            # NOTE(review): appears unused in this file
VOCAB_SAVE_PATH = "vocab_map.pt"  # where the reduced-vocab remap table is saved

# ---------------- Optimization hyperparameters ----------------
EPOCHS = 25
MICRO_BATCH_SIZE = 1            # effective batch = MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS
GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 3e-4

# ---------------- Model architecture ----------------
D_MODEL = 256                   # embedding / hidden width
N_LAYERS = 4
MAX_SEQ_LEN = 1024              # context window (also positional embedding size)

LOCAL_KERNEL_SIZE = 5           # receptive field of the local depthwise conv
GLOBAL_KERNEL_SIZE = 256        # learned long-conv kernel length
USE_GLOBAL_EVERY_N_LAYERS = 2   # layers i with i % N == 0 get a global conv

FFT_SIZE = 1024                 # block length for the FFT long convolution

# ---------------- Checkpointing ----------------
SAVE_PATH = "model.pt"          # NOTE(review): appears unused; train() derives its own path
SAVE_N_EPOCHS = 1               # save a checkpoint every N epochs

# ---------------- Runtime toggles ----------------
USE_DEVICE = "cuda"             # NOTE(review): appears unused; device is auto-detected
USE_AMP = True                  # mixed precision (CUDA only)
USE_ACTIVATION_CHECKPOINTING = False  # NOTE(review): appears unused in this file

COMPILE = False                 # wrap the model with torch.compile on CUDA
COMPILE_MODE = "reduce-overhead"
COMPILE_BACKEND = "eager"

import os

# Reduce CUDA allocator fragmentation; skipped on Windows builds.
if os.name != "nt":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import tiktoken

# Allow TF32 matmuls on supported GPUs for faster training.
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Special token ids; real (remapped) vocabulary tokens start at OFFSET.
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
|
| 56 |
+
|
| 57 |
+
def build_dataset_vocab(data_dir, tokenizer, save_path):
    """Build a reduced vocabulary from every .txt file under *data_dir*.

    Tokenizes each file with *tokenizer*, collects the set of raw token ids
    actually used, and assigns them compact ids starting at OFFSET (ids below
    OFFSET are reserved for PAD/SEP/EOS). The mapping is persisted to
    *save_path* via torch.save.

    Args:
        data_dir: directory containing the .txt training files.
        tokenizer: object with an ``encode(text) -> list[int]`` method.
        save_path: destination file for the vocab mapping.

    Returns:
        (used, id2new): sorted list of raw token ids, and a dict mapping
        raw id -> compact id.
    """
    all_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".txt")]
    print(f"[INFO] Building vocab from {len(all_files)} files...")

    combined_used = set()
    for fpath in all_files:
        # Fix: use a context manager so each file handle is closed
        # deterministically (the original left handles to the GC).
        with open(fpath, "r", encoding="utf-8") as fh:
            text = fh.read()
        combined_used.update(tokenizer.encode(text))

    used = sorted(combined_used)
    id2new = {tok: i + OFFSET for i, tok in enumerate(used)}

    torch.save({
        "used_tokens": used,
        "id2new": id2new,
        "PAD_ID": PAD_ID,
        "SEP_ID": SEP_ID,
        "EOS_ID": EOS_ID,
    }, save_path)

    print(f"[OK] Total Vocab size: {len(used) + OFFSET}")
    return used, id2new
|
| 80 |
+
|
| 81 |
+
class RemappedTextDataset(Dataset):
    """Sliding-window next-token-prediction view over a flat id sequence.

    Sample i is the pair (ids[i : i+max_len], ids[i+1 : i+max_len+1]),
    i.e. inputs and targets shifted by one position.
    """

    def __init__(self, ids, max_len):
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        # Number of full windows that leave room for the shifted target.
        n_windows = len(self.ids) - self.max_len - 1
        return n_windows if n_windows > 0 else 0

    def __getitem__(self, i):
        window = self.ids[i : i + self.max_len + 1]
        inputs = torch.tensor(window[:-1], dtype=torch.long)
        targets = torch.tensor(window[1:], dtype=torch.long)
        return inputs, targets
|
| 93 |
+
|
| 94 |
+
class GlobalConv1D(nn.Module):
    """Causal long convolution computed blockwise in the frequency domain.

    Convolves a (B, C, T) signal with a learned per-channel kernel of length
    ``kernel_size`` using FFT-based, overlap-save-style blocking, so the cost
    scales with T log(fft_size) rather than T * kernel_size.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # One learned filter per channel; 0.01 scale keeps early outputs small.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        # x: assumed (batch, channels, time) — B, C, T below. TODO confirm at call sites.
        B, C, T = x.shape
        # Never use a kernel longer than the sequence itself.
        K = min(self.kernel_size, T)

        # Overlap-save parameters: each FFT block of length fft_size yields
        # `block` valid output samples after discarding the first K-1 (the
        # circular wrap-around region).
        overlap = K - 1
        block = self.fft_size - overlap

        # Left-pad so the convolution is causal (output at t sees inputs <= t).
        x = F.pad(x, (overlap, 0))
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        # Kernel spectrum computed once and reused for every block.
        k_f = torch.fft.rfft(k, n=self.fft_size)

        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos+self.fft_size]
            if seg.shape[-1] < self.fft_size:
                # Zero-pad the final partial block up to the FFT length.
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))

            # Circular convolution via pointwise product in the frequency domain.
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size
            )
            # Keep only the `block` samples unaffected by circular wrap-around.
            outs.append(y[..., overlap:overlap+block])
            pos += block

        # Concatenate valid regions and trim tail padding back to length T.
        return torch.cat(outs, dim=-1)[..., :T]
|
| 128 |
+
|
| 129 |
+
class LocalConv1D(nn.Module):
    """Causal depthwise-separable 1D convolution over (B, C, T) inputs."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # Depthwise stage: one length-k filter per channel.
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        # Pointwise stage: 1x1 convolution mixing channels.
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Left-pad only, so position t never sees inputs after t (causal).
        padded = F.pad(x, (self.k - 1, 0))
        activated = F.relu(self.dw(padded))
        return self.pw(activated)
|
| 139 |
+
|
| 140 |
+
class Block(nn.Module):
    """Pre-norm residual block: local conv -> optional global conv -> MLP."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        hidden = d_model * 4
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model)
        )

    def forward(self, x):
        # Conv sub-layers are channels-first, so transpose around each call.
        local_out = self.local(self.ln1(x).transpose(1, 2))
        x = x + local_out.transpose(1, 2)
        if self.use_global:
            global_out = self.global_conv(self.ln2(x).transpose(1, 2))
            x = x + global_out.transpose(1, 2)
        return x + self.ff(self.ln3(x))
|
| 164 |
+
|
| 165 |
+
class CrimsonBase(nn.Module):
    """Token-level language model stacking hybrid local/global conv blocks."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)

        # Every USE_GLOBAL_EVERY_N_LAYERS-th layer (starting at 0) gets a
        # global convolution in addition to the local one.
        self.layers = nn.ModuleList(
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        )

        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
|
| 187 |
+
|
| 188 |
+
def format_params(num):
    """Render a parameter count as a short human-readable string (e.g. "8.9M")."""
    for threshold, suffix in ((1_000_000_000, "B"), (1_000_000, "M")):
        if num >= threshold:
            return f"{num / threshold:.1f}{suffix}"
    # Anything under a million is reported in thousands.
    return f"{num / 1_000:.1f}K"
|
| 195 |
+
|
| 196 |
+
@torch.no_grad()
def estimate_loss(model, dl, device, ctx):
    """Average cross-entropy over up to 50 batches from *dl*.

    Puts the model in eval mode for the measurement and restores train mode
    afterwards. PAD positions are excluded from the loss. Returns 0.0 when
    the loader yields no batches.
    """
    MAX_BATCHES = 50
    model.eval()
    batch_losses = []
    for step, (inputs, targets) in enumerate(dl):
        if step >= MAX_BATCHES:
            break
        inputs = inputs.to(device)
        targets = targets.to(device)
        with ctx:
            logits = model(inputs)
            flat = logits.reshape(-1, logits.size(-1))
            loss = F.cross_entropy(flat, targets.reshape(-1), ignore_index=PAD_ID)
        batch_losses.append(loss.item())
    model.train()
    if not batch_losses:
        return 0.0
    return sum(batch_losses) / len(batch_losses)
|
| 210 |
+
|
| 211 |
+
def train():
    """Training entry point: build the vocab, tokenize the corpus, train CrimsonBase."""
    # --- Device selection: prefer CUDA, then Apple MPS, else CPU. ---
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("[INFO] Device:", device)

    tok = tiktoken.get_encoding(TOKENIZER_NAME)

    # Build the reduced vocabulary (dataset tokens only) and its remap table.
    used, id2new = build_dataset_vocab(DATA_DIR, tok, VOCAB_SAVE_PATH)
    vocab = len(used) + OFFSET

    print("[INFO] Loading and tokenizing text from all files...")
    all_files = [os.path.join(DATA_DIR, f) for f in os.listdir(DATA_DIR) if f.endswith(".txt")]
    full_text = ""
    for fpath in all_files:
        # NOTE(review): file handles are not closed explicitly — consider `with open(...)`.
        full_text += open(fpath, "r", encoding="utf-8").read() + "\n"

    # Optionally train on only a leading fraction of the corpus (by characters).
    if DATA_PCT < 1.0:
        full_text = full_text[:int(len(full_text) * DATA_PCT)]

    raw_ids = tok.encode(full_text)
    # Remap raw tokenizer ids to compact ids; unknown ids fall back to PAD_ID.
    ids = [id2new.get(i, PAD_ID) for i in raw_ids] + [EOS_ID]

    # Contiguous 90/10 train/validation split.
    n = len(ids)
    split_idx = int(n * 0.9)
    train_ids = ids[:split_idx]
    val_ids = ids[split_idx:]

    print(f"[INFO] Tokens: {n} | Train: {len(train_ids)} | Val: {len(val_ids)}")

    train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
    val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)

    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)

    model = CrimsonBase(vocab).to(device)

    # Checkpoint filename embeds the parameter count (e.g. crimson_base_8.9M.pt).
    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    save_path = f"crimson_base_{param_str}.pt"
    print(f"[INFO] Model parameters: {num_params:,} ({param_str})")
    print(f"[INFO] Save path: {save_path}")

    # Resume from an existing checkpoint with a matching name, if present.
    if os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device))
        print(f"[RESUME] Loaded existing checkpoint from {save_path}")

    if device == "cuda" and COMPILE:
        print("[INFO] Compiling model with torch.compile...")
        model = torch.compile(
            model,
            mode=COMPILE_MODE,
            fullgraph=False,
            backend=COMPILE_BACKEND
        )

    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # Mixed precision only on CUDA; elsewhere a no-op context and no scaler.
    if device == "cuda" and USE_AMP:
        ctx = torch.amp.autocast(device)
        scaler = torch.amp.GradScaler(device)
    else:
        import contextlib
        ctx = contextlib.nullcontext()
        scaler = None

    for ep in range(EPOCHS):
        print(f"\nEpoch {ep+1}/{EPOCHS}")
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_dl, desc="Training")
        running_loss = 0.0

        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)

            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
            loss_val = loss.item()
            # Scale down so accumulated gradients average over the macro-batch.
            loss = loss / GRAD_ACCUM_STEPS

            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step the optimizer only every GRAD_ACCUM_STEPS micro-batches.
            if (i+1) % GRAD_ACCUM_STEPS == 0:
                if scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)

            # Exponential moving average of the loss shown in the progress bar.
            running_loss = 0.9 * running_loss + 0.1 * loss_val if running_loss > 0 else loss_val
            pbar.set_postfix(loss=f"{running_loss:.4f}")

        val_loss = estimate_loss(model, val_dl, device, ctx)
        print(f"Epoch {ep+1} finished. Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Periodic checkpointing every SAVE_N_EPOCHS epochs.
        if SAVE_N_EPOCHS and (ep+1) % SAVE_N_EPOCHS == 0:
            torch.save(model.state_dict(), save_path)
            print(f"[OK] Saved checkpoint to {save_path}")

    # Final save regardless of the periodic schedule.
    torch.save(model.state_dict(), save_path)
    print("[DONE] Training complete.")
|
| 323 |
+
|
| 324 |
+
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    train()
|
vocab_map.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b3403a73a81a7186b53b28a093fbff744e12a9fcaffe663fd7294b8a0f875a7e
|
| 3 |
+
size 214393
|