Upload 5 files
Browse files- .gitattributes +1 -0
- AgGPT21.pt +3 -0
- AgGPT21.py +276 -0
- README.md +241 -3
- banner.png +3 -0
- chat.py +185 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
banner.png filter=lfs diff=lfs merge=lfs -text
|
AgGPT21.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ee462d86bdd3f9ae980836dc5ea91d2d0b308777c6fa5c7a229a05f30c8015c
|
| 3 |
+
size 16629913
|
AgGPT21.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from collections import Counter
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import glob
|
| 10 |
+
|
| 11 |
+
# --- Artifact and data locations ---
MODEL_FILE = "AgGPT21.pt"          # checkpoint path (weights + vocab maps)
DATA_FOLDER = "training_corpora/"  # folder scanned for *.txt training files

# Seed both RNGs so data shuffling and weight init are reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# --- Model hyperparameters ---
SEQ_LEN = 64       # tokens per training sequence
STRIDE = 64        # step between sequence starts (== SEQ_LEN -> no overlap)
EMBED_SIZE = 128   # word-embedding dimension (also output dim via weight tying)
HIDDEN_SIZE = 128  # GRU hidden-state dimension
NUM_LAYERS = 1     # stacked GRU layers
DROPOUT = 0.2      # dropout applied after the embedding and after the GRU

# --- Training hyperparameters ---
BATCH_SIZE = 8
EPOCHS = 6
LR = 2e-3
WEIGHT_DECAY = 1e-4  # AdamW L2 regularization strength
CLIP_NORM = 1.0      # gradient-clipping max norm

# --- Data-budget limits ---
GENERATE_LENGTH = 200    # default number of words to generate
DATA_PERCENT = 0.1       # fraction of the corpus files to read
MAX_TOKENS = 1_000_000   # hard cap on total training tokens
MAX_VOCAB = 30000        # vocabulary cap (the <unk> slot counts toward it)

# --- Sampling defaults for generation ---
TEMPERATURE = 0.9  # softmax temperature (<1 sharpens, >1 flattens)
TOP_K = 50         # keep only the k most likely words when sampling
TOP_P = 0.9        # nucleus-sampling cumulative-probability cutoff

# Pick the best available accelerator: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
|
| 46 |
+
|
| 47 |
+
def build_vocab_and_ids(folder_path, percent=1.0, max_tokens=None, max_vocab=None):
    """Build a word-level vocabulary and the encoded corpus from a folder of .txt files.

    Args:
        folder_path: Directory scanned (non-recursively) for ``*.txt`` files.
        percent: Fraction (0-1] of the files to read; at least one file is used.
        max_tokens: Optional hard cap on the number of tokens kept.
        max_vocab: Optional vocabulary cap (the ``<unk>`` slot counts toward it).

    Returns:
        Tuple ``(vocab, stoi, itos, ids)``: ``vocab`` is a word list with
        ``<unk>`` at index 0, ``stoi``/``itos`` are the word<->id maps, and
        ``ids`` is the whole corpus encoded as ids (out-of-vocab words -> 0).

    Raises:
        FileNotFoundError: If the folder contains no ``.txt`` files.
    """
    all_words = []

    # Get all .txt files in the folder.
    txt_files = glob.glob(os.path.join(folder_path, "*.txt"))
    if not txt_files:
        raise FileNotFoundError(f"No .txt files found in {folder_path}")

    # FIX: remember the total before truncating instead of re-globbing the
    # directory later (the old message re-scanned the folder after the list
    # was cut down, which is wasteful and can disagree with the first scan).
    total_files = len(txt_files)
    print(f"Found {total_files} training files")

    # Limit number of files to process based on percent.
    if percent < 1.0:
        num_files_to_use = max(1, int(total_files * percent))
        txt_files = txt_files[:num_files_to_use]
        print(f"Using {percent*100}% of files: {num_files_to_use}/{total_files} files")

    # Read and tokenize the selected files (lowercased, whitespace-split).
    for file_path in sorted(txt_files):
        print(f"Reading {os.path.basename(file_path)}...")
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read().lower()
            # Split by whitespace and filter out empty strings.
            words = [w for w in text.split() if w]
            all_words.extend(words)

    print(f"Total words loaded: {len(all_words):,}")

    if max_tokens is not None:
        all_words = all_words[:max_tokens]
        print(f"Truncated to max_tokens: {len(all_words):,} words")

    counts = Counter(all_words)
    if max_vocab is not None:
        # Reserve index 0 for <unk>; keep the most frequent remaining words.
        keep = max(1, max_vocab - 1)
        common = [w for w, _ in counts.most_common(keep) if w != "<unk>"]
        vocab = ["<unk>"] + common
    else:
        vocab = ["<unk>"] + [w for w in counts if w != "<unk>"]

    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}
    ids = [stoi.get(w, 0) for w in all_words]  # unknown words map to <unk> (0)

    print(f"Vocabulary size: {len(vocab):,}")
    return vocab, stoi, itos, ids
|
| 93 |
+
|
| 94 |
+
class WordDataset(Dataset):
    """Fixed-length (input, target) windows over an encoded token stream.

    Each sample is a pair of LongTensors of length ``seq_len`` where the
    target sequence is the input shifted one position to the right.
    """

    def __init__(self, ids, seq_len, stride=None):
        self.ids = ids
        self.seq_len = seq_len
        self.stride = stride or seq_len
        # Number of complete windows that fit (seq_len + 1 tokens each).
        self.n = max(0, (len(self.ids) - self.seq_len - 1) // self.stride + 1)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        lo = idx * self.stride
        hi = lo + self.seq_len
        inputs = torch.tensor(self.ids[lo:hi], dtype=torch.long)
        targets = torch.tensor(self.ids[lo + 1:hi + 1], dtype=torch.long)
        return inputs, targets
|
| 107 |
+
|
| 108 |
+
class WordRNN(nn.Module):
    """Word-level GRU language model with tied input/output embeddings."""

    def __init__(self, vocab_size, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout=DROPOUT):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.drop = nn.Dropout(dropout)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        # When the GRU width differs from the embedding width, project back
        # down so the output layer can share (tie) the embedding matrix.
        if hidden_size != embed_size:
            self.proj = nn.Linear(hidden_size, embed_size, bias=False)
        else:
            self.proj = None
        out_size = embed_size if self.proj else hidden_size
        self.fc = nn.Linear(out_size, vocab_size, bias=False)
        self.fc.weight = self.embed.weight  # weight tying

    def forward(self, x, hidden=None):
        """Return (logits, hidden) for a batch of token-id sequences *x*."""
        emb = self.drop(self.embed(x))
        out, hidden = self.gru(emb, hidden)
        out = self.drop(out)
        if self.proj is not None:
            out = self.proj(out)
        return self.fc(out), hidden
|
| 128 |
+
|
| 129 |
+
def count_parameters(model):
    """Return the total number of trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
| 131 |
+
|
| 132 |
+
def evaluate(model, dataloader, device, use_amp):
|
| 133 |
+
model.eval()
|
| 134 |
+
criterion = nn.CrossEntropyLoss(ignore_index=0)
|
| 135 |
+
total_loss = 0.0
|
| 136 |
+
with torch.no_grad():
|
| 137 |
+
for x, y in dataloader:
|
| 138 |
+
x = x.to(device)
|
| 139 |
+
y = y.to(device)
|
| 140 |
+
with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
|
| 141 |
+
logits, _ = model(x)
|
| 142 |
+
loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
|
| 143 |
+
total_loss += loss.item()
|
| 144 |
+
return total_loss / max(1, len(dataloader))
|
| 145 |
+
|
| 146 |
+
def train(model, train_loader, val_loader, epochs, lr, device, weight_decay, clip_norm, stoi, itos):
    """Train *model* with AdamW, gradient clipping, and early stopping.

    Saves a checkpoint (weights + vocab maps) to MODEL_FILE whenever the
    validation loss improves, and reloads the best checkpoint before
    returning the model.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # id 0 (<unk>) is ignored
    # Mixed precision only on accelerators; autocast is disabled on CPU.
    use_amp = device.type in {"mps", "cuda"}
    best_val = float("inf")
    patience = 2              # epochs without improvement before stopping
    epochs_no_improve = 0
    print(f"Train batches per epoch: {len(train_loader)} | Val batches: {len(val_loader)}")
    epoch_bar = tqdm(range(epochs), desc="Epochs")
    for epoch in epoch_bar:
        model.train()
        total_loss = 0.0
        batch_bar = tqdm(train_loader, desc=f"Train {epoch+1}/{epochs}", leave=False)
        for x, y in batch_bar:
            x = x.to(device)
            y = y.to(device)
            opt.zero_grad()
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, _ = model(x)
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            # Clip gradients to stabilize recurrent-network training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()
            total_loss += loss.item()
        batch_bar.close()
        train_loss = total_loss / max(1, len(train_loader))
        val_loss = evaluate(model, val_loader, device, use_amp)
        epoch_bar.set_postfix(train=f"{train_loss:.4f}", val=f"{val_loss:.4f}")
        if val_loss < best_val - 1e-4:  # require a small margin of improvement
            best_val = val_loss
            epochs_no_improve = 0
            # Checkpoint the best-so-far model together with its vocabulary.
            torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping.")
                break
    # Restore the best checkpoint (the in-memory model may be past its best
    # epoch when early stopping fires).
    ckpt = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    return model
|
| 187 |
+
|
| 188 |
+
def _sample_next_id(probs_1d, top_k=None, top_p=None, temperature=1.0, forbid_ids=None):
|
| 189 |
+
probs = probs_1d.clone()
|
| 190 |
+
if forbid_ids:
|
| 191 |
+
for i in forbid_ids:
|
| 192 |
+
if 0 <= i < probs.numel():
|
| 193 |
+
probs[i] = 0
|
| 194 |
+
if temperature != 1.0:
|
| 195 |
+
logits = torch.log(probs + 1e-9) / temperature
|
| 196 |
+
probs = torch.softmax(logits, dim=-1)
|
| 197 |
+
if probs.sum() <= 0:
|
| 198 |
+
probs = torch.ones_like(probs)
|
| 199 |
+
if forbid_ids:
|
| 200 |
+
for i in forbid_ids:
|
| 201 |
+
if 0 <= i < probs.numel():
|
| 202 |
+
probs[i] = 0
|
| 203 |
+
probs = probs / probs.sum()
|
| 204 |
+
if top_k is not None and top_k > 0:
|
| 205 |
+
k = min(top_k, probs.size(-1))
|
| 206 |
+
values, indices = torch.topk(probs, k)
|
| 207 |
+
values = values / values.sum()
|
| 208 |
+
idx = indices[torch.multinomial(values, 1)]
|
| 209 |
+
return idx.item()
|
| 210 |
+
if top_p is not None and 0 < top_p < 1.0:
|
| 211 |
+
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|
| 212 |
+
cumulative = torch.cumsum(sorted_probs, dim=-1)
|
| 213 |
+
keep_mask = cumulative <= top_p
|
| 214 |
+
keep = int(keep_mask.nonzero()[-1].item()) + 1 if keep_mask.any() else 1
|
| 215 |
+
sorted_probs = sorted_probs[:keep]
|
| 216 |
+
sorted_indices = sorted_indices[:keep]
|
| 217 |
+
sorted_probs = sorted_probs / sorted_probs.sum()
|
| 218 |
+
idx_pos = torch.multinomial(sorted_probs, 1)
|
| 219 |
+
return sorted_indices[idx_pos].item()
|
| 220 |
+
probs = probs / probs.sum()
|
| 221 |
+
return torch.multinomial(probs, 1).item()
|
| 222 |
+
|
| 223 |
+
def generate_text(model, stoi, itos, prompt, length=GENERATE_LENGTH, seq_len=SEQ_LEN, device=DEVICE, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P):
    """Autoregressively extend *prompt* by *length* sampled words.

    The prompt is lowercased and whitespace-tokenized; unknown words map to
    id 0, which is also forbidden during sampling so <unk> never appears in
    the output. Returns the prompt words followed by the generated words.
    """
    model.to(device)
    model.eval()
    prompt_words = prompt.lower().split()
    prompt_ids = [stoi.get(w, 0) for w in prompt_words]
    # Left-pad (with id 0) or truncate the prompt to exactly seq_len ids.
    if len(prompt_ids) >= seq_len:
        context = prompt_ids[-seq_len:]
    else:
        context = [0] * (seq_len - len(prompt_ids)) + prompt_ids
    input_ids = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device)
    hidden = None
    generated = list(prompt_words)
    use_amp = device.type in {"mps", "cuda"}
    with torch.no_grad():
        gen_bar = tqdm(range(length), desc="Generating", leave=False)
        for _ in gen_bar:
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, hidden = model(input_ids, hidden)
            probs = torch.softmax(logits[:, -1, :], dim=-1).squeeze(0)
            next_id = _sample_next_id(probs, top_k=top_k, top_p=top_p, temperature=temperature, forbid_ids=[0])
            generated.append(itos.get(next_id, "<unk>"))
            # Feed only the new token; the GRU hidden state carries context.
            input_ids = torch.tensor([[next_id]], dtype=torch.long).to(device)
    return " ".join(generated)
|
| 244 |
+
|
| 245 |
+
if __name__ == "__main__":
    # Load an existing checkpoint if present; otherwise train from scratch.
    if os.path.exists(MODEL_FILE):
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]
        itos = ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        print(f"Loaded model {MODEL_FILE} | device={DEVICE} | params={count_parameters(model):,}")
    else:
        if not os.path.exists(DATA_FOLDER):
            raise FileNotFoundError(f"Training folder not found: {DATA_FOLDER}")
        vocab, stoi, itos, ids = build_vocab_and_ids(DATA_FOLDER, percent=DATA_PERCENT, max_tokens=MAX_TOKENS, max_vocab=MAX_VOCAB)
        print(f"Vocab size: {len(vocab):,} | Tokens used: {len(ids):,} | device={DEVICE}")
        # Hold out the tail of the corpus (at least 5 sequences' worth of
        # tokens, or 5% of the data) for validation.
        val_tokens = max(SEQ_LEN * 5, int(0.05 * len(ids)))
        train_ids = ids[:-val_tokens]
        val_ids = ids[-val_tokens:]
        train_dataset = WordDataset(train_ids, SEQ_LEN, stride=STRIDE)
        val_dataset = WordDataset(val_ids, SEQ_LEN, stride=STRIDE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
        model = WordRNN(len(vocab))
        print(f"Model params: {count_parameters(model):,}")
        model = train(model, train_loader, val_loader, EPOCHS, LR, DEVICE, WEIGHT_DECAY, CLIP_NORM, stoi, itos)
        # Save the final (best) weights together with the vocab maps.
        torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        print(f"Saved {MODEL_FILE}")

    # Quick qualitative check: generate continuations for a few prompts.
    print("\n=== AgGPT-21 Demo ===")
    prompts = ["hello world", "how are you", "once upon a time", "tell me about"]
    for p in prompts:
        print(f"\nPrompt: {p}")
        print(f"Generated: {generate_text(model, stoi, itos, p)}")
    print("\nTraining complete! Use chat.py for interactive conversation.")
|
README.md
CHANGED
|
@@ -1,3 +1,241 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🤖 AgGPT-21
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
A powerful and lightweight GPT-style language model built with PyTorch, featuring word-level tokenization and GRU-based architecture.
|
| 6 |
+
|
| 7 |
+
## ✨ Features
|
| 8 |
+
|
| 9 |
+
- **🧠 Intelligent Architecture**: GRU-based neural network with embedding layers
|
| 10 |
+
- **📚 Multi-File Training**: Trains on multiple corpus files automatically
|
| 11 |
+
- **⚡ Optimized Performance**: Supports GPU (CUDA), Apple Silicon (MPS), and CPU
|
| 12 |
+
- **🎛️ Flexible Generation**: Configurable temperature, top-k, and top-p sampling
|
| 13 |
+
- **💬 Interactive Chat**: Beautiful command-line chat interface
|
| 14 |
+
- **🔄 Early Stopping**: Prevents overfitting with validation-based early stopping
|
| 15 |
+
- **📊 Progress Tracking**: Real-time training progress with tqdm
|
| 16 |
+
|
| 17 |
+
## 🚀 Quick Start
|
| 18 |
+
|
| 19 |
+
### Prerequisites
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install torch tqdm
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Training the Model
|
| 26 |
+
|
| 27 |
+
1. **Prepare your training data**: Place your text files in the `training_corpora/` folder
|
| 28 |
+
2. **Start training**:
|
| 29 |
+
```bash
|
| 30 |
+
python AgGPT21.py
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
The model will automatically:
|
| 34 |
+
- Load all `.txt` files from `training_corpora/`
|
| 35 |
+
- Build vocabulary from your data
|
| 36 |
+
- Train with validation split and early stopping
|
| 37 |
+
- Save the trained model as `AgGPT21.pt`
|
| 38 |
+
|
| 39 |
+
### Interactive Chat
|
| 40 |
+
|
| 41 |
+
Once trained, start chatting with your model:
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
python chat.py
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## 📁 Project Structure
|
| 48 |
+
|
| 49 |
+
```
|
| 50 |
+
AgGPT-21-2/
|
| 51 |
+
├── banner.png # Project banner image
|
| 52 |
+
├── AgGPT21.py # Main training script
|
| 53 |
+
├── chat.py # Interactive chat interface
|
| 54 |
+
├── README.md # This file
|
| 55 |
+
├── AgGPT21.pt # Trained model (generated after training)
|
| 56 |
+
└── training_corpora/ # Training data folder
|
| 57 |
+
├── corpora_000.txt # Training file 1
|
| 58 |
+
├── corpora_001.txt # Training file 2
|
| 59 |
+
├── ... # More training files
|
| 60 |
+
└── corpora_041.txt # Training file N
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## ⚙️ Configuration
|
| 64 |
+
|
| 65 |
+
### Model Hyperparameters
|
| 66 |
+
|
| 67 |
+
| Parameter | Default | Description |
|
| 68 |
+
|-----------|---------|-------------|
|
| 69 |
+
| `SEQ_LEN` | 64 | Sequence length for training |
|
| 70 |
+
| `EMBED_SIZE` | 128 | Embedding dimension |
|
| 71 |
+
| `HIDDEN_SIZE` | 128 | GRU hidden dimension |
|
| 72 |
+
| `NUM_LAYERS` | 1 | Number of GRU layers |
|
| 73 |
+
| `DROPOUT` | 0.2 | Dropout rate |
|
| 74 |
+
|
| 75 |
+
### Training Parameters
|
| 76 |
+
|
| 77 |
+
| Parameter | Default | Description |
|
| 78 |
+
|-----------|---------|-------------|
|
| 79 |
+
| `BATCH_SIZE` | 8 | Training batch size |
|
| 80 |
+
| `EPOCHS` | 6 | Maximum training epochs |
|
| 81 |
+
| `LR` | 2e-3 | Learning rate |
|
| 82 |
+
| `WEIGHT_DECAY` | 1e-4 | L2 regularization |
|
| 83 |
+
| `CLIP_NORM` | 1.0 | Gradient clipping |
|
| 84 |
+
|
| 85 |
+
### Generation Settings
|
| 86 |
+
|
| 87 |
+
| Parameter | Default | Description |
|
| 88 |
+
|-----------|---------|-------------|
|
| 89 |
+
| `TEMPERATURE` | 0.9 | Sampling temperature (0.1-2.0) |
|
| 90 |
+
| `TOP_K` | 50 | Top-k sampling limit |
|
| 91 |
+
| `TOP_P` | 0.9 | Nucleus sampling threshold |
|
| 92 |
+
| `GENERATE_LENGTH` | 200 | Default generation length |
|
| 93 |
+
|
| 94 |
+
## 🎮 Chat Commands
|
| 95 |
+
|
| 96 |
+
In the interactive chat mode, you can use these commands:
|
| 97 |
+
|
| 98 |
+
- **Basic Chat**: Just type your message
|
| 99 |
+
- **`quit`/`exit`/`bye`**: End the conversation
|
| 100 |
+
- **`help`**: Show available commands
|
| 101 |
+
- **`clear`**: Clear the screen
|
| 102 |
+
- **`model`**: Display model information
|
| 103 |
+
- **`temp X`**: Set temperature (e.g., `temp 0.8`)
|
| 104 |
+
- **`length X`**: Set response length (e.g., `length 150`)
|
| 105 |
+
|
| 106 |
+
## 🧪 Example Usage
|
| 107 |
+
|
| 108 |
+
### Training Example
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
# Train the model (automatic multi-file loading)
|
| 112 |
+
python AgGPT21.py
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
Output:
|
| 116 |
+
```
|
| 117 |
+
Found 42 training files
|
| 118 |
+
Reading corpora_000.txt...
|
| 119 |
+
Reading corpora_001.txt...
|
| 120 |
+
...
|
| 121 |
+
Total words loaded: 2,847,392
|
| 122 |
+
Vocabulary size: 30,000
|
| 123 |
+
Tokens used: 1,000,000 | device=mps
|
| 124 |
+
Model params: 4,099,200
|
| 125 |
+
Train batches per epoch: 1,562 | Val batches: 79
|
| 126 |
+
Epochs: 100%|████████████| 6/6 [05:23<00:00, 53.92s/it, train=2.1847, val=2.3456]
|
| 127 |
+
Saved AgGPT21.pt
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
### Chat Example
|
| 131 |
+
|
| 132 |
+
```
|
| 133 |
+
👤 You: Tell me about artificial intelligence
|
| 134 |
+
|
| 135 |
+
🤖 AgGPT-21: Artificial intelligence is a fascinating field that focuses on creating systems capable of performing tasks that typically require human intelligence. These systems can learn from data, recognize patterns, make decisions, and solve complex problems. AI has applications in many areas including natural language processing, computer vision, robotics, and machine learning...
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
## 🔧 Advanced Usage
|
| 139 |
+
|
| 140 |
+
### Custom Vocabulary Size
|
| 141 |
+
|
| 142 |
+
```python
|
| 143 |
+
MAX_VOCAB = 50000 # Increase vocabulary size
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
### Training on Subset of Data
|
| 147 |
+
|
| 148 |
+
```python
|
| 149 |
+
DATA_PERCENT = 0.5 # Use only 50% of available data
|
| 150 |
+
MAX_TOKENS = 500_000 # Limit to 500k tokens
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
### Hardware Acceleration
|
| 154 |
+
|
| 155 |
+
```python
|
| 156 |
+
# The model automatically detects and uses available accelerators:
|
| 157 |
+
# - CUDA (NVIDIA GPUs)
|
| 158 |
+
# - MPS (Apple Silicon)
|
| 159 |
+
# - CPU (fallback)
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
## 📊 Model Architecture
|
| 163 |
+
|
| 164 |
+
```
|
| 165 |
+
Input → Embedding → Dropout → GRU → Dropout → [Projection] → Linear → Output
|
| 166 |
+
↑ ↓ ↓ ↓
|
| 167 |
+
Token Vector Hidden Logits
|
| 168 |
+
IDs (128-dim) States (Vocab-size)
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
**Key Features:**
|
| 172 |
+
- **Weight Tying**: Output layer shares weights with embedding layer
|
| 173 |
+
- **Gradient Clipping**: Prevents exploding gradients
|
| 174 |
+
- **Mixed Precision**: Automatic FP16 on supported devices
|
| 175 |
+
- **Early Stopping**: Validation-based training termination
|
| 176 |
+
|
| 177 |
+
## 🎯 Performance Tips
|
| 178 |
+
|
| 179 |
+
1. **GPU Acceleration**: Use CUDA or MPS for faster training
|
| 180 |
+
2. **Batch Size**: Increase if you have more memory
|
| 181 |
+
3. **Sequence Length**: Longer sequences capture more context
|
| 182 |
+
4. **Vocabulary**: Smaller vocab = faster training, larger vocab = better coverage
|
| 183 |
+
5. **Data Quality**: Clean, relevant training data improves results
|
| 184 |
+
|
| 185 |
+
## 🐛 Troubleshooting
|
| 186 |
+
|
| 187 |
+
### Common Issues
|
| 188 |
+
|
| 189 |
+
**"No .txt files found"**
|
| 190 |
+
- Ensure your training files are in `training_corpora/` with `.txt` extension
|
| 191 |
+
|
| 192 |
+
**"CUDA out of memory"**
|
| 193 |
+
- Reduce `BATCH_SIZE` or `SEQ_LEN`
|
| 194 |
+
- Use `DATA_PERCENT < 1.0` to train on less data
|
| 195 |
+
|
| 196 |
+
**"Model file not found"**
|
| 197 |
+
- Train the model first with `python AgGPT21.py`
|
| 198 |
+
- Ensure `AgGPT21.pt` exists in the project directory
|
| 199 |
+
|
| 200 |
+
## 📈 Training Data Format
|
| 201 |
+
|
| 202 |
+
Your training files should be plain text. The model will automatically:
|
| 203 |
+
- Convert to lowercase
|
| 204 |
+
- Split on whitespace
|
| 205 |
+
- Treat tokens like `<pad>` and `<eos>` as ordinary whitespace-separated words (no special handling)
|
| 206 |
+
- Build vocabulary from all files combined
|
| 207 |
+
|
| 208 |
+
Example format:
|
| 209 |
+
```
|
| 210 |
+
user: how are you today
|
| 211 |
+
<pad>
|
| 212 |
+
ai: I'm doing well, thank you for asking! How are you?
|
| 213 |
+
<eos>
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
## 🤝 Contributing
|
| 217 |
+
|
| 218 |
+
1. Fork the repository
|
| 219 |
+
2. Create a feature branch
|
| 220 |
+
3. Make your improvements
|
| 221 |
+
4. Test thoroughly
|
| 222 |
+
5. Submit a pull request
|
| 223 |
+
|
| 224 |
+
## 📄 License
|
| 225 |
+
|
| 226 |
+
This project is open source. Feel free to use, modify, and distribute as needed.
|
| 227 |
+
|
| 228 |
+
## 🙋‍♂️ Support
|
| 229 |
+
|
| 230 |
+
If you encounter issues or have questions:
|
| 231 |
+
|
| 232 |
+
1. Check the troubleshooting section
|
| 233 |
+
2. Review your training data format
|
| 234 |
+
3. Ensure all dependencies are installed
|
| 235 |
+
4. Verify your PyTorch installation supports your hardware
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
**Made with ❤️ for the AI community**
|
| 240 |
+
|
| 241 |
+
*AgGPT-21 - Where conversation meets intelligence.*
|
banner.png
ADDED
|
Git LFS Details
|
chat.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AgGPT-21 Interactive Chat Interface
|
| 4 |
+
A conversational interface for the trained AgGPT-21 model.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import torch
|
| 10 |
+
from AgGPT21 import WordRNN, generate_text, MODEL_FILE, DEVICE
|
| 11 |
+
|
| 12 |
+
def load_model():
    """Load the trained AgGPT-21 model."""
    if not os.path.exists(MODEL_FILE):
        print(f"❌ Model file '{MODEL_FILE}' not found!")
        print("Please train the model first by running: python AgGPT21.py")
        sys.exit(1)

    try:
        print("🔄 Loading AgGPT-21 model...")
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi, itos = ckpt["stoi"], ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        model.eval()

        # Report a short summary of what was loaded.
        param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"✅ Model loaded successfully!")
        print(f"   • Parameters: {param_count:,}")
        print(f"   • Vocabulary size: {len(stoi):,}")
        print(f"   • Device: {DEVICE}")
        print()

        return model, stoi, itos
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        sys.exit(1)
|
| 40 |
+
|
| 41 |
+
def print_banner():
    """Display the AgGPT-21 banner."""
    # Box-drawn welcome banner shown at startup and after the 'clear' command.
    banner = """
    ╔══════════════════════════════════════════════════════════════════════╗
    ║                           🤖 AgGPT-21 🤖                             ║
    ║                     Interactive Chat Interface                       ║
    ║                                                                      ║
    ║  • Type your message and press Enter to chat                         ║
    ║  • Use 'quit', 'exit', or 'bye' to end the conversation              ║
    ║  • Use 'help' for more options                                       ║
    ╚══════════════════════════════════════════════════════════════════════╝
    """
    print(banner)
|
| 54 |
+
|
| 55 |
+
def print_help():
    """Display help information."""
    # Static usage text listing chat commands and the tunable settings.
    help_text = """
🔧 AgGPT-21 Chat Commands:
   • Just type your message to chat with the AI
   • 'quit', 'exit', 'bye' - End the conversation
   • 'help' - Show this help message
   • 'clear' - Clear the screen
   • 'model' - Show model information
   • 'temp X' - Set temperature (e.g., 'temp 0.8')
   • 'length X' - Set response length (e.g., 'length 150')

🎛️ Current Settings:
   • Temperature: Controls creativity (0.1-2.0, default: 0.9)
   • Length: Number of words to generate (50-500, default: 200)
    """
    print(help_text)
|
| 72 |
+
|
| 73 |
+
def main():
    """Main chat loop: load the model, then read/dispatch user input forever."""
    print_banner()

    # Load the model (exits the process with a message on failure).
    model, stoi, itos = load_model()

    # Chat settings; 'temp' and 'length' commands mutate these at runtime.
    temperature = 0.9
    length = 200
    top_k = 50
    top_p = 0.9

    print("💬 Chat started! Type your message below:")
    print("="*70)

    while True:
        try:
            # Get user input
            user_input = input("\n👤 You: ").strip()

            if not user_input:
                continue

            # Handle commands (matched case-insensitively).
            user_lower = user_input.lower()

            if user_lower in ['quit', 'exit', 'bye']:
                print("\n👋 Goodbye! Thanks for chatting with AgGPT-21!")
                break

            elif user_lower == 'help':
                print_help()
                continue

            elif user_lower == 'clear':
                # 'clear' on POSIX shells, 'cls' on Windows.
                os.system('clear' if os.name == 'posix' else 'cls')
                print_banner()
                continue

            elif user_lower == 'model':
                param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print(f"\n🤖 Model Information:")
                print(f"   • Parameters: {param_count:,}")
                print(f"   • Vocabulary: {len(stoi):,} words")
                print(f"   • Device: {DEVICE}")
                print(f"   • Temperature: {temperature}")
                print(f"   • Max length: {length}")
                continue

            elif user_lower.startswith('temp '):
                # Update sampling temperature, clamped to a sane range.
                try:
                    new_temp = float(user_lower.split()[1])
                    if 0.1 <= new_temp <= 2.0:
                        temperature = new_temp
                        print(f"🌡️ Temperature set to {temperature}")
                    else:
                        print("❌ Temperature must be between 0.1 and 2.0")
                except (IndexError, ValueError):
                    print("❌ Invalid temperature. Use: temp 0.8")
                continue

            elif user_lower.startswith('length '):
                # Update the number of words generated per reply.
                try:
                    new_length = int(user_lower.split()[1])
                    if 50 <= new_length <= 500:
                        length = new_length
                        print(f"📏 Response length set to {length} words")
                    else:
                        print("❌ Length must be between 50 and 500")
                except (IndexError, ValueError):
                    print("❌ Invalid length. Use: length 150")
                continue

            # Generate AI response
            print(f"\n🤖 AgGPT-21 (thinking...)", end="", flush=True)

            try:
                response = generate_text(
                    model=model,
                    stoi=stoi,
                    itos=itos,
                    prompt=user_input,
                    length=length,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    device=DEVICE
                )

                # Clean up the response (remove the original prompt)
                response_words = response.split()
                prompt_words = user_input.lower().split()

                # Find where the new content starts.
                # NOTE(review): this assumes generate_text returns the prompt
                # followed by the new words, so slicing by the prompt's word
                # count strips the echo — confirm against generate_text.
                if len(response_words) > len(prompt_words):
                    ai_response = " ".join(response_words[len(prompt_words):])
                else:
                    ai_response = response

                # \r overwrites the "(thinking...)" line.
                print(f"\r🤖 AgGPT-21: {ai_response}")

            except Exception as e:
                print(f"\r❌ Error generating response: {e}")

        except KeyboardInterrupt:
            print("\n\n👋 Chat interrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Unexpected error: {e}")

if __name__ == "__main__":
    main()
|