""" ADAPT-DIFF Inference & Benchmark Script Downloads 'dataopsnick/adapt-diff-qwen-0.8b' and compares it with 'Qwen/Qwen3.5-0.8B'. """ import os import gc import time import re from collections import defaultdict import torch import torch.nn as nn import torch.nn.functional as F # 1. Install/Update Dependencies print("Ensuring dependencies are installed...") os.system("pip install -q transformers>=4.40.0 datasets>=2.18.0 accelerate>=0.29.0 huggingface_hub") import transformers from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM from transformers.cache_utils import DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from datasets import load_dataset from huggingface_hub import hf_hub_download # Clean up GPU cache before running gc.collect() torch.cuda.empty_cache() DEVICE = "cuda" if torch.cuda.is_available() else "cpu" BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B" ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b" print(f"Loading {BASE_MODEL_ID} metadata to dynamically resolve architecture classes...") src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID) if src_tokenizer.pad_token is None: src_tokenizer.pad_token = src_tokenizer.eos_token # Load temporary instance to resolve base classes exactly as in your environment temp_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL_ID, torch_dtype=torch.bfloat16, device_map="cpu" ) src_config = temp_model.config BaseConfig = src_config.__class__ BaseModel = temp_model.model.__class__ BaseCausalLM = temp_model.__class__ BasePreTrainedModel = next( (cls for cls in BaseCausalLM.__mro__ if cls.__name__.endswith("PreTrainedModel")), None ) if BasePreTrainedModel is None: BasePreTrainedModel = BaseCausalLM.__bases__[0] # Free temporary model memory del temp_model gc.collect() # ============================================================================== # Custom ADAPT-DIFF Architecture Classes # ============================================================================== class A2DQwenConfig(BaseConfig): model_type = "a2d-qwen" class A2DQwenModel(BaseModel): def forward( self, input_ids = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, use_cache = None, cache_position = None, **kwargs, ): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("Specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) # Core ADAPT-DIFF modification: replace causal mask with bidirectional/padding-only mask if not isinstance(causal_mask_mapping := attention_mask, dict): if attention_mask is None: attention_mask = torch.ones( inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long ) if not (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4): attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) causal_mask_mapping = defaultdict(lambda: attention_mask) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers[: self.config.num_hidden_layers]: attn_type = getattr(decoder_layer, "attention_type", "self_attn") hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask_mapping[attn_type], position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) class A2DQwenLMHeadModel(BaseCausalLM): config_class = A2DQwenConfig def __init__(self, config): BasePreTrainedModel.__init__(self, config) self.model = A2DQwenModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() # Register custom classes with Hugging Face AutoClasses transformers.AutoConfig.register("a2d-qwen", A2DQwenConfig) transformers.AutoModel.register(A2DQwenConfig, A2DQwenLMHeadModel) transformers.AutoModelForCausalLM.register(A2DQwenConfig, A2DQwenLMHeadModel) # ============================================================================== # Custom Projection and Search Pipeline Components # ============================================================================== class StackedLDMHeads(nn.Module): def __init__(self, hidden_size, vocab_size, block_size=12): super().__init__() self.block_size = block_size self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=torch.bfloat16) self.head = nn.Linear(hidden_size, vocab_size, dtype=torch.bfloat16) def forward(self, hidden_states): batch_size, seq_len, hidden_size = hidden_states.shape forecast = self.proj(hidden_states) forecast = forecast.view(batch_size, seq_len, self.block_size, hidden_size) logits = self.head(forecast) return logits class LogitUncertaintyFilter(nn.Module): def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor: probs = F.softmax(logits.float(), dim=-1) entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1) return entropy def forward(self, logits: torch.Tensor, threshold: float): entropy = self.compute_entropy(logits) mask = entropy >= threshold return mask, entropy class ActorCriticPruner: def __init__(self, lm_head, lambda_reg=0.1): self.lm_head = lm_head self.lambda_reg = lambda_reg def evaluate_sequence_value(self, candidate_tokens, logits): log_probs = F.log_softmax(logits.float(), dim=-1) gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1) return gathered.mean().item() def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta): refined_sequence = sequence.clone() if depth == 0 or mask.sum() == 0: return refined_sequence, self.evaluate_sequence_value(sequence, logits) high_unc_positions = torch.where(mask)[0] if len(high_unc_positions) == 0: return refined_sequence, self.evaluate_sequence_value(sequence, logits) target_pos = high_unc_positions[0].item() top_logits, top_tokens = torch.topk(logits[target_pos], k=3) best_val = float('-inf') for token_opt in top_tokens: candidate = sequence.clone() candidate[target_pos] = token_opt approx_val = self.evaluate_sequence_value(candidate, logits) - (self.lambda_reg * entropy[target_pos].item()) if approx_val < alpha: continue new_mask = mask.clone() new_mask[target_pos] = False _, path_val = self.recursive_refine(candidate, logits, new_mask, entropy, depth - 1, alpha, beta) if path_val > alpha: alpha = path_val best_val = path_val refined_sequence = candidate if alpha >= beta: break return refined_sequence, best_val class ADAPTDIFFPipeline(nn.Module): def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5): super().__init__() self.base_model = base_lm_model.model self.lm_head = base_lm_model.lm_head self.block_size = block_size self.entropy_threshold = entropy_threshold self.ldm_heads = StackedLDMHeads( hidden_size=base_lm_model.config.hidden_size, vocab_size=base_lm_model.config.vocab_size, block_size=block_size ).to(DEVICE) self.router = LogitUncertaintyFilter() self.pruner = ActorCriticPruner(self.lm_head) def generate_adapt_diff(self, input_ids, max_new_tokens=128): current_seq = input_ids.clone() generated_count = 0 total_full_transformer_evals = 0 while generated_count < max_new_tokens: outputs = self.base_model(input_ids=current_seq) total_full_transformer_evals += 1 last_hidden = outputs.last_hidden_state[:, -1:, :] block_logits = self.ldm_heads(last_hidden).squeeze(0).squeeze(0) draft_tokens = torch.argmax(block_logits, dim=-1) mask, entropy = self.router(block_logits, self.entropy_threshold) if not mask.any(): final_block = draft_tokens else: total_full_transformer_evals += 1 final_block, _ = self.pruner.recursive_refine( sequence=draft_tokens, logits=block_logits, mask=mask, entropy=entropy, depth=2, alpha=float('-inf'), beta=float('inf') ) current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1) generated_count += self.block_size return current_seq[0, input_ids.shape[1]:], total_full_transformer_evals # ============================================================================== # Model Loading & LDM Weights Initialization # ============================================================================== print(f"Downloading custom bidirectional model {ADAPT_DIFF_ID} from Hugging Face...") a2d_model = AutoModelForCausalLM.from_pretrained( ADAPT_DIFF_ID, torch_dtype=torch.bfloat16, device_map=DEVICE ) print(f"Downloading baseline model {BASE_MODEL_ID} for comparative evaluation...") baseline_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL_ID, torch_dtype=torch.bfloat16, device_map=DEVICE ) # Initialize generation pipeline and load pre-trained custom LDM weights pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5) print("Downloading LDM head projection weights...") ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt") pipeline.ldm_heads.load_state_dict(torch.load(ldm_weights_path, map_location=DEVICE)) pipeline.eval() # ============================================================================== # Sub-Sampled Benchmark Initialization # ============================================================================== print("\nLoading GSM8K and MBPP evaluation datasets...") gsm8k_ds = load_dataset("openai/gsm8k", "main", split="test") mbpp_ds = load_dataset("google-research-datasets/mbpp", split="test") val_math = [] for item in gsm8k_ds: val_math.append((f"Problem: {item['question']}\nSolution:", item['answer'])) if len(val_math) >= 10: # Fast benchmark slice break val_code = [] for item in mbpp_ds: val_code.append((f"Write a Python function to solve this task:\n{item['text']}\nSolution:\n", item['code'], item['test_list'])) if len(val_code) >= 10: break # ============================================================================== # Validation Helpers # ============================================================================== def extract_answer(text): if "####" in text: text = text.split("####")[-1] matches = re.findall(r'-?[\d,]*\.?\d+', text) return matches[-1].replace(',', '') if matches else None def verify_math(generated_text, ref_ans): pred_val = extract_answer(generated_text) ref_val = extract_answer(ref_ans) if pred_val is None or ref_val is None: return 0.0 try: return 1.0 if float(pred_val) == float(ref_val) else 0.0 except ValueError: return 1.0 if str(pred_val).strip() == str(ref_val).strip() else 0.0 def verify_code(generated_text, test_list): code_block = generated_text if "```python" in generated_text: code_block = generated_text.split("```python")[-1].split("```")[0] elif "```" in generated_text: code_block = generated_text.split("```")[-1].split("```")[0] local_scope = {} try: compiled_code = compile(code_block, "", "exec") exec(compiled_code, local_scope, local_scope) for test in test_list: exec(test, local_scope, local_scope) return 1.0 except Exception: return 0.0 # ============================================================================== # Evaluation Loop # ============================================================================== def run_benchmark(pipeline, base_model, dataset, is_code=False): ar_correct = 0 ad_correct = 0 total = len(dataset) ar_total_tokens = 0 ad_total_tokens = 0 ar_total_time = 0.0 ad_total_time = 0.0 ad_total_evals = 0 for idx, item in enumerate(dataset): prompt = item[0] inputs = src_tokenizer(prompt, return_tensors="pt").to(DEVICE) max_new_tokens = 48 # Autoregressive generation t_start = time.time() with torch.no_grad(): ar_outputs = base_model.generate( **inputs, max_new_tokens=max_new_tokens, pad_token_id=src_tokenizer.pad_token_id, eos_token_id=src_tokenizer.eos_token_id, do_sample=False ) ar_total_time += (time.time() - t_start) ar_gen_tokens = ar_outputs[0][inputs.input_ids.shape[1]:] ar_total_tokens += len(ar_gen_tokens) ar_text = src_tokenizer.decode(ar_gen_tokens, skip_special_tokens=True) # ADAPT-DIFF speculative generation t_start = time.time() with torch.no_grad(): ad_gen_tokens, step_evals = pipeline.generate_adapt_diff( input_ids=inputs.input_ids, max_new_tokens=max_new_tokens ) ad_total_time += (time.time() - t_start) ad_total_tokens += len(ad_gen_tokens) ad_total_evals += step_evals ad_text = src_tokenizer.decode(ad_gen_tokens, skip_special_tokens=True) if is_code: ar_correct += verify_code(ar_text, item[2]) ad_correct += verify_code(ad_text, item[2]) else: ar_correct += verify_math(ar_text, item[1]) ad_correct += verify_math(ad_text, item[1]) ar_throughput = ar_total_tokens / (ar_total_time + 1e-9) ad_throughput = ad_total_tokens / (ad_total_time + 1e-9) ad_flops_per_token = ad_total_evals / (ad_total_tokens + 1e-9) return { "ar_acc": ar_correct / total, "ad_acc": ad_correct / total, "ar_speed": ar_throughput, "ad_speed": ad_throughput, "ar_flops": 1.0, "ad_flops": ad_flops_per_token } print("\nStarting evaluation run...") math_results = run_benchmark(pipeline, baseline_model, val_math, is_code=False) code_results = run_benchmark(pipeline, baseline_model, val_code, is_code=True) # Print comparative results print("\n" + "="*95) print(" ADAPT-DIFF INFERENCE BENCHMARK RESULTS (Block Size L = 12)") print("="*95) print(f"{'Task / Strategy':<30} | {'Throughput (tok/s)':<20} | {'Task Acc':<15} | {'Relative FLOPs/Tok':<20}") print("-"*95) print(f"{'GSM8K (Autoregressive Baseline)':<30} | {math_results['ar_speed']:<20.2f} | {math_results['ar_acc']:<15.2%} | {math_results['ar_flops']:<20.4f}") print(f"{'GSM8K (ADAPT-DIFF Speculative)':<30} | {math_results['ad_speed']:<20.2f} | {math_results['ad_acc']:<15.2%} | {math_results['ad_flops']:<20.4f}") print("-"*95) print(f"{'MBPP (Autoregressive Baseline)':<30} | {code_results['ar_speed']:<20.2f} | {code_results['ar_acc']:<15.2%} | {code_results['ar_flops']:<20.4f}") print(f"{'MBPP (ADAPT-DIFF Speculative)':<30} | {code_results['ad_speed']:<20.2f} | {code_results['ad_acc']:<15.2%} | {code_results['ad_flops']:<20.4f}") print("="*95)