File size: 6,489 Bytes
89b6927 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import torch
import torch.nn as nn
from diffusers import DDPMScheduler
from transformers import AutoTokenizer
from datasets import load_dataset
import os
import time
import math
# --- DiffReaper Configuration ---
# Filesystem locations.
MODEL_PATH = "./DiffReaper-Talk"  # source passed to AutoTokenizer.from_pretrained
OUTPUT_DIR = "./training_output"  # created at import time, below
LOG_FILE = "training.log"  # file that log() appends to

# Optimisation schedule (counts are optimizer steps).
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_STEPS = 50_000
SAVE_EVERY = 2_500
TEST_EVERY = 500  # steps between run_test() diagnostics

# Transformer width / heads / depth.
N_EMBD = 1024
N_HEAD = 16
N_LAYER = 12

# Sequence layout: [prompt positions | response positions].
MAX_PROMPT_LEN = 32
MAX_RESP_LEN = 32
TOTAL_LEN = MAX_PROMPT_LEN + MAX_RESP_LEN

os.makedirs(OUTPUT_DIR, exist_ok=True)
def log(msg, log_file=None):
    """Print ``msg`` with a timestamp and append the same line to a log file.

    Args:
        msg: Message to record; formatted with ``str()`` via an f-string.
        log_file: Path to append to. Defaults to the module-level ``LOG_FILE``.
    """
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    formatted = f"[{timestamp}] {msg}"
    print(formatted)
    path = LOG_FILE if log_file is None else log_file
    # Explicit UTF-8: messages include decoded model output, which may hold
    # characters the platform default encoding (e.g. cp1252) cannot write.
    with open(path, "a", encoding="utf-8") as f:
        f.write(formatted + "\n")
# --- Time Embedding (Sinusoidal) ---
class TimeEmbedding(nn.Module):
    """Map integer diffusion timesteps to ``n_embd``-wide conditioning vectors.

    A fixed sinusoidal encoding (sin and cos over geometrically spaced
    frequencies) is passed through a two-layer GELU MLP.

    Args:
        n_embd: Output width (also the MLP width).
    """

    def __init__(self, n_embd):
        super().__init__()
        # Stored so forward() builds an encoding of exactly the width the MLP
        # expects, independent of the module-level N_EMBD constant.
        self.n_embd = n_embd
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.GELU(),
            nn.Linear(n_embd, n_embd),
        )

    def forward(self, t):
        """Embed a batch of timesteps.

        Args:
            t: 1-D tensor of timesteps, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, n_embd)``.
        """
        half_dim = self.n_embd // 2
        # max(..., 1) avoids a division by zero for very small widths.
        freq_scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32) * -freq_scale
        )
        args = t.float()[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.n_embd % 2:
            # Odd widths: pad one zero column so the MLP input size matches.
            emb = nn.functional.pad(emb, (0, 1))
        return self.mlp(emb)
# --- Conditioned Diffusion Block ---
class DiffReaperBlock(nn.Module):
    """Pre-norm transformer block whose attention input is timestep-modulated.

    The timestep embedding is projected to a per-channel ``scale`` and
    ``shift`` that modulate the LayerNorm output feeding self-attention
    (``ln1(x) * (1 + scale) + shift``). The MLP sub-layer uses a plain
    LayerNorm and is not time-conditioned.

    Args:
        n_embd: Model width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )
        # Projects t_emb to [scale | shift], each n_embd wide.
        self.time_mlp = nn.Linear(n_embd, n_embd * 2)

    def forward(self, x, t_emb):
        """Apply the block.

        Args:
            x: Hidden states, shape ``(batch, seq, n_embd)``.
            t_emb: Timestep embedding, shape ``(batch, n_embd)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # unsqueeze(1) -> (batch, 1, 2*n_embd) so the modulation broadcasts
        # over every sequence position.
        scale, shift = self.time_mlp(t_emb).unsqueeze(1).chunk(2, dim=-1)
        x_norm = self.ln1(x) * (1 + scale) + shift
        # need_weights=False: the attention map is discarded, and not asking
        # for it lets PyTorch use its fused scaled-dot-product attention path.
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
class DiffReaperModel(nn.Module):
    """Transformer denoiser operating on sequences of embedding vectors.

    ``forward`` takes already-embedded inputs (not token ids), adds learned
    positional embeddings, and runs them through time-conditioned blocks.
    ``token_embedding`` is held here but is not used inside ``forward``.

    Args:
        vocab_size: Rows of the token embedding table.
        n_embd: Model width.
        n_head: Attention heads per block.
        n_layer: Number of blocks.
        max_len: Number of learned positions; defaults to ``TOTAL_LEN``.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = TOTAL_LEN
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, n_embd))
        self.time_embed = TimeEmbedding(n_embd)
        self.blocks = nn.ModuleList([DiffReaperBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)

    def forward(self, x_input, t):
        """Denoise ``x_input`` at timesteps ``t``.

        Args:
            x_input: Embeddings, shape ``(batch, seq, n_embd)`` with
                ``seq <= max_len``.
            t: Timesteps, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, seq, n_embd)``.

        Raises:
            ValueError: If ``seq`` exceeds the number of learned positions.
        """
        seq_len = x_input.shape[1]
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        t_emb = self.time_embed(t)
        x = x_input + self.pos_embedding[:, :seq_len, :]
        for block in self.blocks:
            x = block(x, t_emb)
        return self.ln_f(x)
log("Initializing DiffReaper Parallel Form (Conditioned Diffusion)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
if tokenizer.pad_token is None:
    # Reuse EOS as padding; the training loss masks pad ids, so EOS positions
    # are masked too whenever this fallback is taken.
    tokenizer.pad_token = tokenizer.eos_token
# len(tokenizer) includes added special tokens; tokenizer.vocab_size does not,
# and an added-token id >= vocab_size would index past the embedding table.
model = DiffReaperModel(len(tokenizer), N_EMBD, N_HEAD, N_LAYER).to("cuda")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
log("Loading Dataset (Prompt/Response focus)...")
dataset = load_dataset("OpenAssistant/oasst1", split="train")
def tokenize_function(examples):
    """Tokenize a batch of rows' ``text`` into fixed-length id sequences.

    Every sequence is truncated or padded to exactly ``TOTAL_LEN`` tokens,
    so the training loop can split it at ``MAX_PROMPT_LEN``.
    """
    return tokenizer(
        examples["text"],
        max_length=TOTAL_LEN,
        truncation=True,
        padding="max_length",
    )


# Replace the raw columns with tokenizer outputs, served as torch tensors.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names,
)
tokenized_dataset.set_format("torch")
dataloader = torch.utils.data.DataLoader(
    tokenized_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)
data_iter = iter(dataloader)
def get_batch():
    """Return the next batch's ``input_ids`` moved to the GPU.

    When the global ``data_iter`` is exhausted, it is rebuilt from
    ``dataloader`` and drawing continues from the new pass.
    """
    global data_iter
    batch = next(data_iter, None)
    if batch is None:
        # End of the current pass over the data: start a new one.
        data_iter = iter(dataloader)
        batch = next(data_iter)
    return batch["input_ids"].to("cuda")
def run_test(prompt="How are you?", num_steps=10):
    """Sample a response for ``prompt`` and log it as a qualitative check.

    Training regresses clean response embeddings (x0) from inputs produced by
    ``noise_scheduler.add_noise``. Sampling mirrors that: at each timestep the
    model predicts x0, and that prediction is re-noised with the same
    scheduler to the next, lower timestep. The last step runs at t = 0. The
    final x0 prediction is decoded by taking, per position, the token whose
    embedding has the highest dot product with it.

    Args:
        prompt: Conditioning text (truncated to ``MAX_PROMPT_LEN`` tokens).
        num_steps: Number of denoising steps; values < 1 are treated as 1.
    """
    log("Running Cropmark Diagnostic...")
    model.eval()
    try:
        with torch.no_grad():
            p_tokens = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
            p_tokens = p_tokens[:, :MAX_PROMPT_LEN]
            p_padded = torch.full((1, MAX_PROMPT_LEN), tokenizer.pad_token_id, device="cuda")
            p_padded[:, :p_tokens.shape[1]] = p_tokens
            p_emb = model.token_embedding(p_padded)

            num_train = noise_scheduler.config.num_train_timesteps
            # Evenly spaced timesteps from the noisiest (T-1) down to 0.
            timesteps = (
                torch.linspace(num_train - 1, 0, max(num_steps, 1)).round().long().tolist()
            )
            r_noisy = torch.randn(1, MAX_RESP_LEN, N_EMBD, device="cuda")
            r_0_pred = r_noisy
            for i, t_val in enumerate(timesteps):
                t = torch.tensor([t_val], device="cuda", dtype=torch.long)
                pred = model(torch.cat([p_emb, r_noisy], dim=1), t)
                r_0_pred = pred[:, MAX_PROMPT_LEN:, :]
                if i + 1 < len(timesteps):
                    # Re-noise the x0 estimate exactly as in training so the
                    # next model call sees an in-distribution input.
                    t_next = torch.tensor([timesteps[i + 1]], device="cuda", dtype=torch.long)
                    r_noisy = noise_scheduler.add_noise(
                        r_0_pred, torch.randn_like(r_0_pred), t_next
                    )
            logits = torch.matmul(r_0_pred, model.token_embedding.weight.T)
            resp_ids = torch.argmax(logits, dim=-1)
            log(f"Prompt: '{prompt}' | [Cropmark]: '{tokenizer.decode(resp_ids[0], skip_special_tokens=False)}'")
    finally:
        # Always restore training mode, even if sampling raised.
        model.train()
log("Starting Conditioned Cropmark Training...")
# NOTE(review): oasst1's "text" column holds one message per row, so the
# "prompt" and "response" halves below are the first and second parts of the
# same message rather than a real prompt/reply pair -- confirm this is intended.
NUM_TRAIN_TIMESTEPS = noise_scheduler.config.num_train_timesteps


def _save_checkpoint(step, filename=None):
    """Write step, model and optimizer state into OUTPUT_DIR for resuming."""
    path = os.path.join(OUTPUT_DIR, filename or f"checkpoint_step_{step}.pt")
    torch.save(
        {"step": step, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )
    log(f"Saved checkpoint to {path}")


start_time = time.time()
for step in range(NUM_STEPS):
    optimizer.zero_grad()
    input_ids = get_batch()
    prompt_ids = input_ids[:, :MAX_PROMPT_LEN]
    resp_ids = input_ids[:, MAX_PROMPT_LEN:]

    # Looked up without grad: token_embedding gets no gradient anywhere, so the
    # optimizer leaves it unchanged and the regression targets stay fixed.
    with torch.no_grad():
        prompt_emb = model.token_embedding(prompt_ids)
        resp_emb = model.token_embedding(resp_ids)

    # Only the response is noised; the prompt stays clean as conditioning.
    noise = torch.randn_like(resp_emb)
    t = torch.randint(0, NUM_TRAIN_TIMESTEPS, (input_ids.shape[0],), device="cuda").long()
    noisy_resp = noise_scheduler.add_noise(resp_emb, noise, t)
    combined_input = torch.cat([prompt_emb, noisy_resp], dim=1)
    predicted_clean_combined = model(combined_input, t)
    predicted_resp = predicted_clean_combined[:, MAX_PROMPT_LEN:, :]

    # x0 objective: summed squared error over non-pad response positions,
    # divided by the number of unmasked elements. If pad == eos (see setup),
    # real EOS tokens are masked out as well.
    mask = (resp_ids != tokenizer.pad_token_id).unsqueeze(-1).expand_as(resp_emb).float()
    loss = torch.nn.functional.mse_loss(
        predicted_resp * mask, resp_emb * mask, reduction='sum'
    ) / (mask.sum() + 1e-8)
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        elapsed = time.time() - start_time
        log(f"Step {step}/{NUM_STEPS} - Loss: {loss.item():.6f} - Speed: {(step+1)/elapsed:.2f} s/s")
    if step > 0 and step % TEST_EVERY == 0:
        run_test()
    if step > 0 and step % SAVE_EVERY == 0:
        _save_checkpoint(step)

# Persist the final weights; previously nothing was ever written to OUTPUT_DIR.
_save_checkpoint(NUM_STEPS, "checkpoint_final.pt")
log("Training complete.")
|