|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from diffusers import DDPMScheduler |
|
|
from transformers import AutoTokenizer |
|
|
from datasets import load_dataset |
|
|
import os |
|
|
import time |
|
|
import math |
|
|
from huggingface_hub import HfApi |
|
|
|
|
|
|
|
|
# --- Paths, Hub target and credentials ---
MODEL_PATH = "./DiffReaper-Talk"          # local directory the tokenizer is loaded from
REPO_ID = "darwinkernelpanic/DiffReaper-5L"
HF_TOKEN = os.environ.get("HF_TOKEN")     # None when the variable is unset
OUTPUT_DIR = "./training_output"          # checkpoint (.pt) destination
LOG_FILE = "training.log"
CHECKPOINT_LOG = "checkpoint_log.txt"

# --- Optimisation schedule (intervals are in optimizer steps) ---
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
SAVE_EVERY = 2_500
TEST_EVERY = 500

# --- Architecture and sequence layout ---
N_EMBD = 2048
N_HEAD = 32
N_LAYER = 24
MAX_PROMPT_LEN = 32
MAX_RESP_LEN = 128
# Every training/sampling sequence is [prompt tokens | response tokens].
TOTAL_LEN = MAX_PROMPT_LEN + MAX_RESP_LEN

os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
def log(msg, log_file=None):
    """Print a timestamped message and append the same line to a log file.

    Args:
        msg: Message to record (interpolated with an f-string, so any object works).
        log_file: Path of the file to append to. Defaults to the module-level
            ``LOG_FILE`` when omitted, preserving the original call signature.
    """
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    formatted = f"[{timestamp}] {msg}"
    print(formatted)
    path = LOG_FILE if log_file is None else log_file
    # Messages can include decoded model output (arbitrary Unicode); pin UTF-8 so
    # the append does not depend on the platform's locale encoding.
    with open(path, "a", encoding="utf-8") as f:
        f.write(formatted + "\n")
|
|
|
|
|
class TimeEmbedding(nn.Module):
    """Map diffusion timesteps to ``n_embd``-dimensional conditioning vectors.

    Each timestep gets a fixed sinusoidal encoding (sin features followed by
    cos features), which is then passed through a two-layer GELU MLP.
    """

    def __init__(self, n_embd):
        super().__init__()
        # Keep the width on the instance so forward() agrees with this module's
        # MLP rather than with the global N_EMBD constant.
        self.n_embd = n_embd
        self.mlp = nn.Sequential(nn.Linear(n_embd, n_embd), nn.GELU(), nn.Linear(n_embd, n_embd))

    def forward(self, t):
        """Embed a batch of timesteps.

        Args:
            t: 1-D tensor with one timestep per batch element (integer or float).

        Returns:
            Tensor of shape ``(len(t), n_embd)``.
        """
        half_dim = self.n_embd // 2
        # Guard the divisor: half_dim <= 1 (n_embd < 4) would divide by zero.
        freq_scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -freq_scale)
        args = t[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        # sin/cos produce 2 * half_dim features; zero-pad the extra column when
        # n_embd is odd so the MLP input width always matches.
        if emb.shape[-1] < self.n_embd:
            emb = F.pad(emb, (0, self.n_embd - emb.shape[-1]))
        return self.mlp(emb)
|
|
|
|
|
class DiffReaperBlock(nn.Module):
    """Pre-norm transformer block with timestep-modulated attention input.

    The timestep embedding is projected to a per-channel ``scale`` and
    ``shift`` that modulate the LayerNorm output feeding self-attention
    (``ln1(x) * (1 + scale) + shift``). The MLP sub-layer uses an
    unmodulated LayerNorm. Both sub-layers are residual.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd))
        # One projection yields both scale and shift (n_embd channels each).
        self.time_mlp = nn.Linear(n_embd, n_embd * 2)

    def forward(self, x, t_emb):
        """Apply the block.

        Args:
            x: Sequence tensor of shape ``(batch, seq, n_embd)``. Attention is
                constructed with ``batch_first=True``.
            t_emb: Timestep embedding, presumably ``(batch, n_embd)``. It is
                unsqueezed to broadcast over the sequence axis.

        Returns:
            Tensor with the same shape as ``x``.
        """
        time_params = self.time_mlp(t_emb).unsqueeze(1)
        scale, shift = time_params.chunk(2, dim=-1)
        x_norm = self.ln1(x) * (1 + scale) + shift
        # The attention weights were always discarded; need_weights=False skips
        # materialising the (batch, seq, seq) averaged weights and lets PyTorch
        # use its fused scaled-dot-product attention kernel.
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
|
|
|
|
|
class DiffReaperModel(nn.Module):
    """Timestep-conditioned transformer denoiser over embedding sequences.

    ``forward`` receives already-embedded inputs: in this script the caller
    embeds tokens with ``token_embedding`` and concatenates prompt embeddings
    with (noisy) response embeddings itself. The model adds learned positional
    embeddings, runs the timestep-conditioned blocks, and returns the final
    LayerNorm output with the same shape as its input.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, max_len=None):
        """Build the model.

        Args:
            vocab_size: Number of rows in the token embedding table.
            n_embd: Model width.
            n_head: Attention heads per block.
            n_layer: Number of blocks.
            max_len: Maximum sequence length supported by the positional table.
                Defaults to the module-level ``TOTAL_LEN`` (prompt + response).
        """
        super().__init__()
        # Plain int attribute, so state_dict keys are unchanged.
        self.max_len = TOTAL_LEN if max_len is None else max_len
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.max_len, n_embd))
        self.time_embed = TimeEmbedding(n_embd)
        self.blocks = nn.ModuleList([DiffReaperBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)

    def forward(self, x_input, t):
        """Denoise a batch of embedding sequences.

        Args:
            x_input: Embeddings of shape ``(batch, seq_len, n_embd)`` with
                ``seq_len <= max_len``.
            t: 1-D tensor of timesteps, one per batch element.

        Returns:
            Tensor of shape ``(batch, seq_len, n_embd)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional embedding table.
        """
        seq_len = x_input.shape[1]
        if seq_len > self.pos_embedding.shape[1]:
            # Fail clearly instead of with an opaque broadcasting error below.
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size "
                f"{self.pos_embedding.shape[1]}"
            )
        t_emb = self.time_embed(t)
        x = x_input + self.pos_embedding[:, :seq_len, :]
        for block in self.blocks:
            x = block(x, t_emb)
        return self.ln_f(x)
|
|
|
|
|
def run_test(model, tokenizer, step):
    """Sample a response to a fixed prompt and record it in both logs.

    The prompt is tokenised, truncated and right-padded to ``MAX_PROMPT_LEN``
    with the pad token. ``MAX_RESP_LEN`` response positions start as Gaussian
    noise and are refined for 10 iterations at timesteps 999, 899, ..., 99 by
    blending ``0.4 * current + 0.6 * prediction``. This is a heuristic update
    and does not use the noise scheduler's reverse step. Each response position
    is then mapped to the token whose embedding has the highest cosine
    similarity, and the decoded text is logged and appended to ``CHECKPOINT_LOG``.

    Args:
        model: A ``DiffReaperModel``.
        tokenizer: Tokenizer with ``pad_token_id`` set.
        step: Training step number, used only in log messages.
    """
    log(f"Running Cropmark Diagnostic [Step {step}]...")
    # Follow the model's placement and width instead of hard-coding CUDA / N_EMBD.
    device = next(model.parameters()).device
    embed_dim = model.token_embedding.embedding_dim
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            prompt = "Hello! Who are you today?"
            p_tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(device)[:, :MAX_PROMPT_LEN]
            p_padded = torch.full((1, MAX_PROMPT_LEN), tokenizer.pad_token_id, device=device)
            p_padded[:, :p_tokens.shape[1]] = p_tokens
            p_emb = model.token_embedding(p_padded)
            r_noise = torch.randn(1, MAX_RESP_LEN, embed_dim, device=device)
            for i in range(10):
                t = torch.tensor([1000 - (i * 100) - 1], device=device).long()
                pred = model(torch.cat([p_emb, r_noise], dim=1), t)[:, MAX_PROMPT_LEN:, :]
                r_noise = 0.4 * r_noise + 0.6 * pred
            # Nearest-token lookup by cosine similarity against the embedding table.
            norm_weights = F.normalize(model.token_embedding.weight, dim=-1)
            norm_r = F.normalize(r_noise, dim=-1)
            logits = torch.matmul(norm_r, norm_weights.T)
            resp_ids = torch.argmax(logits, dim=-1)
            result = tokenizer.decode(resp_ids[0], skip_special_tokens=True)
            log(f"Prompt: '{prompt}' | [Cropmark]: '{result}'")
            # Decoded text may contain arbitrary Unicode; pin the file encoding.
            with open(CHECKPOINT_LOG, "a", encoding="utf-8") as f:
                f.write(f"Step {step} - Prompt: '{prompt}' | [Cropmark]: '{result}'\n")
    finally:
        # Restore the caller's mode even if sampling raised.
        model.train(was_training)
|
|
|
|
|
if __name__ == "__main__":
    log("Initializing DiffReaper-5L (Autopilot)...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # len(tokenizer) includes added special tokens; tokenizer.vocab_size does
    # not, so sizing the embedding table by vocab_size can leave ids out of range.
    model = DiffReaperModel(len(tokenizer), N_EMBD, N_HEAD, N_LAYER).to("cuda")

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
    # Sample timesteps from the scheduler's own range rather than a duplicated literal.
    num_timesteps = noise_scheduler.config.num_train_timesteps
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    log("Loading OpenAssistant...")
    dataset = load_dataset("OpenAssistant/oasst1", split="train")

    def tokenize_function(examples):
        # NOTE(review): the whole "text" field is tokenised as ONE sequence and the
        # training loop later splits it at MAX_PROMPT_LEN. The "prompt" is therefore
        # the first 32 tokens of a row and the "response" the next 128, presumably
        # not a real prompt/reply pair. Confirm this is intended.
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=TOTAL_LEN)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
    tokenized_dataset.set_format("torch")
    dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True)

    log("DiffReaper-5L training starting...")
    api = HfApi()
    # Upload this script by its real location instead of a CWD-relative name.
    script_path = os.path.abspath(__file__)
    start_time = time.time()
    step = 0

    while True:
        for batch in dataloader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to("cuda")
            prompt_ids = input_ids[:, :MAX_PROMPT_LEN]
            resp_ids = input_ids[:, MAX_PROMPT_LEN:]
            prompt_emb = model.token_embedding(prompt_ids)
            # NOTE(review): resp_emb is both the noised input and the (non-detached)
            # regression target, so the embedding table is trained through both
            # paths. Confirm this is intended.
            resp_emb = model.token_embedding(resp_ids)

            noise = torch.randn_like(resp_emb)
            t = torch.randint(0, num_timesteps, (input_ids.shape[0],), device="cuda").long()
            noisy_resp = noise_scheduler.add_noise(resp_emb, noise, t)

            pred_resp = model(torch.cat([prompt_emb, noisy_resp], dim=1), t)[:, MAX_PROMPT_LEN:, :]
            # Average (1 - cosine similarity) over non-pad response positions only.
            # Because pad_token may have been set to eos_token above, genuine EOS
            # tokens are excluded from the loss as well.
            mask = (resp_ids != tokenizer.pad_token_id).float()
            cos_sim = F.cosine_similarity(pred_resp, resp_emb, dim=-1)
            loss = 1 - (cos_sim * mask).sum() / (mask.sum() + 1e-8)

            loss.backward()
            optimizer.step()

            if step % 50 == 0:
                elapsed = time.time() - start_time
                log(f"Step {step} - Loss: {loss.item():.6f} - Speed: {(step+1)/elapsed:.2f} s/s")

            if step > 0 and step % TEST_EVERY == 0:
                run_test(model, tokenizer, step)

            if step > 0 and step % SAVE_EVERY == 0:
                ckpt_name = f"diffreaper5l_{step}.pt"
                ckpt_path = os.path.join(OUTPUT_DIR, ckpt_name)
                torch.save(model.state_dict(), ckpt_path)
                log("Syncing to HF...")
                try:
                    api.upload_file(path_or_fileobj=ckpt_path, path_in_repo=ckpt_name, repo_id=REPO_ID, token=HF_TOKEN)
                    # The sample log only exists once run_test has run at least once.
                    if os.path.exists(CHECKPOINT_LOG):
                        api.upload_file(path_or_fileobj=CHECKPOINT_LOG, path_in_repo="checkpoint_log.txt", repo_id=REPO_ID, token=HF_TOKEN)
                    api.upload_file(path_or_fileobj=script_path, path_in_repo="train_diffreaper_5l.py", repo_id=REPO_ID, token=HF_TOKEN)
                except Exception as e:
                    # Best-effort sync: a Hub failure is logged but must not stop training.
                    log(f"HF Sync Error: {e}")

            step += 1
|
|
|