darwinkernelpanic committed
Commit 2827773 · verified · 1 Parent(s): 7d0032b

Upload train_autogrow.py with huggingface_hub

Files changed (1)
  1. train_autogrow.py +149 -0
train_autogrow.py ADDED
@@ -0,0 +1,149 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from diffusers import DDPMScheduler
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+ import os
+ import time
+ import math
+ from huggingface_hub import HfApi
+
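+ # Overview: this script trains a small diffusion transformer over token
+ # embeddings. The first MAX_PROMPT_LEN positions of each sequence condition
+ # the model; the last MAX_RESP_LEN positions are noised and the model learns
+ # to denoise them back to the clean response embeddings.
+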
+ # --- FAILPROOF CONFIG ---
+ MODEL_PATH = "./DiffReaper-Talk"
+ REPO_ID = "darwinkernelpanic/DiffReaper-5"
+ HF_TOKEN = os.environ.get("HF_TOKEN")  # read the token from the environment rather than hardcoding a literal placeholder
+ OUTPUT_DIR = "./training_output"
+ LOG_FILE = "training.log"
+ BATCH_SIZE = 16  # Lower for 3090 VRAM
+ LEARNING_RATE = 1e-4
+ SAVE_EVERY = 2500
+ TEST_EVERY = 500
+
+ N_EMBD = 1024
+ N_HEAD = 16
+ N_LAYER = 12
+ MAX_PROMPT_LEN = 32
+ MAX_RESP_LEN = 32
+ TOTAL_LEN = MAX_PROMPT_LEN + MAX_RESP_LEN
+
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ def log(msg):
+     timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
+     formatted = f"[{timestamp}] {msg}"
+     print(formatted)
+     with open(LOG_FILE, "a") as f:
+         f.write(formatted + "\n")
+
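+ # Standard sinusoidal timestep embedding (DDPM-style), projected through a
+ # two-layer MLP before conditioning the transformer blocks.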
+ class TimeEmbedding(nn.Module):
+     def __init__(self, n_embd):
+         super().__init__()
+         self.mlp = nn.Sequential(nn.Linear(n_embd, n_embd), nn.GELU(), nn.Linear(n_embd, n_embd))
+     def forward(self, t):
+         half_dim = N_EMBD // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
+         emb = t[:, None] * emb[None, :]
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return self.mlp(emb)
+
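+ # Pre-LN transformer block with adaLN-style conditioning: the timestep
+ # embedding produces a per-block scale and shift that modulate the
+ # normalized input before self-attention.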
+ class DiffReaperBlock(nn.Module):
+     def __init__(self, n_embd, n_head):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
+         self.ln2 = nn.LayerNorm(n_embd)
+         self.mlp = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd))
+         self.time_mlp = nn.Linear(n_embd, n_embd * 2)
+     def forward(self, x, t_emb):
+         time_params = self.time_mlp(t_emb).unsqueeze(1)
+         scale, shift = time_params.chunk(2, dim=-1)
+         x_norm = self.ln1(x) * (1 + scale) + shift
+         attn_out, _ = self.attn(x_norm, x_norm, x_norm)
+         x = x + attn_out
+         x = x + self.mlp(self.ln2(x))
+         return x
+
+ class DiffReaperModel(nn.Module):
+     def __init__(self, vocab_size, n_embd, n_head, n_layer):
+         super().__init__()
+         self.token_embedding = nn.Embedding(vocab_size, n_embd)
+         self.pos_embedding = nn.Parameter(torch.zeros(1, TOTAL_LEN, n_embd))
+         self.time_embed = TimeEmbedding(n_embd)
+         self.blocks = nn.ModuleList([DiffReaperBlock(n_embd, n_head) for _ in range(n_layer)])
+         self.ln_f = nn.LayerNorm(n_embd)
+     def forward(self, x_input, t):
+         t_emb = self.time_embed(t)
+         x = x_input + self.pos_embedding[:, :x_input.shape[1], :]
+         for block in self.blocks: x = block(x, t_emb)
+         return self.ln_f(x)
+
+ log("Initializing Autogrow Model...")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
+
+ model = DiffReaperModel(tokenizer.vocab_size, N_EMBD, N_HEAD, N_LAYER).to("cuda")
+ noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
+
+ log("Loading Dataset...")
+ dataset = load_dataset("OpenAssistant/oasst1", split="train")
+ def tokenize_function(examples):
+     return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=TOTAL_LEN)
+ tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
+ tokenized_dataset.set_format("torch")
+ dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True)
+
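+ # Qualitative diagnostic: run a crude 10-step denoise from pure noise,
+ # conditioned on a fixed prompt, then decode each position by nearest
+ # (cosine) neighbor in the token embedding table.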
+ def run_test(step):
+     log(f"Running Cropmark Diagnostic [Step {step}]...")
+     model.eval()
+     with torch.no_grad():
+         prompt = "Hello! Who are you?"
+         p_tokens = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")[:, :MAX_PROMPT_LEN]
+         p_padded = torch.full((1, MAX_PROMPT_LEN), tokenizer.pad_token_id, device="cuda")
+         p_padded[:, :p_tokens.shape[1]] = p_tokens
+         p_emb = model.token_embedding(p_padded)
+         r_noise = torch.randn(1, MAX_RESP_LEN, N_EMBD).to("cuda")
+         for i in range(10):
+             t = torch.tensor([1000 - (i * 100) - 1], device="cuda").long()
+             pred = model(torch.cat([p_emb, r_noise], dim=1), t)
+             r_noise = 0.4 * r_noise + 0.6 * pred[:, MAX_PROMPT_LEN:, :]
+         norm_weights = F.normalize(model.token_embedding.weight, dim=-1)
+         norm_r = F.normalize(r_noise, dim=-1)
+         logits = torch.matmul(norm_r, norm_weights.T)
+         resp_ids = torch.argmax(logits, dim=-1)
+         log(f"Prompt: '{prompt}' | [Cropmark]: '{tokenizer.decode(resp_ids[0], skip_special_tokens=True)}'")
+     model.train()
+
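+ # Training objective: noise the response embeddings with the DDPM scheduler,
+ # condition on the clean prompt embeddings, and train the model to
+ # reconstruct the clean response embeddings under a cosine-similarity loss.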
+ log("Autonomous growth starting...")
+ api = HfApi()
+ start_time = time.time()
+ step = 0
+ while True:  # Unlimited steps, controlled by your credit
+     for batch in dataloader:
+         optimizer.zero_grad()
+         input_ids = batch["input_ids"].to("cuda")
+         prompt_emb = model.token_embedding(input_ids[:, :MAX_PROMPT_LEN])
+         resp_emb = model.token_embedding(input_ids[:, MAX_PROMPT_LEN:])
+
+         noise = torch.randn_like(resp_emb)
+         t = torch.randint(0, 1000, (input_ids.shape[0],), device="cuda").long()
+         noisy_resp = noise_scheduler.add_noise(resp_emb, noise, t)
+
+         pred_resp = model(torch.cat([prompt_emb, noisy_resp], dim=1), t)[:, MAX_PROMPT_LEN:, :]
+         loss = 1 - F.cosine_similarity(pred_resp, resp_emb, dim=-1).mean()
+         loss.backward()
+         optimizer.step()
+
+         if step % 100 == 0:
+             elapsed = time.time() - start_time
+             log(f"Step {step} - Loss: {loss.item():.6f} - Speed: {(step + 1) / elapsed:.2f} steps/s")
+         if step > 0 and step % TEST_EVERY == 0: run_test(step)
+         if step > 0 and step % SAVE_EVERY == 0:
+             ckpt_path = os.path.join(OUTPUT_DIR, "cropmark_latest.pt")
+             torch.save(model.state_dict(), ckpt_path)
+             log("Syncing to HF...")
+             try: api.upload_file(path_or_fileobj=ckpt_path, path_in_repo="cropmark_latest.pt", repo_id=REPO_ID, token=HF_TOKEN)
+             except Exception as e: log(f"HF Sync Error: {e}")
+         step += 1