Hanrui / syxin /eval_dflash_b16_baseline.py
Lekr0's picture
Add files using upload-large-folder tool
7c50656 verified
#!/usr/bin/env python3
"""
Offline evaluation for DFlash-b16 baseline: measure accepted length.
8 GPUs parallel, each GPU loads target + draft independently.
Usage:
# 8 GPUs
torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py
# quick test
torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py --num-samples 20
# single GPU
python3 eval_dflash_b16_baseline.py --benchmarks humaneval
"""
import argparse
import json
import os
import sys
import time
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache
# Add DFlash model path so we can import utils
sys.path.insert(0, "/workspace/models/Qwen3-8B-DFlash-b16")
from utils import extract_context_feature, sample
# ──────────────────────────────────────────────────────────────────
# Default locations consumed by parse_args(); each one can be overridden on the
# command line via --base-model, --draft-model and --output-dir respectively.
BASE_MODEL: str = "/workspace/models/Qwen3-8B"  # target model checkpoint
DRAFT_MODEL: str = "/workspace/models/Qwen3-8B-DFlash-b16"  # DFlash-b16 draft checkpoint
RESULT_DIR: str = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"  # JSON output dir
# ──────────────────────────────────────────────────────────────────
# Distributed helpers
# ──────────────────────────────────────────────────────────────────
def is_distributed():
    """Return True only when torch.distributed is usable and a process group exists."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def get_rank():
    """Return this process's global rank; single-process runs report rank 0."""
    if not is_distributed():
        return 0
    return dist.get_rank()
def get_world_size():
    """Return the number of processes in the group; single-process runs report 1."""
    if not is_distributed():
        return 1
    return dist.get_world_size()
def is_main():
    """Return True on the rank-0 process (always True when not distributed)."""
    rank = get_rank()
    return rank == 0
def print_rank0(*args, **kwargs):
    """Forward to print() on rank 0 only, so multi-GPU runs log each line once."""
    if not is_main():
        return
    print(*args, **kwargs)
def split_list(lst, rank, world_size):
    """Return the round-robin shard of ``lst`` owned by ``rank``.

    Element ``i`` goes to the rank equal to ``i % world_size``, so shard sizes
    differ by at most one and the original order is preserved within a shard.
    """
    shard = []
    for position, element in enumerate(lst):
        if position % world_size == rank:
            shard.append(element)
    return shard
# ──────────────────────────────────────────────────────────────────
# Prompts
# ──────────────────────────────────────────────────────────────────
# Built-in local JSONL copies of each supported benchmark; the keys of this
# mapping are also the set of benchmark names load_prompts() accepts.
_DEFAULT_PROMPT_PATHS = {
    "humaneval": "/workspace/hanrui/datasets/humaneval/test.jsonl",
    "mtbench": "/workspace/hanrui/datasets/mtbench/question.jsonl",
    "gsm8k": "/workspace/hanrui/datasets/gsm8k/test.jsonl",
}


def _humaneval_prompt(code: str) -> str:
    """Wrap a HumanEval problem stub in the instruction used for this benchmark."""
    return (
        "Write a solution to the following problem and make sure that it passes the tests:\n"
        f"```python\n{code}\n```"
    )


def _gsm8k_prompt(question: str) -> str:
    """Append the step-by-step / boxed-answer instruction to a GSM8K question."""
    return question + "\nPlease reason step by step, and put your final answer within \\boxed{}."


def _load_local_prompts(bench_name: str, path: str) -> List[str]:
    """Build prompts from a local JSONL file (one JSON object per non-blank line)."""
    prompts = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines instead of crashing in json.loads
            item = json.loads(line)
            if bench_name == "humaneval":
                prompts.append(_humaneval_prompt(item["prompt"]))
            elif bench_name == "mtbench":
                # Only the first turn is used; fall back to a "prompt" field when
                # the file has no "turns" list.
                prompts.append(item.get("turns", [item.get("prompt", "")])[0])
            else:  # gsm8k (bench_name was validated by the caller)
                prompts.append(_gsm8k_prompt(item["question"]))
    return prompts


def _load_hf_prompts(bench_name: str) -> List[str]:
    """Build prompts from the Hugging Face Hub copy of the benchmark."""
    # Imported lazily: `datasets` is only needed when no local copy exists.
    from datasets import load_dataset

    if bench_name == "humaneval":
        ds = load_dataset("openai/openai_humaneval", split="test")
        return [_humaneval_prompt(x["prompt"]) for x in ds]
    if bench_name == "mtbench":
        ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
        return [x["prompt"][0] for x in ds]
    ds = load_dataset("openai/gsm8k", "main", split="test")
    return [_gsm8k_prompt(x["question"]) for x in ds]


def load_prompts(
    bench_name: str,
    num_samples: Optional[int] = None,
    path: Optional[str] = None,
) -> List[str]:
    """Load the user-message prompts for one benchmark.

    Prompts are read from a local JSONL file when it exists; otherwise they are
    downloaded with ``datasets.load_dataset``.

    Args:
        bench_name: one of "humaneval", "mtbench", "gsm8k".
        num_samples: if not None, keep only the first ``num_samples`` prompts.
        path: optional JSONL file overriding the built-in local path for
            ``bench_name``. Unlike the built-in path (which silently falls back
            to the Hub when missing), an explicit path must exist.

    Returns:
        Prompt strings in file / dataset order.

    Raises:
        ValueError: ``bench_name`` is not a supported benchmark (previously
            this silently produced an empty prompt list).
        FileNotFoundError: ``path`` was given but does not exist.
    """
    if bench_name not in _DEFAULT_PROMPT_PATHS:
        raise ValueError(
            f"Unknown benchmark {bench_name!r}; expected one of {sorted(_DEFAULT_PROMPT_PATHS)}"
        )
    if path is not None and not os.path.exists(path):
        raise FileNotFoundError(path)
    local_path = path if path is not None else _DEFAULT_PROMPT_PATHS[bench_name]

    if os.path.exists(local_path):
        prompts = _load_local_prompts(bench_name, local_path)
    else:
        prompts = _load_hf_prompts(bench_name)

    if num_samples is not None:
        prompts = prompts[:num_samples]
    return prompts
# ──────────────────────────────────────────────────────────────────
# spec_generate with acceptance_lengths returned
# (Same logic as DFlashDraftModel.spec_generate but returns accept lens)
# ──────────────────────────────────────────────────────────────────
@torch.inference_mode()
def spec_generate_b16(
    draft_model,
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, List[int]]:
    """Block speculative decoding with a DFlash draft, also returning acceptance lengths.

    Each round, the draft fills the mask positions of a ``block_size`` block
    that starts with the last committed token. The target verifies the whole
    block in one forward pass. The longest matching prefix plus the target's
    next ("bonus") token is committed.

    Args:
        draft_model: DFlash draft model exposing ``block_size``, ``mask_token_id``
            and ``target_layer_ids``.
        target_model: causal LM whose ``model.embed_tokens`` and ``lm_head`` are
            also used to embed / decode the draft block.
        input_ids: prompt token ids. The output buffer has batch size 1, so this
            assumes a single prompt of shape (1, prompt_len).
        max_new_tokens: generation budget beyond the prompt.
        temperature: passed to ``sample`` for the target's tokens (draft tokens
            are sampled without a temperature argument).
        stop_token_ids: generation stops once any of these ids is committed.

    Returns:
        (output_ids, acceptance_lengths).
        - output_ids: prompt + generated tokens, truncated to
          ``prompt_len + max_new_tokens``, with mask-token positions removed,
          and cut just after the first stop token.
        - acceptance_lengths: one entry per verify round, equal to the number
          of accepted draft tokens + 1 (the bonus token).
    """
    draft_model.eval()
    device = target_model.device if hasattr(target_model, 'device') else input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    block_size = draft_model.block_size
    mask_token_id = draft_model.mask_token_id

    # One extra block of slack so a block starting just below max_length never
    # indexes past the end of the buffer.
    output_ids = torch.full(
        (1, max_length + block_size), mask_token_id,
        dtype=torch.long, device=device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()

    # Prefill: run the target on the prompt, sample the first new token, and keep
    # the selected hidden states as the draft's conditioning context.
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=True,
    )
    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
    target_hidden = extract_context_feature(output.hidden_states, draft_model.target_layer_ids)

    # Decode
    acceptance_lengths = []
    start = num_input_tokens
    while start < max_length:
        # Block = [last committed token, mask, mask, ...].
        block_output_ids = output_ids[:, start:start + block_size].clone()
        block_position_ids = position_ids[:, start:start + block_size]
        noise_embedding = target_model.model.embed_tokens(block_output_ids)
        draft_logits = target_model.lm_head(
            draft_model(
                target_hidden=target_hidden,
                noise_embedding=noise_embedding,
                position_ids=position_ids[:, past_key_values_draft.get_seq_length():start + block_size],
                past_key_values=past_key_values_draft,
                use_cache=True,
                is_causal=False,
            )[:, -block_size + 1:, :]
        )
        # Truncate the draft cache back to `start` positions.
        past_key_values_draft.crop(start)
        block_output_ids[:, 1:] = sample(draft_logits)

        # Verify the whole block with a single target forward pass.
        output = target_model(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=True,
        )
        posterior = sample(output.logits, temperature)
        # Length of the leading run of draft tokens matching the target's predictions.
        acceptance_length = int(
            (block_output_ids[:, 1:] == posterior[:, :-1])
            .cumprod(dim=1).sum(dim=1)[0].item()
        )
        output_ids[:, start:start + acceptance_length + 1] = block_output_ids[:, :acceptance_length + 1]
        # Bonus token: the target's prediction after the last accepted token.
        output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length]
        start += acceptance_length + 1
        past_key_values_target.crop(start)
        target_hidden = extract_context_feature(
            output.hidden_states, draft_model.target_layer_ids
        )[:, :acceptance_length + 1, :]
        acceptance_lengths.append(acceptance_length + 1)
        # Index `start` now holds the bonus token written above, so it must be
        # part of the scan. Stopping the slice at `start` (exclusive) detected a
        # stop token emitted as the bonus token only one round later, after an
        # extra draft/verify pass had appended a spurious acceptance length.
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:start + 1] for sid in stop_token_ids
        ):
            break

    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]
    return output_ids, acceptance_lengths
# ──────────────────────────────────────────────────────────────────
def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options for the offline DFlash-b16 evaluation.

    Args:
        argv: argument list to parse. ``None`` (the default, used by main())
            parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with base_model, draft_model, max_new_tokens,
        temperature, benchmarks, num_samples and output_dir.
    """
    supported_benchmarks = ["humaneval", "mtbench", "gsm8k"]
    p = argparse.ArgumentParser(
        description="Offline accepted-length evaluation for the DFlash-b16 draft model."
    )
    p.add_argument("--base-model", default=BASE_MODEL,
                   help="Target model path.")
    p.add_argument("--draft-model", default=DRAFT_MODEL,
                   help="DFlash-b16 draft model path.")
    p.add_argument("--max-new-tokens", type=int, default=512,
                   help="Maximum number of generated tokens per prompt.")
    p.add_argument("--temperature", type=float, default=0.0,
                   help="Sampling temperature for the target model's tokens.")
    # `choices` rejects unsupported names at parse time instead of letting them
    # run to completion with an empty prompt list.
    p.add_argument("--benchmarks", nargs="+", default=list(supported_benchmarks),
                   choices=supported_benchmarks,
                   help="Benchmarks to evaluate.")
    p.add_argument("--num-samples", type=int, default=None,
                   help="Use only the first N prompts of each benchmark.")
    p.add_argument("--output-dir", default=RESULT_DIR,
                   help="Directory for the results JSON file.")
    return p.parse_args(argv)
def main():
    """Evaluate DFlash-b16 acceptance length on every benchmark in ``--benchmarks``.

    Runs single-process (``python3 ...``) or one process per GPU under
    ``torchrun``. Prompts are sharded round-robin across ranks with
    split_list(), and every rank loads its own target and draft model.
    Per-benchmark sums are all-reduced across ranks. Rank 0 prints the
    aggregate metrics and writes them to a timestamped JSON file in
    ``--output-dir``.
    """
    args = parse_args()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    distributed = world_size > 1
    # Select this process's GPU *before* creating the NCCL process group so that
    # NCCL communicators and collectives (barrier / all_reduce) bind to it.
    torch.cuda.set_device(local_rank)
    if distributed:
        dist.init_process_group(backend="nccl")
    device = f"cuda:{local_rank}"
    rank = get_rank()

    try:
        print_rank0(f"Running DFlash-b16 baseline on {world_size} GPU(s)")

        # ── Load models ──
        print_rank0(f"Loading target: {args.base_model}")
        target_model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        )
        target_model.eval()
        print_rank0(f"Loading DFlash-b16 draft: {args.draft_model}")
        draft_model = AutoModel.from_pretrained(
            args.draft_model,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).to(device)
        draft_model.eval()
        tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
        # A tokenizer without an EOS id would give [None], which cannot be matched
        # against token-id tensors; disable early stopping in that case instead.
        eos_id = tokenizer.eos_token_id
        stop_token_ids = [eos_id] if eos_id is not None else None
        print_rank0(f"DFlash-b16: block_size={draft_model.block_size}, "
                    f"target_layer_ids={draft_model.target_layer_ids}, "
                    f"num_layers={len(draft_model.layers)}")

        # ── Run benchmarks ──
        results = {"model": "Qwen3-8B-DFlash-b16", "type": "baseline",
                   "block_size": draft_model.block_size}
        for bench_name in args.benchmarks:
            print_rank0(f"\n{'='*60}")
            print_rank0(f"Benchmark: {bench_name} ({world_size} GPUs)")
            print_rank0(f"{'='*60}")
            all_prompts = load_prompts(bench_name, args.num_samples)
            my_prompts = split_list(all_prompts, rank, world_size)
            print_rank0(f"Total {len(all_prompts)} prompts, ~{len(my_prompts)} per GPU")
            local_accept_lengths = []
            local_tokens = 0
            if distributed:
                # Start every rank's clock together so the max elapsed time reduced
                # below is the wall time of the whole sharded run.
                dist.barrier()
            t0 = time.perf_counter()
            iterator = tqdm(my_prompts, desc=f"[GPU{rank}] {bench_name}", unit="sample",
                            disable=(rank != 0))
            for prompt in iterator:
                messages = [{"role": "user", "content": prompt}]
                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
                output_ids, accept_lens = spec_generate_b16(
                    draft_model=draft_model,
                    target_model=target_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )
                local_accept_lengths.extend(accept_lens)
                num_gen = output_ids.shape[1] - input_ids.shape[1]
                local_tokens += num_gen
                if rank == 0 and len(local_accept_lengths) > 0:
                    avg = sum(local_accept_lengths) / len(local_accept_lengths)
                    iterator.set_postfix(accept_len=f"{avg:.2f}", tokens=local_tokens, gen=num_gen)
            elapsed = time.perf_counter() - t0

            # ── Gather ──
            if distributed:
                local_sum = torch.tensor(sum(local_accept_lengths), dtype=torch.float64, device=device)
                local_count = torch.tensor(len(local_accept_lengths), dtype=torch.long, device=device)
                local_tok = torch.tensor(local_tokens, dtype=torch.long, device=device)
                local_elapsed = torch.tensor(elapsed, dtype=torch.float64, device=device)
                dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
                dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
                dist.all_reduce(local_tok, op=dist.ReduceOp.SUM)
                # Aggregate throughput is bounded by the slowest rank, so use the
                # maximum elapsed time rather than rank 0's own timing.
                dist.all_reduce(local_elapsed, op=dist.ReduceOp.MAX)
                total_accept_sum = local_sum.item()
                total_count = local_count.item()
                total_tokens = local_tok.item()
                elapsed = local_elapsed.item()
            else:
                total_accept_sum = sum(local_accept_lengths)
                total_count = len(local_accept_lengths)
                total_tokens = local_tokens
            avg_accept_length = total_accept_sum / max(total_count, 1)
            throughput = total_tokens / elapsed if elapsed > 0 else 0
            print_rank0(f"\n{bench_name} Results:")
            print_rank0(f" Avg Accept Length: {avg_accept_length:.3f}")
            print_rank0(f" Total tokens: {total_tokens}")
            print_rank0(f" Latency: {elapsed:.1f}s")
            print_rank0(f" Throughput: {throughput:.1f} tok/s (aggregate {world_size} GPUs)")
            print_rank0(f" Num verify rounds: {total_count}")
            print_rank0(f" Num samples: {len(all_prompts)}")
            results[bench_name] = {
                "avg_accept_length": avg_accept_length,
                "total_tokens": total_tokens,
                "latency": elapsed,
                "throughput": throughput,
                "num_samples": len(all_prompts),
                "num_verify_rounds": total_count,
                "num_gpus": world_size,
            }

        # ── Save ──
        if is_main():
            os.makedirs(args.output_dir, exist_ok=True)
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            result_file = os.path.join(
                args.output_dir,
                f"dflash_b16_baseline_offline_{timestamp}.json",
            )
            with open(result_file, "w") as f:
                json.dump(results, f, indent=2)
            print(f"\nResults saved to: {result_file}")
    finally:
        # Tear down the process group even if evaluation raised.
        if distributed:
            dist.destroy_process_group()
# Script entry point: run main() when executed directly (python3 or torchrun),
# but not when this file is imported as a module.
if __name__ == "__main__":
    main()