| """ |
| ADAPT-DIFF Calibration & Training Script |
| Finetunes the Custom Stacked LDM Heads using target sequences from GSM8K & MBPP. |
| """ |
|
|
| import os |
| import gc |
| import copy |
| import random |
| import time |
| import re |
| from collections import defaultdict |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
import subprocess
import sys

print("Ensuring dependencies are installed...")
# Each requirement is passed as its own argv entry and no shell is involved.
# In a shell command string an unquoted "transformers>=4.40.0" is parsed as
# the word "transformers" plus an output redirection into a file named
# "=4.40.0", which silently drops the version constraint.
# `sys.executable -m pip` installs into the interpreter running this script.
_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "-q",
        "transformers>=4.40.0",
        "datasets>=2.18.0",
        "accelerate>=0.29.0",
        "huggingface_hub",
    ],
    check=False,  # best effort, as before: a failed install does not abort the script
)
if _pip_result.returncode != 0:
    print(
        f"Warning: pip exited with status {_pip_result.returncode}; "
        "continuing with whatever package versions are already installed."
    )
|
|
| import transformers |
| from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM |
| from transformers.cache_utils import DynamicCache |
| from transformers.modeling_outputs import BaseModelOutputWithPast |
| from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
| from datasets import load_dataset |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Start from a clean slate: run the garbage collector and release cached CUDA
# blocks (presumably useful when the script is re-run in a live session).
gc.collect()
torch.cuda.empty_cache()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B"
ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b"

print(f"Loading {BASE_MODEL_ID} tokenizer and model structure metadata...")
src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
if src_tokenizer.pad_token is None:
    # Reuse EOS as the padding token when the tokenizer defines none.
    src_tokenizer.pad_token = src_tokenizer.eos_token

# The base checkpoint is instantiated on CPU only so that its concrete
# config/backbone/LM classes can be looked up; it is deleted again below.
temp_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
src_config = temp_model.config

BaseConfig = type(src_config)
BaseModel = type(temp_model.model)
BaseCausalLM = type(temp_model)

# First class in the MRO whose name ends in "PreTrainedModel"; when none
# matches, fall back to the first direct base of the causal-LM class.
BasePreTrainedModel = BaseCausalLM.__bases__[0]
for _ancestor in BaseCausalLM.__mro__:
    if _ancestor.__name__.endswith("PreTrainedModel"):
        BasePreTrainedModel = _ancestor
        break

del temp_model
gc.collect()
|
|
|
|
| |
| |
| |
class A2DQwenConfig(BaseConfig):
    """Configuration for the "a2d-qwen" model type.

    Inherits every field from the base checkpoint's configuration class and
    only overrides ``model_type`` so the Auto* factories can dispatch on it.
    """

    model_type = "a2d-qwen"
|
|
class A2DQwenModel(BaseModel):
    """Backbone that reuses the base model's layers with a custom forward pass.

    Unless the caller supplies a ready-made mask, the attention mask is built
    with ``_prepare_4d_attention_mask`` (a padding-mask expansion) rather than
    a causal mask -- presumably so tokens can attend in both directions;
    confirm against the upstream A2D reference implementation.
    """

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        cache_position=None,
        **kwargs,
    ):
        """Run the decoder stack and return the normalised hidden states.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: ``None`` (attend to everything), a 2D mask that
                is expanded to 4D, a 4D tensor used as-is, or a dict mapping
                a layer's ``attention_type`` to its mask.
            position_ids: Optional positions; derived from ``cache_position``
                when omitted.
            past_key_values: Optional cache; a ``DynamicCache`` is created
                when ``use_cache`` is truthy and none is given.
            inputs_embeds: Pre-computed embeddings instead of ``input_ids``.
            use_cache: Whether to create/return the key-value cache.
            cache_position: Optional absolute positions of the new tokens.
            **kwargs: Forwarded unchanged to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` holding ``last_hidden_state`` and,
            when ``use_cache`` is truthy, ``past_key_values``.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        ids_given = input_ids is not None
        embeds_given = inputs_embeds is not None
        if ids_given == embeds_given:
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")

        if not embeds_given:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # New tokens are positioned right after whatever the cache holds.
            offset = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                offset,
                offset + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if isinstance(attention_mask, dict):
            # Caller already prepared one mask per attention type.
            causal_mask_mapping = attention_mask
        else:
            if attention_mask is None:
                attention_mask = torch.ones(
                    inputs_embeds.shape[:2],
                    device=inputs_embeds.device,
                    dtype=torch.long,
                )
            is_4d_tensor = isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4
            if not is_4d_tensor:
                attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
            # Every attention type shares the same mask.
            causal_mask_mapping = defaultdict(lambda: attention_mask)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            layer_kind = getattr(decoder_layer, "attention_type", "self_attn")
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[layer_kind],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values if use_cache else None,
        )
|
|
class A2DQwenLMHeadModel(BaseCausalLM):
    """Causal-LM wrapper whose backbone is ``A2DQwenModel``."""

    config_class = A2DQwenConfig

    def __init__(self, config):
        """Build the A2D backbone and an untied, bias-free LM head.

        Args:
            config: Model configuration providing ``hidden_size`` and
                ``vocab_size``.
        """
        # Initialise via the pre-trained base class directly instead of
        # BaseCausalLM.__init__ -- presumably so the stock backbone is not
        # constructed before being replaced; confirm.
        BasePreTrainedModel.__init__(self, config)
        self.vocab_size = config.vocab_size
        self.model = A2DQwenModel(config)
        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )
        self.post_init()
|
|
|
|
| |
# Make the "a2d-qwen" model type resolvable through the Auto* factories:
# first the config, then the model class for both AutoModel and
# AutoModelForCausalLM.
AutoConfig.register(A2DQwenConfig.model_type, A2DQwenConfig)
for _auto_factory in (AutoModel, AutoModelForCausalLM):
    _auto_factory.register(A2DQwenConfig, A2DQwenLMHeadModel)
|
|
|
|
class StackedLDMHeads(nn.Module):
    """Project each hidden state into ``block_size`` slots and score each slot.

    A single linear layer expands a hidden vector into ``block_size`` vectors
    of the same width; a shared linear head then maps every slot to
    vocabulary logits. Both layers are created in bfloat16.
    """

    def __init__(self, hidden_size, vocab_size, block_size=12):
        """
        Args:
            hidden_size: Width of the incoming hidden states.
            vocab_size: Number of logits produced per slot.
            block_size: Number of slots forecast per input position.
        """
        super().__init__()
        self.block_size = block_size
        self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=torch.bfloat16)
        self.head = nn.Linear(hidden_size, vocab_size, dtype=torch.bfloat16)

    def forward(self, hidden_states):
        """Return logits of shape ``(batch, seq, block_size, vocab_size)``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq, hidden_size)``.
        """
        n_batch, n_pos, width = hidden_states.shape
        # Split the widened projection back into one vector per block slot.
        per_slot = self.proj(hidden_states).view(n_batch, n_pos, self.block_size, width)
        return self.head(per_slot)
|
|
class LogitUncertaintyFilter(nn.Module):
    """Flag positions whose predictive entropy reaches a threshold."""

    def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the Shannon entropy (in nats) over the last dimension.

        The logits are upcast to float32 first; ``1e-9`` is added inside the
        logarithm so zero probabilities do not produce ``-inf``.
        """
        probabilities = logits.float().softmax(dim=-1)
        weighted_log = probabilities * torch.log(probabilities + 1e-9)
        return -weighted_log.sum(dim=-1)

    def forward(self, logits: torch.Tensor, threshold: float):
        """Return ``(mask, entropy)`` where ``mask`` is ``entropy >= threshold``."""
        entropy = self.compute_entropy(logits)
        return entropy >= threshold, entropy
|
|
class ActorCriticPruner:
    """Greedy, depth-limited search that re-picks tokens at uncertain slots.

    A candidate sequence is scored by the mean log-probability of its tokens
    under the block logits. At every level the first still-masked position is
    re-filled with each of the ``top_k`` highest-logit tokens, and the best
    scoring continuation is kept.
    """

    def __init__(self, lm_head, lambda_reg=0.1, top_k=3):
        """
        Args:
            lm_head: Stored on the instance; not used by the methods below.
            lambda_reg: Weight of the entropy penalty in the pruning estimate.
            top_k: Number of candidate tokens tried per refined position
                (clamped to the vocabulary size).
        """
        self.lm_head = lm_head
        self.lambda_reg = lambda_reg
        self.top_k = top_k

    def evaluate_sequence_value(self, candidate_tokens, logits):
        """Return the mean log-probability of ``candidate_tokens`` as a float.

        Args:
            candidate_tokens: Integer tensor with one token id per position.
            logits: Tensor with one row of vocabulary logits per position.
        """
        log_probs = F.log_softmax(logits.float(), dim=-1)
        gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1)
        return gathered.mean().item()

    def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta):
        """Refine up to ``depth`` masked positions and return the best result.

        Args:
            sequence: 1-D tensor of draft token ids (not modified).
            logits: ``(positions, vocab)`` logits used for scoring.
            mask: 1-D boolean tensor, ``True`` where refinement is allowed.
            entropy: Per-position entropy used for the pruning penalty.
            depth: Maximum number of positions to refine along one path.
            alpha: Best value found so far; candidates whose penalised
                estimate is below it are skipped.
            beta: Upper cut-off; the candidate loop stops once ``alpha``
                reaches it.

        Returns:
            ``(refined_sequence, value)``. When no candidate beats ``alpha``
            the untouched clone of ``sequence`` is returned together with
            ``-inf`` so the caller discards this branch.
        """
        refined_sequence = sequence.clone()
        open_positions = torch.where(mask)[0]
        if depth == 0 or len(open_positions) == 0:
            return refined_sequence, self.evaluate_sequence_value(sequence, logits)

        target_pos = open_positions[0].item()
        # Clamp k so tiny vocabularies do not make topk raise.
        k = min(self.top_k, logits.shape[-1])
        _, top_tokens = torch.topk(logits[target_pos], k=k)

        # Loop invariants: the penalty and the mask handed to deeper levels.
        penalty = self.lambda_reg * entropy[target_pos].item()
        remaining = mask.clone()
        remaining[target_pos] = False

        best_val = float("-inf")
        for token_opt in top_tokens:
            candidate = sequence.clone()
            candidate[target_pos] = token_opt

            # Heuristic prune: score of the candidate before deeper positions
            # are refined, minus the entropy penalty.
            approx_val = self.evaluate_sequence_value(candidate, logits) - penalty
            if approx_val < alpha:
                continue

            child_sequence, path_val = self.recursive_refine(
                candidate, logits, remaining, entropy, depth - 1, alpha, beta
            )
            if path_val > alpha:
                alpha = path_val
                best_val = path_val
                # Keep the sequence that actually achieved path_val, including
                # tokens chosen at deeper levels (previously only the
                # single-position candidate was kept).
                refined_sequence = child_sequence

            if alpha >= beta:
                break

        return refined_sequence, best_val
|
|
|
|
class ADAPTDIFFPipeline(nn.Module):
    """Block-wise generator: backbone -> stacked heads -> entropy router -> pruner.

    Each iteration runs the backbone on the current sequence, forecasts a
    block of ``block_size`` tokens from the last hidden state, and refines
    the draft with ``ActorCriticPruner`` when any slot's entropy reaches
    ``entropy_threshold``.
    """

    def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5):
        """
        Args:
            base_lm_model: Causal-LM wrapper exposing ``.model``, ``.lm_head``
                and ``.config`` (``hidden_size``, ``vocab_size``).
            block_size: Number of tokens forecast per backbone call.
            entropy_threshold: Entropy at or above which a slot is refined.
        """
        super().__init__()
        self.base_model = base_lm_model.model
        self.lm_head = base_lm_model.lm_head
        self.block_size = block_size
        self.entropy_threshold = entropy_threshold

        self.ldm_heads = StackedLDMHeads(
            hidden_size=base_lm_model.config.hidden_size,
            vocab_size=base_lm_model.config.vocab_size,
            block_size=block_size,
        ).to(DEVICE)

        self.router = LogitUncertaintyFilter()
        self.pruner = ActorCriticPruner(self.lm_head)

    @torch.no_grad()  # pure inference: do not build an autograd graph per block
    def generate_adapt_diff(self, input_ids, max_new_tokens=128):
        """Generate up to ``max_new_tokens`` tokens, one block at a time.

        Args:
            input_ids: Prompt token ids of shape ``(1, prompt_len)``.
            max_new_tokens: Upper bound on the number of returned tokens.

        Returns:
            ``(tokens, evals)``: a 1-D tensor with at most ``max_new_tokens``
            generated ids, and a counter that is incremented once per
            backbone call and once more each time the refinement branch runs.

        Raises:
            ValueError: If ``input_ids`` is not a single-sequence 2-D batch.
        """
        if input_ids.ndim != 2 or input_ids.shape[0] != 1:
            raise ValueError(
                "generate_adapt_diff expects input_ids of shape (1, seq_len), "
                f"got {tuple(input_ids.shape)}"
            )

        prompt_len = input_ids.shape[1]
        current_seq = input_ids.clone()
        generated_count = 0
        total_full_transformer_evals = 0

        while generated_count < max_new_tokens:
            outputs = self.base_model(input_ids=current_seq)
            total_full_transformer_evals += 1
            last_hidden = outputs.last_hidden_state[:, -1:, :]

            # (1, 1, block, vocab) -> (block, vocab)
            block_logits = self.ldm_heads(last_hidden)[0, 0]
            draft_tokens = torch.argmax(block_logits, dim=-1)

            mask, entropy = self.router(block_logits, self.entropy_threshold)

            if not mask.any():
                final_block = draft_tokens
            else:
                total_full_transformer_evals += 1
                final_block, _ = self.pruner.recursive_refine(
                    sequence=draft_tokens,
                    logits=block_logits,
                    mask=mask,
                    entropy=entropy,
                    depth=2,
                    alpha=float("-inf"),
                    beta=float("inf"),
                )

            current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1)
            generated_count += self.block_size

        # The last block may overshoot the budget; never return more than
        # max_new_tokens tokens.
        new_tokens = current_seq[0, prompt_len:prompt_len + max_new_tokens]
        return new_tokens, total_full_transformer_evals
|
|
|
|
| |
| |
| |
print(f"Loading ADAPT-DIFF base model {ADAPT_DIFF_ID}...")
a2d_model = AutoModelForCausalLM.from_pretrained(
    ADAPT_DIFF_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)

pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5)
print("Downloading LDM head projection weights for calibration baseline...")
ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt")
# The checkpoint comes from the network, so restrict unpickling to tensors
# and plain containers. A state dict needs nothing more, and this prevents
# arbitrary code execution from a tampered file.
ldm_state_dict = torch.load(ldm_weights_path, map_location=DEVICE, weights_only=True)
pipeline.ldm_heads.load_state_dict(ldm_state_dict)
|
|
|
|
| |
| |
| |
print("\nDownloading datasets (GSM8K & MBPP) for calibration phase...")
gsm8k_ds = load_dataset("openai/gsm8k", "main")
mbpp_ds = load_dataset("google-research-datasets/mbpp")

# (prompt, completion) pairs: up to 40 math problems followed by up to 40
# code tasks.
candidate_train = []

if "train" in gsm8k_ds:
    for taken, row in enumerate(gsm8k_ds["train"], start=1):
        candidate_train.append(
            (f"Problem: {row['question']}\nSolution:", f" {row['answer']}")
        )
        if taken >= 40:
            break

# Prefer the "train" split; otherwise use whichever split is listed first.
if "train" in mbpp_ds:
    mbpp_train_raw = mbpp_ds["train"]
else:
    mbpp_train_raw = list(mbpp_ds.values())[0]

code_count = 0
for row in mbpp_train_raw:
    if "text" not in row or "code" not in row:
        continue  # skip records missing either field
    candidate_train.append(
        (
            f"Write a Python function to solve this task:\n{row['text']}\nSolution:\n",
            f"{row['code']}",
        )
    )
    code_count += 1
    if code_count >= 40:
        break

print(f"Assembled training set with {len(candidate_train)} sequences.")

# Tokenise prompt + completion as one sequence and keep only sequences
# longer than block_size + 2 tokens.
min_token_count = pipeline.block_size + 2
train_tensors = []
for prompt, completion in candidate_train:
    token_ids = src_tokenizer(prompt + completion, return_tensors="pt").to(DEVICE).input_ids
    if token_ids.shape[1] > min_token_count:
        train_tensors.append(token_ids)
|
|
|
|
| |
| |
| |
# Switch every sub-module of the pipeline to training mode.
pipeline.train()
# NOTE(review): pipeline.parameters() also yields the parameters of
# base_model and lm_head (both are attributes of the pipeline module), so the
# backbone is optimized together with the stacked heads, although the log
# messages speak of calibrating the heads only -- confirm this is intended.
optimizer = torch.optim.AdamW(
    params=pipeline.parameters(),
    lr=2e-4,
    weight_decay=0.01,
)
|
|
def compute_ldm_forecast_loss(pipeline, input_ids):
    """Cross-entropy between block forecasts and the tokens that follow.

    The head output at position ``i`` is trained to predict the ``L`` tokens
    ``input_ids[:, i + 1 : i + 1 + L]``, where ``L`` is the block size. All
    ``S - L`` positions that have a complete window of ``L`` following tokens
    contribute, including the last one (``i = S - L - 1``).

    NOTE(review): ``pipeline.base_model`` sees the full sequence here. If it
    attends bidirectionally, the hidden state at position ``i`` may already
    encode its own targets -- confirm this matches the intended objective.

    Args:
        pipeline: Object exposing ``base_model`` (returns an output with
            ``last_hidden_state``) and ``ldm_heads`` (returns logits of shape
            ``(batch, seq, block, vocab)``).
        input_ids: Token ids of shape ``(batch, seq)``.

    Returns:
        Scalar loss tensor. When the sequence is too short to provide a
        single complete target window, a constant ``0.0`` tensor with
        ``requires_grad=True`` is returned instead.
    """
    outputs = pipeline.base_model(input_ids=input_ids)
    hidden_states = outputs.last_hidden_state

    block_logits = pipeline.ldm_heads(hidden_states)
    _, seq_len, block_len, vocab = block_logits.shape

    # Position i is valid while i + block_len <= seq_len - 1.
    num_positions = seq_len - block_len
    if num_positions <= 0:
        return torch.tensor(0.0, device=input_ids.device, requires_grad=True)

    pred_logits = block_logits[:, :num_positions]
    # Sliding windows over the sequence shifted by one token: window i is
    # input_ids[:, i + 1 : i + 1 + block_len]; shape (batch, num_positions, block_len).
    targets = input_ids[:, 1:].unfold(1, block_len, 1)

    return F.cross_entropy(pred_logits.reshape(-1, vocab), targets.reshape(-1))
|
|
epochs = 20
step = 0
best_loss = float('inf')  # best mean loss over one epoch seen so far
best_state_dict = None

print(f"\nCalibrating Stacked LDM heads across {epochs} epochs...")

for epoch in range(epochs):
    random.shuffle(train_tensors)
    epoch_loss = 0.0
    epoch_steps = 0

    for input_ids in train_tensors:
        pipeline.train()
        optimizer.zero_grad(set_to_none=True)

        loss = compute_ldm_forecast_loss(pipeline, input_ids)
        current_loss = loss.item()  # one host sync per step
        if current_loss == 0.0:
            # Placeholder loss for sequences too short to supervise.
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(pipeline.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += current_loss
        epoch_steps += 1
        step += 1

        if step % 20 == 0:
            print(
                f"Step {step:3d} | Epoch {epoch+1} | Loss: {current_loss:.4f} "
                f"(Epoch mean so far: {epoch_loss / epoch_steps:.4f})"
            )

    if epoch_steps == 0:
        continue

    # Select the checkpoint on the epoch's mean loss. A single sample's loss
    # is measured before the optimizer step and mostly reflects how easy that
    # sample is, so it is not a usable signal for "best weights".
    mean_epoch_loss = epoch_loss / epoch_steps
    if mean_epoch_loss < best_loss:
        best_loss = mean_epoch_loss
        # Snapshot on CPU so a second full copy of the weights is not kept
        # in accelerator memory.
        best_state_dict = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in pipeline.state_dict().items()
        }
    print(f"Epoch {epoch+1} | Mean loss: {mean_epoch_loss:.4f} (Best: {best_loss:.4f})")

print("\nSFT alignment completed.")
if best_state_dict is not None:
    pipeline.load_state_dict(best_state_dict)
    print(f"Successfully loaded best state checkpoint with loss: {best_loss:.4f}")
|
|
|
|
| |
| |
| |
pipeline.eval()
print("\nVerifying model calibration progress on training sequence forecasts...")

# For the first two training sequences: hide the final block, forecast it
# from the remaining prefix and print expected vs. predicted text.
with torch.no_grad():
    for idx, input_ids in enumerate(train_tensors[:2]):
        block_len = pipeline.block_size
        total_len = input_ids.shape[1]
        if total_len <= block_len + 1:
            continue

        split_at = total_len - block_len
        prefix_ids = input_ids[:, :split_at]
        target_ids = input_ids[0, split_at:]

        prefix_hidden = pipeline.base_model(input_ids=prefix_ids).last_hidden_state
        # Forecast emitted at the last prefix position: (block_len, vocab).
        forecast_logits = pipeline.ldm_heads(prefix_hidden)[0, -1]
        pred_ids = forecast_logits.argmax(dim=-1)

        prompt_text = src_tokenizer.decode(prefix_ids[0], skip_special_tokens=True)
        expected_text = src_tokenizer.decode(target_ids, skip_special_tokens=True)
        predicted_text = src_tokenizer.decode(pred_ids, skip_special_tokens=True)

        # Show at most the last 200 characters of the context.
        truncated_prompt = prompt_text[-200:]
        print(f"\n--- Sequence Output Check {idx + 1} ---")
        print(f"[Context Prompt Segment]: ... {truncated_prompt}")
        print(f"[Expected Block Output]: {expected_text}")
        print(f"[Predicted Block Output]: {predicted_text}")