# dataopsnick's picture
# Create infer.py
# 763a9f6 verified
# Raw
# History Blame Contribute Delete
# 16.9 kB
"""
ADAPT-DIFF Inference & Benchmark Script
Downloads 'dataopsnick/adapt-diff-qwen-0.8b' and compares it with 'Qwen/Qwen3.5-0.8B'.
"""
import os
import gc
import time
import re
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
# 1. Install/Update Dependencies
# The requirement specifiers are passed as separate argv entries rather than as
# one shell string: inside a shell, the ">" in "transformers>=4.40.0" is an
# output redirection, so the version constraints were silently dropped and
# pip's output was redirected into files named "=4.40.0", "=2.18.0", ...
import subprocess
import sys

print("Ensuring dependencies are installed...")
subprocess.run(
    [
        sys.executable,  # install into the interpreter that is running this script
        "-m",
        "pip",
        "install",
        "-q",
        "transformers>=4.40.0",
        "datasets>=2.18.0",
        "accelerate>=0.29.0",
        "huggingface_hub",
    ],
    check=False,  # best-effort, like the os.system call it replaces
)
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from datasets import load_dataset
from huggingface_hub import hf_hub_download
# Start from a clean slate: drop unreachable Python objects, then release any
# GPU memory still held by the caching allocator.
gc.collect()
torch.cuda.empty_cache()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B"
ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b"

print(f"Loading {BASE_MODEL_ID} metadata to dynamically resolve architecture classes...")
src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
if src_tokenizer.pad_token is None:
    # Fall back to EOS as the padding token when the tokenizer defines none.
    src_tokenizer.pad_token = src_tokenizer.eos_token

# Instantiate the base checkpoint once, on CPU, purely to discover which
# concrete config / backbone / causal-LM classes the installed transformers
# version maps it to.
temp_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
src_config = temp_model.config
BaseConfig = type(src_config)
BaseModel = type(temp_model.model)
BaseCausalLM = type(temp_model)

# First class in the MRO whose name ends in "PreTrainedModel"; if there is
# none, use the causal-LM class's first direct base instead.
BasePreTrainedModel = next(
    (klass for klass in BaseCausalLM.__mro__ if klass.__name__.endswith("PreTrainedModel")),
    BaseCausalLM.__bases__[0],
)

# The temporary model has served its purpose; free its memory.
del temp_model
gc.collect()
# ==============================================================================
# Custom ADAPT-DIFF Architecture Classes
# ==============================================================================
class A2DQwenConfig(BaseConfig):
    """Config for the ADAPT-DIFF model: the base config class with its own ``model_type`` tag."""

    # Identifier under which this config is registered with the Auto* classes.
    model_type = "a2d-qwen"
class A2DQwenModel(BaseModel):
    """Base decoder stack whose forward pass uses a padding-only (non-causal) attention mask."""

    def forward(
        self,
        input_ids = None,
        attention_mask = None,
        position_ids = None,
        past_key_values = None,
        inputs_embeds = None,
        use_cache = None,
        cache_position = None,
        **kwargs,
    ):
        """Run the decoder layers and return the final normalised hidden states.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: ``None`` (treated as all-ones), a tensor mask, or a
                dict keyed by attention type that is used as-is.
            position_ids: Optional positions; derived from ``cache_position``
                when omitted.
            past_key_values: Optional cache; a ``DynamicCache`` is created when
                ``use_cache`` is truthy and none is supplied.
            inputs_embeds: Pre-computed embeddings; mutually exclusive with
                ``input_ids``.
            use_cache: Whether the cache is returned in the output.
            cache_position: Optional absolute positions of the current tokens.
            **kwargs: Forwarded unchanged to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` with ``last_hidden_state`` and, when
            ``use_cache`` is truthy, ``past_key_values``.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        # XOR is true when both are missing or both are present.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        if cache_position is None:
            # Continue numbering after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # Core ADAPT-DIFF modification: replace causal mask with bidirectional/padding-only mask
        # A dict passed as attention_mask is taken to be a ready-made
        # per-attention-type mapping and is left untouched.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            if attention_mask is None:
                # No mask supplied: every position is attendable.
                attention_mask = torch.ones(
                    inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
                )
            # Anything that is not already a 4-D tensor is converted by the
            # transformers helper; no causal mask is built here.
            if not (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4):
                attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
            # Every attention type looked up below receives the same mask.
            causal_mask_mapping = defaultdict(lambda: attention_mask)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # Layers without an ``attention_type`` attribute use the "self_attn" key.
            attn_type = getattr(decoder_layer, "attention_type", "self_attn")
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[attn_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
class A2DQwenLMHeadModel(BaseCausalLM):
    """Causal-LM shell that pairs an ``A2DQwenModel`` backbone with a bias-free LM head."""

    config_class = A2DQwenConfig

    def __init__(self, config):
        """Build the backbone and output projection described by ``config``."""
        # Initialise via the PreTrainedModel base directly, skipping
        # BaseCausalLM.__init__ (presumably so the stock backbone is never
        # constructed before being replaced).
        BasePreTrainedModel.__init__(self, config)

        vocab_size = config.vocab_size
        self.model = A2DQwenModel(config)
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=vocab_size,
            bias=False,
        )
        self.post_init()
# Make the custom architecture discoverable through the Hugging Face Auto*
# factories, so that from_pretrained() on an "a2d-qwen" checkpoint resolves to
# the classes defined in this file.
AutoConfig.register("a2d-qwen", A2DQwenConfig)
AutoModel.register(A2DQwenConfig, A2DQwenLMHeadModel)
AutoModelForCausalLM.register(A2DQwenConfig, A2DQwenLMHeadModel)
# ==============================================================================
# Custom Projection and Search Pipeline Components
# ==============================================================================
class StackedLDMHeads(nn.Module):
    """Expand each hidden state into ``block_size`` forecast states and score them over the vocabulary."""

    def __init__(self, hidden_size, vocab_size, block_size=12):
        """Create the bfloat16 forecast projection and the shared vocabulary head.

        Args:
            hidden_size: Width of the incoming hidden states.
            vocab_size: Number of output classes per forecast slot.
            block_size: Number of forecast slots produced per input position.
        """
        super().__init__()
        self.block_size = block_size
        # One linear map emits all block_size forecast vectors at once.
        self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=torch.bfloat16)
        # The same head scores every forecast slot.
        self.head = nn.Linear(hidden_size, vocab_size, dtype=torch.bfloat16)

    def forward(self, hidden_states):
        """Map ``(batch, seq, hidden)`` states to ``(batch, seq, block_size, vocab)`` logits."""
        n_batch, n_pos, width = hidden_states.shape
        stacked = self.proj(hidden_states).view(n_batch, n_pos, self.block_size, width)
        return self.head(stacked)
class LogitUncertaintyFilter(nn.Module):
    """Flag positions whose predictive entropy reaches a threshold."""

    def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the Shannon entropy (in nats) of ``softmax(logits)`` along the last axis.

        The logits are cast to float32 first, and a 1e-9 offset inside the
        logarithm keeps zero probabilities finite.
        """
        distribution = F.softmax(logits.float(), dim=-1)
        weighted_log = distribution * torch.log(distribution + 1e-9)
        return -weighted_log.sum(dim=-1)

    def forward(self, logits: torch.Tensor, threshold: float):
        """Return ``(mask, entropy)``; ``mask`` is True where ``entropy >= threshold``."""
        entropy = self.compute_entropy(logits)
        return entropy >= threshold, entropy
class ActorCriticPruner:
    """Bounded-depth search that re-picks tokens at high-uncertainty block positions.

    Candidates are scored by their mean log-probability under the supplied
    block logits, and branches are skipped with an alpha/beta-style bound.
    """

    def __init__(self, lm_head, lambda_reg=0.1, top_k=3):
        """Store the search settings.

        Args:
            lm_head: Kept on the instance; not used by the methods of this class.
            lambda_reg: Weight of the entropy penalty subtracted from a
                candidate's optimistic score.
            top_k: Number of highest-logit tokens tried per refined position
                (clamped to the vocabulary size).
        """
        self.lm_head = lm_head
        self.lambda_reg = lambda_reg
        self.top_k = top_k

    def evaluate_sequence_value(self, candidate_tokens, logits):
        """Return the mean log-probability of ``candidate_tokens`` under ``logits`` as a float."""
        log_probs = F.log_softmax(logits.float(), dim=-1)
        gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1)
        return gathered.mean().item()

    def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta):
        """Refine up to ``depth`` masked positions of ``sequence``.

        Args:
            sequence: 1-D tensor of draft token ids (not modified).
            logits: ``(len(sequence), vocab)`` logits used for scoring.
            mask: 1-D boolean tensor; True marks positions still open for refinement.
            entropy: Per-position entropy used for the regularisation penalty.
            depth: Maximum number of positions to refine along one path.
            alpha: Best value found so far; candidates scoring below it are skipped.
            beta: Upper cut-off; the candidate loop stops once ``alpha >= beta``.

        Returns:
            ``(refined_sequence, value)``. ``value`` is ``-inf`` when every
            candidate at this level was pruned, in which case the returned
            sequence is an unmodified copy of ``sequence``.
        """
        refined_sequence = sequence.clone()
        if depth == 0 or not mask.any():
            return refined_sequence, self.evaluate_sequence_value(sequence, logits)

        # Refine the first still-masked position.
        target_pos = torch.where(mask)[0][0].item()
        # Clamp k so that vocabularies smaller than top_k do not make topk raise.
        k = min(self.top_k, logits.shape[-1])
        _, top_tokens = torch.topk(logits[target_pos], k=k)
        # The penalty depends only on the position, so compute it once.
        penalty = self.lambda_reg * entropy[target_pos].item()

        best_val = float("-inf")
        for token_opt in top_tokens:
            candidate = sequence.clone()
            candidate[target_pos] = token_opt
            approx_val = self.evaluate_sequence_value(candidate, logits) - penalty
            if approx_val < alpha:
                continue
            new_mask = mask.clone()
            new_mask[target_pos] = False
            sub_sequence, path_val = self.recursive_refine(
                candidate, logits, new_mask, entropy, depth - 1, alpha, beta
            )
            if path_val > alpha:
                alpha = path_val
                best_val = path_val
                # Keep the fully refined sequence from the sub-search, not just
                # this level's one-token change; otherwise substitutions made
                # deeper in the recursion would be scored and then thrown away.
                refined_sequence = sub_sequence
            if alpha >= beta:
                break
        return refined_sequence, best_val
class ADAPTDIFFPipeline(nn.Module):
    """Block-wise generator: draft ``block_size`` tokens per backbone pass, refine uncertain ones."""

    def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5):
        """Wire the backbone, LDM heads, uncertainty router and pruner together.

        Args:
            base_lm_model: Model exposing ``.model``, ``.lm_head`` and ``.config``
                (with ``hidden_size`` and ``vocab_size``).
            block_size: Number of tokens drafted per backbone evaluation.
            entropy_threshold: Entropy at or above which a drafted position is
                sent to the pruner.
        """
        super().__init__()
        self.base_model = base_lm_model.model
        self.lm_head = base_lm_model.lm_head
        self.block_size = block_size
        self.entropy_threshold = entropy_threshold
        self.ldm_heads = StackedLDMHeads(
            hidden_size=base_lm_model.config.hidden_size,
            vocab_size=base_lm_model.config.vocab_size,
            block_size=block_size,
        ).to(DEVICE)
        self.router = LogitUncertaintyFilter()
        self.pruner = ActorCriticPruner(self.lm_head)

    def generate_adapt_diff(self, input_ids, max_new_tokens=128):
        """Generate up to ``max_new_tokens`` tokens after the prompt.

        Args:
            input_ids: Prompt token ids of shape ``(1, seq_len)``.
            max_new_tokens: Upper bound on the number of returned tokens.

        Returns:
            ``(new_tokens, evals)``: a 1-D tensor of at most ``max_new_tokens``
            generated ids, and the evaluation counter (one per backbone pass
            plus one per block that went through refinement).

        Raises:
            ValueError: If ``input_ids`` is not a single ``(1, seq_len)`` sequence.
        """
        # The block handling below indexes a single sequence; fail loudly
        # rather than mis-shaping tensors when a batch is passed.
        if input_ids.ndim != 2 or input_ids.shape[0] != 1:
            raise ValueError("generate_adapt_diff expects input_ids of shape (1, seq_len)")

        prompt_len = input_ids.shape[1]
        current_seq = input_ids.clone()
        generated_count = 0
        total_full_transformer_evals = 0
        while generated_count < max_new_tokens:
            outputs = self.base_model(input_ids=current_seq)
            total_full_transformer_evals += 1
            # Only the final position's hidden state is used to forecast the next block.
            last_hidden = outputs.last_hidden_state[:, -1:, :]
            block_logits = self.ldm_heads(last_hidden)[0, 0]
            draft_tokens = torch.argmax(block_logits, dim=-1)
            mask, entropy = self.router(block_logits, self.entropy_threshold)
            if not mask.any():
                final_block = draft_tokens
            else:
                # A refined block is charged as one additional evaluation.
                total_full_transformer_evals += 1
                final_block, _ = self.pruner.recursive_refine(
                    sequence=draft_tokens,
                    logits=block_logits,
                    mask=mask,
                    entropy=entropy,
                    depth=2,
                    alpha=float("-inf"),
                    beta=float("inf"),
                )
            current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1)
            generated_count += self.block_size
        # Whole blocks are appended, so the final block can overshoot the
        # requested budget; trim to exactly max_new_tokens.
        new_tokens = current_seq[0, prompt_len:prompt_len + max_new_tokens]
        return new_tokens, total_full_transformer_evals
# ==============================================================================
# Model Loading & LDM Weights Initialization
# ==============================================================================
print(f"Downloading custom bidirectional model {ADAPT_DIFF_ID} from Hugging Face...")
a2d_model = AutoModelForCausalLM.from_pretrained(
    ADAPT_DIFF_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
print(f"Downloading baseline model {BASE_MODEL_ID} for comparative evaluation...")
baseline_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
# Initialize generation pipeline and load pre-trained custom LDM weights
pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5)
print("Downloading LDM head projection weights...")
ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt")
# SECURITY: the checkpoint is a pickle file fetched from the network.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered file cannot execute arbitrary code while being loaded.
ldm_state_dict = torch.load(ldm_weights_path, map_location=DEVICE, weights_only=True)
pipeline.ldm_heads.load_state_dict(ldm_state_dict)
pipeline.eval()
# ==============================================================================
# Sub-Sampled Benchmark Initialization
# ==============================================================================
print("\nLoading GSM8K and MBPP evaluation datasets...")
gsm8k_ds = load_dataset("openai/gsm8k", "main", split="test")
mbpp_ds = load_dataset("google-research-datasets/mbpp", split="test")

# Fast benchmark slice: keep only the first 10 examples of each test split.
# zip() against range(10) stops after ten rows without reading any further.
val_math = [
    (f"Problem: {row['question']}\nSolution:", row["answer"])
    for _, row in zip(range(10), gsm8k_ds)
]
val_code = [
    (
        f"Write a Python function to solve this task:\n{row['text']}\nSolution:\n",
        row["code"],
        row["test_list"],
    )
    for _, row in zip(range(10), mbpp_ds)
]
# ==============================================================================
# Validation Helpers
# ==============================================================================
def extract_answer(text):
    """Return the last number found in ``text`` as a string, or ``None`` if there is none.

    When the GSM8K answer marker ``####`` is present, only the text after its
    final occurrence is searched. Thousands separators are removed from the
    result.
    """
    searchable = text.split("####")[-1] if "####" in text else text
    numbers = re.findall(r'-?[\d,]*\.?\d+', searchable)
    if not numbers:
        return None
    return numbers[-1].replace(',', '')
def verify_math(generated_text, ref_ans):
    """Score a math answer: 1.0 if the extracted numbers agree, otherwise 0.0.

    Both texts are reduced to their final number with ``extract_answer``. The
    two are compared as floats, or as stripped strings if float conversion
    fails. A missing number on either side scores 0.0.
    """
    predicted = extract_answer(generated_text)
    expected = extract_answer(ref_ans)
    if predicted is None or expected is None:
        return 0.0
    try:
        is_match = float(predicted) == float(expected)
    except ValueError:
        is_match = str(predicted).strip() == str(expected).strip()
    return 1.0 if is_match else 0.0
def _extract_code_block(generated_text):
    """Return the code portion of ``generated_text``, unwrapping a Markdown fence if present."""
    if "```python" in generated_text:
        # Use the last ```python block, up to its closing fence.
        return generated_text.split("```python")[-1].split("```")[0]
    if "```" in generated_text:
        # Splitting on the fence gives [before, code, after, ...]: the code is
        # element 1. Taking the last element would select whatever follows the
        # closing fence instead of the code itself.
        return generated_text.split("```")[1]
    return generated_text


def verify_code(generated_text, test_list):
    """Execute generated code plus its tests; return 1.0 if all pass, else 0.0.

    Args:
        generated_text: Model output, optionally wrapped in a Markdown code fence.
        test_list: Iterable of Python statements (typically ``assert``s) that
            are executed in the same namespace as the generated code.

    Returns:
        1.0 when the code compiles, runs and every test statement executes
        without raising; 0.0 otherwise.
    """
    code_block = _extract_code_block(generated_text)
    # SECURITY: this executes model-generated code in the current process with
    # no sandbox or timeout. Only run it in a disposable environment.
    local_scope = {}
    try:
        exec(compile(code_block, "<string>", "exec"), local_scope, local_scope)
        for test in test_list:
            exec(test, local_scope, local_scope)
    except (Exception, SystemExit):
        # SystemExit is included so that generated code calling exit() counts
        # as a failed sample instead of terminating the whole benchmark.
        return 0.0
    return 1.0
# ==============================================================================
# Evaluation Loop
# ==============================================================================
def _sync_device():
    """Block until queued CUDA work has finished so wall-clock timings are meaningful."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def run_benchmark(pipeline, base_model, dataset, is_code=False):
    """Compare autoregressive decoding with ADAPT-DIFF generation on ``dataset``.

    Args:
        pipeline: Object exposing ``generate_adapt_diff(input_ids=..., max_new_tokens=...)``
            that returns ``(tokens, eval_count)``.
        base_model: Model exposing a Hugging Face style ``generate``.
        dataset: Sequence of tuples. ``item[0]`` is the prompt; ``item[1]`` is
            the reference answer (math) and ``item[2]`` the test list (code).
        is_code: Score with ``verify_code`` when True, else with ``verify_math``.

    Returns:
        Dict with accuracy (``ar_acc``/``ad_acc``), throughput in tokens per
        second (``ar_speed``/``ad_speed``) and evaluations per generated token
        (``ar_flops`` fixed at 1.0, ``ad_flops``). An empty dataset yields
        zero accuracy rather than raising ``ZeroDivisionError``.
    """
    total = len(dataset)
    max_new_tokens = 48
    ar_correct = 0
    ad_correct = 0
    ar_total_tokens = 0
    ad_total_tokens = 0
    ar_total_time = 0.0
    ad_total_time = 0.0
    ad_total_evals = 0

    for item in dataset:
        prompt = item[0]
        inputs = src_tokenizer(prompt, return_tensors="pt").to(DEVICE)
        prompt_len = inputs.input_ids.shape[1]

        # Autoregressive generation
        _sync_device()
        t_start = time.perf_counter()  # monotonic, high-resolution timer
        with torch.no_grad():
            ar_outputs = base_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=src_tokenizer.pad_token_id,
                eos_token_id=src_tokenizer.eos_token_id,
                do_sample=False,
            )
        _sync_device()
        ar_total_time += time.perf_counter() - t_start
        ar_gen_tokens = ar_outputs[0][prompt_len:]
        ar_total_tokens += len(ar_gen_tokens)
        ar_text = src_tokenizer.decode(ar_gen_tokens, skip_special_tokens=True)

        # ADAPT-DIFF speculative generation
        _sync_device()
        t_start = time.perf_counter()
        with torch.no_grad():
            ad_gen_tokens, step_evals = pipeline.generate_adapt_diff(
                input_ids=inputs.input_ids,
                max_new_tokens=max_new_tokens,
            )
        _sync_device()
        ad_total_time += time.perf_counter() - t_start
        ad_total_tokens += len(ad_gen_tokens)
        ad_total_evals += step_evals
        ad_text = src_tokenizer.decode(ad_gen_tokens, skip_special_tokens=True)

        if is_code:
            ar_correct += verify_code(ar_text, item[2])
            ad_correct += verify_code(ad_text, item[2])
        else:
            ar_correct += verify_math(ar_text, item[1])
            ad_correct += verify_math(ad_text, item[1])

    # The small epsilons keep the ratios finite when nothing was generated.
    ar_throughput = ar_total_tokens / (ar_total_time + 1e-9)
    ad_throughput = ad_total_tokens / (ad_total_time + 1e-9)
    ad_flops_per_token = ad_total_evals / (ad_total_tokens + 1e-9)
    return {
        "ar_acc": ar_correct / total if total else 0.0,
        "ad_acc": ad_correct / total if total else 0.0,
        "ar_speed": ar_throughput,
        "ad_speed": ad_throughput,
        "ar_flops": 1.0,
        "ad_flops": ad_flops_per_token,
    }
print("\nStarting evaluation run...")
math_results = run_benchmark(pipeline, baseline_model, val_math, is_code=False)
code_results = run_benchmark(pipeline, baseline_model, val_code, is_code=True)


def _print_row(label, speed, accuracy, flops):
    """Print one table row using the same column layout as the header."""
    print(f"{label:<30} | {speed:<20.2f} | {accuracy:<15.2%} | {flops:<20.4f}")


# Print comparative results
_heavy_rule = "=" * 95
_light_rule = "-" * 95

print("\n" + _heavy_rule)
print(" ADAPT-DIFF INFERENCE BENCHMARK RESULTS (Block Size L = 12)")
print(_heavy_rule)
print(f"{'Task / Strategy':<30} | {'Throughput (tok/s)':<20} | {'Task Acc':<15} | {'Relative FLOPs/Tok':<20}")
print(_light_rule)
_print_row(
    "GSM8K (Autoregressive Baseline)",
    math_results["ar_speed"],
    math_results["ar_acc"],
    math_results["ar_flops"],
)
_print_row(
    "GSM8K (ADAPT-DIFF Speculative)",
    math_results["ad_speed"],
    math_results["ad_acc"],
    math_results["ad_flops"],
)
print(_light_rule)
_print_row(
    "MBPP (Autoregressive Baseline)",
    code_results["ar_speed"],
    code_results["ar_acc"],
    code_results["ar_flops"],
)
_print_row(
    "MBPP (ADAPT-DIFF Speculative)",
    code_results["ad_speed"],
    code_results["ad_acc"],
    code_results["ad_flops"],
)
print(_heavy_rule)