""" ADAPT-DIFF Calibration & Training Script Finetunes the Custom Stacked LDM Heads using target sequences from GSM8K & MBPP. """ import os import gc import copy import random import time import re from collections import defaultdict import torch import torch.nn as nn import torch.nn.functional as F print("Ensuring dependencies are installed...") os.system("pip install -q transformers>=4.40.0 datasets>=2.18.0 accelerate>=0.29.0 huggingface_hub") import transformers from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM from transformers.cache_utils import DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from datasets import load_dataset from huggingface_hub import hf_hub_download # Clean up GPU cache before running gc.collect() torch.cuda.empty_cache() DEVICE = "cuda" if torch.cuda.is_available() else "cpu" BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B" ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b" print(f"Loading {BASE_MODEL_ID} tokenizer and model structure metadata...") src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID) if src_tokenizer.pad_token is None: src_tokenizer.pad_token = src_tokenizer.eos_token # Load temporary instance to resolve base classes dynamically temp_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL_ID, torch_dtype=torch.bfloat16, device_map="cpu" ) src_config = temp_model.config BaseConfig = src_config.__class__ BaseModel = temp_model.model.__class__ BaseCausalLM = temp_model.__class__ BasePreTrainedModel = next( (cls for cls in BaseCausalLM.__mro__ if cls.__name__.endswith("PreTrainedModel")), None ) if BasePreTrainedModel is None: BasePreTrainedModel = BaseCausalLM.__bases__[0] del temp_model gc.collect() # ============================================================================== # Model & Pipeline Definitions # ============================================================================== 
class A2DQwenConfig(BaseConfig):
    """Configuration of the base model re-registered under the "a2d-qwen" model type."""

    model_type = "a2d-qwen"


class A2DQwenModel(BaseModel):
    """Backbone whose forward pass feeds a padding-only 4D mask to every layer."""

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        cache_position=None,
        **kwargs,
    ):
        """Run the decoder stack and return the normalized final hidden states.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``attention_mask`` may be ``None`` (all ones), a 2D mask, a ready 4D
        mask, or a dict keyed by layer attention type (used as-is).

        Returns:
            BaseModelOutputWithPast with ``last_hidden_state`` and, only when
            ``use_cache`` is truthy, ``past_key_values``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are supplied.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # A dict is taken as a ready per-attention-type mapping; anything else
        # is expanded to a single 4D mask shared by every attention type.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            if attention_mask is None:
                attention_mask = torch.ones(
                    inputs_embeds.shape[:2],
                    device=inputs_embeds.device,
                    dtype=torch.long,
                )
            if not (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4):
                # NOTE(review): this helper is documented as building a
                # non-causal mask, i.e. attention is presumably bidirectional.
                attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
            causal_mask_mapping = defaultdict(lambda: attention_mask)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            attn_type = getattr(decoder_layer, "attention_type", "self_attn")
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[attn_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class A2DQwenLMHeadModel(BaseCausalLM):
    """Causal-LM wrapper that uses ``A2DQwenModel`` as its backbone."""

    config_class = A2DQwenConfig

    def __init__(self, config):
        # Call the PreTrainedModel initializer directly (skipping
        # BaseCausalLM.__init__) so the backbone is built as A2DQwenModel.
        BasePreTrainedModel.__init__(self, config)
        self.model = A2DQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()


# Register custom classes
transformers.AutoConfig.register("a2d-qwen", A2DQwenConfig)
transformers.AutoModel.register(A2DQwenConfig, A2DQwenLMHeadModel)
transformers.AutoModelForCausalLM.register(A2DQwenConfig, A2DQwenLMHeadModel)


class StackedLDMHeads(nn.Module):
    """Maps each hidden state to ``block_size`` sets of vocabulary logits."""

    def __init__(self, hidden_size, vocab_size, block_size=12):
        super().__init__()
        self.block_size = block_size
        self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=torch.bfloat16)
        self.head = nn.Linear(hidden_size, vocab_size, dtype=torch.bfloat16)

    def forward(self, hidden_states):
        """Return logits of shape (batch, seq, block_size, vocab) for (batch, seq, hidden) input."""
        batch_size, seq_len, hidden_size = hidden_states.shape
        forecast = self.proj(hidden_states)
        forecast = forecast.view(batch_size, seq_len, self.block_size, hidden_size)
        return self.head(forecast)


class LogitUncertaintyFilter(nn.Module):
    """Flags positions whose predictive entropy reaches a threshold."""

    def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the softmax entropy over the last dimension, computed in float32."""
        probs = F.softmax(logits.float(), dim=-1)
        # The 1e-9 offset keeps log() finite for zero probabilities.
        return -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)

    def forward(self, logits: torch.Tensor, threshold: float):
        """Return ``(mask, entropy)`` where ``mask`` is ``entropy >= threshold``."""
        entropy = self.compute_entropy(logits)
        mask = entropy >= threshold
        return mask, entropy


class ActorCriticPruner:
    """Bounded top-k search over the draft tokens at high-entropy positions."""

    def __init__(self, lm_head, lambda_reg=0.1):
        self.lm_head = lm_head
        self.lambda_reg = lambda_reg

    def evaluate_sequence_value(self, candidate_tokens, logits):
        """Return the mean log-probability of ``candidate_tokens`` under ``logits`` as a float."""
        log_probs = F.log_softmax(logits.float(), dim=-1)
        gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1)
        return gathered.mean().item()

    def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta):
        """Search replacements for the first masked position, recursing up to ``depth`` levels.

        Args:
            sequence: 1D tensor of draft token ids.
            logits: per-position logits, indexed by position on dim 0.
            mask: boolean tensor marking positions still open for refinement.
            entropy: per-position entropy, used as a penalty scaled by ``lambda_reg``.
            depth: remaining recursion depth; 0 evaluates ``sequence`` as-is.
            alpha: lower bound; candidates scoring below it are skipped.
            beta: upper bound; the search stops once ``alpha >= beta``.

        Returns:
            ``(refined_sequence, value)``. When no candidate improves on
            ``alpha`` the input sequence is returned with value ``-inf``.
        """
        refined_sequence = sequence.clone()
        if depth == 0 or not mask.any():
            return refined_sequence, self.evaluate_sequence_value(sequence, logits)

        target_pos = torch.where(mask)[0][0].item()
        # Guard k so a vocabulary smaller than 3 does not make topk raise.
        top_k = min(3, logits.shape[-1])
        top_tokens = torch.topk(logits[target_pos], k=top_k).indices
        penalty = self.lambda_reg * entropy[target_pos].item()

        best_val = float('-inf')
        for token_opt in top_tokens:
            candidate = sequence.clone()
            candidate[target_pos] = token_opt
            approx_val = self.evaluate_sequence_value(candidate, logits) - penalty
            if approx_val < alpha:
                continue

            new_mask = mask.clone()
            new_mask[target_pos] = False
            child_sequence, path_val = self.recursive_refine(
                candidate, logits, new_mask, entropy, depth - 1, alpha, beta
            )
            if path_val > alpha:
                alpha = path_val
                best_val = path_val
                # Keep the sequence that actually achieved path_val. Returning
                # `candidate` here would throw away every replacement made by
                # the deeper levels of the search.
                refined_sequence = child_sequence
            if alpha >= beta:
                break
        return refined_sequence, best_val


class ADAPTDIFFPipeline(nn.Module):
    """Block-wise decoder: backbone -> stacked heads -> entropy router -> pruner."""

    def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5):
        super().__init__()
        self.base_model = base_lm_model.model
        self.lm_head = base_lm_model.lm_head
        self.block_size = block_size
        self.entropy_threshold = entropy_threshold
        self.ldm_heads = StackedLDMHeads(
            hidden_size=base_lm_model.config.hidden_size,
            vocab_size=base_lm_model.config.vocab_size,
            block_size=block_size,
        ).to(DEVICE)
        self.router = LogitUncertaintyFilter()
        self.pruner = ActorCriticPruner(self.lm_head)

    @torch.no_grad()  # decoding only takes argmax/top-k; no autograd graph is needed
    def generate_adapt_diff(self, input_ids, max_new_tokens=128):
        """Generate tokens one block at a time from a single prompt.

        Args:
            input_ids: prompt ids of shape (1, prompt_len).
            max_new_tokens: upper bound on the number of returned tokens.

        Returns:
            ``(new_tokens, evals)``: a 1D tensor with at most
            ``max_new_tokens`` generated ids, and the evaluation counter.

        Raises:
            ValueError: if ``input_ids`` is not a batch of exactly one sequence.
        """
        if input_ids.ndim != 2 or input_ids.shape[0] != 1:
            # The squeeze/indexing below is only valid for a single sequence.
            raise ValueError("generate_adapt_diff expects input_ids of shape (1, seq_len)")

        current_seq = input_ids.clone()
        generated_count = 0
        total_full_transformer_evals = 0

        while generated_count < max_new_tokens:
            outputs = self.base_model(input_ids=current_seq)
            total_full_transformer_evals += 1

            last_hidden = outputs.last_hidden_state[:, -1:, :]
            block_logits = self.ldm_heads(last_hidden).squeeze(0).squeeze(0)
            draft_tokens = torch.argmax(block_logits, dim=-1)
            mask, entropy = self.router(block_logits, self.entropy_threshold)

            if not mask.any():
                final_block = draft_tokens
            else:
                # Counted as one extra evaluation (original accounting), even
                # though the refinement itself does not call the backbone.
                total_full_transformer_evals += 1
                final_block, _ = self.pruner.recursive_refine(
                    sequence=draft_tokens,
                    logits=block_logits,
                    mask=mask,
                    entropy=entropy,
                    depth=2,
                    alpha=float('-inf'),
                    beta=float('inf'),
                )

            current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1)
            generated_count += self.block_size

        # Whole blocks are appended, so the last one may overshoot the budget;
        # trim so callers never receive more than max_new_tokens tokens.
        new_tokens = current_seq[0, input_ids.shape[1]:]
        return new_tokens[:max_new_tokens], total_full_transformer_evals


# ==============================================================================
# Model Loading
# ==============================================================================
print(f"Loading ADAPT-DIFF base model {ADAPT_DIFF_ID}...")
a2d_model = AutoModelForCausalLM.from_pretrained(
    ADAPT_DIFF_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5)

print("Downloading LDM head projection weights for calibration baseline...")
ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered download cannot execute arbitrary code while being loaded.
pipeline.ldm_heads.load_state_dict(
    torch.load(ldm_weights_path, map_location=DEVICE, weights_only=True)
)

# ==============================================================================
# SFT Training Dataset Setup
# ==============================================================================
print("\nDownloading datasets (GSM8K & MBPP) for calibration phase...")
gsm8k_ds = load_dataset("openai/gsm8k", "main")
mbpp_ds = load_dataset("google-research-datasets/mbpp")

candidate_train = []

# Up to 40 math word problems.
if "train" in gsm8k_ds:
    for item in gsm8k_ds["train"]:
        prompt = f"Problem: {item['question']}\nSolution:"
        completion = f" {item['answer']}"
        candidate_train.append((prompt, completion))
        if len(candidate_train) >= 40:
            break

# Up to 40 code tasks; use the first available split when "train" is missing.
mbpp_train_raw = mbpp_ds["train"] if "train" in mbpp_ds else list(mbpp_ds.values())[0]
code_count = 0
for item in mbpp_train_raw:
    if 'text' in item and 'code' in item:
        prompt = f"Write a Python function to solve this task:\n{item['text']}\nSolution:\n"
        completion = f"{item['code']}"
        candidate_train.append((prompt, completion))
        code_count += 1
        if code_count >= 40:
            break

print(f"Assembled training set with {len(candidate_train)} sequences.")

# One (1, seq_len) id tensor per example; sequences of at most
# block_size + 2 tokens are dropped.
train_tensors = []
for prompt, completion in candidate_train:
    full_text = prompt + completion
    encoded = src_tokenizer(full_text, return_tensors="pt").to(DEVICE)
    if encoded.input_ids.shape[1] > (pipeline.block_size + 2):
        train_tensors.append(encoded.input_ids)

# ==============================================================================
# Calibration Loop
# ==============================================================================
# Only the stacked LDM heads are calibrated. Building the optimizer over
# pipeline.parameters() would also update the backbone and the LM head, which
# contradicts the stated purpose of this script and multiplies the memory needed.
pipeline.base_model.requires_grad_(False)
pipeline.lm_head.requires_grad_(False)
pipeline.ldm_heads.requires_grad_(True)
pipeline.eval()
pipeline.ldm_heads.train()

head_params = list(pipeline.ldm_heads.parameters())
optimizer = torch.optim.AdamW(head_params, lr=2e-4, weight_decay=0.01)


def compute_ldm_forecast_loss(pipeline, input_ids):
    """Cross-entropy between each position's block forecast and the next ``L`` tokens.

    Position ``i`` is trained to predict ``input_ids[:, i + 1 : i + 1 + L]``
    with ``L = pipeline.block_size``; positions without ``L`` following tokens
    are left out.

    Args:
        pipeline: object exposing ``base_model``, ``ldm_heads`` and ``block_size``.
        input_ids: token ids of shape (batch, seq_len).

    Returns:
        Scalar loss tensor. A zero tensor (with ``requires_grad=True``) is
        returned when the sequence is too short to hold a single full window.
    """
    seq_len = input_ids.shape[1]
    block = pipeline.block_size
    # Positions 0 .. seq_len - block - 1 each have `block` following tokens,
    # which is seq_len - block windows (one more than `seq_len - 1 - block`).
    num_windows = seq_len - block
    if num_windows <= 0:
        return torch.tensor(0.0, device=input_ids.device, requires_grad=True)

    # NOTE(review): the backbone sees the whole sequence here, whereas
    # generation only sees the prefix. If the loaded backbone attends
    # bidirectionally, hidden state i can already see its own targets —
    # confirm this train/inference mismatch is intended.
    hidden_states = pipeline.base_model(input_ids=input_ids).last_hidden_state

    # The heads act on each position independently, so run them only on the
    # positions that have targets instead of slicing the logits afterwards.
    pred_logits = pipeline.ldm_heads(hidden_states[:, :num_windows])
    vocab = pred_logits.shape[-1]

    # Sliding windows of length `block` over the shifted ids:
    # shape (batch, num_windows, block).
    targets = input_ids[:, 1:].unfold(1, block, 1)
    return F.cross_entropy(pred_logits.reshape(-1, vocab), targets.reshape(-1))


epochs = 20
step = 0
best_loss = float('inf')
best_state_dict = None

print(f"\nCalibrating Stacked LDM heads across {epochs} epochs...")
for epoch in range(epochs):
    random.shuffle(train_tensors)
    epoch_loss = 0.0
    epoch_steps = 0

    for input_ids in train_tensors:
        optimizer.zero_grad(set_to_none=True)
        loss = compute_ldm_forecast_loss(pipeline, input_ids)
        current_loss = loss.item()
        if current_loss == 0.0:
            # Sentinel for "sequence too short": nothing to learn from.
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(head_params, max_norm=1.0)
        optimizer.step()

        epoch_loss += current_loss
        epoch_steps += 1
        step += 1

        if step % 20 == 0:
            print(f"Step {step:3d} | Epoch {epoch+1} | Loss: {current_loss:.4f} (Best: {best_loss:.4f})")

    if epoch_steps == 0:
        continue

    # Pick the checkpoint by the epoch's mean loss. A single-sequence loss is
    # measured before the update that yields the saved weights and mostly
    # reflects how easy that one sequence was.
    mean_epoch_loss = epoch_loss / epoch_steps
    if mean_epoch_loss < best_loss:
        best_loss = mean_epoch_loss
        # Snapshot only the trained heads, on the CPU, so keeping the best
        # checkpoint does not hold a second copy of the model on the device.
        best_state_dict = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in pipeline.ldm_heads.state_dict().items()
        }

print("\nSFT alignment completed.")
if best_state_dict is not None:
    pipeline.ldm_heads.load_state_dict(best_state_dict)
    print(f"Successfully loaded best state checkpoint with loss: {best_loss:.4f}")

# ==============================================================================
# Model Post-Training Evaluation
# ==============================================================================
pipeline.eval()
print("\nVerifying model calibration progress on training sequence forecasts...")

with torch.no_grad():
    for idx, input_ids in enumerate(train_tensors[:2]):
        seq_len = input_ids.shape[1]
        L = pipeline.block_size
        if seq_len <= L + 1:
            continue

        # Hold out the final L tokens and forecast them from the prefix alone.
        prefix_len = seq_len - L
        prefix_ids = input_ids[:, :prefix_len]
        target_ids = input_ids[0, prefix_len : prefix_len + L]

        outputs = pipeline.base_model(input_ids=prefix_ids)
        hidden_states = outputs.last_hidden_state
        # Only the last prefix position's forecast is inspected.
        forecast_logits = pipeline.ldm_heads(hidden_states[:, -1:])[0, 0]
        pred_ids = torch.argmax(forecast_logits, dim=-1)

        prompt_text = src_tokenizer.decode(prefix_ids[0], skip_special_tokens=True)
        expected_text = src_tokenizer.decode(target_ids, skip_special_tokens=True)
        predicted_text = src_tokenizer.decode(pred_ids, skip_special_tokens=True)

        truncated_prompt = prompt_text[-200:] if len(prompt_text) > 200 else prompt_text
        print(f"\n--- Sequence Output Check {idx + 1} ---")
        print(f"[Context Prompt Segment]: ... {truncated_prompt}")
        print(f"[Expected Block Output]: {expected_text}")
        print(f"[Predicted Block Output]: {predicted_text}")