| |
| """ |
| Offline evaluation for DFlash-b16 baseline: measure accepted length. |
| 8 GPUs parallel, each GPU loads target + draft independently. |
| |
| Usage: |
| # 8 GPUs |
| torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py |
| |
| # quick test |
| torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py --num-samples 20 |
| |
| # single GPU |
| python3 eval_dflash_b16_baseline.py --benchmarks humaneval |
| """ |
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.distributed as dist |
| from tqdm import tqdm |
| from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache |
|
|
| |
| sys.path.insert(0, "/workspace/models/Qwen3-8B-DFlash-b16") |
| from utils import extract_context_feature, sample |
|
|
| |
# Default filesystem locations. They are only the argparse defaults for
# --base-model, --draft-model and --output-dir, so each can be overridden
# from the command line.
BASE_MODEL = "/workspace/models/Qwen3-8B"
DRAFT_MODEL = "/workspace/models/Qwen3-8B-DFlash-b16"
RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"
|
|
|
|
| |
| |
| |
def is_distributed():
    """Return True only when torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
|
def get_rank():
    """Return this process's distributed rank, or 0 outside a distributed run."""
    if is_distributed():
        return dist.get_rank()
    return 0
|
|
def get_world_size():
    """Return the number of distributed processes, or 1 outside a distributed run."""
    if is_distributed():
        return dist.get_world_size()
    return 1
|
|
def is_main():
    """Return True when the current process is rank 0."""
    current_rank = get_rank()
    return current_rank == 0
|
|
def print_rank0(*args, **kwargs):
    """Forward all arguments to print(), but only on the rank-0 process."""
    if not is_main():
        return
    print(*args, **kwargs)
|
|
def split_list(lst, rank, world_size):
    """Return the round-robin shard of ``lst`` owned by ``rank``.

    Element ``i`` is kept when ``i % world_size == rank``; relative order is
    preserved and a new list is always returned.
    """
    shard = []
    for index, item in enumerate(lst):
        if index % world_size == rank:
            shard.append(item)
    return shard
|
|
|
|
| |
| |
| |
def load_prompts(bench_name: str, num_samples: Optional[int] = None) -> List[str]:
    """Load the evaluation prompts for one benchmark.

    A local JSONL copy is used when it exists on disk; otherwise the dataset
    is fetched through ``datasets.load_dataset`` (imported lazily so the
    dependency is only needed for the fallback path).

    Args:
        bench_name: One of ``"humaneval"``, ``"mtbench"`` or ``"gsm8k"``.
        num_samples: If not ``None``, the result is sliced to
            ``prompts[:num_samples]``.

    Returns:
        The prompt strings, in dataset order.

    Raises:
        ValueError: If ``bench_name`` is not a supported benchmark. An
            unknown name used to yield an empty prompt list, which made the
            evaluation report all-zero metrics instead of failing.
    """
    local_paths = {
        "humaneval": "/workspace/hanrui/datasets/humaneval/test.jsonl",
        "mtbench": "/workspace/hanrui/datasets/mtbench/question.jsonl",
        "gsm8k": "/workspace/hanrui/datasets/gsm8k/test.jsonl",
    }
    if bench_name not in local_paths:
        raise ValueError(
            f"Unknown benchmark {bench_name!r}; expected one of {sorted(local_paths)}"
        )

    # Prompt templates shared by the local-file and hub code paths so the two
    # can never drift apart.
    gsm8k_suffix = "\nPlease reason step by step, and put your final answer within \\boxed{}."

    def humaneval_prompt(code: str) -> str:
        return (
            "Write a solution to the following problem and make sure that it "
            f"passes the tests:\n```python\n{code}\n```"
        )

    prompts: List[str] = []
    path = local_paths[bench_name]
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank or trailing lines in the JSONL file.
                    continue
                item = json.loads(line)
                if bench_name == "humaneval":
                    prompts.append(humaneval_prompt(item["prompt"]))
                elif bench_name == "mtbench":
                    # Only the first turn is used; fall back to "prompt" when
                    # "turns" is missing or empty.
                    turns = item.get("turns") or [item.get("prompt", "")]
                    prompts.append(turns[0])
                else:  # gsm8k
                    prompts.append(item["question"] + gsm8k_suffix)
    else:
        from datasets import load_dataset

        if bench_name == "humaneval":
            ds = load_dataset("openai/openai_humaneval", split="test")
            prompts = [humaneval_prompt(x["prompt"]) for x in ds]
        elif bench_name == "mtbench":
            ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
            prompts = [x["prompt"][0] for x in ds]
        else:  # gsm8k
            ds = load_dataset("openai/gsm8k", "main", split="test")
            prompts = [x["question"] + gsm8k_suffix for x in ds]

    if num_samples is not None:
        prompts = prompts[:num_samples]
    return prompts
|
|
|
|
| |
| |
| |
| |
@torch.inference_mode()
def spec_generate_b16(
    draft_model,
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, List[int]]:
    """Block-wise speculative generation that also records acceptance lengths.

    Intended to match ``DFlashDraftModel.spec_generate`` while additionally
    returning how many tokens each verification round committed.

    Args:
        draft_model: Draft model; must expose ``block_size``,
            ``mask_token_id`` and ``target_layer_ids``.
        target_model: Target model; its ``model.embed_tokens`` and
            ``lm_head`` are reused for the draft.
        input_ids: Prompt token ids. The output buffer is allocated with a
            leading dimension of 1, so a single sequence is assumed.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Passed to ``sample`` for the target model's logits.
        stop_token_ids: Optional token ids that end generation.

    Returns:
        ``(output_ids, acceptance_lengths)``: the prompt plus generated
        tokens (cut after the first stop token, if any), and one entry per
        verification round equal to the accepted draft tokens plus one.
    """
    draft_model.eval()
    if hasattr(target_model, "device"):
        device = target_model.device
    else:
        device = input_ids.device

    prompt_len = input_ids.shape[1]
    max_length = prompt_len + max_new_tokens
    block_size = draft_model.block_size
    mask_token_id = draft_model.mask_token_id
    layer_ids = draft_model.target_layer_ids

    # The buffer is pre-filled with the mask token and over-allocated by one
    # block so the last round can write past max_length; it is trimmed below.
    output_ids = torch.full(
        (1, max_length + block_size),
        mask_token_id,
        dtype=torch.long,
        device=device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)

    target_cache = DynamicCache()
    draft_cache = DynamicCache()

    # Prefill: run the prompt through the target and sample the first token.
    prefill = target_model(
        input_ids,
        position_ids=position_ids[:, :prompt_len],
        past_key_values=target_cache,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=True,
    )
    output_ids[:, :prompt_len] = input_ids
    output_ids[:, prompt_len:prompt_len + 1] = sample(prefill.logits, temperature)
    target_hidden = extract_context_feature(prefill.hidden_states, layer_ids)

    # Decode: draft one block, verify it with the target, commit the prefix.
    acceptance_lengths = []
    start = prompt_len
    while start < max_length:
        block_end = start + block_size
        block_output_ids = output_ids[:, start:block_end].clone()
        block_position_ids = position_ids[:, start:block_end]
        noise_embedding = target_model.model.embed_tokens(block_output_ids)

        draft_hidden = draft_model(
            target_hidden=target_hidden,
            noise_embedding=noise_embedding,
            position_ids=position_ids[:, draft_cache.get_seq_length():block_end],
            past_key_values=draft_cache,
            use_cache=True,
            is_causal=False,
        )
        # Only the last block_size - 1 positions are projected to logits;
        # slot 0 of the block already holds a committed token.
        draft_logits = target_model.lm_head(draft_hidden[:, -block_size + 1:, :])
        draft_cache.crop(start)
        # Draft tokens use sample()'s default temperature, not `temperature`.
        block_output_ids[:, 1:] = sample(draft_logits)

        verify = target_model(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=target_cache,
            use_cache=True,
            output_hidden_states=True,
        )

        posterior = sample(verify.logits, temperature)
        # Length of the leading run where draft and target tokens agree.
        agree = block_output_ids[:, 1:] == posterior[:, :-1]
        num_accepted = int(agree.cumprod(dim=1).sum(dim=1)[0].item())
        committed = num_accepted + 1

        output_ids[:, start:start + committed] = block_output_ids[:, :committed]
        # The target's own token after the accepted run seeds the next block.
        output_ids[:, start + committed] = posterior[:, num_accepted]
        start += committed
        target_cache.crop(start)
        target_hidden = extract_context_feature(
            verify.hidden_states, layer_ids
        )[:, :committed, :]
        acceptance_lengths.append(committed)

        if stop_token_ids is not None:
            # NOTE(review): the token just written at index `start` is outside
            # this slice, so a stop token there is only noticed after one more
            # round - confirm this is intended.
            emitted = output_ids[:, prompt_len:start]
            if any(stop_id in emitted for stop_id in stop_token_ids):
                break

    # Drop the over-allocation and any positions still holding the mask token.
    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_tensor = torch.tensor(stop_token_ids, device=output_ids.device)
        hits = torch.isin(output_ids[0][prompt_len:], stop_tensor).nonzero(as_tuple=True)[0]
        if hits.numel() > 0:
            # Keep everything up to and including the first stop token.
            output_ids = output_ids[:, :prompt_len + hits[0] + 1]

    return output_ids, acceptance_lengths
|
|
|
|
| |
def parse_args():
    """Parse the evaluation script's command-line options from ``sys.argv``."""
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--base-model", {"default": BASE_MODEL}),
        ("--draft-model", {"default": DRAFT_MODEL}),
        ("--max-new-tokens", {"type": int, "default": 512}),
        ("--temperature", {"type": float, "default": 0.0}),
        ("--benchmarks", {"nargs": "+", "default": ["humaneval", "mtbench", "gsm8k"]}),
        ("--num-samples", {"type": int, "default": None}),
        ("--output-dir", {"default": RESULT_DIR}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def _reduce_benchmark_stats(accept_sum, num_rounds, num_tokens, elapsed, device, world_size):
    """Merge per-rank benchmark counters into global totals.

    The three integer counters are summed over all ranks. Elapsed time is
    reduced with MAX instead: the ranks run concurrently, so the wall-clock
    cost of the benchmark is that of the slowest rank. With a single process
    the inputs are returned unchanged.
    """
    if world_size <= 1:
        return accept_sum, num_rounds, num_tokens, elapsed

    counters = torch.tensor(
        [accept_sum, num_rounds, num_tokens], dtype=torch.long, device=device
    )
    slowest = torch.tensor(elapsed, dtype=torch.float64, device=device)
    dist.all_reduce(counters, op=dist.ReduceOp.SUM)
    dist.all_reduce(slowest, op=dist.ReduceOp.MAX)
    total_accept, total_rounds, total_tokens = counters.tolist()
    return total_accept, total_rounds, total_tokens, slowest.item()


def main():
    """Evaluate the DFlash-b16 draft against the target model on each benchmark.

    Every process loads its own copy of both models onto ``cuda:<LOCAL_RANK>``,
    handles a round-robin shard of the prompts, and the per-rank statistics
    are combined with all-reduce. Rank 0 prints the summary and writes a
    timestamped JSON file into ``--output-dir``.

    Reported latency (and therefore throughput) is the maximum elapsed time
    over all ranks, so the aggregate throughput is not inflated when rank 0
    happens to finish its shard before the others.
    """
    args = parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    device = f"cuda:{local_rank}"
    rank = get_rank()

    print_rank0(f"Running DFlash-b16 baseline on {world_size} GPU(s)")

    print_rank0(f"Loading target: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    print_rank0(f"Loading DFlash-b16 draft: {args.draft_model}")
    draft_model = AutoModel.from_pretrained(
        args.draft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    # A tokenizer without an EOS id must not yield [None]; spec_generate_b16
    # treats stop_token_ids=None as "no stop tokens".
    eos_token_id = tokenizer.eos_token_id
    stop_token_ids = [eos_token_id] if eos_token_id is not None else None

    print_rank0(f"DFlash-b16: block_size={draft_model.block_size}, "
                f"target_layer_ids={draft_model.target_layer_ids}, "
                f"num_layers={len(draft_model.layers)}")

    results = {"model": "Qwen3-8B-DFlash-b16", "type": "baseline",
               "block_size": draft_model.block_size}

    for bench_name in args.benchmarks:
        print_rank0(f"\n{'='*60}")
        print_rank0(f"Benchmark: {bench_name} ({world_size} GPUs)")
        print_rank0(f"{'='*60}")

        all_prompts = load_prompts(bench_name, args.num_samples)
        my_prompts = split_list(all_prompts, rank, world_size)
        print_rank0(f"Total {len(all_prompts)} prompts, ~{len(my_prompts)} per GPU")

        # Running counters; only the sum and count of acceptance lengths are
        # ever needed, so the individual values are not stored.
        local_accept_sum = 0
        local_rounds = 0
        local_tokens = 0
        t0 = time.perf_counter()

        iterator = tqdm(my_prompts, desc=f"[GPU{rank}] {bench_name}", unit="sample",
                        disable=(rank != 0))
        for prompt in iterator:
            messages = [{"role": "user", "content": prompt}]
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

            output_ids, accept_lens = spec_generate_b16(
                draft_model=draft_model,
                target_model=target_model,
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                stop_token_ids=stop_token_ids,
            )

            local_accept_sum += sum(accept_lens)
            local_rounds += len(accept_lens)
            num_gen = output_ids.shape[1] - input_ids.shape[1]
            local_tokens += num_gen

            if rank == 0 and local_rounds > 0:
                avg = local_accept_sum / local_rounds
                iterator.set_postfix(accept_len=f"{avg:.2f}", tokens=local_tokens, gen=num_gen)

        local_elapsed = time.perf_counter() - t0

        total_accept_sum, total_count, total_tokens, elapsed = _reduce_benchmark_stats(
            local_accept_sum, local_rounds, local_tokens, local_elapsed, device, world_size
        )

        avg_accept_length = total_accept_sum / max(total_count, 1)
        throughput = total_tokens / elapsed if elapsed > 0 else 0

        print_rank0(f"\n{bench_name} Results:")
        print_rank0(f" Avg Accept Length: {avg_accept_length:.3f}")
        print_rank0(f" Total tokens: {total_tokens}")
        print_rank0(f" Latency: {elapsed:.1f}s")
        print_rank0(f" Throughput: {throughput:.1f} tok/s (aggregate {world_size} GPUs)")
        print_rank0(f" Num verify rounds: {total_count}")
        print_rank0(f" Num samples: {len(all_prompts)}")

        results[bench_name] = {
            "avg_accept_length": avg_accept_length,
            "total_tokens": total_tokens,
            "latency": elapsed,
            "throughput": throughput,
            "num_samples": len(all_prompts),
            "num_verify_rounds": total_count,
            "num_gpus": world_size,
        }

    # Every rank holds the same reduced numbers; only rank 0 writes the file.
    if is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_b16_baseline_offline_{timestamp}.json",
        )
        with open(result_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {result_file}")

    if world_size > 1:
        dist.destroy_process_group()
|
|
|
|
# Run the evaluation only when this file is executed as a script
# (python3 / torchrun), not when it is imported.
if __name__ == "__main__":
    main()
|
|