darwinkernelpanic committed on
Commit
89b6927
·
verified ·
1 Parent(s): 12d112c

Upload train_diffreaper.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_diffreaper.py +186 -0
train_diffreaper.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from diffusers import DDPMScheduler
4
+ from transformers import AutoTokenizer
5
+ from datasets import load_dataset
6
+ import os
7
+ import time
8
+ import math
9
+
# --- DiffReaper Configuration ---
MODEL_PATH = "./DiffReaper-Talk"   # local directory the tokenizer is loaded from
OUTPUT_DIR = "./training_output"   # created below; destination for training artifacts
LOG_FILE = "training.log"          # appended to by log()
BATCH_SIZE = 32                    # rows per optimizer step
LEARNING_RATE = 1e-4               # AdamW learning rate
NUM_STEPS = 50000                  # total optimizer steps in the training loop
SAVE_EVERY = 2500                  # NOTE(review): defined but never referenced below — no checkpoint saving is implemented
TEST_EVERY = 500                   # run_test() diagnostic cadence, in steps

# Transformer size and fixed sequence layout.
N_EMBD = 1024                      # embedding / hidden width
N_HEAD = 16                        # attention heads (nn.MultiheadAttention requires N_EMBD % N_HEAD == 0)
N_LAYER = 12                       # number of DiffReaperBlock layers
MAX_PROMPT_LEN = 32                # token positions reserved for the prompt half
MAX_RESP_LEN = 32                  # token positions reserved for the response half
TOTAL_LEN = MAX_PROMPT_LEN + MAX_RESP_LEN  # full sequence length seen by the model

os.makedirs(OUTPUT_DIR, exist_ok=True)
28
+
def log(msg):
    """Print *msg* prefixed with a timestamp and append the same line to LOG_FILE."""
    line = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}"
    print(line)
    # Open in append mode per call so the file is always flushed/closed promptly.
    with open(LOG_FILE, "a") as fh:
        fh.write(line + "\n")
35
+
# --- Time Embedding (Sinusoidal) ---
class TimeEmbedding(nn.Module):
    """Sinusoidal diffusion-timestep embedding refined by a two-layer MLP.

    Args (constructor):
        n_embd: embedding width; must be even (sin/cos halves are concatenated).

    forward(t) takes a 1-D tensor of timesteps (int or float, shape (B,)) and
    returns a float tensor of shape (B, n_embd).

    Bug fix: forward previously computed the sinusoid width from the module-level
    constant N_EMBD instead of the constructor argument, so any instance built
    with n_embd != N_EMBD produced a shape mismatch inside self.mlp. The width
    is now stored on the instance and used consistently.
    """

    def __init__(self, n_embd):
        super().__init__()
        self.n_embd = n_embd  # remember the width instead of relying on the global
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.GELU(),
            nn.Linear(n_embd, n_embd),
        )

    def forward(self, t):
        # Standard transformer/DDPM sinusoidal embedding: frequencies spaced
        # geometrically from 1 to 1/10000 across half the width.
        half_dim = self.n_embd // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]  # (B, half_dim); int t promotes to float here
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # (B, n_embd)
        return self.mlp(emb)
53
+
# --- Conditioned Diffusion Block ---
class DiffReaperBlock(nn.Module):
    """Pre-norm transformer block whose first LayerNorm output is modulated
    (scaled and shifted) by a projection of the timestep embedding, followed
    by residual self-attention and a residual 4x-wide GELU MLP."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )
        # Projects t_emb to a (scale, shift) pair per feature.
        self.time_mlp = nn.Linear(n_embd, n_embd * 2)

    def forward(self, x, t_emb):
        """x: (B, S, n_embd); t_emb: (B, n_embd). Returns (B, S, n_embd)."""
        # FiLM-style conditioning: broadcast one scale/shift pair over the sequence.
        scale, shift = self.time_mlp(t_emb).unsqueeze(1).chunk(2, dim=-1)
        conditioned = self.ln1(x) * (scale + 1) + shift
        attended, _ = self.attn(conditioned, conditioned, conditioned)
        residual = x + attended
        return residual + self.mlp(self.ln2(residual))
79
+
class DiffReaperModel(nn.Module):
    """Denoising transformer: token + learned positional embeddings feed a stack
    of time-conditioned DiffReaperBlocks, closed by a final LayerNorm.

    forward(x_input, t) expects pre-embedded input of shape (B, S, n_embd)
    with S <= TOTAL_LEN, plus a timestep tensor t of shape (B,)."""

    def __init__(self, vocab_size, n_embd, n_head, n_layer):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        # Learned positional table covering the full prompt+response window.
        self.pos_embedding = nn.Parameter(torch.zeros(1, TOTAL_LEN, n_embd))
        self.time_embed = TimeEmbedding(n_embd)
        self.blocks = nn.ModuleList(
            DiffReaperBlock(n_embd, n_head) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)

    def forward(self, x_input, t):
        cond = self.time_embed(t)
        seq_len = x_input.shape[1]
        hidden = x_input + self.pos_embedding[:, :seq_len, :]
        for layer in self.blocks:
            hidden = layer(hidden, cond)
        return self.ln_f(hidden)
95
+
# --- Script setup: tokenizer, model, diffusion scheduler, optimizer ---
# (Runs at import time; requires a CUDA device and the MODEL_PATH tokenizer files.)
log("Initializing DiffReaper Parallel Form (Conditioned Diffusion)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Some tokenizers ship without a pad token; reuse EOS so fixed-length padding works.
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): tokenizer.vocab_size excludes tokens added after loading; if any
# special tokens were added, len(tokenizer) would be the safer embedding size.
model = DiffReaperModel(tokenizer.vocab_size, N_EMBD, N_HEAD, N_LAYER).to("cuda")
# 1000-step DDPM with the cosine ("squaredcos_cap_v2") beta schedule.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
103
+
log("Loading Dataset (Prompt/Response focus)...")
# OpenAssistant conversation corpus; rows carry a raw "text" field used below.
dataset = load_dataset("OpenAssistant/oasst1", split="train")
106
+
def tokenize_function(examples):
    """Tokenize a batch of raw texts to a fixed TOTAL_LEN window (pad + truncate)."""
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=TOTAL_LEN,
    )
# Tokenize the whole dataset (dropping the original columns), expose it as torch
# tensors, and wrap it in a shuffling DataLoader.
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
tokenized_dataset.set_format("torch")
dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Manual iterator so get_batch() can restart it when an epoch is exhausted.
data_iter = iter(dataloader)
115
+
def get_batch():
    """Return the next batch's input_ids on CUDA, cycling the dataloader forever."""
    global data_iter
    batch = next(data_iter, None)
    if batch is None:  # epoch exhausted — restart from a fresh iterator
        data_iter = iter(dataloader)
        batch = next(data_iter)
    return batch["input_ids"].to("cuda")
123
+
def run_test():
    """Qualitative sampling check: denoise a response for one fixed prompt and log it.

    Uses the module-level `model` and `tokenizer`; side effects only (log lines).
    Temporarily switches the model to eval() and restores train() at the end.
    """
    log("Running Cropmark Diagnostic...")
    model.eval()
    with torch.no_grad():
        prompt = "How are you?"
        # Tokenize, clip to MAX_PROMPT_LEN, and right-pad with pad_token_id.
        p_tokens = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
        p_tokens = p_tokens[:, :MAX_PROMPT_LEN]
        p_padded = torch.full((1, MAX_PROMPT_LEN), tokenizer.pad_token_id, device="cuda")
        p_padded[:, :p_tokens.shape[1]] = p_tokens

        p_emb = model.token_embedding(p_padded)
        # Response region starts as pure Gaussian noise in embedding space.
        r_noise = torch.randn(1, MAX_RESP_LEN, N_EMBD).to("cuda")

        # Ten coarse refinement steps at t = 999, 899, ..., 99.
        # NOTE(review): each step blends 40% of the current state with 60% of the
        # model's clean prediction rather than using noise_scheduler's reverse
        # step — an ad-hoc sampler; confirm this is intentional.
        for i in range(10):
            t = torch.tensor([1000 - (i*100) - 1], device="cuda").long()
            combined = torch.cat([p_emb, r_noise], dim=1)
            pred = model(combined, t)
            r_0_pred = pred[:, MAX_PROMPT_LEN:, :]
            r_noise = 0.4 * r_noise + 0.6 * r_0_pred

        # Read out tokens by nearest-logit match against the embedding matrix.
        logits = torch.matmul(r_noise, model.token_embedding.weight.T)
        resp_ids = torch.argmax(logits, dim=-1)
        log(f"Prompt: '{prompt}' | [Cropmark]: '{tokenizer.decode(resp_ids[0], skip_special_tokens=False)}'")

    model.train()
149
+
# --- Main training loop: x0-prediction diffusion on response embeddings ---
log("Starting Conditioned Cropmark Training...")
start_time = time.time()

for step in range(NUM_STEPS):
    optimizer.zero_grad()

    # Each fixed-length row is split into a prompt half and a response half.
    input_ids = get_batch()
    prompt_ids = input_ids[:, :MAX_PROMPT_LEN]
    resp_ids = input_ids[:, MAX_PROMPT_LEN:]

    # NOTE(review): embeddings are produced under no_grad and token_embedding is
    # used nowhere else in the loss path, so the embedding table never receives
    # gradients — it stays at its random initialization. Confirm this is intended,
    # since run_test() decodes against this same (untrained) matrix.
    with torch.no_grad():
        prompt_emb = model.token_embedding(prompt_ids)
        resp_emb = model.token_embedding(resp_ids)

    # Forward diffusion: corrupt only the response embeddings at a random timestep.
    noise = torch.randn_like(resp_emb)
    t = torch.randint(0, 1000, (input_ids.shape[0],), device="cuda").long()
    noisy_resp = noise_scheduler.add_noise(resp_emb, noise, t)

    # Condition on the clean prompt by prepending it to the noisy response.
    combined_input = torch.cat([prompt_emb, noisy_resp], dim=1)

    predicted_clean_combined = model(combined_input, t)
    # x0-prediction objective: regress the clean response embeddings directly.
    predicted_resp = predicted_clean_combined[:, MAX_PROMPT_LEN:, :]

    # Masked MSE — padding positions in the response are excluded from the loss.
    mask = (resp_ids != tokenizer.pad_token_id).unsqueeze(-1).expand_as(resp_emb).float()
    loss = torch.nn.functional.mse_loss(predicted_resp * mask, resp_emb * mask, reduction='sum') / (mask.sum() + 1e-8)

    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        elapsed = time.time() - start_time
        # "s/s" = steps per second.
        log(f"Step {step}/{NUM_STEPS} - Loss: {loss.item():.6f} - Speed: {(step+1)/elapsed:.2f} s/s")

    if step > 0 and step % TEST_EVERY == 0:
        run_test()

log("Training complete.")