#!/usr/bin/env python3
"""
Offline evaluation for the DFlash-b16 baseline: measure accepted length.

Runs data-parallel across GPUs: every rank loads its own copy of the target
and draft models, evaluates a strided shard of the prompts, and the per-rank
statistics are reduced across ranks at the end of each benchmark.

Usage:
    # 8 GPUs
    torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py

    # quick test
    torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py --num-samples 20

    # single GPU
    python3 eval_dflash_b16_baseline.py --benchmarks humaneval
"""

import argparse
import json
import os
import sys
import time
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache

# ──────────────────────────────────────────────────────────────────
BASE_MODEL = "/workspace/models/Qwen3-8B"
DRAFT_MODEL = "/workspace/models/Qwen3-8B-DFlash-b16"
RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"

# Add the DFlash model directory to the import path so we can import its utils.
# NOTE: utils always come from DRAFT_MODEL, even if --draft-model points elsewhere.
sys.path.insert(0, DRAFT_MODEL)
from utils import extract_context_feature, sample  # noqa: E402

SUPPORTED_BENCHMARKS = ("humaneval", "mtbench", "gsm8k")

# Local JSONL copies; used when present, otherwise the Hub datasets are loaded.
LOCAL_DATASET_PATHS = {
    "humaneval": "/workspace/hanrui/datasets/humaneval/test.jsonl",
    "mtbench": "/workspace/hanrui/datasets/mtbench/question.jsonl",
    "gsm8k": "/workspace/hanrui/datasets/gsm8k/test.jsonl",
}


# ──────────────────────────────────────────────────────────────────
# Distributed helpers
# ──────────────────────────────────────────────────────────────────
def is_distributed():
    """Return True when a torch.distributed process group is initialized."""
    return dist.is_available() and dist.is_initialized()


def get_rank():
    """Return this process's global rank (0 when not distributed)."""
    return dist.get_rank() if is_distributed() else 0


def get_world_size():
    """Return the number of processes (1 when not distributed)."""
    return dist.get_world_size() if is_distributed() else 1


def is_main():
    """Return True on rank 0."""
    return get_rank() == 0


def print_rank0(*args, **kwargs):
    """print() only on rank 0 so multi-GPU runs do not duplicate output."""
    if is_main():
        print(*args, **kwargs)


def split_list(lst, rank, world_size):
    """Return the strided shard of ``lst`` owned by ``rank`` (items i with i % world_size == rank)."""
    return [x for i, x in enumerate(lst) if i % world_size == rank]


# ──────────────────────────────────────────────────────────────────
# Prompts
# ──────────────────────────────────────────────────────────────────
def _humaneval_prompt(problem: str) -> str:
    """Wrap a HumanEval problem in the code-completion instruction."""
    return (
        "Write a solution to the following problem and make sure that it passes the tests:\n"
        f"```python\n{problem}\n```"
    )


def _gsm8k_prompt(question: str) -> str:
    """Append the step-by-step / boxed-answer instruction to a GSM8K question."""
    return question + "\nPlease reason step by step, and put your final answer within \\boxed{}."


def load_prompts(bench_name: str, num_samples: Optional[int] = None) -> List[str]:
    """Load the user prompts for ``bench_name``.

    Reads the local JSONL copy when it exists, otherwise downloads the dataset
    with ``datasets.load_dataset``. For MT-Bench only the first turn is used.

    Args:
        bench_name: One of ``SUPPORTED_BENCHMARKS``.
        num_samples: If given, keep only the first ``num_samples`` prompts.

    Returns:
        The formatted prompt strings.

    Raises:
        ValueError: If ``bench_name`` is not a supported benchmark.
    """
    if bench_name not in SUPPORTED_BENCHMARKS:
        raise ValueError(
            f"Unknown benchmark {bench_name!r}; expected one of {SUPPORTED_BENCHMARKS}"
        )

    prompts: List[str] = []
    path = LOCAL_DATASET_PATHS.get(bench_name)
    if path and os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines in the JSONL file
                item = json.loads(line)
                if bench_name == "humaneval":
                    prompts.append(_humaneval_prompt(item["prompt"]))
                elif bench_name == "mtbench":
                    prompts.append(item.get("turns", [item.get("prompt", "")])[0])
                else:  # gsm8k
                    prompts.append(_gsm8k_prompt(item["question"]))
    else:
        from datasets import load_dataset

        if bench_name == "humaneval":
            ds = load_dataset("openai/openai_humaneval", split="test")
            prompts = [_humaneval_prompt(x["prompt"]) for x in ds]
        elif bench_name == "mtbench":
            ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
            prompts = [x["prompt"][0] for x in ds]
        else:  # gsm8k
            ds = load_dataset("openai/gsm8k", "main", split="test")
            prompts = [_gsm8k_prompt(x["question"]) for x in ds]

    if num_samples is not None:
        prompts = prompts[:num_samples]
    return prompts


# ──────────────────────────────────────────────────────────────────
# spec_generate with acceptance_lengths returned
# (Same logic as DFlashDraftModel.spec_generate but returns accept lens)
# ──────────────────────────────────────────────────────────────────
@torch.inference_mode()
def spec_generate_b16(
    draft_model,
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, List[int]]:
    """Block speculative decoding; same as DFlashDraftModel.spec_generate but
    also returns the per-round acceptance lengths.

    Each round, the draft model fills a block of ``block_size`` positions in one
    non-causal forward pass (conditioned on target hidden states), the target
    verifies the block in a single forward, and the longest prefix that matches
    the target's own samples is accepted, plus one bonus token from the target.

    Args:
        draft_model: DFlash draft model; must expose ``block_size``,
            ``mask_token_id`` and ``target_layer_ids``.
        target_model: HF causal LM exposing ``model.embed_tokens`` and ``lm_head``.
        input_ids: Prompt token ids; only batch size 1 is handled (row 0 is used).
        max_new_tokens: Generation budget beyond the prompt.
        temperature: Sampling temperature for the target (draft sampling uses
            ``sample``'s default, as upstream).
        stop_token_ids: Generation stops after a round that commits any of these.

    Returns:
        ``(output_ids, acceptance_lengths)``. ``output_ids`` holds prompt +
        generation, truncated after the first stop token. ``acceptance_lengths``
        has one entry per verify round, counting the bonus token.
    """
    draft_model.eval()
    device = target_model.device if hasattr(target_model, "device") else input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    block_size = draft_model.block_size
    mask_token_id = draft_model.mask_token_id

    # Padded by one extra block so the final draft block never runs off the end.
    output_ids = torch.full(
        (1, max_length + block_size),
        mask_token_id,
        dtype=torch.long,
        device=device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()
    # Built once; used both for early stopping and for the final truncation.
    stop_t = (
        torch.tensor(stop_token_ids, dtype=torch.long, device=device)
        if stop_token_ids
        else None
    )

    # Prefill
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=True,
    )
    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
    target_hidden = extract_context_feature(output.hidden_states, draft_model.target_layer_ids)

    # Decode
    acceptance_lengths: List[int] = []
    start = num_input_tokens
    while start < max_length:
        # Position 0 of the block is the last committed token; the rest are masks.
        block_output_ids = output_ids[:, start:start + block_size].clone()
        block_position_ids = position_ids[:, start:start + block_size]
        noise_embedding = target_model.model.embed_tokens(block_output_ids)
        draft_hidden = draft_model(
            target_hidden=target_hidden,
            noise_embedding=noise_embedding,
            position_ids=position_ids[:, past_key_values_draft.get_seq_length():start + block_size],
            past_key_values=past_key_values_draft,
            use_cache=True,
            is_causal=False,
        )
        draft_logits = target_model.lm_head(draft_hidden[:, -block_size + 1:, :])
        # Drop the draft cache entries for the speculative block.
        past_key_values_draft.crop(start)
        block_output_ids[:, 1:] = sample(draft_logits)

        # Verify the whole block with one target forward.
        output = target_model(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=True,
        )
        posterior = sample(output.logits, temperature)
        # Length of the leading run where draft token i+1 equals the target's prediction i.
        accepted = int(
            (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item()
        )
        output_ids[:, start:start + accepted + 1] = block_output_ids[:, :accepted + 1]
        output_ids[:, start + accepted + 1] = posterior[:, accepted]  # bonus token
        start += accepted + 1
        # Discard target KV entries for rejected draft tokens.
        past_key_values_target.crop(start)
        target_hidden = extract_context_feature(
            output.hidden_states, draft_model.target_layer_ids
        )[:, :accepted + 1, :]
        acceptance_lengths.append(accepted + 1)

        # output_ids[start] now holds the committed bonus token, so include it;
        # otherwise an EOS bonus token triggers a spurious extra round.
        if stop_t is not None and torch.isin(
            output_ids[0, num_input_tokens:start + 1], stop_t
        ).any().item():
            break

    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_t is not None:
        stop_idx = torch.isin(output_ids[0, num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + int(stop_idx[0]) + 1]
    return output_ids, acceptance_lengths


# ──────────────────────────────────────────────────────────────────
def parse_args():
    """Parse command-line options for the evaluation run."""
    p = argparse.ArgumentParser()
    p.add_argument("--base-model", default=BASE_MODEL)
    p.add_argument("--draft-model", default=DRAFT_MODEL)
    p.add_argument("--max-new-tokens", type=int, default=512)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument(
        "--benchmarks",
        nargs="+",
        default=["humaneval", "mtbench", "gsm8k"],
        choices=SUPPORTED_BENCHMARKS,
    )
    p.add_argument("--num-samples", type=int, default=None)
    p.add_argument("--output-dir", default=RESULT_DIR)
    return p.parse_args()


def _reduce_stats(accept_sum, count, tokens, elapsed, device, world_size):
    """Combine per-rank statistics.

    Sums the acceptance-length sum, round count and token count across ranks,
    and takes the MAX of the wall times: aggregate throughput must be measured
    against the slowest rank, not rank 0's local time.

    Returns:
        ``(accept_sum, count, tokens, elapsed)``. The inputs are returned
        unchanged when ``world_size <= 1``.
    """
    if world_size <= 1:
        return accept_sum, count, tokens, elapsed
    t_sum = torch.tensor(accept_sum, dtype=torch.float64, device=device)
    t_count = torch.tensor(count, dtype=torch.long, device=device)
    t_tok = torch.tensor(tokens, dtype=torch.long, device=device)
    t_elapsed = torch.tensor(elapsed, dtype=torch.float64, device=device)
    for t in (t_sum, t_count, t_tok):
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    dist.all_reduce(t_elapsed, op=dist.ReduceOp.MAX)
    return t_sum.item(), t_count.item(), t_tok.item(), t_elapsed.item()


def main():
    """Load the models once, run each requested benchmark, and save a JSON summary (rank 0)."""
    args = parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # Bind this process to its GPU before creating the NCCL process group.
    torch.cuda.set_device(local_rank)
    if world_size > 1:
        dist.init_process_group(backend="nccl")
    device = f"cuda:{local_rank}"
    rank = get_rank()

    print_rank0(f"Running DFlash-b16 baseline on {world_size} GPU(s)")

    # ── Load models ──
    print_rank0(f"Loading target: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    print_rank0(f"Loading DFlash-b16 draft: {args.draft_model}")
    draft_model = AutoModel.from_pretrained(
        args.draft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    # Without an EOS token, generation simply runs to max_new_tokens.
    stop_token_ids = (
        [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None
    )

    print_rank0(
        f"DFlash-b16: block_size={draft_model.block_size}, "
        f"target_layer_ids={draft_model.target_layer_ids}, "
        f"num_layers={len(draft_model.layers)}"
    )

    # ── Run benchmarks ──
    results = {
        "model": "Qwen3-8B-DFlash-b16",
        "type": "baseline",
        "block_size": draft_model.block_size,
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
    }
    for bench_name in args.benchmarks:
        print_rank0(f"\n{'='*60}")
        print_rank0(f"Benchmark: {bench_name} ({world_size} GPUs)")
        print_rank0(f"{'='*60}")

        all_prompts = load_prompts(bench_name, args.num_samples)
        my_prompts = split_list(all_prompts, rank, world_size)
        print_rank0(f"Total {len(all_prompts)} prompts, ~{len(my_prompts)} per GPU")

        local_accept_lengths: List[int] = []
        local_tokens = 0
        t0 = time.perf_counter()
        iterator = tqdm(
            my_prompts,
            desc=f"[GPU{rank}] {bench_name}",
            unit="sample",
            disable=(rank != 0),
        )
        for prompt in iterator:
            messages = [{"role": "user", "content": prompt}]
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
            output_ids, accept_lens = spec_generate_b16(
                draft_model=draft_model,
                target_model=target_model,
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                stop_token_ids=stop_token_ids,
            )
            local_accept_lengths.extend(accept_lens)
            num_gen = output_ids.shape[1] - input_ids.shape[1]
            local_tokens += num_gen
            if rank == 0 and local_accept_lengths:
                avg = sum(local_accept_lengths) / len(local_accept_lengths)
                iterator.set_postfix(accept_len=f"{avg:.2f}", tokens=local_tokens, gen=num_gen)
        elapsed = time.perf_counter() - t0

        # ── Gather ──
        total_accept_sum, total_count, total_tokens, elapsed = _reduce_stats(
            sum(local_accept_lengths),
            len(local_accept_lengths),
            local_tokens,
            elapsed,
            device,
            world_size,
        )

        avg_accept_length = total_accept_sum / max(total_count, 1)
        throughput = total_tokens / elapsed if elapsed > 0 else 0

        print_rank0(f"\n{bench_name} Results:")
        print_rank0(f"  Avg Accept Length: {avg_accept_length:.3f}")
        print_rank0(f"  Total tokens: {total_tokens}")
        print_rank0(f"  Latency: {elapsed:.1f}s")
        print_rank0(f"  Throughput: {throughput:.1f} tok/s (aggregate {world_size} GPUs)")
        print_rank0(f"  Num verify rounds: {total_count}")
        print_rank0(f"  Num samples: {len(all_prompts)}")

        results[bench_name] = {
            "avg_accept_length": avg_accept_length,
            "total_tokens": total_tokens,
            "latency": elapsed,
            "throughput": throughput,
            "num_samples": len(all_prompts),
            "num_verify_rounds": total_count,
            "num_gpus": world_size,
        }

    # ── Save ──
    if is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_b16_baseline_offline_{timestamp}.json",
        )
        with open(result_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {result_file}")

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()