AGofficial commited on
Commit
238d08f
·
verified ·
1 Parent(s): 25bd6e1

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +43 -0
  2. banner.png +0 -0
  3. model.py +94 -0
  4. sample.py +103 -0
  5. train_chatgclm.py +174 -0
  6. vocab_map.pt +3 -0
README.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## ChatGCLM-330M
2
+ <img src="./banner.png" alt="ChatGCLM Hero" width="600">
3
+ <strong>A high-performance language model architecture.</strong>
4
+
5
+ ---
6
+
7
+ ## Overview
8
+
9
+ **ChatGCLM** is a generative language model that deviates from the traditional Transformer architecture by utilizing a hybrid approach of **Local** and **Global Convolutions**. By leveraging Fast Fourier Transforms (FFT) for global context, ChatGCLM achieves a massive receptive field with a fraction of the computational overhead associated with standard attention mechanisms.
10
+
11
+ The architecture is designed for efficiency, speed, and high-quality generation, featuring a custom vocabulary reduction system that optimizes the embedding space for specific datasets.
12
+
13
+
14
+ ## 📦 Installation
15
+
16
+ Download this repository and extract it.
17
+
18
+ ---
19
+
20
+ ## Usage
21
+
22
+ ### 1. Training the Model
23
+ Place your `.txt` data files in the `data/` directory and run:
24
+ ```bash
25
+ python train_chatgclm.py
26
+ ```
27
+ This script will build the vocabulary and train the foundation model.
28
+
29
+ ### 2. Interactive Chat Interface
30
+ Launch the Tkinter-based UI to interact with your model:
31
+ ```bash
32
+ python chat_interface.py
33
+ ```
34
+
35
+ ---
36
+
37
+ ## Fine-tuning
38
+
39
+ You may fine-tune the model by resuming training from a checkpoint. You may use a different dataset as long as the vocabulary is the same, and you may also change parameters such as the learning rate, batch size, etc.
40
+
41
+ <p align="center">
42
+ Built with ❤️ by AG
43
+ </p>
banner.png ADDED
model.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
# Model hyperparameters.
D_MODEL = 1152                 # embedding / hidden width
N_LAYERS = 22                  # number of residual blocks
MAX_SEQ_LEN = 4096             # maximum context length (positional table size)
LOCAL_KERNEL_SIZE = 5          # depthwise causal conv kernel width
GLOBAL_KERNEL_SIZE = 256       # learned long-range FIR kernel length
USE_GLOBAL_EVERY_N_LAYERS = 2  # layers with i % 2 == 0 get a global conv branch
FFT_SIZE = 1024                # FFT block size for overlap-save convolution
12
+
13
class GlobalConv1D(nn.Module):
    """Causal long convolution evaluated blockwise in the frequency domain.

    Each channel owns a learned FIR kernel of length ``kernel_size`` that is
    applied causally via the overlap-save FFT method, giving a long receptive
    field without attention-style quadratic cost.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # Small init keeps the residual branch near-identity early in training.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        _, _, time = x.shape
        taps = min(self.kernel_size, time)  # never use more taps than samples
        overlap = taps - 1
        step = self.fft_size - overlap      # valid outputs produced per block

        # Left-pad so output position t only sees inputs <= t (causality).
        padded = F.pad(x, (overlap, 0))

        # Zero-pad the (possibly truncated) kernel to the FFT length and
        # transform it once; the spectrum is reused for every block.
        weight = F.pad(self.kernel[:, :taps], (0, self.fft_size - taps))
        weight_f = torch.fft.rfft(weight, n=self.fft_size).unsqueeze(0)

        chunks = []
        start = 0
        while start < time:
            seg = padded[..., start:start + self.fft_size]
            short = self.fft_size - seg.shape[-1]
            if short > 0:
                seg = F.pad(seg, (0, short))
            # Circular convolution; the first `overlap` samples are
            # wrap-around garbage and are discarded (overlap-save).
            conv = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * weight_f,
                n=self.fft_size,
            )
            chunks.append(conv[..., overlap:overlap + step])
            start += step

        return torch.cat(chunks, dim=-1)[..., :time]
39
+
40
class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution over a short local window."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # Depthwise: each channel is filtered independently over k timesteps.
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        # Pointwise: 1x1 conv mixes information across channels.
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Left padding keeps the convolution causal.
        padded = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(padded)))
50
+
51
class Block(nn.Module):
    """Pre-norm residual block: local conv, optional global conv, then FFN."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        # Standard 4x-expansion feed-forward network.
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # Conv sublayers operate channel-first, so transpose (B,T,C) -> (B,C,T)
        # going in and back again coming out.
        h = self.ln1(x).transpose(1, 2)
        x = x + self.local(h).transpose(1, 2)
        if self.use_global:
            g = self.ln2(x).transpose(1, 2)
            x = x + self.global_conv(g).transpose(1, 2)
        return x + self.ff(self.ln3(x))
72
+
73
class ChatGCLM(nn.Module):
    """Convolutional language model: embeddings -> N blocks -> tied LM head."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        # Every USE_GLOBAL_EVERY_N_LAYERS-th block carries the global branch.
        self.layers = nn.ModuleList(
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        )
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        # Keep only the most recent MAX_SEQ_LEN tokens.
        if x.size(1) > MAX_SEQ_LEN:
            x = x[:, -MAX_SEQ_LEN:]
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
sample.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import tiktoken
5
+ from model import ChatGCLM, MAX_SEQ_LEN
6
+
7
# Locate a trained checkpoint in the working directory.
# sorted() makes the selection deterministic: os.listdir returns entries in
# arbitrary order, so with multiple checkpoints the original code could pick
# a different file on every run.
MODEL_PATH = None
for f in sorted(os.listdir(".")):
    if f.startswith("ChatGCLM_") and f.endswith(".pt"):
        MODEL_PATH = f
        break

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train_chatgclm.py")
    exit(1)

TOKENIZER_NAME = "gpt2"  # tiktoken encoding used at train time
EOS_ID = 2               # token id treated as end-of-sequence during sampling
20
+
21
def load_model(device):
    """Build a ChatGCLM sized for the gpt2 vocabulary and load the checkpoint.

    Returns (model, tokenizer) on success, or (None, None) when the
    checkpoint file is missing or empty.
    """
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    model = ChatGCLM(tok.n_vocab).to(device)

    # Guard clause: refuse to proceed with a missing/zero-byte checkpoint.
    if not (os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0):
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None, None

    print(f"Loading model from: {MODEL_PATH}")
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, tok
34
+
35
@torch.no_grad()
def generate(model, prompt, tokenizer, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Sample up to max_new_tokens continuation tokens for `prompt`.

    Streams decoded tokens to stdout as they are sampled and returns the
    decoded continuation (prompt excluded). Stops early when EOS_ID is drawn.
    """
    model.eval()
    x = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)

    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)

    new_ids = []
    for _ in range(max_new_tokens):
        # Trim the context to the model's maximum window.
        ctx = x if x.size(1) <= MAX_SEQ_LEN else x[:, -MAX_SEQ_LEN:]
        next_logits = model(ctx)[:, -1, :] / temperature

        if top_k is not None:
            # Mask every logit below the k-th largest.
            kth = torch.topk(next_logits, min(top_k, next_logits.size(-1)))[0][:, [-1]]
            next_logits[next_logits < kth] = -float('Inf')

        choice = torch.multinomial(F.softmax(next_logits, dim=-1), num_samples=1)
        token_id = choice.item()

        if token_id == EOS_ID:
            break

        x = torch.cat((x, choice), dim=1)
        new_ids.append(token_id)
        print(tokenizer.decode([token_id]), end="", flush=True)

    print(f"\n{'='*70}\n")
    return tokenizer.decode(new_ids)
71
+
72
if __name__ == "__main__":
    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model, tokenizer = load_model(device)
    if model is None:
        exit(1)

    banner = "=" * 70
    print("\n" + banner)
    print("ChatGCLM Text Generation Demo")
    print(banner)

    # Canned demo prompts, sampled one after another.
    for prompt in ("Once upon a time", "The future of AI is", "In a world where"):
        generate(model, prompt, tokenizer, device, max_new_tokens=150, temperature=0.8, top_k=50)

    print("\n" + banner)
    print("Interactive Mode - Enter your own prompts!")
    print(banner)

    # REPL: keep sampling until the user types 'exit'.
    while True:
        user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        if user_prompt.lower() == 'exit':
            break
        if user_prompt.strip():
            generate(model, user_prompt, tokenizer, device, max_new_tokens=200, temperature=0.8, top_k=50)
train_chatgclm.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from tqdm import tqdm
7
+ import tiktoken
8
+ import contextlib
9
+ from model import ChatGCLM, MAX_SEQ_LEN
10
+
11
# Mitigate CUDA memory fragmentation on non-Windows platforms.
if os.name != "nt":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Allow TF32 matmuls/convs for faster training on supporting GPUs.
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

DATA_DIR = "data"                # directory scanned for .txt training files
DATA_PCT = 0.002                 # NOTE(review): unused in this script — confirm before removing
TOKENIZER_NAME = "gpt2"          # tiktoken encoding name
VOCAB_SAVE_PATH = "vocab_map.pt" # where vocabulary metadata is persisted

EPOCHS = 50
MICRO_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8             # effective batch = MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS
LEARNING_RATE = 5e-4
MIN_LR = 1e-5                    # floor of the cosine LR schedule

SAVE_N_EPOCHS = 1                # NOTE(review): unused — checkpoint is saved every epoch regardless

# Special token ids used for loss masking / sequence termination.
# NOTE(review): these ids overlap with real gpt2 token ids — confirm intended.
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
36
+
37
def build_dataset_vocab(data_dir, tokenizer, save_path):
    """Persist vocabulary metadata for `tokenizer` and return its size.

    `data_dir` is currently unused: the full tokenizer vocabulary is kept
    rather than a dataset-specific reduction.
    """
    size = tokenizer.n_vocab
    payload = {
        "vocab_size": size,
        "PAD_ID": PAD_ID,
        "SEP_ID": SEP_ID,
        "EOS_ID": EOS_ID,
    }
    torch.save(payload, save_path)
    return size
46
+
47
class RemappedTextDataset(Dataset):
    """Chunks a flat token-id list into fixed-length next-token training pairs."""

    def __init__(self, ids, max_len):
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        # Only windows that also have a one-step-shifted target are counted.
        return max(0, (len(self.ids) - 1) // self.max_len)

    def __getitem__(self, i):
        lo = i * self.max_len
        hi = lo + self.max_len
        x = self.ids[lo:hi]
        y = self.ids[lo + 1:hi + 1]  # targets are inputs shifted by one

        # Right-pad short tails with 0 (== PAD_ID) up to the fixed length.
        x = x + [0] * (self.max_len - len(x))
        y = y + [0] * (self.max_len - len(y))

        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
66
+
67
def format_params(num):
    """Format a parameter count with a B/M/K suffix and one decimal place."""
    for threshold, suffix in ((1_000_000_000, "B"), (1_000_000, "M")):
        if num >= threshold:
            return f"{num / threshold:.1f}{suffix}"
    return f"{num / 1_000:.1f}K"
74
+
75
@torch.no_grad()
def estimate_loss(model, dl, device, ctx):
    """Average cross-entropy over up to 50 batches of `dl`.

    Temporarily switches the model to eval mode and restores train mode
    before returning. Returns 0.0 for an empty loader.
    """
    model.eval()
    losses = []
    for step, (x, y) in enumerate(dl):
        if step >= 50:
            break
        x, y = x.to(device), y.to(device)
        with ctx:
            logits = model(x)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1),
                ignore_index=PAD_ID,
            )
        losses.append(loss.item())
    model.train()
    if not losses:
        return 0.0
    return sum(losses) / len(losses)
89
+
90
def train():
    """Train ChatGCLM on every .txt file in DATA_DIR, checkpointing each epoch."""
    # Device preference: CUDA > Apple MPS > CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    vocab = build_dataset_vocab(DATA_DIR, tok, VOCAB_SAVE_PATH)

    # Concatenate every .txt file into one training corpus.
    full_text = ""
    for f in os.listdir(DATA_DIR):
        if not f.endswith(".txt"): continue
        fpath = os.path.join(DATA_DIR, f)
        content = open(fpath, "r", encoding="utf-8").read()
        full_text += content + "\n"

    ids = tok.encode(full_text) + [EOS_ID]

    # Contiguous 90/10 train/validation split over the token stream.
    n = len(ids)
    split_idx = int(n * 0.9)
    train_ids = ids[:split_idx]
    val_ids = ids[split_idx:]

    train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
    val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)
    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)

    model = ChatGCLM(vocab).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    # Checkpoint filename encodes the parameter count, e.g. "ChatGCLM_330.0M.pt".
    save_path = f"ChatGCLM_{param_str}.pt"

    print("-" * 30)
    print(f"ChatGCLM TRAINING START")
    print(f"Model ID: {save_path}")
    print(f"Parameters: {num_params:,}")
    print(f"Device: {device}")
    print(f"Vocab Size: {vocab}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Epochs: {EPOCHS}")
    print("-" * 30)

    # Resume from an existing non-empty checkpoint with the same name.
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f"⏳ Found checkpoint at {save_path}, loading...")
        model.load_state_dict(torch.load(save_path, map_location=device))
        print("✓ Model weights loaded successfully! Resuming training.")
    else:
        print("ℹ No checkpoint found. Starting training from scratch.")

    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # One cosine cycle over the whole run, decaying the LR to MIN_LR.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=MIN_LR)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    # Mixed precision only on CUDA; elsewhere run in full precision.
    ctx = torch.amp.autocast(device) if device == "cuda" else contextlib.nullcontext()
    scaler = torch.amp.GradScaler(device) if device == "cuda" else None

    for ep in range(EPOCHS):
        opt.zero_grad(set_to_none=True)
        pbar = tqdm(train_dl, desc=f"Epoch {ep+1}/{EPOCHS}")
        running_loss = 0.0
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
            loss_val = loss.item()
            # Scale down so the accumulated gradient averages the micro-batches.
            loss = loss / GRAD_ACCUM_STEPS
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            # Step the optimizer once every GRAD_ACCUM_STEPS micro-batches.
            if (i+1) % GRAD_ACCUM_STEPS == 0:
                if scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)
            # Exponential moving average of the loss, for display only.
            running_loss = 0.9 * running_loss + 0.1 * loss_val if running_loss > 0 else loss_val
            pbar.set_postfix(loss=f"{running_loss:.4f}")
        val_loss = estimate_loss(model, val_dl, device, ctx)
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {ep+1} | Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")
        # Checkpoint after every epoch; resumed runs reuse this same path.
        torch.save(model.state_dict(), save_path)
        print(f"✓ Model saved successfully after epoch {ep+1} to {save_path}")
        scheduler.step()

if __name__ == "__main__":
    train()
vocab_map.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3622ed31c3f722a9e12ae90ffdc9a51a063809d43c7aee885c1b75037161b202
3
+ size 1337