| """ |
| ADAPT-DIFF Inference & Benchmark Script |
| Downloads 'dataopsnick/adapt-diff-qwen-0.8b' and compares it with 'Qwen/Qwen3.5-0.8B'. |
| """ |
|
|
import gc
import os
import re
import subprocess
import sys
import time
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
| |
print("Ensuring dependencies are installed...")
# Pass each requirement as its own argv item instead of a shell string: under
# `sh -c`, `transformers>=4.40.0` parses as `transformers` plus a stdout
# redirection to a file named `=4.40.0`, so the version pins were silently
# dropped (and stray `=x.y.z` files were created). Running
# `sys.executable -m pip` also installs into the interpreter executing this
# script. check=False keeps the original best-effort behaviour: a failed
# install does not abort the run.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "-q",
        "transformers>=4.40.0",
        "datasets>=2.18.0",
        "accelerate>=0.29.0",
        "huggingface_hub",
    ],
    check=False,
)
|
|
| import transformers |
| from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM |
| from transformers.cache_utils import DynamicCache |
| from transformers.modeling_outputs import BaseModelOutputWithPast |
| from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
| from datasets import load_dataset |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Release unreferenced Python objects and any cached CUDA allocator blocks
# before the large model loads below.
gc.collect()
torch.cuda.empty_cache()

# Target device for every model and tensor in this script.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hub repository ids: the reference autoregressive model and the ADAPT-DIFF variant.
BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B"
ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b"
|
|
def _find_pretrained_base(causal_lm_cls):
    """Return the first class in ``causal_lm_cls``'s MRO named ``*PreTrainedModel``.

    Falls back to the first direct base class when no MRO entry has that suffix.
    """
    for klass in causal_lm_cls.__mro__:
        if klass.__name__.endswith("PreTrainedModel"):
            return klass
    return causal_lm_cls.__bases__[0]


print(f"Loading {BASE_MODEL_ID} metadata to dynamically resolve architecture classes...")
src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
if src_tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS so generate() receives a valid pad_token_id.
    src_tokenizer.pad_token = src_tokenizer.eos_token

# Materialise the base checkpoint once (bf16, on CPU) purely to discover its
# concrete config / backbone / causal-LM classes; the A2D classes subclass them.
# The instance itself is discarded right after.
_probe_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
src_config = _probe_model.config

BaseConfig = type(src_config)
BaseModel = type(_probe_model.model)
BaseCausalLM = type(_probe_model)
BasePreTrainedModel = _find_pretrained_base(BaseCausalLM)

del _probe_model
gc.collect()
|
|
|
|
| |
| |
| |
class A2DQwenConfig(BaseConfig):
    """Config for the A2D (bidirectional) Qwen variant.

    Inherits every field from the base checkpoint's config class; only the
    ``model_type`` key differs, which is what the Auto* registries key on.
    """

    model_type = "a2d-qwen"
|
|
class A2DQwenModel(BaseModel):
    """Backbone of the base Qwen architecture run with a non-causal attention mask.

    Subclasses the base checkpoint's backbone class (``BaseModel``, resolved at
    runtime from the base checkpoint) and overrides only ``forward``; the
    submodules used here (``embed_tokens``, ``rotary_emb``, ``layers``, ``norm``)
    are inherited. 2D masks are expanded with ``_prepare_4d_attention_mask``,
    which adds no causal triangle, so positions are not restricted to attending
    only to earlier positions.
    """

    def forward(
        self,
        input_ids = None,
        attention_mask = None,
        position_ids = None,
        past_key_values = None,
        inputs_embeds = None,
        use_cache = None,
        cache_position = None,
        **kwargs,
    ):
        """Embed the inputs and run them through the decoder stack.

        Args:
            input_ids: token ids; exactly one of ``input_ids`` / ``inputs_embeds``
                must be given.
            attention_mask: ``None`` (attend everywhere), a 2D padding mask, a 4D
                tensor mask (used unchanged), or a dict keyed by layer attention
                type (used unchanged as the per-layer mask mapping).
            position_ids: defaults to ``cache_position`` with a leading batch axis.
            past_key_values: optional cache; a ``DynamicCache`` is created when
                ``use_cache`` is truthy and none is given.
            inputs_embeds: precomputed input embeddings.
            use_cache: whether to create and return the cache.
            cache_position: positions of the new tokens; defaults to
                ``past_len .. past_len + new_len - 1``.
            **kwargs: forwarded unchanged to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` with the final-norm hidden states and, only
            when ``use_cache`` is truthy, the cache.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds``
                are provided.
        """
        # XOR is True exactly when both or neither input is given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # New tokens continue after whatever the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # A dict mask is taken as a ready-made per-attention-type mapping (the
        # walrus binds it to causal_mask_mapping). Otherwise build one
        # bidirectional 4D mask and serve it for every attention type.
        # NOTE(review): _prepare_4d_attention_mask is called without tgt_len, so the
        # query length equals the 2D mask's length; with a non-empty cache this may
        # not match the number of new tokens — confirm this path is unused (the
        # ADAPT-DIFF pipeline in this file calls the backbone without a cache).
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            if attention_mask is None:
                attention_mask = torch.ones(
                    inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
                )
            if not (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4):
                attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
            # Any attention_type key looked up below receives the same mask.
            causal_mask_mapping = defaultdict(lambda: attention_mask)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # NOTE(review): assumes each decoder layer returns the hidden-state tensor
        # directly (not a tuple) and honours the explicit mask; both depend on the
        # installed transformers version / attention backend (flash-attention
        # paths may apply causal masking on their own) — TODO confirm.
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            attn_type = getattr(decoder_layer, "attention_type", "self_attn")
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[attn_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
|
|
class A2DQwenLMHeadModel(BaseCausalLM):
    """Causal-LM wrapper whose backbone is ``A2DQwenModel``.

    The constructor calls the pretrained-model base initializer directly rather
    than ``BaseCausalLM.__init__`` — presumably so the stock backbone is never
    built only to be replaced — then attaches the A2D backbone and a bias-free
    vocabulary projection before running ``post_init``.
    """

    config_class = A2DQwenConfig

    def __init__(self, config):
        # Bypass BaseCausalLM.__init__; initialise only the shared pretrained base.
        BasePreTrainedModel.__init__(self, config)
        self.vocab_size = config.vocab_size
        self.model = A2DQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
|
|
|
|
| |
# Let the Auto* factories map configs with model_type "a2d-qwen" to the A2D
# classes, so AutoModelForCausalLM.from_pretrained can build the ADAPT-DIFF model.
transformers.AutoConfig.register("a2d-qwen", A2DQwenConfig)
for _auto_factory in (transformers.AutoModel, transformers.AutoModelForCausalLM):
    _auto_factory.register(A2DQwenConfig, A2DQwenLMHeadModel)
del _auto_factory
|
|
|
|
| |
| |
| |
class StackedLDMHeads(nn.Module):
    """Project each hidden state into ``block_size`` future-token logit vectors.

    One linear layer expands every hidden vector into ``block_size`` hidden-sized
    "forecast" vectors; a shared output head maps each forecast to vocabulary logits.

    Args:
        hidden_size: width of the incoming hidden states.
        vocab_size: number of logits produced per forecast position.
        block_size: number of future positions forecast per input position.
        dtype: parameter dtype of both layers (bfloat16 by default, as before).
    """

    def __init__(self, hidden_size, vocab_size, block_size=12, dtype=torch.bfloat16):
        super().__init__()
        self.block_size = block_size
        self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=dtype)
        self.head = nn.Linear(hidden_size, vocab_size, dtype=dtype)

    def forward(self, hidden_states):
        """Map ``(..., hidden)`` states to ``(..., block_size, vocab)`` logits.

        Any number of leading dimensions is accepted; the pipeline in this file
        passes ``(batch, seq, hidden)``.
        """
        *leading, hidden_size = hidden_states.shape
        forecast = self.proj(hidden_states)
        # Linear output is contiguous, so view() can split the last axis in place.
        forecast = forecast.view(*leading, self.block_size, hidden_size)
        return self.head(forecast)
|
|
class LogitUncertaintyFilter(nn.Module):
    """Flag positions whose predictive distribution is too uncertain.

    Uncertainty is the Shannon entropy (nats) of ``softmax(logits)`` over the last
    axis; a position is flagged when its entropy is ``>= threshold``.
    """

    def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 entropy of ``softmax(logits)`` along the last axis.

        A 1e-9 offset inside the log keeps zero probabilities away from ``log(0)``.
        """
        p = F.softmax(logits.float(), dim=-1)
        weighted_log = p * torch.log(p + 1e-9)
        return weighted_log.sum(dim=-1).neg()

    def forward(self, logits: torch.Tensor, threshold: float):
        """Return ``(mask, entropy)``; ``mask`` is True where entropy >= threshold."""
        scores = self.compute_entropy(logits)
        return scores >= threshold, scores
|
|
class ActorCriticPruner:
    """Depth-limited alpha-beta search over top-k substitutions at uncertain positions.

    Candidates are scored with ``evaluate_sequence_value`` (mean token
    log-probability under fixed per-position logits). Each search level
    substitutes the top-3 tokens at the first still-flagged position and recurses
    on the remaining flagged positions.

    Args:
        lm_head: stored on the instance; not used by the search itself.
        lambda_reg: weight of the entropy penalty applied when pruning candidates.
    """

    def __init__(self, lm_head, lambda_reg=0.1):
        self.lm_head = lm_head
        self.lambda_reg = lambda_reg

    def evaluate_sequence_value(self, candidate_tokens, logits):
        """Return the mean log-probability (a Python float) of ``candidate_tokens``.

        ``logits`` has one row per position and ``candidate_tokens`` one token id
        per row; log-probabilities come from ``log_softmax`` over the last axis.
        """
        log_probs = F.log_softmax(logits.float(), dim=-1)
        gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1)
        return gathered.mean().item()

    def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta):
        """Refine flagged positions of ``sequence`` by alpha-beta search.

        Args:
            sequence: 1D token ids (the draft block).
            logits: per-position logits, one row per entry of ``sequence``.
            mask: boolean tensor; True marks positions still eligible for refinement.
            entropy: per-position entropy used for the pruning penalty.
            depth: maximum number of positions to refine along one search path.
            alpha: best value already guaranteed to the caller.
            beta: cutoff above which the search stops exploring siblings.

        Returns:
            ``(refined_sequence, value)``. At a leaf (``depth == 0`` or nothing
            flagged) this is a copy of ``sequence`` and its value. Otherwise it is the
            best fully refined path found, or ``(copy of sequence, -inf)`` when no
            candidate beat ``alpha``.
        """
        if depth == 0 or not mask.any():
            return sequence.clone(), self.evaluate_sequence_value(sequence, logits)

        target_pos = torch.where(mask)[0][0].item()
        _, top_tokens = torch.topk(logits[target_pos], k=3)
        penalty = self.lambda_reg * entropy[target_pos].item()

        refined_sequence = sequence.clone()
        best_val = float('-inf')
        for token_opt in top_tokens:
            candidate = sequence.clone()
            candidate[target_pos] = token_opt

            # Cheap pre-check: skip candidates whose penalised value is below alpha.
            approx_val = self.evaluate_sequence_value(candidate, logits) - penalty
            if approx_val < alpha:
                continue

            child_mask = mask.clone()
            child_mask[target_pos] = False

            # Keep the child's refined sequence (not just `candidate`) so the
            # returned tokens are the ones that actually score `path_val`.
            child_seq, path_val = self.recursive_refine(
                candidate, logits, child_mask, entropy, depth - 1, alpha, beta
            )
            if path_val > alpha:
                alpha = path_val
                best_val = path_val
                refined_sequence = child_seq

            if alpha >= beta:
                break

        return refined_sequence, best_val
|
|
|
|
class ADAPTDIFFPipeline(nn.Module):
    """Block-wise drafting on top of a bidirectional backbone.

    Each step runs the backbone once over the whole current sequence, turns the
    last position's hidden state into ``block_size`` draft-token distributions with
    ``StackedLDMHeads``, and, when any draft position's entropy reaches
    ``entropy_threshold``, refines the draft with ``ActorCriticPruner``.

    Args:
        base_lm_model: causal-LM wrapper exposing ``.model``, ``.lm_head`` and ``.config``.
        block_size: number of tokens drafted per backbone pass (must be >= 1).
        entropy_threshold: draft positions with entropy ``>=`` this are refined.

    Raises:
        ValueError: if ``block_size`` is smaller than 1 (generation would never advance).
    """

    def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5):
        super().__init__()
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.base_model = base_lm_model.model
        self.lm_head = base_lm_model.lm_head
        self.block_size = block_size
        self.entropy_threshold = entropy_threshold

        self.ldm_heads = StackedLDMHeads(
            hidden_size=base_lm_model.config.hidden_size,
            vocab_size=base_lm_model.config.vocab_size,
            block_size=block_size,
        ).to(DEVICE)

        self.router = LogitUncertaintyFilter()
        self.pruner = ActorCriticPruner(self.lm_head)

    def generate_adapt_diff(self, input_ids, max_new_tokens=128):
        """Generate up to ``max_new_tokens`` tokens after ``input_ids``, block by block.

        Batch size 1 is assumed (row 0 is returned and each block is appended as a
        single row). No EOS check is made, so the full budget is always spent.

        Args:
            input_ids: prompt token ids of shape ``(1, prompt_len)``.
            max_new_tokens: exact number of new tokens to produce; the last block
                is truncated when this is not a multiple of ``block_size``.

        Returns:
            ``(new_tokens, evals)``: the 1D generated token ids and the number of
            evaluations counted during generation.
        """
        current_seq = input_ids.clone()
        generated_count = 0
        total_full_transformer_evals = 0

        while generated_count < max_new_tokens:
            # No KV cache: the backbone re-encodes the whole sequence every step.
            outputs = self.base_model(input_ids=current_seq)
            total_full_transformer_evals += 1
            last_hidden = outputs.last_hidden_state[:, -1:, :]

            # (1, 1, block_size, vocab) -> (block_size, vocab).
            block_logits = self.ldm_heads(last_hidden).squeeze(0).squeeze(0)
            draft_tokens = torch.argmax(block_logits, dim=-1)

            mask, entropy = self.router(block_logits, self.entropy_threshold)

            if not mask.any():
                final_block = draft_tokens
            else:
                # NOTE(review): refinement only re-scores block_logits; no extra
                # backbone pass runs here, yet one is counted — confirm intent.
                total_full_transformer_evals += 1
                final_block, _ = self.pruner.recursive_refine(
                    sequence=draft_tokens,
                    logits=block_logits,
                    mask=mask,
                    entropy=entropy,
                    depth=2,
                    alpha=float('-inf'),
                    beta=float('inf'),
                )

            # Never exceed the token budget on the final block.
            final_block = final_block[: max_new_tokens - generated_count]
            current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1)
            generated_count += final_block.shape[0]

        return current_seq[0, input_ids.shape[1]:], total_full_transformer_evals
|
|
|
|
| |
| |
| |
# Both checkpoints are loaded in bfloat16 directly onto DEVICE. The A2D
# checkpoint resolves to A2DQwenLMHeadModel through the Auto* registration.
_LOAD_KWARGS = {"torch_dtype": torch.bfloat16, "device_map": DEVICE}

print(f"Downloading custom bidirectional model {ADAPT_DIFF_ID} from Hugging Face...")
a2d_model = AutoModelForCausalLM.from_pretrained(ADAPT_DIFF_ID, **_LOAD_KWARGS)

print(f"Downloading baseline model {BASE_MODEL_ID} for comparative evaluation...")
baseline_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **_LOAD_KWARGS)
|
|
| |
pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5)
print("Downloading LDM head projection weights...")
ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt")
# The checkpoint is a pickle fetched from a remote repository. weights_only=True
# restricts unpickling to tensors and plain containers, so a tampered file cannot
# execute code at load time; a state_dict of tensors loads unchanged.
pipeline.ldm_heads.load_state_dict(
    torch.load(ldm_weights_path, map_location=DEVICE, weights_only=True)
)
pipeline.eval()
|
|
|
|
| |
| |
| |
print("\nLoading GSM8K and MBPP evaluation datasets...")
gsm8k_ds = load_dataset("openai/gsm8k", "main", split="test")
mbpp_ds = load_dataset("google-research-datasets/mbpp", split="test")

# First 10 rows of each split. zip() stops as soon as range(10) is exhausted,
# so no row past the tenth is ever pulled from the dataset.
# Math items: (prompt, reference answer).
val_math = [
    (f"Problem: {row['question']}\nSolution:", row['answer'])
    for _, row in zip(range(10), gsm8k_ds)
]
# Code items: (prompt, reference code, list of assert statements).
val_code = [
    (f"Write a Python function to solve this task:\n{row['text']}\nSolution:\n", row['code'], row['test_list'])
    for _, row in zip(range(10), mbpp_ds)
]
|
|
|
|
| |
| |
| |
def extract_answer(text):
    """Return the last number in ``text`` (commas stripped) as a string, or None.

    When ``text`` contains the GSM8K ``####`` marker, only the part after the last
    marker is searched. A number is an optional minus sign, digits/commas, and an
    optional decimal part ending in a digit.
    """
    _, _, tail = text.rpartition("####")
    numbers = re.findall(r'-?[\d,]*\.?\d+', tail)
    if not numbers:
        return None
    return numbers[-1].replace(',', '')
|
|
def verify_math(generated_text, ref_ans):
    """Score 1.0 when the final number in the generation equals the reference's.

    Both texts go through ``extract_answer``; numbers are compared as floats,
    with a stripped-string comparison as a fallback if float parsing fails.
    Returns 0.0 when either side has no number.
    """
    predicted = extract_answer(generated_text)
    reference = extract_answer(ref_ans)
    if predicted is None or reference is None:
        return 0.0
    try:
        matched = float(predicted) == float(reference)
    except ValueError:
        matched = str(predicted).strip() == str(reference).strip()
    return 1.0 if matched else 0.0
|
|
def verify_code(generated_text, test_list, timeout=10.0):
    """Return 1.0 if the code in ``generated_text`` passes every test in ``test_list``.

    Code selection: the last ```` ```python ```` fenced block if present; otherwise
    the last generic ```` ``` ```` fenced block (closed, or left open by truncated
    generation), minus a leading language tag such as ``py``; otherwise the whole text.

    SECURITY NOTE: the snippet is model-generated, i.e. untrusted. It runs in a
    separate interpreter with a timeout and no stdin, which contains crashes,
    ``sys.exit()``/``exit()`` calls and infinite loops. It is NOT a sandbox: the
    child keeps this user's file-system and network access.

    Args:
        generated_text: raw model output.
        test_list: assertion statements (MBPP ``test_list``), each executed in the
            namespace produced by running the code.
        timeout: seconds before the child process is killed and the sample fails.

    Returns:
        1.0 on success; 0.0 on any failure (syntax error, exception, exit, timeout).
    """
    if "```python" in generated_text:
        code_block = generated_text.split("```python")[-1].split("```")[0]
    elif "```" in generated_text:
        # Odd-indexed pieces lie inside fences. With a closing fence the last fenced
        # body is pieces[-2]; pieces[-1] is whatever text follows it.
        pieces = generated_text.split("```")
        code_block = pieces[-1] if len(pieces) % 2 == 0 else pieces[-2]
        tag, newline, rest = code_block.partition("\n")
        if newline and tag.strip().isidentifier():
            code_block = rest  # drop a language tag like "py" or "python3"
    else:
        code_block = generated_text

    # Mirrors the previous in-process semantics: compile+exec the code, then exec
    # each test in the same namespace. BaseException is caught so that
    # sys.exit(0) inside the snippet counts as a failure, not a pass.
    runner = (
        "import sys\n"
        "scope = {}\n"
        "try:\n"
        "    exec(compile(sys.argv[1], '<string>', 'exec'), scope, scope)\n"
        "    for test in sys.argv[2:]:\n"
        "        exec(test, scope, scope)\n"
        "except BaseException:\n"
        "    sys.exit(1)\n"
    )
    try:
        result = subprocess.run(
            [sys.executable, "-c", runner, code_block, *test_list],
            stdin=subprocess.DEVNULL,
            capture_output=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return 0.0
    except (OSError, TypeError, ValueError):
        # e.g. NUL bytes, non-string tests, or an over-long command line.
        return 0.0
    return 1.0 if result.returncode == 0 else 0.0
|
|
|
|
| |
| |
| |
def _sync_device():
    """Block until queued CUDA work has finished (no-op without CUDA).

    CUDA kernel launches are asynchronous, so reading the clock without a sync can
    stop a timer before the measured GPU work has completed.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def run_benchmark(pipeline, base_model, dataset, is_code=False, max_new_tokens=48):
    """Compare greedy autoregressive decoding with ADAPT-DIFF on one dataset.

    Args:
        pipeline: object providing ``generate_adapt_diff(input_ids, max_new_tokens)``.
        base_model: Hugging Face causal LM used for the autoregressive baseline.
        dataset: sequence of tuples ``(prompt, reference_answer, ...)``; code items
            carry their assert statements at index 2.
        is_code: score with ``verify_code`` instead of ``verify_math``.
        max_new_tokens: generation budget for both strategies (48, as before).

    Returns:
        dict with accuracies ``ar_acc``/``ad_acc`` (0.0 for an empty dataset),
        throughputs ``ar_speed``/``ad_speed`` in tokens per second, and
        evaluations per token ``ar_flops`` (fixed reference 1.0) / ``ad_flops``.
    """
    ar_correct = 0
    ad_correct = 0
    total = len(dataset)

    ar_total_tokens = 0
    ad_total_tokens = 0
    ar_total_time = 0.0
    ad_total_time = 0.0
    ad_total_evals = 0

    for item in dataset:
        prompt = item[0]
        inputs = src_tokenizer(prompt, return_tensors="pt").to(DEVICE)

        # Autoregressive baseline: greedy generate(), which stops at EOS.
        _sync_device()
        t_start = time.perf_counter()
        with torch.no_grad():
            ar_outputs = base_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=src_tokenizer.pad_token_id,
                eos_token_id=src_tokenizer.eos_token_id,
                do_sample=False,
            )
        _sync_device()
        ar_total_time += time.perf_counter() - t_start
        ar_gen_tokens = ar_outputs[0][inputs.input_ids.shape[1]:]
        ar_total_tokens += len(ar_gen_tokens)
        ar_text = src_tokenizer.decode(ar_gen_tokens, skip_special_tokens=True)

        # ADAPT-DIFF block drafting. NOTE(review): the pipeline in this file has no
        # EOS check, so it always spends the full budget; keep in mind when
        # comparing throughput against the EOS-terminated baseline.
        _sync_device()
        t_start = time.perf_counter()
        with torch.no_grad():
            ad_gen_tokens, step_evals = pipeline.generate_adapt_diff(
                input_ids=inputs.input_ids,
                max_new_tokens=max_new_tokens,
            )
        _sync_device()
        ad_total_time += time.perf_counter() - t_start
        ad_total_tokens += len(ad_gen_tokens)
        ad_total_evals += step_evals
        ad_text = src_tokenizer.decode(ad_gen_tokens, skip_special_tokens=True)

        if is_code:
            ar_correct += verify_code(ar_text, item[2])
            ad_correct += verify_code(ad_text, item[2])
        else:
            ar_correct += verify_math(ar_text, item[1])
            ad_correct += verify_math(ad_text, item[1])

    ar_throughput = ar_total_tokens / (ar_total_time + 1e-9)
    ad_throughput = ad_total_tokens / (ad_total_time + 1e-9)
    ad_flops_per_token = ad_total_evals / (ad_total_tokens + 1e-9)

    # Guard against ZeroDivisionError on an empty dataset (counts are 0 then).
    n_items = total if total else 1
    return {
        "ar_acc": ar_correct / n_items,
        "ad_acc": ad_correct / n_items,
        "ar_speed": ar_throughput,
        "ad_speed": ad_throughput,
        "ar_flops": 1.0,
        "ad_flops": ad_flops_per_token,
    }
|
|
print("\nStarting evaluation run...")
# GSM8K is scored by final-number match; MBPP by executing its assert statements.
math_results, code_results = (
    run_benchmark(pipeline, baseline_model, val_math, is_code=False),
    run_benchmark(pipeline, baseline_model, val_code, is_code=True),
)
|
|
| |
def _print_report(block_size, rows_by_task, width=95):
    """Print the benchmark comparison table.

    Args:
        block_size: ADAPT-DIFF block size shown in the title.
        rows_by_task: ``(task_name, results)`` pairs, where ``results`` is a dict
            as returned by ``run_benchmark``.
        width: width of the horizontal rules.
    """
    print("\n" + "=" * width)
    print(f" ADAPT-DIFF INFERENCE BENCHMARK RESULTS (Block Size L = {block_size})")
    print("=" * width)
    print(f"{'Task / Strategy':<30} | {'Throughput (tok/s)':<20} | {'Task Acc':<15} | {'Relative FLOPs/Tok':<20}")
    for task_name, results in rows_by_task:
        print("-" * width)
        for prefix, strategy in (("ar", "Autoregressive Baseline"), ("ad", "ADAPT-DIFF Speculative")):
            label = f"{task_name} ({strategy})"
            print(
                f"{label:<30} | {results[prefix + '_speed']:<20.2f} | "
                f"{results[prefix + '_acc']:<15.2%} | {results[prefix + '_flops']:<20.4f}"
            )
    print("=" * width)


# The title reads the block size from the pipeline rather than hard-coding 12.
_print_report(pipeline.block_size, [("GSM8K", math_results), ("MBPP", code_results)])