| |
| """ |
| Offline evaluation for DFlash-LoRA-Inject: measure accepted length & speedup. |
| Aligned with official DFlash benchmark.py methodology. |
| |
| Unlike DFlash-b16 which uses a small 5-layer draft model with fc/hidden_norm, |
| LoRA-Inject uses a full Qwen3-8B with LoRA adapters that receives target hidden |
| states via layer-by-layer injection. |
| |
| Usage: |
| conda activate spec |
| |
| # 8 GPU parallel (default, all 10 benchmarks) |
| torchrun --nproc_per_node 8 eval_dflash_lora_inject.py |
| |
| # single GPU |
| python3 eval_dflash_lora_inject.py |
| |
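    # specific checkpoint / benchmark (an existing --merged-path takes precedence
    # over --ckpt, so pass --merged-path "" to evaluate an adapter checkpoint)
    torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --merged-path "" --ckpt epoch_0_step_1000 --datasets humaneval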
| |
| # quick test |
| torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --max-samples 20 |
| """ |
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| import time |
| import warnings |
| from itertools import chain |
| from types import SimpleNamespace |
| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.distributed as dist |
| from peft import PeftModel |
| from tqdm import tqdm |
| from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache |
|
|
| |
| sys.path.insert(0, "/workspace/hanrui/dflash") |
| from model.utils import load_and_process_dataset |
|
|
| |
| |
| |
# Default filesystem locations; each one can be overridden from the command
# line (see parse_args: --base-model, --adapter-root, --ckpt, --output-dir).
BASE_MODEL = "/workspace/models/Qwen3-8B"
ADAPTER_ROOT = "/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-inject"
DEFAULT_CKPT = "epoch_3_step_1400"
RESULT_DIR = "/workspace/hanrui/syxin/Specforge/benchmarks/results"

# Token id used to pre-fill output buffers; slots still holding it after
# generation are filtered out of the returned ids.
MASK_TOKEN_ID = 151669
# Speculative block length: one anchor token followed by drafted positions.
BLOCK_SIZE = 16

# Benchmark name -> number of samples evaluated (without --max-samples, a
# larger dataset is shuffled with seed 0 and truncated to this many).
_OFFICIAL_TASK_COUNTS = (
    ("gsm8k", 128),
    ("math500", 128),
    ("aime24", 30),
    ("aime25", 30),
    ("humaneval", 164),
    ("mbpp", 128),
    ("livecodebench", 128),
    ("swe-bench", 128),
    ("mt-bench", 80),
    ("alpaca", 128),
)
OFFICIAL_TASKS = dict(_OFFICIAL_TASK_COUNTS)
|
|
|
|
| |
| |
| |
def cuda_time() -> float:
    """Return a ``time.perf_counter()`` timestamp taken after queued CUDA work.

    CUDA kernels launch asynchronously, so the device is synchronized first;
    otherwise the timestamp could precede the GPU work being timed. Without
    CUDA there is nothing to wait for, and the helper degrades to a plain
    ``perf_counter`` instead of raising.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()
|
|
|
|
def has_flash_attn() -> bool:
    """Report whether the ``flash_attn`` package can be imported.

    On ImportError a warning is printed and ``False`` is returned so the
    caller can fall back to the ``sdpa`` attention implementation.
    """
    try:
        import flash_attn  # noqa: F401  (import only probes availability)
    except ImportError:
        print("[WARN] flash_attn not installed, falling back to sdpa.")
        return False
    return True
|
|
|
|
| |
| |
| |
def dist_init():
    """Create the NCCL process group when launched by a distributed launcher.

    The group is initialised from environment variables
    (``init_method="env://"``). When ``RANK`` is absent (plain ``python3``
    launch), a warning is emitted and the script continues single-process.
    """
    launched_distributed = "RANK" in os.environ
    if launched_distributed:
        dist.init_process_group(backend="nccl", init_method="env://")
    else:
        warnings.warn("RANK not set. Skipping distributed init.")
|
|
def dist_rank():
    """Global rank of this process, read from ``$RANK`` (``0`` when unset)."""
    return int(os.getenv("RANK", "0"))
|
|
def dist_size():
    """Number of participating processes, read from ``$WORLD_SIZE`` (``1`` when unset)."""
    return int(os.getenv("WORLD_SIZE", "1"))
|
|
def dist_local_rank():
    """Per-node rank, read from ``$LOCAL_RANK`` (``0`` when unset); used as the CUDA device index."""
    return int(os.getenv("LOCAL_RANK", "0"))
|
|
def dist_is_main():
    """True only on global rank 0 (also true for single-process runs)."""
    return not dist_rank()
|
|
def dist_gather(obj, dst=0):
    """Gather a picklable object from every rank onto rank ``dst``.

    Args:
        obj: Object contributed by this rank (transferred by pickling).
        dst: Rank that receives the gathered list.

    Returns:
        On rank ``dst``: a list with one entry per rank, in rank order.
        On every other rank: ``None``.
        Without an initialised process group: ``[obj]``.
    """
    if not dist.is_initialized():
        return [obj]
    # The receiving side must be decided by comparing against ``dst``
    # (previously hard-wired to rank 0). torch.distributed rejects an output
    # list on non-destination ranks and requires one on the destination.
    if dist.get_rank() != dst:
        dist.gather_object(obj, dst=dst)
        return None
    objs = [None] * dist.get_world_size()
    dist.gather_object(obj, objs, dst=dst)
    return objs
|
|
def print_rank0(*args, **kwargs):
    """``print`` wrapper that only emits output on the main (rank 0) process."""
    if not dist_is_main():
        return
    print(*args, **kwargs)
|
|
|
|
| |
| |
| |
def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Pick one token id per position from ``logits``.

    Args:
        logits: Scores of shape ``(batch, seq_len, vocab)``; may be a
            non-contiguous slice (e.g. ``logits[:, :-1, :]``).
        temperature: Values below ``1e-5`` select greedy argmax; otherwise the
            logits are divided by it and one token is drawn from the softmax.

    Returns:
        Long tensor of shape ``(batch, seq_len)``.
    """
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    bsz, seq_len, vocab_size = logits.shape
    if bsz * seq_len == 0:
        # e.g. draft logits of a length-1 block after dropping the last slot.
        return torch.empty((bsz, seq_len), dtype=torch.long, device=logits.device)
    # reshape, not view: sliced logits are non-contiguous when batch > 1.
    flat = logits.reshape(-1, vocab_size)
    # Normalise in float32 for precision (inputs are typically bf16).
    probs = torch.softmax(flat.float() / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)
|
|
|
|
| |
| |
| |
def build_dflash_mask(ctx_len: int, block_size: int, device, dtype=torch.bfloat16):
    """
    Additive attention mask for a ``[context | block]`` sequence.

    Returns a ``(1, 1, L, L)`` tensor, ``L = ctx_len + block_size``, holding
    ``0`` where attention is allowed and ``torch.finfo(dtype).min`` elsewhere:
    - context rows attend causally to context columns only;
    - block rows attend to every context column and to every block column
      (bidirectional inside the block).
    """
    total_len = ctx_len + block_size
    positions = torch.arange(total_len, device=device)
    query_pos = positions.unsqueeze(1)
    key_pos = positions.unsqueeze(0)

    in_block = query_pos >= ctx_len   # block queries see everything
    causal = key_pos <= query_pos     # context queries see only the past
    allowed = in_block | causal

    mask = torch.zeros((1, 1, total_len, total_len), device=device, dtype=dtype)
    mask[0, 0].masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask
|
|
|
|
| |
| |
| |
| |
@torch.inference_mode()
def ar_generate(
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Pure autoregressive generation using only the target model.
    Mirrors official benchmark.py with block_size=1 (no draft model involved).
    Returns SimpleNamespace matching official dflash_generate output format.

    Args:
        target_model: causal LM called with ``position_ids``,
            ``past_key_values`` and ``use_cache``; ``.logits`` of its output
            is passed to ``sample``.
        input_ids: prompt ids. Assumed shape ``(1, prompt_len)``, since the
            output buffer below is allocated with batch size 1.
        max_new_tokens: upper bound on generated tokens.
        mask_token_id: filler for unwritten buffer slots; every slot still
            holding it is dropped from the returned ids.
        temperature: forwarded to ``sample`` (below 1e-5 means greedy).
        stop_token_ids: decoding stops once any of these ids appears in the
            generated part, and the output is cut right after the first one.

    Returns:
        SimpleNamespace with:
        - ``output_ids``: prompt + generated ids, shape ``(1, n)``.
        - ``num_input_tokens`` and ``num_output_tokens``.
        - ``time_to_first_token``: prefill seconds.
        - ``time_per_output_token``: decode seconds / generated tokens.
        - ``acceptance_lengths``: all 1s, one per generated token.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    # One spare slot: the last decode step writes its token at index
    # max_length, which is sliced off after the loop.
    output_ids = torch.full(
        (1, max_length + 1), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
    past_key_values = DynamicCache()

    # Prefill: the whole prompt in one forward; only the last position's
    # logits are kept (logits_to_keep=1) to sample the first new token.
    prefill_start = cuda_time()
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=False,
    )
    first_token = sample(output.logits, temperature)
    output_ids[:, num_input_tokens:num_input_tokens + 1] = first_token
    time_to_first_token = cuda_time() - prefill_start

    # Decode: one target forward per token, reusing the KV cache.
    decode_start = cuda_time()
    start = num_input_tokens

    while start < max_length:
        cur_token = output_ids[:, start:start + 1]
        cur_pos = position_ids[:, start:start + 1]

        output = target_model(
            cur_token,
            position_ids=cur_pos,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=False,
        )

        next_token = sample(output.logits, temperature)
        start += 1
        output_ids[:, start:start + 1] = next_token
        # The cache already holds `start` positions at this point, so this
        # crop appears to be a no-op.
        past_key_values.crop(start)

        # Scans the whole generated region every step; this check runs
        # inside the timed decode region.
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    # Drop the spare slot, unwritten mask slots, and anything after the
    # first stop token (the stop token itself is kept).
    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    # Decode time also covers the post-processing above. The first token
    # (sampled during prefill) is included in num_output_tokens.
    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=[1] * max(num_output_tokens, 0),
    )
|
|
|
|
| |
| |
| |
def _draft_block_predictions(
    draft_model: nn.Module,
    ctx_hidden_per_layer: List[torch.Tensor],
    block_ids: torch.LongTensor,
    temperature: float,
) -> torch.Tensor:
    """Run the LoRA-inject draft over one block and sample its proposals.

    Draft layer ``i`` receives ``[ctx_hidden_per_layer[i] | block hidden]``
    under the DFlash mask. Only the block slice of each layer's output is
    carried into the next layer.

    Returns:
        Sampled ids of shape ``(1, block_len - 1)``. Entry ``k`` comes from
        the draft logits at block position ``k`` and fills block slot ``k + 1``.
    """
    backbone = draft_model.model
    device = block_ids.device
    block_len = block_ids.shape[1]
    ctx_len = ctx_hidden_per_layer[0].shape[1]

    draft_hidden = backbone.embed_tokens(block_ids)
    dflash_mask = build_dflash_mask(ctx_len, block_len, device)
    combined_pos = torch.arange(ctx_len + block_len, device=device).unsqueeze(0)

    # Placeholder input for rotary_emb; it is created with torch.empty, so
    # presumably only its dtype/device matter.
    placeholder = torch.empty(1, ctx_len + block_len, draft_hidden.shape[-1],
                              device=device, dtype=torch.bfloat16)
    position_embeddings = backbone.rotary_emb(placeholder, combined_pos)

    for layer_idx, layer in enumerate(backbone.layers):
        combined = torch.cat([ctx_hidden_per_layer[layer_idx], draft_hidden], dim=1)
        layer_output = layer(
            combined,
            attention_mask=dflash_mask,
            position_ids=combined_pos,
            position_embeddings=position_embeddings,
        )
        # Some transformers versions return a tuple from decoder layers.
        if isinstance(layer_output, tuple):
            layer_output = layer_output[0]
        draft_hidden = layer_output[:, ctx_len:, :]

    draft_logits = draft_model.lm_head(backbone.norm(draft_hidden))
    # The last position has no following slot inside the block.
    return sample(draft_logits[:, :-1, :], temperature)


@torch.inference_mode()
def spec_generate_inject(
    target_model: nn.Module,
    draft_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    block_size: int = 16,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Speculative generation using DFlash-LoRA-Inject inference pattern.
    Returns SimpleNamespace matching official dflash_generate output format.

    Each iteration:
    1. The draft fills slots 1..B-1 of a block whose slot 0 is the last
       verified token, conditioning on per-layer target hidden states of all
       verified tokens.
    2. The target verifies the block in one forward.
    3. The longest prefix of drafted tokens matching the target's own
       predictions is accepted, plus one target ("bonus") token.

    Args:
        target_model: verifier LM; called with ``output_hidden_states=True``.
        draft_model: merged LoRA draft exposing ``model.layers``,
            ``model.embed_tokens``, ``model.norm``, ``model.rotary_emb``
            and ``lm_head``.
        input_ids: prompt ids. Assumed shape ``(1, prompt_len)``, since the
            buffers below use batch size 1.
        max_new_tokens: upper bound on generated tokens.
        block_size: speculative block length B.
        mask_token_id: filler for undrafted slots and unwritten buffer positions.
        temperature: forwarded to ``sample`` (below 1e-5 means greedy).
        stop_token_ids: decoding stops once any of these appears, and the
            output is cut right after the first one.

    Returns:
        SimpleNamespace with:
        - ``output_ids``, ``num_input_tokens``, ``num_output_tokens``.
        - ``time_to_first_token`` and ``time_per_output_token``.
        - ``acceptance_lengths``: accepted drafts + 1, one entry per
          verification step.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    num_layers = len(draft_model.model.layers)

    # block_size spare slots: a full block plus its bonus token never index
    # past the end of the buffer.
    output_ids = torch.full(
        (1, max_length + block_size), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids

    # ---- Prefill: target over the prompt, first token + hidden states ----
    prefill_start = cuda_time()
    target_kv = DynamicCache()
    target_output = target_model(
        input_ids,
        past_key_values=target_kv,
        use_cache=True,
        output_hidden_states=True,
    )
    first_token = sample(target_output.logits[:, -1:, :], temperature)
    output_ids[:, num_input_tokens] = first_token.squeeze()

    # hidden_states[0] is the embedding output; entry i + 1 is paired with
    # draft layer i.
    ctx_hidden_per_layer = [target_output.hidden_states[i + 1] for i in range(num_layers)]
    time_to_first_token = cuda_time() - prefill_start

    # ---- Decode: draft a block, verify with the target, accept a prefix ----
    decode_start = cuda_time()
    acceptance_lengths = []
    start = num_input_tokens
    draft_prefill = True

    while start < max_length:
        end = min(start + block_size, max_length)
        actual_block_size = end - start

        # Slot 0 holds the last verified token; the rest are mask tokens.
        block_ids = output_ids[:, start:end].clone()
        draft_predictions = _draft_block_predictions(
            draft_model, ctx_hidden_per_layer, block_ids, temperature
        )
        block_ids[:, 1:actual_block_size] = draft_predictions[:, :actual_block_size - 1]

        # The first draft pass is excluded from the decode timer.
        if draft_prefill:
            draft_prefill = False
            decode_start = cuda_time()

        position_ids_block = torch.arange(
            start, start + actual_block_size, device=device
        ).unsqueeze(0)
        target_verify = target_model(
            block_ids,
            position_ids=position_ids_block,
            past_key_values=target_kv,
            use_cache=True,
            output_hidden_states=True,
        )
        target_tokens = sample(target_verify.logits, temperature)

        # Count drafted tokens equal to the target's prediction up to the
        # first mismatch (cumprod zeroes everything after it).
        matches = block_ids[:, 1:actual_block_size] == target_tokens[:, :actual_block_size - 1]
        acceptance_length = int(matches.cumprod(dim=1).sum(dim=1)[0].item())

        output_ids[:, start:start + acceptance_length + 1] = block_ids[:, :acceptance_length + 1]
        # Bonus token: the target's prediction after the last accepted one.
        # It becomes slot 0 of the next block.
        output_ids[:, start + acceptance_length + 1] = target_tokens[:, acceptance_length]

        # Discard KV / hidden states of rejected positions.
        accepted_end = start + acceptance_length + 1
        target_kv.crop(accepted_end)
        for i in range(num_layers):
            new_hidden = target_verify.hidden_states[i + 1][:, :acceptance_length + 1, :]
            ctx_hidden_per_layer[i] = torch.cat([ctx_hidden_per_layer[i], new_hidden], dim=1)

        start = accepted_end
        acceptance_lengths.append(acceptance_length + 1)

        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    # Keep every token written below max_length, including the final bonus
    # token at index `start`. Slicing to `start` dropped it (and an EOS
    # emitted as bonus), unlike the AR baseline. Unwritten slots still hold
    # mask_token_id and are filtered next.
    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
    )
|
|
|
|
| |
| |
| |
def parse_args(argv=None):
    """Parse command-line options for the offline evaluation.

    Args:
        argv: Argument list to parse; ``None`` (default) reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the evaluation settings.
    """
    p = argparse.ArgumentParser(description="Offline eval for DFlash-LoRA-Inject (aligned with official)")
    p.add_argument("--base-model", default=BASE_MODEL)
    p.add_argument("--adapter-root", default=ADAPTER_ROOT)
    p.add_argument("--ckpt", default=DEFAULT_CKPT,
                   help="Checkpoint folder name under --adapter-root "
                        "(only loaded when --merged-path is not an existing directory)")
    p.add_argument("--merged-path",
                   default="/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged",
                   help="Path to a pre-merged draft model. Used whenever it is an existing "
                        "directory, in which case --adapter-root/--ckpt are ignored. "
                        "Pass an empty string to merge base + LoRA adapter on the fly.")
    p.add_argument("--block-size", type=int, default=BLOCK_SIZE)
    p.add_argument("--max-new-tokens", type=int, default=2048,
                   help="Max new tokens per turn (official shell uses 2048)")
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--datasets", nargs="+", default=list(OFFICIAL_TASKS.keys()),
                   help="Benchmarks to run (default: all 10 official tasks)")
    p.add_argument("--max-samples", type=int, default=None,
                   help="Override max samples per dataset (None = use official per-task counts)")
    p.add_argument("--output-dir", default=RESULT_DIR)
    return p.parse_args(argv)
|
|
|
|
def _load_draft_model(args, device, attn_impl):
    """Load the LoRA-inject draft model in eval mode.

    Uses ``args.merged_path`` when it is an existing directory. Otherwise
    loads ``args.base_model`` and merges the adapter found at
    ``args.adapter_root/args.ckpt``.

    Returns:
        ``(model, source)``, where ``source`` records which weights were loaded.
    """
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
        device_map=device,
        trust_remote_code=True,
    )
    if args.merged_path and os.path.isdir(args.merged_path):
        print_rank0(f"Loading pre-merged draft model: {args.merged_path}")
        # The merged weights take precedence, so --ckpt has no effect here.
        print_rank0(f"[WARN] --ckpt {args.ckpt} is not loaded; pass --merged-path '' to use it.")
        model = AutoModelForCausalLM.from_pretrained(args.merged_path, **load_kwargs)
        source = f"merged:{args.merged_path}"
    else:
        adapter_path = os.path.join(args.adapter_root, args.ckpt)
        print_rank0(f"Loading base + LoRA adapter: {adapter_path}")
        model = AutoModelForCausalLM.from_pretrained(args.base_model, **load_kwargs)
        model = PeftModel.from_pretrained(model, adapter_path).merge_and_unload()
        source = f"adapter:{adapter_path}"
    model.eval()
    return model, source


def _summarize(responses, block_size):
    """Aggregate per-turn results into the reported metrics.

    - decoding_speedup: mean AR time-per-token / mean speculative time-per-token.
    - avg_accept_length: mean over responses of each response's mean
      acceptance length.
    - acceptance_histogram[b]: fraction of verification steps with
      acceptance length b, for b = 0..block_size.

    An empty ``responses`` list yields zeros instead of raising.
    """
    if not responses:
        return {
            "decoding_speedup": 0.0,
            "avg_accept_length": 0.0,
            "acceptance_histogram": [0.0] * (block_size + 1),
            "num_responses": 0,
        }

    t1 = np.mean([r[1].time_per_output_token for r in responses])
    tb = np.mean([r[block_size].time_per_output_token for r in responses])
    speedup = t1 / tb if tb > 0 else 0
    tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])

    lengths = list(chain.from_iterable(r[block_size].acceptance_lengths for r in responses))
    if lengths:
        counts = np.bincount(lengths, minlength=block_size + 1)[:block_size + 1]
        histogram = (counts / len(lengths)).tolist()
    else:
        histogram = [0.0] * (block_size + 1)

    return {
        "decoding_speedup": speedup,
        "avg_accept_length": tau,
        "acceptance_histogram": histogram,
        "num_responses": len(responses),
    }


def main():
    """Compare AR and DFlash-LoRA-Inject decoding on each requested benchmark.

    Samples are sharded round-robin across ranks. Rank 0 gathers all per-turn
    results, prints the metrics and writes them to a timestamped JSON file in
    ``--output-dir``.
    """
    args = parse_args()

    # Fixed seeds: dataset subsampling and temperature > 0 sampling repeat.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Bind this process to its GPU before creating the NCCL group so that
    # collectives use the right device.
    local_rank = dist_local_rank()
    torch.cuda.set_device(local_rank)
    dist_init()
    device = torch.device(f"cuda:{local_rank}")

    print_rank0(f"Running on {dist_size()} GPU(s)")

    installed_flash_attn = has_flash_attn()
    target_attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"
    # The draft layers get an explicit 4-D additive mask (build_dflash_mask),
    # presumably why the draft is pinned to sdpa.
    draft_attn_impl = "sdpa"
    print_rank0(f"Using attn_implementation: target={target_attn_impl}, draft={draft_attn_impl}")

    print_rank0(f"Loading target model: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        attn_implementation=target_attn_impl,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    draft_model, draft_source = _load_draft_model(args, device, draft_attn_impl)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    stop_token_ids = [tokenizer.eos_token_id]
    block_size = args.block_size

    all_results = {
        "model": f"dflash-lora-inject/{args.ckpt}",
        "block_size": block_size,
        "draft_source": draft_source,
    }

    for dataset_name in args.datasets:
        print_rank0(f"\n{'=' * 60}")
        print_rank0(f"Benchmark: {dataset_name} ({dist_size()} GPUs)")
        print_rank0(f"{'=' * 60}")

        dataset = load_and_process_dataset(dataset_name)
        max_samples = args.max_samples if args.max_samples is not None else OFFICIAL_TASKS.get(dataset_name)
        if max_samples is not None and len(dataset) > max_samples:
            dataset = dataset.shuffle(seed=0).select(range(max_samples))
        print_rank0(f"Total {len(dataset)} samples, distributed across {dist_size()} GPUs")

        responses = []
        # Round-robin shard: rank r handles samples r, r + world, r + 2*world, ...
        indices = range(dist_rank(), len(dataset), dist_size())
        iterator = tqdm(indices, desc=f"[GPU{dist_rank()}] {dataset_name}",
                        unit="sample", disable=not dist_is_main())

        for idx in iterator:
            instance = dataset[idx]
            messages = []
            # Multi-turn samples: later turns see earlier turns plus the
            # speculative run's replies as chat history.
            for user_content in instance["turns"]:
                messages.append({"role": "user", "content": user_content})
                input_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                    enable_thinking=False,
                )
                input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

                # Key 1 = target-only AR baseline; key block_size = speculative run.
                response = {}
                response[1] = ar_generate(
                    target_model=target_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )
                response[block_size] = spec_generate_inject(
                    target_model=target_model,
                    draft_model=draft_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    block_size=block_size,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )

                spec_response = response[block_size]
                generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens:]
                output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                messages.append({"role": "assistant", "content": output_text})

                # Move ids off the GPU before gather_object pickles them. CUDA
                # tensors are restored on their original device, which would
                # make rank 0 allocate on every other rank's GPU.
                for result in response.values():
                    result.output_ids = result.output_ids.cpu()
                responses.append(response)

            if dist_is_main() and responses:
                recent_tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses[-5:]])
                iterator.set_postfix(accept_len=f"{recent_tau:.2f}")

        if dist_size() > 1:
            gathered = dist_gather(responses, dst=0)
            if not dist_is_main():
                continue
            responses = list(chain.from_iterable(gathered))
        elif not dist_is_main():
            continue

        summary = _summarize(responses, block_size)
        speedup = summary["decoding_speedup"]
        tau = summary["avg_accept_length"]
        histogram = summary["acceptance_histogram"]

        print_rank0(f"\n{dataset_name} Results:")
        print_rank0(f" Decoding speedup: {speedup:.2f}x")
        print_rank0(f" Average Acceptance length: {tau:.2f}")
        print_rank0(f" Acceptance length histogram: {[f'{x * 100:.1f}%' for x in histogram]}")
        print_rank0(f" Num responses: {summary['num_responses']}")

        all_results[dataset_name] = {**summary, "num_gpus": dist_size()}

    if dist_is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_lora_inject_offline_{args.ckpt}_{timestamp}.json",
        )
        with open(result_file, "w") as f:
            json.dump(all_results, f, indent=2)
        print(f"\nResults saved to: {result_file}")

    # Tear down the process group if dist_init created one.
    if dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
# Run the benchmark only when executed as a script (python3 / torchrun),
# not when the module is imported.
if __name__ == "__main__":
    main()
|
|