Upload train_diffreaper.py with huggingface_hub
Browse files- train_diffreaper.py +186 -0
train_diffreaper.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from diffusers import DDPMScheduler
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
# --- DiffReaper Configuration ---
MODEL_PATH = "./DiffReaper-Talk"    # local directory the tokenizer is loaded from (AutoTokenizer)
OUTPUT_DIR = "./training_output"    # output directory, created below
LOG_FILE = "training.log"           # append-only log file mirrored to stdout by log()
BATCH_SIZE = 32                     # samples per optimizer step
LEARNING_RATE = 1e-4                # AdamW learning rate
NUM_STEPS = 50000                   # total training steps
SAVE_EVERY = 2500                   # checkpoint interval (in steps)
TEST_EVERY = 500                    # steps between run_test() diagnostics

# Model hyperparameters.
N_EMBD = 1024                       # embedding / hidden width
N_HEAD = 16                         # attention heads per block
N_LAYER = 12                        # number of DiffReaperBlocks
MAX_PROMPT_LEN = 32                 # tokens reserved for the prompt half of the window
MAX_RESP_LEN = 32                   # tokens reserved for the response half of the window
TOTAL_LEN = MAX_PROMPT_LEN + MAX_RESP_LEN  # full fixed sequence window

os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 28 |
+
|
| 29 |
+
def log(msg):
    """Print *msg* prefixed with a timestamp and append the same line to LOG_FILE."""
    line = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}"
    print(line)
    with open(LOG_FILE, "a") as fh:
        fh.write(line + "\n")
|
| 35 |
+
|
| 36 |
+
# --- Time Embedding (Sinusoidal) ---
|
| 37 |
+
# --- Time Embedding (Sinusoidal) ---
class TimeEmbedding(nn.Module):
    """Sinusoidal diffusion-timestep embedding followed by a 2-layer MLP.

    forward() maps timesteps t of shape (B,) to embeddings of shape (B, n_embd).
    """

    def __init__(self, n_embd):
        super().__init__()
        # Bug fix: remember our own width. The original forward() read the
        # module-level constant N_EMBD, so a TimeEmbedding constructed with
        # n_embd != N_EMBD produced sinusoids of the wrong size and crashed
        # in the MLP.
        self.n_embd = n_embd
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.GELU(),
            nn.Linear(n_embd, n_embd),
        )

    def forward(self, t):
        """Return (B, n_embd) embeddings for the (B,) timestep tensor *t*."""
        half_dim = self.n_embd // 2
        # Standard DDPM/transformer log-spaced sinusoidal frequencies.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.mlp(emb)
|
| 53 |
+
|
| 54 |
+
# --- Conditioned Diffusion Block ---
|
| 55 |
+
# --- Conditioned Diffusion Block ---
class DiffReaperBlock(nn.Module):
    """Transformer block with FiLM-style timestep conditioning.

    The time embedding is projected to a per-channel (scale, shift) pair that
    modulates the normalized input before self-attention; the MLP branch is a
    plain pre-norm residual.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )
        # Emits scale and shift concatenated along the channel dimension.
        self.time_mlp = nn.Linear(n_embd, n_embd * 2)

    def forward(self, x, t_emb):
        # Broadcast the conditioning over the sequence axis, then split it.
        scale, shift = self.time_mlp(t_emb).unsqueeze(1).chunk(2, dim=-1)

        # FiLM-modulate the normalized input that feeds self-attention.
        conditioned = self.ln1(x) * (scale + 1) + shift
        attn_out, _ = self.attn(conditioned, conditioned, conditioned)

        h = x + attn_out
        return h + self.mlp(self.ln2(h))
|
| 79 |
+
|
| 80 |
+
class DiffReaperModel(nn.Module):
    """Stack of conditioned diffusion blocks that denoises a [prompt | response]
    embedding sequence given a diffusion timestep.

    Note: forward() takes already-embedded inputs of shape (B, T, n_embd) —
    token_embedding is exposed so callers embed token ids themselves — and
    returns predicted clean embeddings of the same shape.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, max_len=TOTAL_LEN):
        """Generalization: *max_len* replaces the previously hard-coded
        TOTAL_LEN positional-table length; the default preserves the old
        behavior for existing callers."""
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        # Learned positional embeddings, zero-initialized.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, n_embd))
        self.time_embed = TimeEmbedding(n_embd)
        self.blocks = nn.ModuleList([DiffReaperBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)

    def forward(self, x_input, t):
        """x_input: (B, T, n_embd) embeddings; t: (B,) timesteps.
        Returns the denoised sequence, shape (B, T, n_embd)."""
        t_emb = self.time_embed(t)
        # Slice the positional table down to the actual sequence length.
        x = x_input + self.pos_embedding[:, :x_input.shape[1], :]
        for block in self.blocks:
            x = block(x, t_emb)
        return self.ln_f(x)
|
| 95 |
+
|
| 96 |
+
# --- Script-level initialization: tokenizer, model, optimizer, data pipeline ---
log("Initializing DiffReaper Parallel Form (Conditioned Diffusion)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Some tokenizers ship without a pad token; reuse EOS so padding is defined.
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): device is hard-coded to "cuda" throughout — crashes on CPU-only hosts.
# NOTE(review): tokenizer.vocab_size excludes added special tokens; if this
# tokenizer carries added tokens, len(tokenizer) would be the safer size — confirm.
model = DiffReaperModel(tokenizer.vocab_size, N_EMBD, N_HEAD, N_LAYER).to("cuda")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

log("Loading Dataset (Prompt/Response focus)...")
dataset = load_dataset("OpenAssistant/oasst1", split="train")

def tokenize_function(examples):
    # Pad/truncate every message to the full prompt+response window (TOTAL_LEN).
    tokens = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=TOTAL_LEN)
    return tokens

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
tokenized_dataset.set_format("torch")
dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Manual iterator so get_batch() can restart it on StopIteration (infinite stream).
data_iter = iter(dataloader)
|
| 115 |
+
|
| 116 |
+
def get_batch():
    """Return the next batch of input_ids moved to the GPU, transparently
    restarting the dataloader when an epoch is exhausted."""
    global data_iter
    try:
        batch = next(data_iter)
    except StopIteration:
        # Epoch boundary: rebuild the iterator and pull the first batch.
        data_iter = iter(dataloader)
        batch = next(data_iter)
    return batch["input_ids"].to("cuda")
|
| 123 |
+
|
| 124 |
+
def run_test():
    """Qualitative sampling diagnostic: denoise a response for a fixed prompt
    and log the decoded text.

    Side effects only (logging); called periodically from the training loop.
    Temporarily switches the model to eval mode and restores train mode after.
    """
    log("Running Cropmark Diagnostic...")
    model.eval()
    with torch.no_grad():
        prompt = "How are you?"
        p_tokens = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
        p_tokens = p_tokens[:, :MAX_PROMPT_LEN]
        # Right-pad the prompt tokens out to the fixed MAX_PROMPT_LEN window.
        p_padded = torch.full((1, MAX_PROMPT_LEN), tokenizer.pad_token_id, device="cuda")
        p_padded[:, :p_tokens.shape[1]] = p_tokens

        p_emb = model.token_embedding(p_padded)
        # Response slot starts from pure Gaussian noise.
        r_noise = torch.randn(1, MAX_RESP_LEN, N_EMBD).to("cuda")

        # Ad-hoc 10-step refinement: feed timesteps 999, 899, ..., 99 and
        # blend the current estimate with the model's x0 prediction (40/60).
        # NOTE(review): this does NOT use noise_scheduler's reverse process;
        # the fixed 0.4/0.6 blend is a heuristic — confirm it is intentional.
        for i in range(10):
            t = torch.tensor([1000 - (i*100) - 1], device="cuda").long()
            combined = torch.cat([p_emb, r_noise], dim=1)
            pred = model(combined, t)
            r_0_pred = pred[:, MAX_PROMPT_LEN:, :]
            r_noise = 0.4 * r_noise + 0.6 * r_0_pred

        # Decode each position to its nearest embedding-table row (by dot product).
        logits = torch.matmul(r_noise, model.token_embedding.weight.T)
        resp_ids = torch.argmax(logits, dim=-1)
        log(f"Prompt: '{prompt}' | [Cropmark]: '{tokenizer.decode(resp_ids[0], skip_special_tokens=False)}'")

    model.train()
|
| 149 |
+
|
| 150 |
+
# --- Main training loop: x0-prediction diffusion over response embeddings ---
log("Starting Conditioned Cropmark Training...")
start_time = time.time()

for step in range(NUM_STEPS):
    optimizer.zero_grad()

    input_ids = get_batch()
    prompt_ids = input_ids[:, :MAX_PROMPT_LEN]
    resp_ids = input_ids[:, MAX_PROMPT_LEN:]

    # Embeddings are looked up under no_grad: they serve as fixed targets.
    # NOTE(review): this also means token_embedding itself never trains —
    # confirm that freezing the (randomly initialized) table is intentional.
    with torch.no_grad():
        prompt_emb = model.token_embedding(prompt_ids)
        resp_emb = model.token_embedding(resp_ids)

    noise = torch.randn_like(resp_emb)
    # Fix: sample t from the scheduler's configured range rather than a
    # hard-coded 1000, so changing num_train_timesteps stays consistent.
    t = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                      (input_ids.shape[0],), device="cuda").long()
    noisy_resp = noise_scheduler.add_noise(resp_emb, noise, t)

    # Condition on the clean prompt; only the response half is noised.
    combined_input = torch.cat([prompt_emb, noisy_resp], dim=1)

    predicted_clean_combined = model(combined_input, t)
    predicted_resp = predicted_clean_combined[:, MAX_PROMPT_LEN:, :]

    # Masked MSE: only non-pad response positions contribute to the loss.
    mask = (resp_ids != tokenizer.pad_token_id).unsqueeze(-1).expand_as(resp_emb).float()
    loss = torch.nn.functional.mse_loss(predicted_resp * mask, resp_emb * mask, reduction='sum') / (mask.sum() + 1e-8)

    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        elapsed = time.time() - start_time
        log(f"Step {step}/{NUM_STEPS} - Loss: {loss.item():.6f} - Speed: {(step+1)/elapsed:.2f} s/s")

    if step > 0 and step % TEST_EVERY == 0:
        run_test()

    # Fix: SAVE_EVERY and OUTPUT_DIR were configured but never used — the
    # original trained for 50k steps without writing a single checkpoint.
    if step > 0 and step % SAVE_EVERY == 0:
        ckpt_path = os.path.join(OUTPUT_DIR, f"diffreaper_step_{step}.pt")
        torch.save({"step": step,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict()}, ckpt_path)
        log(f"Saved checkpoint to {ckpt_path}")

# Persist the final weights so the run's result survives process exit.
torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "diffreaper_final.pt"))
log("Training complete.")
|