# train.py - uploaded by dataopsnick to the Hugging Face Hub
# (commit f580dce, "Create train.py", verified; 15.7 kB).
"""
ADAPT-DIFF Calibration & Training Script
Finetunes the Custom Stacked LDM Heads using target sequences from GSM8K & MBPP.
"""
import os
import gc
import copy
import random
import time
import re
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
print("Ensuring dependencies are installed...")
# Invoke pip through the running interpreter with an argument list (no shell).
# Passed through a shell, "transformers>=4.40.0" is parsed as "transformers"
# plus an output redirection to a file named "=4.40.0", silently dropping the
# version pins and littering the working directory.
import subprocess
import sys

_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "-q",
        "transformers>=4.40.0",
        "datasets>=2.18.0",
        "accelerate>=0.29.0",
        "huggingface_hub",
    ],
    check=False,  # best-effort, like the original: a failed install is reported, not fatal
)
if _pip_result.returncode != 0:
    print(f"Warning: pip exited with status {_pip_result.returncode}; continuing with installed packages.")
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from datasets import load_dataset
from huggingface_hub import hf_hub_download
# Clean up Python garbage and cached GPU memory before loading any model.
gc.collect()
torch.cuda.empty_cache()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B"
ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b"

print(f"Loading {BASE_MODEL_ID} tokenizer and model structure metadata...")
src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Reuse EOS as the pad token when the tokenizer defines none.
if src_tokenizer.pad_token is None:
    src_tokenizer.pad_token = src_tokenizer.eos_token

# A throwaway CPU copy of the base checkpoint is loaded only so its concrete
# config / backbone / causal-LM classes can be looked up and subclassed below.
temp_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
src_config = temp_model.config
BaseConfig = type(src_config)
BaseModel = type(temp_model.model)
BaseCausalLM = type(temp_model)

# The pretrained-model base: the first class in the causal-LM MRO whose name
# ends in "PreTrainedModel", falling back to the causal-LM's first direct base.
BasePreTrainedModel = next(
    (klass for klass in BaseCausalLM.__mro__ if klass.__name__.endswith("PreTrainedModel")),
    None,
) or BaseCausalLM.__bases__[0]

# Only the classes are needed from here on; release the temporary weights.
del temp_model
gc.collect()
# ==============================================================================
# Model & Pipeline Definitions
# ==============================================================================
class A2DQwenConfig(BaseConfig):
    """Config of the resolved base model, re-tagged with the custom "a2d-qwen" model type."""

    model_type = "a2d-qwen"
class A2DQwenModel(BaseModel):
    """Backbone that reimplements the base model's forward with its own mask handling.

    Subclasses the dynamically resolved ``BaseModel``; the forward pass runs
    embeddings -> rotary position embeddings -> the first
    ``config.num_hidden_layers`` decoder layers -> final norm.

    NOTE(review): a 2D (or absent) mask is expanded with
    ``_prepare_4d_attention_mask``, which transformers documents as building a
    *non-causal* 4D mask, so positions presumably attend to the whole sequence
    (including later tokens) - confirm against the installed transformers.
    """

    def forward(
        self,
        input_ids = None,
        attention_mask = None,
        position_ids = None,
        past_key_values = None,
        inputs_embeds = None,
        use_cache = None,
        cache_position = None,
        **kwargs,
    ):
        """Run the backbone and return a ``BaseModelOutputWithPast``.

        Args:
            input_ids: token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: ``None``, a 2D padding mask, a 4D mask (used as-is),
                or a dict mapping attention type -> mask (used as-is).
            position_ids: defaults to ``cache_position`` with a leading batch dim.
            past_key_values: optional cache; a ``DynamicCache`` is created when
                ``use_cache`` is truthy and none is supplied.
            inputs_embeds: precomputed embeddings; mutually exclusive with ``input_ids``.
            use_cache: whether to create and return the cache.
            cache_position: positions of the new tokens; defaults to the range
                right after the cache's current length (0 without a cache).
            **kwargs: forwarded unchanged to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` holding the final-norm hidden states and,
            only when ``use_cache`` is truthy, the cache.

        Raises:
            ValueError: if both or neither of ``input_ids``/``inputs_embeds`` are given.
        """
        # XOR is True exactly when both or neither of the inputs were provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        if cache_position is None:
            # New tokens continue after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # A dict is treated as a ready-made per-attention-type mask mapping;
        # anything else is turned into one 4D mask shared by every layer type.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            if attention_mask is None:
                # NOTE(review): this ones-mask spans only the new tokens; with a
                # non-empty cache it would not cover the cached keys - verify
                # before using this forward for cached decoding.
                attention_mask = torch.ones(
                    inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
                )
            # 4D tensors are trusted as-is; 2D padding masks are expanded.
            if not (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4):
                attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
            causal_mask_mapping = defaultdict(lambda: attention_mask)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # Layers without an "attention_type" attribute look up "self_attn";
            # with the defaultdict above every key yields the same mask.
            attn_type = getattr(decoder_layer, "attention_type", "self_attn")
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[attn_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
class A2DQwenLMHeadModel(BaseCausalLM):
    """Causal-LM wrapper whose backbone is ``A2DQwenModel``.

    ``BaseCausalLM.__init__`` is not called; only ``BasePreTrainedModel.__init__``
    runs (presumably so the parent's own backbone is never constructed). The
    backbone and a bias-free output projection are then built here, followed
    by ``post_init``.
    """

    config_class = A2DQwenConfig

    def __init__(self, config):
        BasePreTrainedModel.__init__(self, config)
        self.model = A2DQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )
        self.post_init()
# Make the Auto* factories aware of the custom "a2d-qwen" model type so that
# from_pretrained can resolve it (presumably the ADAPT_DIFF_ID checkpoint's
# config declares this type).
transformers.AutoConfig.register("a2d-qwen", A2DQwenConfig)
# NOTE(review): AutoModel is mapped to the LM-head class as well, rather than
# to the bare A2DQwenModel backbone - confirm this is intended.
for _auto_cls in (transformers.AutoModel, transformers.AutoModelForCausalLM):
    _auto_cls.register(A2DQwenConfig, A2DQwenLMHeadModel)
del _auto_cls
class StackedLDMHeads(nn.Module):
    """Forecast ``block_size`` token distributions from every hidden state.

    One linear layer expands each hidden vector into ``block_size`` hidden-sized
    slots; a shared vocabulary projection turns every slot into logits.

    Args:
        hidden_size: width of the incoming hidden states.
        vocab_size: number of logits produced per slot.
        block_size: number of slots produced per input position.
        dtype: parameter dtype of both linear layers (bfloat16 by default,
            matching the previous hard-coded value).
    """

    def __init__(self, hidden_size, vocab_size, block_size=12, dtype=torch.bfloat16):
        super().__init__()
        self.block_size = block_size
        self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=dtype)
        self.head = nn.Linear(hidden_size, vocab_size, dtype=dtype)

    def forward(self, hidden_states):
        """Map ``(..., hidden)`` states to ``(..., block_size, vocab)`` logits.

        Any number of leading dimensions is accepted; the common case is
        ``(batch, seq, hidden)`` -> ``(batch, seq, block_size, vocab)``.
        """
        forecast = self.proj(hidden_states)
        # Split the expanded last dimension into (block_size, hidden) slots.
        forecast = forecast.unflatten(-1, (self.block_size, hidden_states.shape[-1]))
        return self.head(forecast)
class LogitUncertaintyFilter(nn.Module):
    """Flag positions whose predictive distribution has high Shannon entropy."""

    def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the entropy (in nats) of softmax(logits) along the last dim.

        Logits are upcast to float32; a 1e-9 epsilon guards the log.
        """
        p = logits.float().softmax(dim=-1)
        return (p * (p + 1e-9).log()).sum(dim=-1).neg()

    def forward(self, logits: torch.Tensor, threshold: float):
        """Return ``(mask, entropy)`` where ``mask`` is ``entropy >= threshold``."""
        ent = self.compute_entropy(logits)
        return ent >= threshold, ent
class ActorCriticPruner:
    """Depth-limited alpha-beta style search over tokens at uncertain positions.

    All values are scored from a fixed ``logits`` tensor; no model is called.
    ``lm_head`` is stored but not used by these methods.
    """

    def __init__(self, lm_head, lambda_reg=0.1):
        self.lm_head = lm_head
        self.lambda_reg = lambda_reg  # weight of the entropy penalty used for pruning

    def evaluate_sequence_value(self, candidate_tokens, logits):
        """Return the mean log-probability of ``candidate_tokens`` under ``logits``.

        Assumes ``logits`` is ``(positions, vocab)`` and ``candidate_tokens`` is a
        ``(positions,)`` integer tensor - TODO confirm for other callers.
        """
        log_probs = F.log_softmax(logits.float(), dim=-1)
        gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1)
        return gathered.mean().item()

    def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta):
        """Re-choose tokens at masked positions, one position per recursion level.

        At each level the first still-masked position is tried with its top-k
        (k <= 3) tokens. A candidate is pruned when its value minus
        ``lambda_reg * entropy[pos]`` is below ``alpha``; survivors are refined
        recursively with that position unmasked and ``depth - 1``.

        Returns:
            ``(refined_sequence, value)``. At a leaf (``depth == 0`` or nothing
            masked) this is ``(sequence, evaluate_sequence_value(sequence))``.
            Otherwise it is the best fully refined sequence found together with
            the value of that same sequence, or ``(sequence, -inf)`` when no
            candidate improved on ``alpha``.
        """
        refined_sequence = sequence.clone()
        if depth == 0 or not mask.any():
            return refined_sequence, self.evaluate_sequence_value(sequence, logits)
        target_pos = torch.where(mask)[0][0].item()
        # Guard against vocabularies smaller than the search width.
        k = min(3, logits.shape[-1])
        _, top_tokens = torch.topk(logits[target_pos], k=k)
        penalty = self.lambda_reg * entropy[target_pos].item()
        best_val = float('-inf')
        for token_opt in top_tokens:
            candidate = sequence.clone()
            candidate[target_pos] = token_opt
            approx_val = self.evaluate_sequence_value(candidate, logits) - penalty
            if approx_val < alpha:
                continue
            new_mask = mask.clone()
            new_mask[target_pos] = False
            sub_sequence, path_val = self.recursive_refine(
                candidate, logits, new_mask, entropy, depth - 1, alpha, beta
            )
            if path_val > alpha:
                alpha = path_val
                best_val = path_val
                # Keep the sub-search's refined sequence (not just `candidate`),
                # so the returned sequence is the one that `path_val` scored.
                refined_sequence = sub_sequence
                if alpha >= beta:
                    break
        return refined_sequence, best_val
class ADAPTDIFFPipeline(nn.Module):
    """Backbone + stacked LDM heads that emit ``block_size`` tokens per backbone pass.

    ``base_model`` and ``lm_head`` are taken from the given causal-LM model and
    registered as sub-modules (so they appear in ``parameters()``).
    """

    def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5):
        super().__init__()
        self.base_model = base_lm_model.model
        self.lm_head = base_lm_model.lm_head
        self.block_size = block_size
        self.entropy_threshold = entropy_threshold
        self.ldm_heads = StackedLDMHeads(
            hidden_size=base_lm_model.config.hidden_size,
            vocab_size=base_lm_model.config.vocab_size,
            block_size=block_size
        ).to(DEVICE)
        self.router = LogitUncertaintyFilter()
        self.pruner = ActorCriticPruner(self.lm_head)

    @torch.no_grad()
    def generate_adapt_diff(self, input_ids, max_new_tokens=128):
        """Generate up to ``max_new_tokens`` tokens, one block per backbone pass.

        Each iteration drafts a block by argmax over the heads' forecast at the
        last position. When any slot's entropy reaches ``entropy_threshold``,
        the block is refined by the pruner. The final block is truncated so
        that at most ``max_new_tokens`` tokens are produced.

        Assumes batch size 1: ``squeeze(0)`` drops the batch dim and only row 0
        is returned.

        Returns:
            ``(new_token_ids, total_full_transformer_evals)``. The counter
            includes one backbone pass per block, plus one per refinement even
            though the pruner itself scores only ``block_logits``.
        """
        current_seq = input_ids.clone()
        generated_count = 0
        total_full_transformer_evals = 0
        while generated_count < max_new_tokens:
            outputs = self.base_model(input_ids=current_seq)
            total_full_transformer_evals += 1
            last_hidden = outputs.last_hidden_state[:, -1:, :]
            block_logits = self.ldm_heads(last_hidden).squeeze(0).squeeze(0)  # (block_size, vocab)
            draft_tokens = torch.argmax(block_logits, dim=-1)
            mask, entropy = self.router(block_logits, self.entropy_threshold)
            if not mask.any():
                final_block = draft_tokens
            else:
                total_full_transformer_evals += 1
                final_block, _ = self.pruner.recursive_refine(
                    sequence=draft_tokens,
                    logits=block_logits,
                    mask=mask,
                    entropy=entropy,
                    depth=2,
                    alpha=float('-inf'),
                    beta=float('inf')
                )
            # Never exceed the requested budget: cut the last block to fit.
            final_block = final_block[: max_new_tokens - generated_count]
            current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1)
            generated_count += final_block.shape[0]
        return current_seq[0, input_ids.shape[1]:], total_full_transformer_evals
# ==============================================================================
# Model Loading
# ==============================================================================
print(f"Loading ADAPT-DIFF base model {ADAPT_DIFF_ID}...")
a2d_model = AutoModelForCausalLM.from_pretrained(
    ADAPT_DIFF_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE
)
pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5)
print("Downloading LDM head projection weights for calibration baseline...")
ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt")
# The file comes from a remote repository and is loaded as a state dict of
# tensors, so restrict torch.load to tensor/container data instead of
# unpickling arbitrary objects.
pipeline.ldm_heads.load_state_dict(
    torch.load(ldm_weights_path, map_location=DEVICE, weights_only=True)
)
# ==============================================================================
# SFT Training Dataset Setup
# ==============================================================================
print("\nDownloading datasets (GSM8K & MBPP) for calibration phase...")
gsm8k_ds = load_dataset("openai/gsm8k", "main")
mbpp_ds = load_dataset("google-research-datasets/mbpp")

candidate_train = []

# Up to 40 GSM8K word problems as (prompt, completion) pairs.
if "train" in gsm8k_ds:
    for _, row in zip(range(40), gsm8k_ds["train"]):
        candidate_train.append(
            (f"Problem: {row['question']}\nSolution:", f" {row['answer']}")
        )

# Up to 40 MBPP tasks that carry both a description and reference code.
mbpp_train_raw = mbpp_ds["train"] if "train" in mbpp_ds else list(mbpp_ds.values())[0]
mbpp_candidates = (
    (
        f"Write a Python function to solve this task:\n{row['text']}\nSolution:\n",
        f"{row['code']}",
    )
    for row in mbpp_train_raw
    if 'text' in row and 'code' in row
)
mbpp_pairs = [pair for _, pair in zip(range(40), mbpp_candidates)]
candidate_train.extend(mbpp_pairs)
code_count = len(mbpp_pairs)
print(f"Assembled training set with {len(candidate_train)} sequences.")

# Tokenize prompt+completion; keep only sequences longer than block_size + 2 tokens.
train_tensors = []
for prompt, completion in candidate_train:
    token_ids = src_tokenizer(prompt + completion, return_tensors="pt").to(DEVICE).input_ids
    if token_ids.shape[1] > pipeline.block_size + 2:
        train_tensors.append(token_ids)
# ==============================================================================
# Calibration Loop
# ==============================================================================
# Training mode for every sub-module, then AdamW over all pipeline parameters.
# NOTE(review): pipeline.parameters() includes the registered backbone and
# lm_head as well as the LDM heads, so this fine-tunes the whole model rather
# than only the heads described in the module docstring - confirm intent.
pipeline.train()
optimizer = torch.optim.AdamW(
    params=pipeline.parameters(),
    lr=2e-4,
    weight_decay=0.01,
)
def compute_ldm_forecast_loss(pipeline, input_ids):
    """Cross-entropy of the stacked LDM heads against the next ``L`` tokens.

    The heads' forecast at position ``i`` is scored against
    ``input_ids[:, i + 1 : i + 1 + L]`` for every position where that full
    window exists, i.e. ``i in range(S - L)``.

    An explicit additive causal 4D mask is passed to the backbone. The backbone
    uses 4D masks as-is, whereas its 2D/absent-mask path builds a non-causal
    mask, which would let position ``i`` attend to the very tokens it is asked
    to forecast.
    NOTE(review): inference (``generate_adapt_diff``) feeds only the prefix
    without this mask, so intermediate representations can still differ
    slightly between training and inference.

    Args:
        pipeline: object exposing ``base_model`` (called with ``input_ids`` and
            ``attention_mask``) and ``ldm_heads``.
        input_ids: ``(batch, seq)`` token ids.

    Returns:
        Scalar loss tensor, or a zero tensor requiring grad when the sequence is
        too short (``seq <= L``) to form any forecast target.
    """
    batch_size, seq_len = input_ids.shape
    dtype = getattr(pipeline.base_model, "dtype", torch.float32)
    # 0 on/below the diagonal, dtype-min above it: no attention to later tokens.
    causal_mask = torch.full(
        (seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=input_ids.device
    ).triu(1)
    causal_mask = causal_mask[None, None].expand(batch_size, 1, seq_len, seq_len)

    outputs = pipeline.base_model(input_ids=input_ids, attention_mask=causal_mask)
    block_logits = pipeline.ldm_heads(outputs.last_hidden_state)
    _, _, L, V = block_logits.shape
    num_positions = seq_len - L  # positions whose full L-token window fits
    if num_positions <= 0:
        return torch.tensor(0.0, device=input_ids.device, requires_grad=True)
    pred_logits = block_logits[:, :num_positions]
    # targets[b, i, k] == input_ids[b, i + 1 + k]; shape (batch, num_positions, L)
    targets = input_ids[:, 1:].unfold(1, L, 1)
    loss_fct = nn.CrossEntropyLoss()
    return loss_fct(pred_logits.reshape(-1, V), targets.reshape(-1))
epochs = 20
step = 0
best_loss = float('inf')
best_state_dict = None
print(f"\nCalibrating Stacked LDM heads across {epochs} epochs...")
for epoch in range(epochs):
    random.shuffle(train_tensors)
    pipeline.train()
    epoch_loss = 0.0
    epoch_steps = 0
    for input_ids in train_tensors:
        optimizer.zero_grad(set_to_none=True)
        loss = compute_ldm_forecast_loss(pipeline, input_ids)
        current_loss = loss.item()
        if current_loss == 0.0:
            # Too short to form a forecast target; nothing to learn from it.
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(pipeline.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += current_loss
        epoch_steps += 1
        step += 1
        if step % 20 == 0:
            print(f"Step {step:3d} | Epoch {epoch+1} | Loss: {current_loss:.4f} (Best: {best_loss:.4f})")
    if epoch_steps == 0:
        continue
    # Select checkpoints on the epoch-mean loss: single-step losses come from
    # different examples and are not comparable with each other.
    mean_loss = epoch_loss / epoch_steps
    print(f"Epoch {epoch+1} | Mean loss: {mean_loss:.4f}")
    if mean_loss < best_loss:
        best_loss = mean_loss
        # Snapshot on CPU (copy=True also guarantees a real copy when DEVICE is
        # "cpu") so the best weights do not occupy a second copy in GPU memory.
        best_state_dict = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in pipeline.state_dict().items()
        }
print("\nSFT alignment completed.")
if best_state_dict is not None:
    pipeline.load_state_dict(best_state_dict)
    print(f"Successfully loaded best state checkpoint with loss: {best_loss:.4f}")
# ==============================================================================
# Model Post-Training Evaluation
# ==============================================================================
pipeline.eval()
print("\nVerifying model calibration progress on training sequence forecasts...")
with torch.no_grad():
    # Spot-check the first two training sequences: hide their final block_size
    # tokens and compare the heads' one-shot forecast with the hidden tokens.
    for idx, input_ids in enumerate(train_tensors[:2]):
        horizon = pipeline.block_size
        total_len = input_ids.shape[1]
        if total_len <= horizon + 1:
            continue  # too short to leave a useful prefix
        cut = total_len - horizon
        prefix_ids = input_ids[:, :cut]
        target_ids = input_ids[0, cut : cut + horizon]

        hidden_states = pipeline.base_model(input_ids=prefix_ids).last_hidden_state
        # Forecast slots at the final prefix position of batch row 0.
        pred_ids = pipeline.ldm_heads(hidden_states)[0, -1].argmax(dim=-1)

        prompt_text = src_tokenizer.decode(prefix_ids[0], skip_special_tokens=True)
        expected_text = src_tokenizer.decode(target_ids, skip_special_tokens=True)
        predicted_text = src_tokenizer.decode(pred_ids, skip_special_tokens=True)
        # At most the last 200 characters of the context.
        truncated_prompt = prompt_text[-200:]

        print(f"\n--- Sequence Output Check {idx + 1} ---")
        print(f"[Context Prompt Segment]: ... {truncated_prompt}")
        print(f"[Expected Block Output]: {expected_text}")
        print(f"[Predicted Block Output]: {predicted_text}")