diff --git a/SpecForge-ext/benchmarks/README.md b/SpecForge-ext/benchmarks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..678ce7a257b8bddb2a44373a4eba1ff77595813a --- /dev/null +++ b/SpecForge-ext/benchmarks/README.md @@ -0,0 +1,67 @@ +# Benchmarking for Speculative Decoding + +## Overview + +We provided a unified script to test the performance of the Speculative Decoding with EAGLE3 algorithm on multiple datasets. You can follow the steps below to run the benchmarks. + +## Run Benchmarks + +### Launch SGLang and Benchmarker Concurrently + +`bench_eagle3.py` can help you launch a SGLang server process and a Benchmarking process concurrently. In this way, you don't have to launch the SGLang server manually, this script will manually handle the SGLang launch under different speculative decoding configurations. Some important arguments are: +- `--model-path`: the path to the target model. +- `--speculative-draft-model-path`: the path to the draft model. +- `--port`: the port to launch the SGLang server. +- `--trust-remote-code`: trust the remote code. +- `--mem-fraction-static`: the memory fraction for the static memory. +- `--tp-size`: the tensor parallelism size. +- `--attention-backend`: the attention backend. +- `--config-list`: the list of speculative decoding configuration to test, the format is `,,,`. +- `--benchmark-list`: the list of benchmarks to test, the format is `::`. + +```shell +python3 bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --port 30000 \ + --trust-remote-code \ + --mem-fraction-static 0.8 \ + --tp-size 1 \ + --attention-backend fa3 \ + --config-list 1,0,0,0 1,3,1,4 \ + --benchmark-list mtbench gsm8k:5 ceval:5:accountant \ + --dtype bfloat16 +``` + +### Launch Benchmarker Independently + +If you want to launch the SGLang server independently, you can use the following command. + +```shell +# you can launch a server +python3 -m sglang.launch_server \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --speculative-num-steps 3 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 \ + --mem-fraction-static 0.75 \ + --cuda-graph-max-bs 1 \ + --tp 1 \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 30000 \ + --dtype bfloat16 +``` + +Then we can start benchmarking. Note that you should use the same host and port as the one used in the SGLang server. Note that `--skip-launch-server` is required to skip the launch of the SGLang server. + +```bash +python bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --port 30000 \ + --config-list 1,3,1,4 \ + --benchmark-list mtbench:5 ceval:5:accountant gsm8k:5 humaneval:5 math500:5 mtbench:5 aime:1 \ + --skip-launch-server +``` diff --git a/SpecForge-ext/benchmarks/__init__.py b/SpecForge-ext/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfec7eeb81271e0d07feafe767881dba5dcf4acd --- /dev/null +++ b/SpecForge-ext/benchmarks/__init__.py @@ -0,0 +1,3 @@ +""" +Benchmark scripts for speculative decoding evaluation. +""" diff --git a/SpecForge-ext/benchmarks/bench_eagle3.py b/SpecForge-ext/benchmarks/bench_eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..988e108f5e1f8ce82e9ccbeaf1b77a5d741fa816 --- /dev/null +++ b/SpecForge-ext/benchmarks/bench_eagle3.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +""" +Usage: + +# if you want to run benchmarks directly +# mtbench:20 means only run 20 samples in the dataset +python bench_eagle3.py \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --port 30000 \ + --config-list 1,0,0,0 1,3,1,4 \ + --benchmark-list mtbench:20 \ + --dtype bfloat16 + + +or if you want run sglang alone. + +# launch sglang +python3 -m sglang.launch_server \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --speculative-num-steps 3 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 \ + --mem-fraction-static 0.75 \ + --cuda-graph-max-bs 1 \ + --tp 1 \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 30000 \ + --dtype bfloat16 + +# then run benchmarks +python bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --port 30000 \ + --config-list 1,0,0,0 \ + --benchmark-list mtbench:80 \ + --dtype bfloat16 \ + --skip-launch-server +""" +import argparse +import json +import os +import time +from dataclasses import asdict +from typing import List + +import requests +from benchmarker import BENCHMARKS +from sglang.srt.server_args import ServerArgs +from sglang.test.test_utils import kill_process_tree, popen_launch_server +from sglang.utils import wait_for_server + + +def parse_args(): + parser = argparse.ArgumentParser() + sglang_group = parser.add_argument_group("sglang") + ServerArgs.add_cli_args(sglang_group) + + # make the follow args a group + benchmark_group = parser.add_argument_group("benchmark") + benchmark_group.add_argument( + "--skip-launch-server", action="store_true", default=False + ) + benchmark_group.add_argument("--timeout-for-server-launch", type=int, default=600) + benchmark_group.add_argument("--num-prompts", type=int, default=80) + benchmark_group.add_argument("--output-dir", type=str, default="./results") + benchmark_group.add_argument( + "--config-list", type=str, nargs="+", default=["1,0,0,0", "1,3,1,4"] + ) + benchmark_group.add_argument( + "--name", + type=str, + default=None, + help="name of this benchmark run, if provided, will be added to the output file name", + ) + benchmark_group.add_argument( + "--benchmark-list", + type=str, + nargs="+", + default=[ + "mtbench:80", + "gsm8k:200", + "humaneval:200", + "math500:200", + "ceval:200", + ], + help=f"The list of benchmarks to run. The format is ::,. We support the following benchmarks: {', '.join(BENCHMARKS.benchmarks.keys())}", + ) + benchmark_group.add_argument( + "--enable-multi-turn-conversation", + action="store_true", + default=False, + ) + return parser.parse_args() + + +def launch_sglang_server( + server_args: ServerArgs, + base_url: str, + batch_size: int, + steps: int, + topk: int, + num_draft_tokens: int, + timeout: int, +): + """ + This function launches the SGLang server with the given server arguments. + """ + sglang_args: List[str] = [] + if steps > 0: + sglang_args.extend( + [ + "--speculative-algorithm", + "EAGLE3", + "--speculative-num-steps", + str(steps), + "--speculative-eagle-topk", + str(topk), + "--speculative-num-draft-tokens", + str(num_draft_tokens), + "--speculative-draft-model-path", + server_args.speculative_draft_model_path, + ] + ) + + sglang_args.extend( + [ + "--cuda-graph-max-bs", + str(batch_size), + "--mem-fraction-static", + str(server_args.mem_fraction_static), + "--tp-size", + str(server_args.tp_size), + "--max-running-requests", + str(batch_size), + ] + ) + + if server_args.trust_remote_code: + sglang_args.extend(["--trust-remote-code"]) + + if server_args.disable_radix_cache: + sglang_args.extend(["--disable-radix-cache"]) + + if server_args.ep_size: + sglang_args.extend(["--ep-size", str(server_args.ep_size)]) + + if server_args.attention_backend: + sglang_args.extend(["--attention-backend", server_args.attention_backend]) + + if server_args.quantization: + sglang_args.extend(["--quantization", server_args.quantization]) + + if server_args.dtype: + sglang_args.extend(["--dtype", server_args.dtype]) + + process = popen_launch_server( + server_args.model_path, + base_url, + timeout=timeout, + other_args=sglang_args, + env={ + "SGLANG_RECORD_STEP_TIME": "1", + "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1", + **os.environ, + }, + ) + return process + + +def send_flush_cache_request(base_url: str): + requests.post(base_url + "/flush_cache") + + +def main(): + args = parse_args() + server_args: ServerArgs = ServerArgs.from_cli_args(args) + configs = [tuple(map(int, config.split(","))) for config in args.config_list] + + # split the arg into list of (bench_name, num_prompts) + benchmark_list = [] + for item in args.benchmark_list: + splits = item.split(":") + if len(splits) == 1: + bench_name = splits[0] + num_prompts = None + subset = None + elif len(splits) == 2: + bench_name, num_prompts = splits + subset = None + elif len(splits) == 3: + bench_name, num_prompts, subset = splits + subset = subset.split(",") + else: + raise ValueError(f"Invalid benchmark list format: {item}") + benchmark_list.append((bench_name, num_prompts, subset)) + assert len(benchmark_list) != 0, "the number of benchmark list is 0" + + base_url = f"http://localhost:{args.port}" + + results = {} + results["model"] = server_args.speculative_draft_model_path + + def run_benchmarks(batch_size: int, steps: int, topk: int, num_draft_tokens: int): + for benchmark_name, num_prompts, subset in benchmark_list: + print( + f"Running benchmark {benchmark_name} with {num_prompts} prompts, batch size {batch_size}, steps {steps}, topk {topk}, num_draft_tokens {num_draft_tokens}, subset {subset}" + ) + benchmarkder_cls = BENCHMARKS.get(benchmark_name) + num_prompts = int(num_prompts) if num_prompts is not None else None + if subset is None: + benchmarker = benchmarkder_cls(num_samples=num_prompts) + else: + benchmarker = benchmarkder_cls(num_samples=num_prompts, subset=subset) + metrics_list = benchmarker.run( + host=args.host, port=args.port, batch_size=batch_size + ) + send_flush_cache_request(base_url) + if benchmark_name not in results: + results[benchmark_name] = [] + results[benchmark_name].append( + dict( + batch_size=batch_size, + steps=steps, + topk=topk, + num_draft_tokens=num_draft_tokens, + metrics=[asdict(metric) for metric in metrics_list], + num_samples=num_prompts, + ) + ) + + if args.skip_launch_server: + batch_size = configs[0][0] if len(configs) > 0 else 8 + run_benchmarks(batch_size, None, None, None) + else: + # we itearate over each config from args + for batch_size, steps, topk, num_draft_tokens in configs: + process = launch_sglang_server( + server_args, + base_url, + batch_size, + steps, + topk, + num_draft_tokens, + args.timeout_for_server_launch, + ) + wait_for_server(base_url) + run_benchmarks(batch_size, steps, topk, num_draft_tokens) + kill_process_tree(process.pid) + process.wait() + + os.makedirs(args.output_dir, exist_ok=True) + timestamp = time.strftime("%Y%m%d_%H%M%S") + result_file = os.path.join( + args.output_dir, + f"{args.name + '_' if args.name else ''}results_{timestamp}.jsonl", + ) + with open(result_file, "w") as f: + json.dump(results, f, indent=4) + print(f"Results saved to {result_file}") + + +if __name__ == "__main__": + main() diff --git a/SpecForge-ext/benchmarks/benchmarker/__init__.py b/SpecForge-ext/benchmarks/benchmarker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e37fb99c2caa75be8ab58dc51f393f9a748c2b7 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/__init__.py @@ -0,0 +1,29 @@ +from .aime import AIMEBenchmarker +from .ceval import CEvalBenchmarker +from .financeqa import FinanceQABenchmarker +from .gpqa import GPQABenchmarker +from .gsm8k import GSM8KBenchmarker +from .humaneval import HumanEvalBenchmarker +from .livecodebench import LCBBenchmarker +from .math500 import Math500Benchmarker +from .mmlu import MMLUBenchmarker +from .mmstar import MMStarBenchmarker +from .mtbench import MTBenchBenchmarker +from .registry import BENCHMARKS +from .simpleqa import SimpleQABenchmarker + +__all__ = [ + "BENCHMARKS", + "AIMEBenchmarker", + "CEvalBenchmarker", + "GSM8KBenchmarker", + "HumanEvalBenchmarker", + "Math500Benchmarker", + "MTBenchBenchmarker", + "MMStarBenchmarker", + "GPQABenchmarker", + "FinanceQABenchmarker", + "MMLUBenchmarker", + "LCBBenchmarker", + "SimpleQABenchmarker", +] diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de7dade82abb7f1d9cfcf39e3574ee3f6d791c9b Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54726f9f9a03ca4654cd8f5a00db780906867433 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-312.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68e3e34578078bd90cedae80d79e5f43e0c509b6 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-312.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..842cab50e3e231d047fd1d25065797f883631634 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ae25449dbab86276b82aa3f3e26f577056dc71a Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-312.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..675c10a9cfe3c44b63f9cc388f7164226901c40c Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-312.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1b712747c91d8601186115c16f86435b1fdf62e Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b66ee0c37dc0aa8429afc963948decbe7e7dfd4 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-312.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac412c8ce303f09971083923a1891b359aee72ff Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-312.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3d90334141894bf535a097768506570fff983b3 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f52ed3a10bbd10586a4af234c04039442022b59 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42ee0110411dcf57daab9faefa07f6e6614bcd27 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f81109f287bf2916635065ef65ada71f6ee60541 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49032422160b6947193140c31ad4a73a72e60287 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..073e41ca7bc5fbcfc7a00a65e7eeb06fd134766f Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50b2f9c20e191c4c6ae9f7515e8ec0fd1f76aa80 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3f4859e065cd56f17f549f2c46010b48ac4dd89 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2909d2ea030fc7b42ed4201f6179be1b0c675356 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9311dedbedfd922545f96844684f46b955904180 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2390d4ebb1dee85aca13ba3c73ee558dbbfa7322 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0baba898f7a8d9b4f121f137f2f40dd4c70aa7c Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01195ef048e0ab96b0ed8fe92fcbf544dbfbffe3 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..574ebda8c12f2d93d1d26368ad6eaf32d08f5b94 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06ff62ca76cf378abc3a5f72c1edb4321e2c666f Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f97f903d4a848d26c287012def10915795f24ce Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4762b87c10f9b1b869fcc48c4b07a8cccddbcb4c Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..827c77551ef80485be4e5d31ea76a2968eea2dba Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6194deaa7df34e9bcd10e2430f38eb5d380d1fea Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74f6d41d76925a7e8b3c7328967f2658864dfbf0 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eee9679164ef34a1e5a94c282df5b191c535037 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beac186cdb97bf776412e69d6a0fe4c4839bea27 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c4d643f3aea3911a157545b19876a2ff1d43f13 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-310.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32a4c020775c95bb8ae4852574bf5b3b65ced13f Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-310.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef5281987b4997aabc5fdd4205e042fe5f117a51 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-311.pyc differ diff --git a/SpecForge-ext/benchmarks/benchmarker/aime.py b/SpecForge-ext/benchmarks/benchmarker/aime.py new file mode 100644 index 0000000000000000000000000000000000000000..fba473c2c6d2f27397f67b424423271288cd6ae7 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/aime.py @@ -0,0 +1,133 @@ +""" +AIME benchmark +""" + +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_aime_answer(output: str) -> Optional[str]: + """Extract final answer from AIME problem solution. + + AIME answers are typically integers between 0 and 999, and are usually + in \boxed{} format. + """ + # Try to find answer in \boxed{} format + boxed_pattern = r"\\boxed\{([^}]+)\}" + match = re.search(boxed_pattern, output) + if match: + answer = match.group(1).strip() + # Extract number from the boxed content + numbers = re.findall(r"\d+", answer) + if numbers: + return numbers[-1] # Take the last number (usually the final answer) + return answer + + # Try to find answer in \boxed format (without braces) + boxed_pattern2 = r"\\boxed\s+(\d+)" + match = re.search(boxed_pattern2, output) + if match: + return match.group(1).strip() + + # Look for patterns like "The answer is 42" or "Answer: 123" + answer_patterns = [ + r"(?:answer|Answer|ANSWER)[\s:]+(\d+)", + r"(?:final\s+answer|Final\s+Answer)[\s:]+(\d+)", + r"(?:is|equals?|=\s*)(\d+)\s*$", + ] + for pattern in answer_patterns: + matches = re.findall(pattern, output, re.IGNORECASE) + if matches: + return matches[-1].strip() + + # Fallback: extract the last integer in the text + numbers = re.findall(r"\b(\d+)\b", output) + if numbers: + # Filter to reasonable AIME answer range (0-999) + valid_numbers = [n for n in numbers if 0 <= int(n) <= 999] + if valid_numbers: + return valid_numbers[-1] + + return None + + +@BENCHMARKS.register("aime") +class AIMEBenchmarker(Benchmarker): + """AIME benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]: + """Load and preprocess AIME dataset.""" + dataset = load_dataset("Maxwell-Jia/AIME_2024")["train"] + questions = [] + labels = [] + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + questions.append({"question": q["Problem"]}) + # Extract answer from Answer field + answer = None + if "Answer" in q: + answer = str(q["Answer"]).strip() + elif "answer" in q: + answer = str(q["answer"]).strip() + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract answer from model output.""" + return extract_aime_answer(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for AIME by comparing numeric answers.""" + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + for pred, label in zip(predictions, labels): + if label is not None: + valid_count += 1 + if pred is not None: + # Normalize answers for comparison + pred_normalized = str(pred).strip() + label_normalized = str(label).strip() + # Try exact match first + if pred_normalized == label_normalized: + correct += 1 + else: + # Try numeric comparison + try: + pred_num = int(pred_normalized) + label_num = int(label_normalized) + if pred_num == label_num: + correct += 1 + except ValueError: + pass + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for AIME with reasoning prompt.""" + return create_simple_sgl_function( + function_name="reasoning_gen", + answer_key="answer", + user_prefix="\nPlease reason step by step, and put your final answer within \\boxed{}.", + ) + + def get_max_new_tokens(self) -> int: + """AIME problems require more tokens.""" + return 32768 diff --git a/SpecForge-ext/benchmarks/benchmarker/base.py b/SpecForge-ext/benchmarks/benchmarker/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f8da625319cc854688521b8d9bf1a4b98ac5006b --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/base.py @@ -0,0 +1,218 @@ +""" +Base class for benchmark implementations. +""" + +import time +from abc import ABC, abstractmethod +from argparse import Namespace +from typing import Any, Callable, Dict, List, Optional, Tuple + +from sglang import set_default_backend +from sglang.test.test_utils import select_sglang_backend + +from .utils import compute_metrics + + +class Benchmarker(ABC): + """ + Base class for benchmark implementations. + + Subclasses should implement: + - load_data(): Load and preprocess dataset + - create_sgl_function(): Create the SGL function for inference + + Optional overrides: + - extract_answer(): Extract answer from model output (if needed) + - compute_accuracy(): Compute accuracy metric (if applicable) + - get_answer_keys(): Get list of answer keys for multi-turn conversations + + Args: + num_samples: The number of samples to run the benchmark on. If not provided, all questions will be used. + subset: The subset of the dataset to run the benchmark on. If not provided, all subsets will be used. + """ + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + self.num_samples = num_samples + self.subset = subset + + @abstractmethod + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]: + """ + Load and preprocess the dataset. + + Returns: + Tuple of (questions, labels) where: + - questions: List of question dicts for SGL function + - labels: List of ground truth labels (can be None if not applicable) + """ + raise NotImplementedError + + @abstractmethod + def create_sgl_function(self) -> Callable: + """ + Create the SGL function for inference. + + Returns: + SGL function decorated with @sgl.function + """ + raise NotImplementedError + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[Any]: + """ + Extract answer from model output. + + Args: + output: Raw model output string + label: Optional ground truth label for reference + + Returns: + Extracted answer, or None if extraction fails + """ + return output + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """ + Compute accuracy metric. + + Args: + predictions: List of predicted answers + labels: List of ground truth labels + + Returns: + Accuracy score (0-1), or None if not applicable + """ + return None + + def get_answer_keys(self) -> Optional[List[str]]: + """ + Get list of answer keys for multi-turn conversations. + + Returns: + List of answer keys (e.g., ["answer_1", "answer_2"]), or None for single-turn + """ + return None + + def get_max_new_tokens(self) -> int: + """ + Get maximum number of new tokens to generate. + + Returns: + Maximum tokens (default: 2048) + """ + return 2048 + + def run( + self, + host: str, + port: int, + batch_size: int, + max_new_tokens: int = None, + num_runs: int = 1, + ): + """ + Run the benchmark evaluation. + + This method handles the common workflow: + 1. Initialize backend + 2. Load data + 3. Create SGL function + 4. Run inference loops + 5. Compute metrics + 6. Print results + + Args: + host (str): The host of the SGLang server + port (int): The port of the SGLang server + batch_size (int): The number of prompts to process in parallel + num_samples (int): The number of samples to run the benchmark on. If not provided, all samples will be used. + max_new_tokens (int): Maximum number of new tokens to generate, default is 2048 + num_runs (int): The number of times to run this benchmark, default is 1. You can set it to a larger number if you want to get more stable results. + """ + if not host.startswith(("http://", "https://")): + host = f"http://{host}" + # Initialize backend + sglang_args = Namespace(host=host, port=port, backend="srt-no-parallel") + set_default_backend(select_sglang_backend(sglang_args)) + + # Load data + questions, labels = self.load_data() + if len(questions) == 0: + print("No valid questions found. Please check the dataset format.") + return + + # Create SGL function + sgl_function = self.create_sgl_function() + + # Run evaluation loops + metrics_list = [] + answer_keys = self.get_answer_keys() + max_new_tokens = max_new_tokens or self.get_max_new_tokens() + + for _ in range(num_runs): + tic = time.perf_counter() + states = sgl_function.run_batch( + questions, + temperature=0, + max_new_tokens=max_new_tokens, + num_threads=batch_size, + progress_bar=True, + ) + latency = time.perf_counter() - tic + + # Extract predictions + predictions = [] + primary_answer_key = answer_keys[0] if answer_keys else "answer" + for i in range(len(states)): + # Access answer from state object (states[i] supports dict-like access) + output = states[i][primary_answer_key] + if isinstance(output, str): + extracted = self.extract_answer( + output, + (labels[i] if labels and i < len(labels) else None), + ) + else: + extracted = output + predictions.append(extracted) + + # Compute accuracy if applicable + accuracy = None + # Check if we have a labels list (even if all labels are None) + has_labels_list = labels and len(labels) > 0 + + if has_labels_list: + # Always call compute_accuracy if we have a labels list + # This allows it to return None, which will be displayed in print_results + accuracy = self.compute_accuracy(predictions, labels) + if accuracy is not None: + valid_count = sum(1 for p in predictions if p is not None) + if valid_count < len(predictions): + print( + f"Warning: {len(predictions) - valid_count} predictions could not be extracted." + ) + + # Compute performance metrics + metrics = compute_metrics( + states, + latency, + answer_key=primary_answer_key, + additional_answer_keys=( + answer_keys[1:] if answer_keys and len(answer_keys) > 1 else None + ), + ) + # Always set accuracy if we have a labels list (even if compute_accuracy returns None) + # This allows print_results to show None when compute_accuracy returns None + if has_labels_list: + metrics.accuracy = ( + accuracy # Can be None if compute_accuracy returns None + ) + if accuracy is not None: + metrics.num_valid_predictions = sum( + 1 for p in predictions if p is not None + ) + + metrics_list.append(metrics) + return metrics_list diff --git a/SpecForge-ext/benchmarks/benchmarker/ceval.py b/SpecForge-ext/benchmarks/benchmarker/ceval.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b77ccbdb0deb5ce4d2c4522a157836cf0e6efb --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/ceval.py @@ -0,0 +1,267 @@ +""" +C-Eval benchmark evaluation script. +""" + +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import concatenate_datasets, load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_answer(answer_str: str) -> str: + """Extract the answer choice (A, B, C, D) from the model output.""" + # Try to find the answer in various formats + answer_str = answer_str.strip().upper() + + # Direct match for single letter + match = re.search(r"\b([ABCD])\b", answer_str) + if match: + return match.group(1) + + # Try to find answer in parentheses or brackets + for pattern in [ + r"\(([ABCD])\)", + r"\[([ABCD])\]", + r"答案[::]\s*([ABCD])", + r"Answer[::]\s*([ABCD])", + ]: + match = re.search(pattern, answer_str, re.IGNORECASE) + if match: + return match.group(1).upper() + + # Try to find the first occurrence of A, B, C, or D + match = re.search(r"([ABCD])", answer_str) + if match: + return match.group(1) + + return None + + +def format_question(question: str, options: List[str]) -> str: + """Format the question with options.""" + prompt = question + "\n\n选项:\n" + for i, option in enumerate(options): + prompt += f"{chr(65 + i)}. {option}\n" + prompt += "\n请从A、B、C、D中选择一个答案。" + return prompt + + +@BENCHMARKS.register("ceval") +class CEvalBenchmarker(Benchmarker): + """C-Eval benchmark implementation.""" + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + if subset is None: + subset = "all" + super().__init__(num_samples, subset) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]: + """Load and preprocess C-Eval dataset.""" + all_configs = [ + "accountant", + "advanced_mathematics", + "art_studies", + "basic_medicine", + "business_administration", + "chinese_language_and_literature", + "civil_servant", + "clinical_medicine", + "college_chemistry", + "college_economics", + "college_physics", + "college_programming", + "computer_architecture", + "computer_network", + "discrete_mathematics", + "education_science", + "electrical_engineer", + "environmental_impact_assessment_engineer", + "fire_engineer", + "high_school_biology", + "high_school_chemistry", + "high_school_chinese", + "high_school_geography", + "high_school_history", + "high_school_mathematics", + "high_school_physics", + "high_school_politics", + "ideological_and_moral_cultivation", + "law", + "legal_professional", + "logic", + "mao_zedong_thought", + "marxism", + "metrology_engineer", + "middle_school_biology", + "middle_school_chemistry", + "middle_school_geography", + "middle_school_history", + "middle_school_mathematics", + "middle_school_physics", + "middle_school_politics", + "modern_chinese_history", + "operating_system", + "physician", + "plant_protection", + "probability_and_statistics", + "professional_tour_guide", + "sports_science", + "tax_accountant", + "teacher_qualification", + "urban_and_rural_planner", + "veterinary_medicine", + ] + + # Select configs to load + if self.subset == "all": + configs_to_load = all_configs + else: + for subset in self.subset: + assert ( + subset in all_configs + ), f"Subset {subset} not found in C-Eval dataset" + configs_to_load = self.subset + + # Load datasets + try: + datasets = [] + for config in configs_to_load: + try: + ds = load_dataset("ceval/ceval-exam", name=config, split="test") + datasets.append(ds) + print(f"Loaded config '{config}' with {len(ds)} samples") + except Exception as e: + print(f"Warning: Failed to load config '{config}': {e}") + if len(datasets) == 0: + raise ValueError("No configs could be loaded") + dataset = concatenate_datasets(datasets) + print( + f"Successfully loaded C-Eval dataset with all configs (total: {len(dataset)} samples)" + ) + except Exception as e: + print(e) + print(f"Failed to load C-Eval dataset from 'ceval/ceval-exam': {e}") + print("Please ensure the dataset is available or install it manually.") + print("You can try: pip install datasets") + print("Or download from: https://huggingface.co/datasets/ceval/ceval-exam") + return [], [] + + # Process questions + questions = [] + labels = [] + for idx, item in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + # Handle different dataset formats + question_text = None + if "question" in item: + question_text = item["question"] + elif "inputs" in item: + question_text = item["inputs"] + elif "problem" in item: + question_text = item["problem"] + elif "content" in item: + question_text = item["content"] + + if not question_text: + continue + + # Get options - C-Eval typically has options as a list or dict + options = None + if "options" in item: + options = item["options"] + if isinstance(options, dict): + # Convert dict to list in order A, B, C, D + options = [ + options.get("A", ""), + options.get("B", ""), + options.get("C", ""), + options.get("D", ""), + ] + elif isinstance(options, list): + # Ensure we have 4 options + while len(options) < 4: + options.append("") + elif "choices" in item: + options = item["choices"] + if isinstance(options, dict): + options = [ + options.get("A", ""), + options.get("B", ""), + options.get("C", ""), + options.get("D", ""), + ] + else: + # Try to construct options from A, B, C, D fields + options = [ + item.get("A", item.get("option_A", "")), + item.get("B", item.get("option_B", "")), + item.get("C", item.get("option_C", "")), + item.get("D", item.get("option_D", "")), + ] + + # Filter out empty options + if options: + options = [str(opt).strip() for opt in options if opt] + if len(options) < 2: # Need at least 2 options + continue + else: + continue + + # Get answer + answer = None + if "answer" in item: + answer = str(item["answer"]).upper().strip() + elif "target" in item: + answer = str(item["target"]).upper().strip() + elif "label" in item: + answer = str(item["label"]).upper().strip() + elif "correct" in item: + answer = str(item["correct"]).upper().strip() + + # Validate answer + if answer and answer in ["A", "B", "C", "D"]: + # Format question + formatted_question = format_question(question_text, options) + questions.append({"question": formatted_question}) + labels.append(answer) + + if len(questions) == 0: + print("No valid questions found. Please check the dataset format.") + print( + "Sample item keys:", + list(dataset[0].keys()) if len(dataset) > 0 else "No items", + ) + return [], [] + + return questions, labels + + def create_sgl_function(self): + """Create SGL function for C-Eval.""" + return create_simple_sgl_function( + function_name="get_ceval_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) + + def extract_answer(self, output: str, label: Any = None) -> str: + """Extract answer choice from model output.""" + return extract_answer(output) + + def compute_accuracy(self, predictions: List[str], labels: List[str]) -> float: + """Compute accuracy metric.""" + correct = 0 + valid_count = 0 + for i in range(len(predictions)): + if predictions[i] is not None: # Only count valid predictions + valid_count += 1 + if predictions[i] == labels[i]: + correct += 1 + return correct / valid_count if valid_count > 0 else 0.0 diff --git a/SpecForge-ext/benchmarks/benchmarker/financeqa.py b/SpecForge-ext/benchmarks/benchmarker/financeqa.py new file mode 100644 index 0000000000000000000000000000000000000000..9323b63423ba288edc79d2ecfb6a33d0a926af7c --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/financeqa.py @@ -0,0 +1,59 @@ +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + +QUESTION_PROMPT = """ +Given the following context: + +{context} + +Can you answer the following question? + +{question} +""".strip() + + +def generate_question(row: Dict[str, Any]) -> str: + if row["context"] is None: + return row["question"].strip() + else: + question = QUESTION_PROMPT.format( + context=row["context"].strip(), + question=row["question"].strip(), + ) + return question + + +@BENCHMARKS.register("financeqa") +class FinanceQABenchmarker(Benchmarker): + """FinanceQA benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + """Load and preprocess FinanceQA dataset.""" + # Read data + ds = load_dataset("AfterQuery/FinanceQA")["test"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(None) + return questions, labels + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_financeqa_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/SpecForge-ext/benchmarks/benchmarker/gpqa.py b/SpecForge-ext/benchmarks/benchmarker/gpqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e2add8fa835a076e51be350c9d95295e0f20bb31 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/gpqa.py @@ -0,0 +1,85 @@ +import random +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + +GPQA_QUERY_TEMPLATE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + + +def generate_question(row: Dict[str, Any]) -> str: + gold_index = random.randint(0, 3) + choices = [ + row["Incorrect Answer 1"], + row["Incorrect Answer 2"], + row["Incorrect Answer 3"], + ] + choices.insert(gold_index, row["Correct Answer"]) + + question = GPQA_QUERY_TEMPLATE.format( + Question=row["Question"].strip(), + A=choices[0].strip(), + B=choices[1].strip(), + C=choices[2].strip(), + D=choices[3].strip(), + ) + + # 0 means A, 1 means B, 2 means C, 3 means D + answer = ["A", "B", "C", "D"][gold_index] + return question, answer + + +@BENCHMARKS.register("gpqa") +class GPQABenchmarker(Benchmarker): + """GPQA benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + """Load and preprocess GPQA dataset.""" + # Read data + ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text, answer = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]: + if "Answer: " not in output: + return None + return output.split("Answer: ")[1].strip() + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + if not labels or len(labels) == 0: + return None + correct = sum(1 for pred, label in zip(predictions, labels) if pred == label) + return correct / len(labels) if len(labels) > 0 else 0.0 + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_gpqa_mcq_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/SpecForge-ext/benchmarks/benchmarker/gsm8k.py b/SpecForge-ext/benchmarks/benchmarker/gsm8k.py new file mode 100644 index 0000000000000000000000000000000000000000..e06e9d5b843a5d61aefc15e450da3fb6a2fe5424 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/gsm8k.py @@ -0,0 +1,108 @@ +""" +GSM8K benchmark evaluation script. +""" + +import ast +import os +import re +from typing import Any, Dict, List, Optional, Tuple + +from sglang.utils import download_and_cache_file, read_jsonl + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_few_shot_sgl_function + +INVALID = -9999999 + + +def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str: + """Format a single example.""" + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines: List[Dict], k: int) -> str: + """Get few-shot examples as a string.""" + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str: str) -> int: + """Extract numeric answer from model output.""" + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +@BENCHMARKS.register("gsm8k") +class GSM8KBenchmarker(Benchmarker): + """GSM8K benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + """Load and preprocess GSM8K dataset.""" + # 优先从本地数据目录读取 + local_path = "/workspace/hanrui/datasets/gsm8k/test.jsonl" + + if os.path.exists(local_path): + print(f"Loading GSM8K data from local: {local_path}") + lines = list(read_jsonl(local_path)) + else: + # 如果本地不存在,从网络下载 + print(f"Local data not found, downloading from GitHub...") + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) + + # Construct prompts + few_shot_examples = get_few_shot_examples(lines, 5) + + questions = [] + labels = [] + for i in range((len(lines))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = get_one_example(lines, i, False) + questions.append({"question": question_text}) + labels.append(get_answer_value(lines[i]["answer"])) + + # Store few_shot_examples for use in create_sgl_function + self.few_shot_examples = few_shot_examples + + assert all(l != INVALID for l in labels), "Some labels are invalid" + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]: + """Extract numeric answer from model output.""" + return get_answer_value(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for GSM8K by comparing numeric answers.""" + if not labels or len(labels) == 0: + return None + correct = sum(1 for pred, label in zip(predictions, labels) if pred == label) + return correct / len(labels) if len(labels) > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for GSM8K with few-shot examples.""" + return create_few_shot_sgl_function( + few_shot_examples=self.few_shot_examples, + function_name="few_shot_gsm8k", + answer_key="answer", + stop=["Question", "Assistant:", "<|separator|>"], + ) diff --git a/SpecForge-ext/benchmarks/benchmarker/humaneval.py b/SpecForge-ext/benchmarks/benchmarker/humaneval.py new file mode 100644 index 0000000000000000000000000000000000000000..fd2b59779ad59c0e5f9590e9abf720f1459ee739 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/humaneval.py @@ -0,0 +1,201 @@ +""" +HumanEval benchmark evaluation script. +""" + +import json +import os +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_code_from_output(output: str) -> Optional[str]: + """Extract Python code from model output. + + Tries to extract code blocks or function definitions. + """ + # Try to find code in markdown code blocks + code_block_pattern = r"```(?:python)?\n(.*?)```" + match = re.search(code_block_pattern, output, re.DOTALL) + if match: + return match.group(1).strip() + + # Try to find function definition (common in HumanEval) + # Look for "def " followed by code until the next def or end of string + def_pattern = r"(def\s+\w+\([^)]*\):.*?)(?=\n\ndef\s+|\Z)" + match = re.search(def_pattern, output, re.DOTALL) + if match: + return match.group(1).strip() + + # Fallback: return the output as-is (might already be code) + return output.strip() if output.strip() else None + + +def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool: + """Check if generated code passes the test cases. + + This is a simplified version. For full evaluation, use the official + HumanEval evaluation framework. + + HumanEval test code typically contains assertions that will raise + AssertionError if the code doesn't pass. If execution completes without + exceptions, the tests pass. + """ + try: + # Create a safe execution environment + namespace = {} + # Execute the code (function definition) + exec(code, namespace) + # Execute the test code (which contains assertions) + # If no exception is raised, the tests pass + exec(test_code, namespace) + return True + except AssertionError: + # Assertion failed - test didn't pass + return False + except Exception: + # Any other exception (syntax error, runtime error, etc.) means test failed + return False + + +@BENCHMARKS.register("humaneval") +class HumanEvalBenchmarker(Benchmarker): + """HumanEval benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + """Initialize benchmark and store test cases.""" + super().__init__(num_samples, None) + self.test_cases = [] + self.entry_points = [] + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]: + """Load and preprocess HumanEval dataset.""" + # 优先从本地数据目录读取 + local_path = "/workspace/hanrui/datasets/humaneval/test.jsonl" + + if os.path.exists(local_path): + print(f"Loading HumanEval data from local: {local_path}") + with open(local_path, 'r') as f: + dataset = [json.loads(line) for line in f] + else: + # 如果本地不存在,从 HuggingFace 下载 + print(f"Local data not found, downloading from HuggingFace...") + dataset = load_dataset("openai/openai_humaneval")["test"] + + questions = [] + labels = [] + self.test_cases = [] + self.entry_points = [] + + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + questions.append({"question": q["prompt"]}) + + # Store test case and entry point for evaluation + test_code = q.get("test", "") + entry_point = q.get("entry_point", "") + self.test_cases.append(test_code) + self.entry_points.append(entry_point) + + # Store canonical solution as reference (optional, for comparison) + canonical_solution = q.get("canonical_solution", "") + labels.append( + { + "test": test_code, + "entry_point": entry_point, + "canonical_solution": canonical_solution, + } + ) + + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract code from model output.""" + return extract_code_from_output(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for HumanEval by checking if code passes tests. + + Note: This is a simplified evaluation. For official pass@k metrics, + use the HumanEval evaluation framework. + """ + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + + for i, (pred, label) in enumerate(zip(predictions, labels)): + if label is not None and isinstance(label, dict): + valid_count += 1 + if pred is not None: + try: + # Get the prompt (function signature and docstring) + prompt = self.questions[i]["question"] + entry_point = label.get("entry_point", "") + + # The prompt contains the function signature (e.g., "def function_name(...):") + # The generated code might be: + # 1. Just the function body (what we want) - need to combine with prompt + # 2. The complete function including signature - use as-is + # 3. Code in markdown blocks - already extracted by extract_code_from_output + + pred_str = str(pred).strip() + + # Check if pred already contains a complete function definition + # (starts with "def " and contains the entry_point function name) + if pred_str.startswith("def ") and entry_point: + # Check if this is the same function (by name) + func_name_match = re.match(r"def\s+(\w+)\s*\(", pred_str) + if ( + func_name_match + and func_name_match.group(1) == entry_point + ): + # Generated code includes complete function, use it as-is + full_code = pred_str + else: + # Different function or no match, combine with prompt + full_code = prompt + "\n" + pred_str + elif pred_str.startswith("def "): + # Has function definition but we can't verify entry_point, use as-is + full_code = pred_str + else: + # Generated code is just the body, combine with prompt + full_code = prompt + "\n" + pred_str + + # Check if code passes tests + test_code = label.get("test", "") + + if test_code and check_code_passes_tests( + full_code, test_code, entry_point + ): + correct += 1 + except Exception as e: + # If evaluation fails, consider it incorrect + # Uncomment for debugging: print(f"Error evaluating code {i}: {e}") + pass + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for HumanEval.""" + return create_simple_sgl_function( + function_name="get_humaneval_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) + + def get_max_new_tokens(self) -> int: + """HumanEval code generation requires more tokens.""" + return 1024 diff --git a/SpecForge-ext/benchmarks/benchmarker/livecodebench.py b/SpecForge-ext/benchmarks/benchmarker/livecodebench.py new file mode 100644 index 0000000000000000000000000000000000000000..490ba2b20349ecd68a3edc468d38ef377c6e8d05 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/livecodebench.py @@ -0,0 +1,46 @@ +""" +GSM8K benchmark evaluation script. +""" + +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def generate_question(row: Dict[str, Any]) -> str: + question = row["question_content"].strip() + return question + + +@BENCHMARKS.register("livecodebench") +class LCBBenchmarker(Benchmarker): + """LiveCodeBench benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + # Read data + ds = load_dataset("livecodebench/code_generation")["test"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(None) + return questions, labels + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_livecodebench_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/SpecForge-ext/benchmarks/benchmarker/math500.py b/SpecForge-ext/benchmarks/benchmarker/math500.py new file mode 100644 index 0000000000000000000000000000000000000000..64ca48eb386aa6f388ef997c34de496dad4db1b7 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/math500.py @@ -0,0 +1,122 @@ +""" +MATH-500 benchmark evaluation script. +""" + +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_math_answer(output: str) -> Optional[str]: + """Extract final answer from math problem solution. + + Tries to extract answer from \boxed{} format first, then looks for + the last number in the output. + """ + # Try to find answer in \boxed{} format + boxed_pattern = r"\\boxed\{([^}]+)\}" + match = re.search(boxed_pattern, output) + if match: + return match.group(1).strip() + + # Try to find answer in \boxed format (without braces) + boxed_pattern2 = r"\\boxed\s+([^\s]+)" + match = re.search(boxed_pattern2, output) + if match: + return match.group(1).strip() + + # Try to find the last number (could be integer or decimal) + # Look for patterns like "The answer is 42" or "Answer: 3.14" + answer_patterns = [ + r"(?:answer|Answer|ANSWER)[\s:]+([-+]?\d*\.?\d+)", + r"(?:is|equals?|=\s*)([-+]?\d*\.?\d+)\s*$", + ] + for pattern in answer_patterns: + matches = re.findall(pattern, output, re.IGNORECASE) + if matches: + return matches[-1].strip() + + # Fallback: extract the last number in the text + numbers = re.findall(r"[-+]?\d*\.?\d+", output) + if numbers: + return numbers[-1] + + return None + + +@BENCHMARKS.register("math500") +class Math500Benchmarker(Benchmarker): + """MATH-500 benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]: + """Load and preprocess MATH-500 dataset.""" + dataset = load_dataset("HuggingFaceH4/MATH-500")["test"] + questions = [] + labels = [] + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + questions.append({"question": q["problem"]}) + # Extract answer from solution or answer field + answer = None + if "answer" in q: + answer = str(q["answer"]).strip() + elif "solution" in q: + # Try to extract from solution + answer = extract_math_answer(q["solution"]) + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract answer from model output.""" + return extract_math_answer(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for MATH-500 by comparing answers.""" + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + for pred, label in zip(predictions, labels): + if label is not None: + valid_count += 1 + if pred is not None: + # Normalize answers for comparison (remove whitespace, handle different formats) + pred_normalized = str(pred).strip().lower() + label_normalized = str(label).strip().lower() + # Try exact match first + if pred_normalized == label_normalized: + correct += 1 + else: + # Try numeric comparison if both are numbers + try: + pred_num = float(pred_normalized) + label_num = float(label_normalized) + if abs(pred_num - label_num) < 1e-6: + correct += 1 + except ValueError: + pass + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for MATH-500.""" + return create_simple_sgl_function( + function_name="get_math500_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/SpecForge-ext/benchmarks/benchmarker/mmlu.py b/SpecForge-ext/benchmarks/benchmarker/mmlu.py new file mode 100644 index 0000000000000000000000000000000000000000..407339a82e2f1d86d8829a33ededb2201f3b2ee2 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/mmlu.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + +GPQA_QUERY_TEMPLATE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + + +def generate_question(row: Dict[str, Any]) -> str: + choices = row["choices"] + question = GPQA_QUERY_TEMPLATE.format( + Question=row["question"].strip(), + A=choices[0].strip(), + B=choices[1].strip(), + C=choices[2].strip(), + D=choices[3].strip(), + ) + + # 0 means A, 1 means B, 2 means C, 3 means D + answer = ["A", "B", "C", "D"][row["answer"]] + print(answer) + return question, answer + + +@BENCHMARKS.register("mmlu") +class MMLUBenchmarker(Benchmarker): + """MMLU benchmark implementation.""" + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + if subset is None: + subset = ["all"] + super().__init__(num_samples, subset) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + # Read data + questions = [] + labels = [] + + for subset in self.subset: + ds = load_dataset("cais/mmlu", subset)["test"] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text, answer = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]: + if "Answer: " not in output: + return None + return output.split("Answer: ")[1].strip() + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + if not labels or len(labels) == 0: + return None + correct = sum(1 for pred, label in zip(predictions, labels) if pred == label) + return correct / len(labels) if len(labels) > 0 else 0.0 + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_mmlu_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/SpecForge-ext/benchmarks/benchmarker/mmstar.py b/SpecForge-ext/benchmarks/benchmarker/mmstar.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab1c44a28023a6bf18277edcacbe96794fa2c6a --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/mmstar.py @@ -0,0 +1,185 @@ +""" +MMStar benchmark evaluation script. +""" + +import os +import re +import shutil +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_image_sgl_function + + +def extract_mmstar_answer( + output: str, options: Optional[List[str]] = None +) -> Optional[str]: + """Extract answer from MMStar model output. + + MMStar questions typically have multiple choice options (A, B, C, D, etc.) + """ + output_upper = output.strip().upper() + + # Try to find answer choice (A, B, C, D, etc.) + # Direct match for single letter + match = re.search(r"\b([A-Z])\b", output_upper) + if match: + letter = match.group(1) + if options and len(options) > 0: + # Validate that the letter is within valid range + max_option = chr(64 + len(options)) # 'A' + (len-1) + if "A" <= letter <= max_option: + return letter + else: + # Assume A-D are valid + if "A" <= letter <= "D": + return letter + + # Try to find answer in parentheses or brackets + for pattern in [ + r"\(([A-Z])\)", + r"\[([A-Z])\]", + r"答案[::]\s*([A-Z])", + r"Answer[::]\s*([A-Z])", + r"选择[::]\s*([A-Z])", + ]: + match = re.search(pattern, output_upper) + if match: + letter = match.group(1) + if options and len(options) > 0: + max_option = chr(64 + len(options)) + if "A" <= letter <= max_option: + return letter + elif "A" <= letter <= "D": + return letter + + return None + + +@BENCHMARKS.register("mmstar") +class MMStarBenchmarker(Benchmarker): + """MMStar benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + """Initialize benchmark and set up cache directory.""" + self.cache_dir = None + self.options_list = [] # Store options for each question + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]: + """Load and preprocess MMStar dataset.""" + self.cache_dir = os.path.join(".cache", "mmstar_specforge") + image_dir = os.path.join(self.cache_dir, "images") + os.makedirs(self.cache_dir, exist_ok=True) + os.makedirs(image_dir, exist_ok=True) + print(f"Created temporary image directory: {self.cache_dir}") + + dataset = load_dataset("Lin-Chen/MMStar")["val"] + questions = [] + labels = [] + self.options_list = [] + + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + image = q["image"] + image_path = os.path.join(self.cache_dir, q["meta_info"]["image_path"]) + image.convert("RGB").save(image_path, "JPEG") + + # Extract question and options + question_full = q["question"] + if "Options:" in question_full: + question_text, options_text = question_full.split("Options:", 1) + question_text = question_text.strip() + # Parse options (typically A. option1 B. option2 etc.) + options = [] + for line in options_text.strip().split("\n"): + line = line.strip() + if line and re.match(r"^[A-Z]\.", line): + option_text = re.sub(r"^[A-Z]\.\s*", "", line).strip() + options.append(option_text) + self.options_list.append(options) + else: + question_text = question_full.strip() + self.options_list.append([]) + + item = { + "image_path": image_path, + "question": question_text, + } + questions.append(item) + + # Extract ground truth answer + answer = None + if "answer" in q: + answer = str(q["answer"]).strip().upper() + elif "correct_answer" in q: + answer = str(q["correct_answer"]).strip().upper() + elif "ground_truth" in q: + answer = str(q["ground_truth"]).strip().upper() + + # Validate answer is a valid option letter + if answer and len(answer) == 1 and "A" <= answer <= "Z": + if self.options_list[-1]: + max_option = chr(64 + len(self.options_list[-1])) + if answer <= max_option: + labels.append(answer) + else: + labels.append(None) + else: + labels.append(answer) + else: + labels.append(None) + + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract answer from model output.""" + # Use the options for the current question if available + # Note: We can't easily get the question index here, so we'll use a simpler approach + return extract_mmstar_answer(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for MMStar by comparing answer choices.""" + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + for pred, label in zip(predictions, labels): + if label is not None: + valid_count += 1 + if pred is not None: + # Normalize to uppercase for comparison + pred_normalized = str(pred).strip().upper() + label_normalized = str(label).strip().upper() + if pred_normalized == label_normalized: + correct += 1 + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for MMStar (image-based Q&A).""" + return create_image_sgl_function( + function_name="get_mmstar_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) + + def run(self, *args, **kwargs): + """Run benchmark and clean up cache directory.""" + try: + return super().run(*args, **kwargs) + finally: + # Clean up cache directory + if self.cache_dir and os.path.exists(self.cache_dir): + shutil.rmtree(self.cache_dir) + print(f"Deleted temporary directory: {self.cache_dir}") diff --git a/SpecForge-ext/benchmarks/benchmarker/mtbench.py b/SpecForge-ext/benchmarks/benchmarker/mtbench.py new file mode 100644 index 0000000000000000000000000000000000000000..5624e434ab8ce0a356b6b5dfac1da7f39a4ac82a --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/mtbench.py @@ -0,0 +1,70 @@ +""" +MT-Bench benchmark evaluation script. +Adapted from https://github.com/chromecast56/sglang/blob/6f145d2eadb93a116134f703358ce76f15381045/benchmark/mtbench/bench_sglang.py +""" + +import os +from typing import Any, Dict, List, Optional, Tuple + +from sglang.utils import download_and_cache_file, read_jsonl + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_multi_turn_sgl_function + +SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." + + +@BENCHMARKS.register("mtbench") +class MTBenchBenchmarker(Benchmarker): + """MT-Bench benchmark implementation.""" + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + # support categorical data for mtbench + if subset is None: + subset = ["all"] + super().__init__(num_samples, subset) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]: + """Load and preprocess MT-Bench dataset.""" + # 优先从本地数据目录读取 + local_path = "/workspace/hanrui/datasets/mtbench/question.jsonl" + + if os.path.exists(local_path): + print(f"Loading MT-Bench data from local: {local_path}") + questions_data = list(read_jsonl(local_path)) + else: + # 如果本地不存在,从网络下载 + print(f"Local data not found, downloading from GitHub...") + url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl" + download_and_cache_file(url, filename="mtbench.jsonl") + questions_data = list(read_jsonl("mtbench.jsonl")) + + questions_data = questions_data + + questions = [ + {"question_1": q["turns"][0], "question_2": q["turns"][1]} + for q in questions_data + ] + # MT-Bench doesn't have labels for accuracy computation + labels = [None] * len(questions) + + if self.num_samples is not None: + questions = questions[: self.num_samples] + labels = labels[: self.num_samples] + return questions, labels + + def create_sgl_function(self): + """Create SGL function for MT-Bench (2-turn conversation).""" + return create_multi_turn_sgl_function( + function_name="answer_mt_bench", + system_prompt=SYSTEM_PROMPT, + num_turns=2, + max_tokens=self.get_max_new_tokens(), + ) + + def get_answer_keys(self) -> List[str]: + """Return answer keys for multi-turn conversation.""" + return ["answer_1", "answer_2"] diff --git a/SpecForge-ext/benchmarks/benchmarker/registry.py b/SpecForge-ext/benchmarks/benchmarker/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4f474fcd15bd9a891b8f8977465aaa233c9fd1 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/registry.py @@ -0,0 +1,31 @@ +class BenchmarkRegistry: + + def __init__(self): + self.benchmarks = {} + + def register(self, name: str): + """ + Usage: + ```python + BENCHMARKS = BenchmarkRegistry() + + BENCHMARKS.register("aime") + class AIMEBenchmarker(Benchmarker): + ... + ``` + """ + + def wrapper(cls): + self.benchmarks[name] = cls + return cls + + return wrapper + + def get(self, name: str) -> type: + """ + Get the benchmark class by name. + """ + return self.benchmarks[name] + + +BENCHMARKS = BenchmarkRegistry() diff --git a/SpecForge-ext/benchmarks/benchmarker/simpleqa.py b/SpecForge-ext/benchmarks/benchmarker/simpleqa.py new file mode 100644 index 0000000000000000000000000000000000000000..5facab00d719d6d235a8cb50d161679ebe28f6a0 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/simpleqa.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def generate_question(row: Dict[str, Any]) -> str: + question = row["problem"].strip() + return question + + +@BENCHMARKS.register("simpleqa") +class SimpleQABenchmarker(Benchmarker): + """SimpleQA benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + # Read data + ds = load_dataset("basicv8vc/SimpleQA")["test"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(None) + return questions, labels + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_simpleqa_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/SpecForge-ext/benchmarks/benchmarker/utils.py b/SpecForge-ext/benchmarks/benchmarker/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a6dabfb9a4ef1789b7a89a5d7131755a6e6fa8 --- /dev/null +++ b/SpecForge-ext/benchmarks/benchmarker/utils.py @@ -0,0 +1,273 @@ +""" +Utility functions for benchmark scripts. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import sglang as sgl + + +@dataclass +class BenchmarkMetrics: + """Container for benchmark performance metrics.""" + + latency: float + output_throughput: float + accept_length: float + accuracy: Optional[float] = None + num_questions: int = 0 + num_valid_predictions: int = 0 + categorical_performance: Optional[Dict[str, "BenchmarkMetrics"]] = None + + +def compute_metrics( + states: List[Any], + latency: float, + answer_key: str = "answer", + additional_answer_keys: Optional[List[str]] = None, +) -> BenchmarkMetrics: + """ + Compute performance metrics from SGLang states. + + Args: + states: List of SGLang state objects from run_batch + latency: Total latency in seconds + answer_key: Primary key for answer in state meta info + additional_answer_keys: Additional keys to include in token count (e.g., ["answer_1", "answer_2"]) + + Returns: + BenchmarkMetrics object with computed metrics + """ + # Compute output tokens + num_output_tokens = 0 + if additional_answer_keys: + for key in [answer_key] + additional_answer_keys: + num_output_tokens += sum( + s.get_meta_info(key)["completion_tokens"] for s in states + ) + else: + num_output_tokens = sum( + s.get_meta_info(answer_key)["completion_tokens"] for s in states + ) + + output_throughput = num_output_tokens / latency if latency > 0 else 0.0 + + # Compute accept length (speculative decoding metric) + has_verify = "spec_verify_ct" in states[0].get_meta_info(answer_key) + if has_verify: + num_verify_tokens = 0 + if additional_answer_keys: + for key in [answer_key] + additional_answer_keys: + num_verify_tokens += sum( + s.get_meta_info(key).get("spec_verify_ct", 0) for s in states + ) + else: + num_verify_tokens = sum( + s.get_meta_info(answer_key).get("spec_verify_ct", 0) for s in states + ) + + if num_verify_tokens == 0: + accept_length = 1.0 + else: + accept_length = num_output_tokens / num_verify_tokens + else: + accept_length = 1.0 + + return BenchmarkMetrics( + latency=latency, + output_throughput=output_throughput, + accept_length=accept_length, + num_questions=len(states), + ) + + +def print_results( + metrics_list: List[BenchmarkMetrics], + benchmark_name: str, + show_accuracy: bool = False, +): + """ + Print benchmark results in a formatted way. + + Args: + metrics_list: List of BenchmarkMetrics from multiple runs + benchmark_name: Name of the benchmark + show_accuracy: Whether to show accuracy metrics + """ + avg_latency = np.mean([m.latency for m in metrics_list]) + avg_throughput = np.mean([m.output_throughput for m in metrics_list]) + avg_accept_length = np.mean([m.accept_length for m in metrics_list]) + + print(f"\n{'='*50}") + print(f"{benchmark_name} Evaluation Results") + print(f"{'='*50}") + print(f"Number of questions: {metrics_list[0].num_questions}") + if show_accuracy: + if metrics_list[0].accuracy is not None: + avg_accuracy = np.mean( + [m.accuracy for m in metrics_list if m.accuracy is not None] + ) + print(f"Average Accuracy: {avg_accuracy:.4f} ({avg_accuracy*100:.2f}%)") + else: + print(f"Average Accuracy: None") + print(f"Average Latency: {avg_latency:.3f} s") + print(f"Average Output throughput: {avg_throughput:.3f} token/s") + print(f"Average Accept length: {avg_accept_length:.3f}") + print(f"{'='*50}\n") + + +def create_simple_sgl_function( + function_name: str = "get_answer", + answer_key: str = "answer", + system_prompt: Optional[str] = None, + max_tokens: int = 2048, + stop: Optional[List[str]] = None, + user_prefix: Optional[str] = None, +) -> Callable: + """ + Create a simple SGL function for single-turn Q&A. + + Args: + function_name: Name of the function + answer_key: Key for storing the answer + system_prompt: Optional system prompt + max_tokens: Maximum tokens to generate + stop: Optional stop sequences + user_prefix: Optional suffix to append to user message (appended after question) + + Returns: + SGL function decorated with @sgl.function + """ + + @sgl.function + def sgl_func(s, question): + if system_prompt: + s += sgl.system(system_prompt) + user_content = question + if user_prefix: + user_content = question + user_prefix + s += sgl.user(user_content) + gen_kwargs = {"max_tokens": max_tokens} + if stop: + gen_kwargs["stop"] = stop + s += sgl.assistant(sgl.gen(answer_key, **gen_kwargs)) + + sgl_func.__name__ = function_name + return sgl_func + + +def create_few_shot_sgl_function( + few_shot_examples: str, + function_name: str = "few_shot_answer", + answer_key: str = "answer", + max_tokens: int = 512, + stop: Optional[List[str]] = None, +) -> Callable: + """ + Create an SGL function for few-shot learning. + + Args: + few_shot_examples: String containing few-shot examples + function_name: Name of the function + answer_key: Key for storing the answer + max_tokens: Maximum tokens to generate + stop: Optional stop sequences + + Returns: + SGL function decorated with @sgl.function + """ + + @sgl.function + def sgl_func(s, question): + s += few_shot_examples + question + gen_kwargs = {"max_tokens": max_tokens} + if stop: + gen_kwargs["stop"] = stop + s += sgl.gen(answer_key, **gen_kwargs) + + sgl_func.__name__ = function_name + return sgl_func + + +def create_multi_turn_sgl_function( + function_name: str = "multi_turn_answer", + system_prompt: Optional[str] = None, + num_turns: int = 2, + max_tokens: int = 2048, +) -> Callable: + """ + Create an SGL function for multi-turn conversations (e.g., MT-Bench with 2 turns). + + Args: + function_name: Name of the function + system_prompt: Optional system prompt + num_turns: Number of conversation turns (default: 2) + max_tokens: Maximum tokens to generate per turn + + Returns: + SGL function decorated with @sgl.function + """ + if num_turns == 2: + # Most common case: 2-turn conversation + @sgl.function + def sgl_func(s, question_1, question_2): + if system_prompt: + s += sgl.system(system_prompt) + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=max_tokens)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=max_tokens)) + + else: + # Generic case: create function with dynamic number of turns + # Note: This requires the caller to pass arguments as a dict + @sgl.function + def sgl_func(s, **kwargs): + if system_prompt: + s += sgl.system(system_prompt) + for i in range(num_turns): + question_key = f"question_{i+1}" + answer_key = f"answer_{i+1}" + if question_key in kwargs: + s += sgl.user(kwargs[question_key]) + s += sgl.assistant(sgl.gen(answer_key, max_tokens=max_tokens)) + + sgl_func.__name__ = function_name + return sgl_func + + +def create_image_sgl_function( + function_name: str = "get_image_answer", + answer_key: str = "answer", + max_tokens: int = 2048, +) -> Callable: + """ + Create an SGL function for image-based Q&A. + + Args: + function_name: Name of the function + answer_key: Key for storing the answer + max_tokens: Maximum tokens to generate + + Returns: + SGL function decorated with @sgl.function + """ + + @sgl.function + def sgl_func(s, image_path, question, **kwargs): + """ + The body of the SGL function: constructs a multimodal conversation flow. + + - First, it inputs an image + text question as 'user'. + - Then, it generates an answer as 'assistant', binding the response to the specified `answer_key`. + + Note: sgl.image() automatically encodes the image into a format supported by the model for multimodal input. + """ + # User input: Image + Text question + s += sgl.user(sgl.image(image_path) + question) + s += sgl.assistant(sgl.gen(answer_key, max_tokens=max_tokens)) + + sgl_func.__name__ = function_name + return sgl_func diff --git a/SpecForge-ext/configs/qwen2-5-vl-7b-eagle3.json b/SpecForge-ext/configs/qwen2-5-vl-7b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..672193e3b1284badcb747356f1cbfcd402e19ccf --- /dev/null +++ b/SpecForge-ext/configs/qwen2-5-vl-7b-eagle3.json @@ -0,0 +1,40 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 8192, + "max_window_layers": 28, + "model_type": "llama", + "target_model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16, + 24, + 24 + ] + }, + "rope_theta": 1000000, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 + } diff --git a/SpecForge-ext/examples/README.md b/SpecForge-ext/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ff5a6f3a8a5aae6c9ff7645afc266dc6cd7363bc --- /dev/null +++ b/SpecForge-ext/examples/README.md @@ -0,0 +1,9 @@ +# Run SpecForge Examples + +This folder contains the examples of running SpecForge on different models. The scripts can be invoked by the following command: + +```bash +bash examples/.sh [NUM_GPUS] [TP_SIZE] +``` + +We use the ShareGPT dataset for all the examples for now, but you can replace it with more robust datasets such as perfectblend, magpie-qwen2.5-pro-1m-v0.1, etc. diff --git a/SpecForge-ext/examples/run_deepseek_v2_lite_eagle3_online.sh b/SpecForge-ext/examples/run_deepseek_v2_lite_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..283c62ce80743d52bc29aaddc3b0b7a9829890c7 --- /dev/null +++ b/SpecForge-ext/examples/run_deepseek_v2_lite_eagle3_online.sh @@ -0,0 +1,25 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for deepseek-v2-lite +NUM_GPUS=${1:-8} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path deepseek-ai/DeepSeek-V2-Lite \ + --draft-model-config $ROOT_DIR/configs/deepseek-v2-lite-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template deepseek \ + --target-model-backend hf \ + --cache-dir $ROOT_DIR/cache diff --git a/SpecForge-ext/examples/run_deepseek_v3_671b_eagle3_offline.sh b/SpecForge-ext/examples/run_deepseek_v3_671b_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..4bede1dd50be44503365fb03fce8624e1bed2d4e --- /dev/null +++ b/SpecForge-ext/examples/run_deepseek_v3_671b_eagle3_offline.sh @@ -0,0 +1,43 @@ + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for deepseek-v3 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# generate hidden states +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + scripts/prepare_hidden_states.py \ + --target-model-path deepseek-ai/DeepSeek-V3 \ + --enable-aux-hidden-states \ + --data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --output-path $ROOT_DIR/cache/hidden_states/perfect-blend-deepseek-v3 \ + --chat-template deepseek-v3 \ + --max-length 2048 \ + --tp-size 8 \ + --batch-size 4 \ + --sglang-mem-fraction-static 0.75 + +# train eagle3 offline +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path deepseek-ai/DeepSeek-V3 \ + --draft-model-config $ROOT_DIR/configs/deepseek-v3-671b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states/perfect-blend-deepseek-v3 \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/deepseek-v3-671B-eagle3-perfect-blend-offline \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template deepseek-v3 \ + --cache-dir $ROOT_DIR/cache diff --git a/SpecForge-ext/examples/run_gemma3_1b_eagle3_online.sh b/SpecForge-ext/examples/run_gemma3_1b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1365069594baae0f7b8acbc45640c5ea39e0731 --- /dev/null +++ b/SpecForge-ext/examples/run_gemma3_1b_eagle3_online.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for gemma3-1b +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path google/gemma-3-1b-it \ + --draft-model-config $ROOT_DIR/configs/gemma3-1b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --output-dir $ROOT_DIR/outputs/gemma3-1b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template gemma \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend hf \ + --log-interval 10 diff --git a/SpecForge-ext/examples/run_gpt_oss_120b_eagle3_online.sh b/SpecForge-ext/examples/run_gpt_oss_120b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..eea5afbd8a6512945d35f4b005c338fc5f1671a8 --- /dev/null +++ b/SpecForge-ext/examples/run_gpt_oss_120b_eagle3_online.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for GPT-OSS-120B +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path openai/gpt-oss-120b \ + --draft-model-config $ROOT_DIR/configs/gpt-oss-20B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend-gptoss-20B.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/gpt-oss-20b-eagle3 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template gpt-oss \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 diff --git a/SpecForge-ext/examples/run_gpt_oss_20b_eagle3_online.sh b/SpecForge-ext/examples/run_gpt_oss_20b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..55baeac1c49576e25162b6ff78558a7df8c4ee2d --- /dev/null +++ b/SpecForge-ext/examples/run_gpt_oss_20b_eagle3_online.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for GPT-OSS-20B +NUM_GPUS=${1:-8} +TP_SIZE=${2:-2} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path openai/gpt-oss-20b \ + --draft-model-config $ROOT_DIR/configs/gpt-oss-20B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend-gptoss-20B.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/perfect-blend-gptoss-20b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template gpt-oss \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 diff --git a/SpecForge-ext/examples/run_ling_flash_2.0_eagle3_offline.sh b/SpecForge-ext/examples/run_ling_flash_2.0_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7f2925b4bd3381d0d9984a59db7e3b0f3699faa --- /dev/null +++ b/SpecForge-ext/examples/run_ling_flash_2.0_eagle3_offline.sh @@ -0,0 +1,45 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for ling-flash-2.0 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# generate hidden states +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + scripts/prepare_hidden_states.py \ + --target-model-path inclusionAI/Ling-flash-2.0 \ + --enable-aux-hidden-states \ + --data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --output-path $ROOT_DIR/cache/hidden_states/perfect-blend-ling-flash-2.0 \ + --chat-template ling-flash-2.0 \ + --max-length 2048 \ + --tp-size $TP_SIZE \ + --batch-size 4 \ + --sglang-mem-fraction-static 0.75 \ + --trust-remote-code + +# train eagle3 offline +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path inclusionAI/Ling-flash-2.0 \ + --draft-model-config $ROOT_DIR/configs/ling-flash-2.0-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states/perfect-blend-ling-flash-2.0 \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/ling-flash-2.0-eagle3-perfect-blend-offline \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template ling-flash-2.0 \ + --embedding-key 'model.word_embeddings.weight' \ + --cache-dir $ROOT_DIR/cache \ + --trust-remote-code diff --git a/SpecForge-ext/examples/run_ling_flash_2.0_eagle3_online.sh b/SpecForge-ext/examples/run_ling_flash_2.0_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..8f9d1cc3d87905107b0708b99ec8a32a832831a2 --- /dev/null +++ b/SpecForge-ext/examples/run_ling_flash_2.0_eagle3_online.sh @@ -0,0 +1,30 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for ling-flash-2.0 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# train eagle3 online +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path inclusionAI/Ling-flash-2.0 \ + --draft-model-config $ROOT_DIR/configs/ling-flash-2.0-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/ling-flash-2.0-eagle3-perfect-blend-online \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template ling-flash-2.0 \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 \ + --sglang-mem-fraction-static 0.75 \ + --embedding-key 'model.word_embeddings.weight' \ + --trust-remote-code diff --git a/SpecForge-ext/examples/run_llama3.1_8b_eagle3_offline.sh b/SpecForge-ext/examples/run_llama3.1_8b_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..dffcbef845b727e6be2eeb2f24d63de2cc8b693f --- /dev/null +++ b/SpecForge-ext/examples/run_llama3.1_8b_eagle3_offline.sh @@ -0,0 +1,39 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# generate hidden states +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --output-path $ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --chat-template llama3 \ + --max-length 4096 \ + --tp-size $TP_SIZE \ + --batch-size 32 + +# train eagle3 offline +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt-offline \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache diff --git a/SpecForge-ext/examples/run_llama3.1_8b_eagle3_online.sh b/SpecForge-ext/examples/run_llama3.1_8b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..d47c1797fa14aaaba21784f108ea2c270163b805 --- /dev/null +++ b/SpecForge-ext/examples/run_llama3.1_8b_eagle3_online.sh @@ -0,0 +1,29 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +# train eagle3 for llama3.1-8b +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend sglang \ + --log-interval 10 \ + --sglang-mem-fraction-static 0.25 diff --git a/SpecForge-ext/examples/run_llama3.3_70b_eagle3_online.sh b/SpecForge-ext/examples/run_llama3.3_70b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..0ea80413df676a319d8a1a38eaeb036d21d3321d --- /dev/null +++ b/SpecForge-ext/examples/run_llama3.3_70b_eagle3_online.sh @@ -0,0 +1,25 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for llama3.1-8b +NUM_GPUS=${1:-8} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.3-70B-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama3-70B-ealge3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama3.3-70b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache \ + --target-model-backend sglang diff --git a/SpecForge-ext/examples/run_llama4_scout_eagle3_online.sh b/SpecForge-ext/examples/run_llama4_scout_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..73ed03a617297b64569cafecfd5bce9a2cf8f940 --- /dev/null +++ b/SpecForge-ext/examples/run_llama4_scout_eagle3_online.sh @@ -0,0 +1,25 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for llama3.1-8b +NUM_GPUS=${1:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-4-Scout-17B-16E-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama4-scout-17B-16E-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama4-scout-17B-16E-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama4 \ + --cache-dir $ROOT_DIR/cache \ + --tp-size 8 \ + --embedding-key language_model.model.embed_tokens.weight \ diff --git a/SpecForge-ext/examples/run_longcat_flash_dflash_online.sh b/SpecForge-ext/examples/run_longcat_flash_dflash_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..2c721354580927dce7e37425b0399569cbb962c5 --- /dev/null +++ b/SpecForge-ext/examples/run_longcat_flash_dflash_online.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=${SPECFORGE_DATA_NUM_PROC:-64} + +NUM_GPUS=${1:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} +WANDB_MODE=offline +SGL_JIT_DEEPGEMM_PRECOMPILE=false +SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dflash.py \ + --target-model-path meituan-longcat/LongCat-Flash-Chat-FP8 \ + --target-model-backend sglang \ + --tp-size $NUM_GPUS \ + --sglang-attention-backend flashinfer \ + --sglang-mem-fraction-static 0.75 \ + --sglang-ep-size $NUM_GPUS \ + --draft-config-path $ROOT_DIR/configs/longcat-flash-dflash.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/longcat-flash-dflash-sharegpt \ + --num-epochs 20 \ + --batch-size 2 \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template longcat \ + --log-interval 50 \ + --save-interval 1000 \ + --report-to wandb \ + --wandb-project specforge-longcat-flash-dflash \ + --wandb-name longcat-flash-dflash-sharegpt \ + --mask-token-id 2 diff --git a/SpecForge-ext/examples/run_longcat_flash_eagle3_online.sh b/SpecForge-ext/examples/run_longcat_flash_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..f89cb502009610f0f19db6fef25c268ff1c8f641 --- /dev/null +++ b/SpecForge-ext/examples/run_longcat_flash_eagle3_online.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meituan-longcat/LongCat-Flash-Chat-FP8 \ + --draft-model-config $ROOT_DIR/configs/longcat-flash-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/longcat-flash-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template longcat \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend sglang \ + --log-interval 10 \ + --sglang-mem-fraction-static 0.75 \ + --sglang-attention-backend flashinfer \ + --sglang-ep-size $NUM_GPUS diff --git a/SpecForge-ext/examples/run_phi4_eagle3_online.sh b/SpecForge-ext/examples/run_phi4_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..f306d22e71941b295bab03695a7f3c3187fc54d5 --- /dev/null +++ b/SpecForge-ext/examples/run_phi4_eagle3_online.sh @@ -0,0 +1,27 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path microsoft/phi-4 \ + --draft-model-config $ROOT_DIR/configs/phi4-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/phi4-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template phi4 \ + --cache-dir $ROOT_DIR/cache \ + --target-model-backend sglang \ + --embedding-key model.embed_tokens.weight diff --git a/SpecForge-ext/examples/run_qwen2.5_32b_vl_eagle3_online.sh b/SpecForge-ext/examples/run_qwen2.5_32b_vl_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..a7c86b0e502e19a1e39f42860d4804c768e84642 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen2.5_32b_vl_eagle3_online.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# support tp1 train eagle3 for qwen2.5-vl-7b-instruct +NUM_GPUS=${1:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen2.5-VL-32B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen2.5-vl-32b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/allava4v_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen2.5-vl-32b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --dist-timeout 360 \ + --chat-template qwen2-vl \ + --target-model-backend sglang \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 4 \ + --sglang-mem-fraction-static 0.5 \ + --is-vlm \ + --min-pixels 200704 \ + --max-pixels 1003520 diff --git a/SpecForge-ext/examples/run_qwen2.5_7b_vl_eagle3_online.sh b/SpecForge-ext/examples/run_qwen2.5_7b_vl_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..e94e6e39882484f0e6590e72686d594ba4bf1ff0 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen2.5_7b_vl_eagle3_online.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# support tp1 train eagle3 for qwen2.5-vl-7b-instruct +NUM_GPUS=${1:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen2.5-VL-7B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen2-5-vl-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/allava4v_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/Qwen2.5-VL-7B-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 8192 \ + --dist-timeout 360 \ + --chat-template qwen2-vl \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 1 \ + --is-vlm \ + --min-pixels 50176 \ + --max-pixels 802816 diff --git a/SpecForge-ext/examples/run_qwen3_235b_a22b_eagle3.sh b/SpecForge-ext/examples/run_qwen3_235b_a22b_eagle3.sh new file mode 100644 index 0000000000000000000000000000000000000000..c96b42cb6267bce71fe520669544d9a59eb1cffc --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_235b_a22b_eagle3.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# support tp4/tp8 train eagle3 for Qwen3-30B-A3B +NUM_GPUS=8 +TP_SIZE=4 +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path /workdir/huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-FP8/\ + --draft-model-config $ROOT_DIR/configs/qwen3-next-80b-a3b-eagle3.json \ + --train-data-path /workdir/data_qwen80b/qwen3_80b_perfectblend_train_regen.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir /workdir/qwen3-80b-regen-blend \ + --num-epochs 2 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir /workdir/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/SpecForge-ext/examples/run_qwen3_8b_dflash_online.sh b/SpecForge-ext/examples/run_qwen3_8b_dflash_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..44125233a731ec500446411331c94852d7e4ba01 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_8b_dflash_online.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=32 +NUM_GPUS=${1:-1} + +ATTENTION_BACKEND=${2:-flex_attention} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dflash.py \ + --target-model-path Qwen/Qwen3-8B \ + --draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-dflash-sharegpt \ + --num-epochs 20 \ + --batch-size 4 \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template qwen \ + --attention-backend $ATTENTION_BACKEND \ + --log-interval 50 \ + --save-interval 1000 \ + diff --git a/SpecForge-ext/examples/run_qwen3_8b_eagle3_online.sh b/SpecForge-ext/examples/run_qwen3_8b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..4aa79654701fb6c05f6430140657ba9e550e4f13 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_8b_eagle3_online.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# support tp8 train eagle3 for Qwen3-4B/8B/32B up to tp_size = 8 +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-8B \ + --draft-model-config $ROOT_DIR/configs/qwen3-8b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/SpecForge-ext/examples/run_qwen3_8b_qwen3eagle_5layer.sh b/SpecForge-ext/examples/run_qwen3_8b_qwen3eagle_5layer.sh new file mode 100644 index 0000000000000000000000000000000000000000..c514da97d961420f8834a7278068b5d421716fb7 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_8b_qwen3eagle_5layer.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Train Qwen3-8B with 5-layer Qwen3-style EAGLE3 draft model +# Uses Qwen3's native architecture (with QK-Norm) instead of Llama's +# +# Key differences from Llama EAGLE: +# - QK-Norm: RMSNorm applied to Q and K projections before RoPE +# - Configurable attention bias (False for Qwen3) +# +# Qwen3-8B dimensions: +# - hidden_size: 4096 +# - intermediate_size: 12288 +# - num_attention_heads: 32 +# - num_key_value_heads: 8 (GQA) +# - head_dim: 128 + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +NUM_GPUS=${1:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-8} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path /workspace/Qwen3-8B \ + --draft-model-config $ROOT_DIR/configs/qwen3-8b-qwen3eagle-5layer.json \ + --train-data-path /workspace/hanrui/qwen3-8b_dflash_regen/sharegpt_train_regenerated.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-qwen3eagle-5layer \ + --num-epochs 10 \ + --batch-size 8 \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --ttt-length 7 \ + --log-interval 100 \ + --save-interval 5000 \ + --target-model-backend sglang \ + --report-to none + diff --git a/SpecForge-ext/examples/run_qwen3_coder_30b_a3b_eagle3_online.sh b/SpecForge-ext/examples/run_qwen3_coder_30b_a3b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..b88d5fcdca1cf34ae6f0b050c0a5af390cf50c05 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_coder_30b_a3b_eagle3_online.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# Train EAGLE3 draft model for Qwen3-Coder-30B-A3B-Instruct +# Uses the regenerated OPC dataset and TP=4 on GPUs 4,5,6,7 + +# GPU Configuration - Use the later 4 GPUs (4,5,6,7) +export CUDA_VISIBLE_DEVICES=4,5,6,7 +NUM_GPUS=4 +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-Coder-30B-A3B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen3-coder-30B-A3B-instruct-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/opc_regenerated.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-coder-30b-a3b-instruct-eagle3-opc-regen \ + --num-epochs 2 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 4 \ + --dist-timeout 60 \ + --log-interval 50 \ + --save-interval 5000 \ + --eval-interval 5000 \ + --report-to wandb \ + --wandb-project specforge-qwen3-coder \ + --wandb-name qwen3-coder-30b-eagle3-tp4-opc-regen diff --git a/SpecForge-ext/examples/run_qwen3_coder_eagle3_offline.sh b/SpecForge-ext/examples/run_qwen3_coder_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7d0f272bfcd23a1073ee6ca012222b7a8a0df82 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_coder_eagle3_offline.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for qwen3-coder +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-Coder-480B-A35B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen3-coder-480B-A35B-instruct-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/opc.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/Qwen3-Coder-480B-A35B-Instruct \ + --num-epochs 10 \ + --draft-micro-batch-size 1 \ + --draft-global-batch-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template qwen \ + --target-model-backend sglang diff --git a/SpecForge-ext/examples/run_qwen3_coder_eagle3_online.sh b/SpecForge-ext/examples/run_qwen3_coder_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..77f7803301b9678538dcbcde664ab8c74b5451a4 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_coder_eagle3_online.sh @@ -0,0 +1,33 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for qwen3-coder +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8 \ + --draft-model-config $ROOT_DIR/configs/qwen3-coder-480B-A35B-instruct-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/opc_regenerated.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/Qwen3-Coder-480B-A35B-Instruct-FP8 \ + --tp-size $TP_SIZE \ + --sglang-ep-size 2 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-5 \ + --ttt-length 13 \ + --sglang-mem-fraction-static 0.6 \ + --max-length 2048 \ + --chat-template qwen \ + --target-model-backend sglang \ + --save-interval 20000 \ + --eval-interval 20000 \ + --report-to wandb \ + --wandb-project specforge-qwen3-480-coder-fp8 \ + --wandb-name qwen3-coder-480b-a35b-eagle3-tp8-ep2-opc-regen diff --git a/SpecForge-ext/examples/run_qwen3_next_80b_eagle3_online.sh b/SpecForge-ext/examples/run_qwen3_next_80b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..14838913f92ca64ccc7fc49f70389d834bf815d7 --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_next_80b_eagle3_online.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +NUM_GPUS=${1:-8} +TP_SIZE=4 +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path $ROOT_DIR//Qwen/Qwen3-Next-80B-A3B-Instruct-FP8/\ + --draft-model-config $ROOT_DIR/configs/qwen3-next-80b-a3b-eagle3.json \ + --train-data-path $ROOT_DIR/data_qwen80b/qwen3_80b_perfectblend_train_regen.jsonl \ + --output-dir $ROOT_DIR/qwen3-80b-regen-blend \ + --num-epochs 2 \ + --batch-size 2 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --sglang-mem-fraction-static 0.5 \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --target-model-backend sglang diff --git a/SpecForge-ext/results/baseline_humaneval_results_20260213_111307.jsonl b/SpecForge-ext/results/baseline_humaneval_results_20260213_111307.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..c7ca3460c5f6c646e1a3eae193607ed4406b3fea --- /dev/null +++ b/SpecForge-ext/results/baseline_humaneval_results_20260213_111307.jsonl @@ -0,0 +1,23 @@ +{ + "model": null, + "humaneval": [ + { + "batch_size": 1, + "steps": null, + "topk": null, + "num_draft_tokens": null, + "metrics": [ + { + "latency": 685.5040215048939, + "output_throughput": 244.09922444022513, + "accept_length": 2.3569738287742625, + "accuracy": 0.0, + "num_questions": 164, + "num_valid_predictions": 164, + "categorical_performance": null + } + ], + "num_samples": 164 + } + ] +} \ No newline at end of file diff --git a/SpecForge-ext/scripts/prepare_data.py b/SpecForge-ext/scripts/prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..dffba5eadd0759295c7cbeecf8e073ef0b8e856c --- /dev/null +++ b/SpecForge-ext/scripts/prepare_data.py @@ -0,0 +1,397 @@ +import argparse +import json +import os +import subprocess +from pathlib import Path +from typing import Dict, Tuple + +from tqdm import tqdm + +from datasets import concatenate_datasets, config, load_dataset + +""" +This script will convert the ultrachat/sharegpt dataset to the following schema in jsonl format: +{ + "id": str, + "conversations": [ + { + "role": str, + "content": str + } + ], +} +""" + +ROLE_MAPPING = { + "human": "user", + "gpt": "assistant", + "chatgpt": "assistant", + "bing": "assistant", + "bard": "assistant", +} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=str, + choices=[ + "ultrachat", + "sharegpt", + "eaglechat", + "perfectblend", + "perfectblend-llama3.1-8b-instruct", + "perfectblend-llama3.3-70b-instruct", + "perfectblend-llama4-scout-instruct", + "perfectblend-llama4-maverick-instruct", + "magpie-qwen2.5-pro-1m-v0.1", + "sharegpt4v", + "allava4v", + "opc", + ], + help="The demo dataset to quickly run the training for speculative decoding", + ) + parser.add_argument( + "--output-path", + type=str, + default=None, + help="The path to save the processed dataset, if not specified, the dataset will be saved in the cache/dataset/dataset_name directory of the root path", + ) + parser.add_argument( + "--data-path", + type=str, + default=None, + help="The path to the custom dataset, if not specified, the default dataset will be loaded", + ) + parser.add_argument( + "--sample-size", + type=int, + default=None, + help="The number of samples to process from the dataset, if not specified, all samples will be processed", + ) + parser.add_argument( + "--split-eval", + action="store_true", + help="Whether to split the dataset into train and eval sets, default is False", + ) + parser.add_argument( + "--opc-subset", + type=str, + default="largescale_diverse_instruct", + choices=[ + "largescale_diverse_instruct", + "filtered_infinity_instruct", + "realuser_instruct", + "all", + ], + help="The subset of OpenCoder opc-sft-stage1 dataset to use, or 'all' to use all subsets (default: largescale_diverse_instruct)", + ) + return parser.parse_args() + + +def get_cache_dir(dataset_name): + cache_dir = None + if dataset_name == "sharegpt4v": + raise ValueError("Downloading 'sharegpt4v' is not supported.") + elif dataset_name == "allava4v": + cache_dir = os.path.join( + config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA" + ) + else: + raise ValueError( + f"Dataset '{dataset_name}' is not a supported VLM dataset for download." + ) + return cache_dir + + +def download_vlm_dataset(dataset_name: str) -> None: + """Download VLM's dataset such as sharegpt4v and allava4v""" + if dataset_name == "sharegpt4v": + raise Exception("Don't Support Download sharegpt4v.") + elif dataset_name == "allava4v": + cache_dir = get_cache_dir(dataset_name) + os.makedirs(cache_dir, exist_ok=True) + script_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "datasets", + "download_laion.sh", + ) + os.chmod(script_path, 0o755) + if not os.path.exists( + os.path.join(cache_dir, "allava_laion", "image_chunks", "images_0.zip") + ): + result = subprocess.run( + ["bash", script_path], + cwd=cache_dir, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"Download image dataset failed: {result.stderr}") + print("##### allava4v dataset Download Complete #####") + else: + print("##### allava4v dataset has existed.") + else: + raise Exception(f"Don't support {dataset_name}") + + +def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the ultrachat dataset. + + The function expects a row with the following schema: + "messages": [ + { + "role": "user" | "assistant", + "content": str + } + ] + """ + conversations = row["messages"] + formatted_conversations = [] + for message in conversations: + role = message["role"] + content = message["content"] + assert role in ["user", "assistant"] + formatted_conversations.append({"role": role, "content": content}) + row = {"id": row["prompt_id"], "conversations": formatted_conversations} + return row, 0 + + +def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """ + sharegpt dataset schema: + { + "conversations": [ + { + "from": , + "value": , + }, + ... + ] + } + """ + conversations = row["conversations"] + formatted_conversations = [] + skipped_count = 0 + for message in conversations: + if message["from"] not in ROLE_MAPPING: + skipped_count += 1 + continue + new_role = ROLE_MAPPING[message["from"]] + content = message["value"] + formatted_conversations.append({"role": new_role, "content": content}) + + row = {"id": row["id"], "conversations": formatted_conversations} + return row, skipped_count + + +def process_sharegpt4v_row(row, dataset_name: str = None) -> Dict: + """ + sharegpt4v dataset schema: + { + "id": str, + "image": str, # path to the image + "conversations": [ + { + "from": , + "value": , + }, + ... + ] + } + """ + cache_dir = get_cache_dir(dataset_name) + conversations = row["conversations"] + image = os.path.join(cache_dir, row["image"]) + if not os.path.exists(image): + print(f"Image path {image} does not exist, skipping this sample.") + return None, None + formatted_conversations = [] + skipped_count = 0 + for message in conversations: + if message["from"] not in ROLE_MAPPING: + skipped_count += 1 + continue + new_role = ROLE_MAPPING[message["from"]] + if new_role == "user": + text_content = message["value"].replace("\n", "") + content = text_content + else: + content = message["value"] + formatted_conversations.append({"role": new_role, "content": content}) + + row = {"id": row["id"], "image": image, "conversations": formatted_conversations} + return row, skipped_count + + +def load_dataset_from_path(data_path: Path): + suffix = data_path.suffix.split(".")[1] + ds = load_dataset(suffix, data_files=str(data_path), split="train") + return ds + + +def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name): + train_output_jsonl_path = output_path.joinpath(f"{dataset_name}_train.jsonl") + if train_output_jsonl_path.exists(): + print( + f"The dataset {dataset_name} has already been processed and saved in {train_output_jsonl_path}, skipping..." + ) + return + + total_skipped_count = 0 + with open(train_output_jsonl_path, "w") as f: + for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"): + if proc_fn is not None: + row, skipped_count = proc_fn(item, dataset_name) + if row is None: + continue + total_skipped_count += skipped_count + else: + row = item + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + if test_ds is not None: + test_output_jsonl_path = output_path.joinpath(f"{dataset_name}_test.jsonl") + with open(test_output_jsonl_path, "w") as f: + for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"): + if proc_fn is not None: + row, skipped_count = proc_fn(item, dataset_name) + if row is None: + continue + total_skipped_count += skipped_count + else: + row = item + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + if total_skipped_count > 0: + total_messages = len(train_ds) + (len(test_ds) if test_ds is not None else 0) + print( + f"Skipped {total_skipped_count}/{total_messages} messages for {dataset_name}" + ) + + +import hashlib + + +def process_opc_sft_stage1(row: Dict) -> Tuple[Dict, int]: + row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": row["instruction"]}, + {"role": "assistant", "content": row["output"]}, + ], + } + return processed_row, 0 + + +def add_index(row, idx) -> Dict: + row["id"] = idx + return row + + +def main(): + args = parse_args() + # load dataset + if args.dataset == "ultrachat": + ds = load_dataset("HuggingFaceH4/ultrachat_200k")["train_sft"] + proc_fn = process_ultrachat_row + elif args.dataset == "sharegpt": + if args.data_path is None: + ds = load_dataset("Aeala/ShareGPT_Vicuna_unfiltered")["train"] + else: + print("Loading dataset from custom data path: ", args.data_path) + ds = load_dataset_from_path(Path(args.data_path)) + proc_fn = process_sharegpt_row + elif args.dataset == "eaglechat": + ds = load_dataset("zhaode/EagleChat")["train"] + proc_fn = lambda row: (row, 0) + elif args.dataset == "perfectblend": + ds = load_dataset("mlabonne/open-perfectblend")["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = process_sharegpt_row + elif args.dataset == "perfectblend-llama3.1-8b-instruct": + ds = load_dataset("frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct")[ + "train" + ] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "perfectblend-llama3.3-70b-instruct": + ds = load_dataset( + "frankleeeee/PerfectBlend-Regenerated-Llama-3.3-70B-Instruct" + )["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "perfectblend-llama4-scout-instruct": + ds = load_dataset( + "frankleeeee/PerfectBlend-Regenerated-Llama-4-Scout-17B-16E-Instruct" + )["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "perfectblend-llama4-maverick-instruct": + ds = load_dataset( + "frankleeeee/PerfectBlend-Regenerated-Llama-4-Maverick-17B-128E-Instruct" + )["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "magpie-qwen2.5-pro-1m-v0.1": + ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1")["train"] + ds = ds.rename_column("uuid", "id") + proc_fn = process_sharegpt_row + elif args.dataset == "sharegpt4v": + ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"] + raise Exception("Not supported sharegpt4v now") + download_vlm_dataset(args.dataset) + proc_fn = process_sharegpt4v_row + elif args.dataset == "allava4v": + ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[ + "instruct" + ] + download_vlm_dataset(args.dataset) + proc_fn = process_sharegpt4v_row + elif args.dataset == "opc": + if args.opc_subset == "all": + # Load all subsets and concatenate them + subsets = [ + "largescale_diverse_instruct", + "filtered_infinity_instruct", + "realuser_instruct", + ] + datasets_list = [ + load_dataset("OpenCoder-LLM/opc-sft-stage1", subset)["train"] + for subset in subsets + ] + ds = concatenate_datasets(datasets_list) + else: + ds = load_dataset("OpenCoder-LLM/opc-sft-stage1", args.opc_subset)["train"] + proc_fn = process_opc_sft_stage1 + else: + raise ValueError( + f"This script only supports ultrachat, sharegpt, sharegpt4v, allava4v, opc, and perfect-blend-gptoss-20B datasets for demo purpose, if you wish to use other datasets, please modify this script." + ) + # filter and split dataset + if args.sample_size is not None and args.sample_size < len(ds): + ds = ds.select(range(args.sample_size)) + print(f"Processing {args.sample_size} samples from the dataset {args.dataset}") + if args.split_eval: + ds = ds.train_test_split(test_size=0.05) + train_ds = ds["train"] + test_ds = ds["test"] + else: + train_ds = ds + test_ds = None + + if args.output_path is None: + root_path = Path(__file__).parent.parent + output_path = root_path.joinpath("cache", "dataset") + output_path.mkdir(parents=True, exist_ok=True) + else: + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + process_and_save_ds(train_ds, test_ds, output_path, proc_fn, args.dataset) + + +if __name__ == "__main__": + main() diff --git a/SpecForge-ext/scripts/prepare_hidden_states.py b/SpecForge-ext/scripts/prepare_hidden_states.py new file mode 100644 index 0000000000000000000000000000000000000000..d201ca47972fdae08be0a98069fde450522d2442 --- /dev/null +++ b/SpecForge-ext/scripts/prepare_hidden_states.py @@ -0,0 +1,678 @@ +""" +This script will generate the hidden states for the dataset use transformer as the target model backend. +By generating hidden states in advance, we can avoid: +- the memory overhead of loading target model +- the latency overhead of generating hidden states for each request. + +Optimized for lower memory usage and higher efficiency. + +Usage: +torchrun --nproc_per_node=8 \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path ./cache/dataset/sharegpt_train.jsonl \ + --output-path ./cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --chat-template llama3 \ + --max-length 2048 \ + --tp-size 1 \ + --batch-size 32 \ + --num-samples 1000 \ + --output-path ./cache/hidden_states + +For pre-formatted data (with chat template already applied), add --is-preformatted: +torchrun --nproc_per_node=8 \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path ./cache/dataset/preformatted_data.jsonl \ + --output-path ./cache/hidden_states \ + --chat-template llama3 \ + --is-preformatted \ + --max-length 2048 +""" + +import argparse +import gc +import hashlib +import os +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +from tqdm import tqdm +from transformers import AutoConfig, AutoProcessor, AutoTokenizer + +from datasets import Dataset +from specforge.args import SGLangBackendArgs +from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders +from specforge.distributed import ( + destroy_distributed, + get_dp_group, + get_tp_group, + init_distributed, + is_tp_rank_0, +) +from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model +from specforge.utils import ( + print_with_rank, + rank_0_priority, + safe_conversations_generator, +) + + +@dataclass +class DataPoint: + input_ids: torch.Tensor + loss_mask: torch.Tensor + hidden_state: torch.Tensor + aux_hidden_state: Optional[torch.Tensor] = None + + +def parse_args(): + parser = argparse.ArgumentParser() + + # model-related arguments + model_group = parser.add_argument_group("model") + model_group.add_argument("--target-model-path", type=str, required=True) + model_group.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code when loading models", + ) + model_group.add_argument( + "--is-vlm", action="store_true", help="Whether the target model is a VLM" + ) + model_group.add_argument("--enable-aux-hidden-states", action="store_true") + model_group.add_argument("--aux-hidden-states-layers", type=str, default=None) + + data_group = parser.add_argument_group("data") + data_group.add_argument("--data-path", type=str, required=True) + data_group.add_argument("--max-length", type=int, default=2048) + data_group.add_argument("--chat-template", type=str, default="llama3") + data_group.add_argument( + "--is-preformatted", + action="store_true", + help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.", + ) + data_group.add_argument("--num-samples", type=int, default=None) + data_group.add_argument("--build-dataset-num-proc", type=int, default=8) + + inference_group = parser.add_argument_group("inference") + inference_group.add_argument("--tp-size", type=int, default=1) + inference_group.add_argument("--batch-size", type=int, default=32) + + others_group = parser.add_argument_group("others") + others_group.add_argument("--cache-dir", type=str, default="./cache") + others_group.add_argument("--output-path", type=str, default=None) + others_group.add_argument( + "--model-download-dir", + type=str, + default=None, + help="The directory to download the target model to", + ) + others_group.add_argument( + "--dist-timeout", + type=int, + default=2000, + help="Timeout for collective communication in minutes, default to 2000 so that it does not go timeout", + ) + others_group.add_argument( + "--num-io-threads", + type=int, + default=4, + help="Number of threads for async I/O operations", + ) + others_group.add_argument( + "--num-workers", type=int, default=4, help="Number of workers for DataLoader" + ) + others_group.add_argument( + "--io-queue-size", + type=int, + default=50, + help="Max number of pending I/O futures.", + ) + others_group.add_argument( + "--file-group-size", + type=int, + default=2000, + help="Number of files per subdirectory.", + ) + + sglang_group = parser.add_argument_group("sglang") + SGLangBackendArgs.add_args(sglang_group) + return parser.parse_args() + + +def build_target_model( + args: argparse.Namespace, model_config: AutoConfig +) -> Tuple[Eagle3TargetModel, Optional[AutoProcessor]]: + """ + Build the target model according to the arguments. + + For VLM models (Qwen2.5-VL) without TP, load directly from transformers. + Otherwise, use the Eagle3 target model wrapper. + """ + if args.is_vlm and model_config.model_type == "qwen2_5_vl" and args.tp_size == 1: + # TODO: replace with sglang + from transformers import Qwen2_5_VLForConditionalGeneration + + target_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=( + model_config.dtype + if hasattr(model_config, "dtype") + else model_config.torch_dtype + ), + ) + .eval() + .cuda() + ) + else: + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + target_model = get_eagle3_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend="sglang", # we set this as the default backend to minimize precision mismatch in training and serving + torch_dtype=( + model_config.dtype + if hasattr(model_config, "dtype") + else model_config.torch_dtype + ), + device="cuda", + cache_dir=args.model_download_dir, + trust_remote_code=args.trust_remote_code, + **target_model_kwargs, + ) + # Set auxiliary hidden states layers if specified + target_model.set_aux_hidden_states_layers(args.aux_hidden_states_layers) + + if args.is_vlm: + processor = AutoProcessor.from_pretrained(args.target_model_path) + else: + processor = None + + return target_model, processor + + +class HiddenStatesGenerator: + """ + This is a generator for creating and saving the hidden states based on the target model. + It includes the following features: + 1. Fixes a potential deadlock in TP > 1 scenarios when a batch is skipped. + 2. Implements a context manager (`with` statement) for robust resource handling. + 3. Makes internal settings (like queue sizes, group sizes) configurable. + 4. Centralizes resource cleanup logic. + """ + + def __init__( + self, + target_model, + enable_aux_hidden_states: bool = True, + num_io_threads: int = 4, + io_queue_size: int = 50, + file_group_size: int = 2000, + ): + """ + Args: + target_model: The model for inference. + enable_aux_hidden_states: Whether to save auxiliary hidden states. + num_io_threads: Number of threads for async I/O. + io_queue_size: Max number of pending I/O futures before cleanup. + file_group_size: Number of files per subdirectory. + """ + self.model = target_model + self.enable_aux_hidden_states = enable_aux_hidden_states + + # --- Configurable parameters --- + self.num_io_threads = num_io_threads + self.io_queue_size = io_queue_size + self.file_group_size = file_group_size + + # progress bar should only shown on TP rank = 0 + self.show_progress = dist.get_rank(get_tp_group()) == 0 + + # --- REFACTOR: Thread pool is now managed by __enter__ and __exit__ --- + self.io_executor = None + self.pending_futures = [] + + def __enter__(self): + """Initializes resources when entering a 'with' block.""" + if is_tp_rank_0(): + self.io_executor = ThreadPoolExecutor(max_workers=self.num_io_threads) + self.pending_futures = [] + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Cleans up resources when exiting a 'with' block.""" + if is_tp_rank_0() and self.io_executor is not None: + if self.show_progress: + print("\nWaiting for all async I/O operations to complete...") + self._wait_all_saves() + self.io_executor.shutdown(wait=True) + self.io_executor = None # Reset for safety + + # Final barrier to ensure all processes exit generate() cleanly + dist.barrier() + + def _save_tensor_sync(self, data_point: DataPoint, output_file: str) -> None: + """ + Save a data point to a file synchronously. If there is any NaN value in the data, this datapoint will be skipped. + + Args: + data_point (DataPoint): The data point to save. + output_file (str): The path to the output file. + """ + if data_point.hidden_state is not None and torch.any( + torch.isnan(data_point.hidden_state) + ): + print( + f"Warning: NaN found in hidden_state for {output_file}. Skipping save." + ) + return + + if data_point.aux_hidden_state is not None and torch.any( + torch.isnan(data_point.aux_hidden_state) + ): + print( + f"Warning: NaN found in aux_hidden_state for {output_file}. Skipping save." + ) + return + + torch.save(asdict(data_point), output_file) + + def _save_tensor_async(self, data_point: DataPoint, output_file: str) -> None: + """ + Submit a job to the io_executor to save the data point asynchronously. + + Args: + data_point (DataPoint): The data point to save. + output_file (str): The path to the output file. + """ + assert is_tp_rank_0(), "Only tp_rank=0 should call _save_tensor_async" + # If the queue of pending save operations is full, we must wait. + if len(self.pending_futures) >= self.io_queue_size: + # First, try to clear any futures that have already finished without waiting. + self.pending_futures = [f for f in self.pending_futures if not f.done()] + # If the queue is *still* full, it means all I/O threads are busy and we have + # a backlog. We must now block the main generation loop and wait for the + # oldest I/O operation to complete before proceeding. + if len(self.pending_futures) >= self.io_queue_size: + self.pending_futures.pop(0).result() + + future = self.io_executor.submit( + self._save_tensor_sync, data_point, output_file + ) + self.pending_futures.append(future) + + def _wait_all_saves(self): + """ + This method is to ensure that all submitted jobs are completed. + """ + if is_tp_rank_0() and self.pending_futures: + for future in tqdm( + self.pending_futures, + desc="Finalizing Writes", + disable=not self.show_progress, + ): + future.result() # Wait and raise exception if any + self.pending_futures.clear() + + def _prepare_output_dirs( + self, output_path: str, start_idx: int, total_samples: int + ) -> None: + """ + The dataset is organized into groups of files, each group has a folder which contains the files for this group. For example, if the + file_group_size is 2000, the 0-1999 samples will be saved in the folder "rows_0-2000", the 2000-3999 samples will be saved in the folder "rows_2000-4000", etc. + + Args: + output_path (str): The path to the output directory. + start_idx (int): The starting index of the samples to save. + total_samples (int): The total number of samples to save. + + Returns: + None + """ + if not is_tp_rank_0() or total_samples == 0: + return + start_group = (start_idx // self.file_group_size) * self.file_group_size + end_sample_idx = start_idx + total_samples - 1 + end_group = (end_sample_idx // self.file_group_size) * self.file_group_size + for group_start_idx in range(start_group, end_group + 1, self.file_group_size): + grouped_subdir = ( + f"rows_{group_start_idx}-{group_start_idx + self.file_group_size}" + ) + output_dir = os.path.join(output_path, grouped_subdir) + os.makedirs(output_dir, exist_ok=True) + + def _check_existing_files_batch( + self, output_path: str, global_indices: List[int] + ) -> List[bool]: + """ + A helper function to check if the files for the given global indices exist. + + Args: + output_path (str): The path to the output directory. + global_indices (List[int]): The global indices of the samples to check. + + Returns: + List[bool]: A list of booleans indicating if the files for the given global indices exist. + """ + if not is_tp_rank_0(): + return [False] * len(global_indices) + + def check_single_file(idx): + return os.path.exists(self._get_file_path(output_path, idx)) + + # Parallel file existence check + with ThreadPoolExecutor(max_workers=self.num_io_threads) as executor: + exists = list(executor.map(check_single_file, global_indices)) + return exists + + def _get_file_path(self, output_path: str, idx: int) -> str: + """ + A helper function to get the standard file path for the data point with the given index. + + Args: + output_path (str): The path to the output directory. + idx (int): The global index of the data point. + + Returns: + str: The file path for the data point. + """ + group_idx = (idx // self.file_group_size) * self.file_group_size + grouped_subdir = f"rows_{group_idx}-{group_idx + self.file_group_size}" + return os.path.join(output_path, grouped_subdir, f"data_{idx}.ckpt") + + @torch.no_grad() + def generate( + self, + data_loader: torch.utils.data.DataLoader, + output_path: str, + start_idx: int = 0, + samples_per_dp: int = 0, + ): + """ + This version prioritizes minimal CPU RAM usage above all else, even at the cost of performance. + - It processes samples one-by-one within the tp_rank_0 process. + - It avoids batching GPU-to-CPU transfers. + - It ensures only one sample's data is in RAM for I/O at any given time. + """ + self._prepare_output_dirs(output_path, start_idx, samples_per_dp) + + tp_group = get_tp_group() + tp_group_ranks = dist.get_process_group_ranks(tp_group) + tp_rank_0_global = tp_group_ranks[0] + global_idx = start_idx + + progress_bar = tqdm( + data_loader, + disable=(not self.show_progress), + desc="Generating Hidden States", + position=dist.get_rank(get_dp_group()), + leave=True, + ) + + total_skipped, total_processed = 0, 0 + + for batch_idx, batch in enumerate(progress_bar): + batch_size = batch["input_ids"].size(0) + current_batch_indices = list(range(global_idx, global_idx + batch_size)) + + # # Step 1: Synchronize valid indices across TP group + # we check which files already exist and sync this info across TP ranks + # if exists, we will skip these samples + if is_tp_rank_0(): + exists_list = self._check_existing_files_batch( + output_path, current_batch_indices + ) + exists_tensor = torch.tensor( + exists_list, dtype=torch.bool, device="cuda" + ) + else: + exists_tensor = torch.tensor( + [False] * batch_size, dtype=torch.bool, device="cuda" + ) + dist.broadcast(exists_tensor, src=tp_rank_0_global, group=tp_group) + + # Step 1: TP rank 0 checks which samples need processing + valid_indices_in_batch = [ + i for i, exists in enumerate(exists_tensor) if not exists + ] + sample_global_indices = [ + current_batch_indices[i] for i in valid_indices_in_batch + ] + num_valid = len(valid_indices_in_batch) + total_skipped += batch_size - num_valid + + # Step 2: Filter batch before moving to GPU to save memory + global_idx += batch_size + filtered_batch = { + "input_ids": batch["input_ids"][valid_indices_in_batch], + "attention_mask": batch["attention_mask"][valid_indices_in_batch], + "loss_mask": batch["loss_mask"][valid_indices_in_batch], + } + del batch + if num_valid == 0: + # Data has already been generated, no sample processing, update progress bar. + if self.show_progress: + progress_bar.set_postfix( + { + "processed": total_processed, + "skipped": total_skipped, + "pending_io": ( + len(self.pending_futures) if is_tp_rank_0() else 0 + ), + } + ) + continue + + filtered_batch_gpu = { + k: v.cuda(non_blocking=True) for k, v in filtered_batch.items() + } + _, _, aux_hidden_states_list, last_hidden_states_list = self.model.extend( + **filtered_batch_gpu, + return_last_hidden_states=True, + return_logits=False, + ) + + del filtered_batch_gpu + + if is_tp_rank_0(): + for i, ( + current_global_idx, + aux_hidden_states, + last_hidden_states, + ) in enumerate( + zip( + sample_global_indices, + aux_hidden_states_list, + last_hidden_states_list, + ) + ): + + # Process ONE sample at a time to minimize CPU RAM footprint + # 1. Transfer only the required slice for one sample to CPU + aux_hidden_states = ( + aux_hidden_states.cpu().clone().unsqueeze(0) + if aux_hidden_states is not None + else None + ) + last_hidden_states = ( + last_hidden_states.cpu().clone().unsqueeze(0) + if last_hidden_states is not None + else None + ) + data_point = DataPoint( + input_ids=filtered_batch["input_ids"][i].clone(), + loss_mask=filtered_batch["loss_mask"][i].clone(), + hidden_state=last_hidden_states, + aux_hidden_state=aux_hidden_states, + ) + + # 3. Save asynchronously (the backpressure logic is still crucial) + output_file = self._get_file_path(output_path, current_global_idx) + self._save_tensor_async(data_point, output_file) + + # 4. Immediately clean up the single-sample CPU tensors + del last_hidden_states, aux_hidden_states + + total_processed += len(sample_global_indices) + + # Clean up the large GPU and CPU batch data + del aux_hidden_states_list, last_hidden_states_list, filtered_batch + + if batch_idx % 5 == 0: # Make GC and cache clearing more frequent + torch.cuda.empty_cache() + gc.collect() + + if self.show_progress: + progress_bar.set_postfix( + { + "processed": total_processed, + "skipped": total_skipped, + "pending_io": ( + len(self.pending_futures) if is_tp_rank_0() else 0 + ), + } + ) + + if self.show_progress: + print( + f"\nGeneration loop finished. Processed: {total_processed}, Skipped: {total_skipped}" + ) + dist.barrier() + + +def main(): + args = parse_args() + if args.aux_hidden_states_layers is not None: + args.aux_hidden_states_layers = [ + int(x) for x in args.aux_hidden_states_layers.split(",") + ] + + # Initialize distributed environment (TP + DP) + init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + + # Build target model (with TP) + target_model_config = AutoConfig.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + target_model, processor = build_target_model(args, target_model_config) + + print_with_rank( + f"DP Rank {dist.get_rank(get_dp_group())}, TP Rank {dist.get_rank(get_tp_group())}, " + f"DP Size {dist.get_world_size(get_dp_group())}, TP Size {dist.get_world_size(get_tp_group())}" + ) + + if args.output_path is None: + args.output_path = os.path.join( + Path(__file__).parent.parent, "cache", "hidden_states" + ) + + # Load complete dataset + assert os.path.exists( + args.data_path + ), f"Dataset path {args.data_path} does not exist" + dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.data_path}, + cache_dir=os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "cache", + "hf_dataset", + ), + ) + if args.num_samples is not None: + dataset = dataset.select(range(args.num_samples)) + # Tokenizer and cache key + tokenizer = AutoTokenizer.from_pretrained( + args.target_model_path, trust_remote_code=True + ) + cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}-{args.is_preformatted}" + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + # Preprocess on complete, un-sharded dataset + with rank_0_priority(): + print_with_rank("Main process is building the dataset cache...") + eagle3_dataset = build_eagle3_dataset( + dataset=dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + is_vlm=args.is_vlm, + is_preformatted=args.is_preformatted, + processor=processor, + num_proc=args.build_dataset_num_proc, + ) + print_with_rank(f"Dataset prepared with {len(eagle3_dataset)} samples.") + + # Create DP-sharded dataloader + data_loader = prepare_dp_dataloaders( + dataset=eagle3_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + shuffle=False, + process_group=get_dp_group(), + is_vlm=args.is_vlm, + ) + + print_with_rank( + f"DataLoader created for DP Rank {dist.get_rank(get_dp_group())}. " + f"Number of batches: {len(data_loader)}" + ) + + # Calculate starting index and sample count for current DP rank + total = len(eagle3_dataset) + dp_rank = dist.get_rank(get_dp_group()) + dp_size = dist.get_world_size(get_dp_group()) + + # Calculate samples per DP rank (handle non-divisible case) + samples_per_dp = total // dp_size + remainder = total % dp_size + + # Earlier ranks handle one extra sample if there's a remainder + if dp_rank < remainder: + samples_per_dp += 1 + start_idx = dp_rank * samples_per_dp + else: + start_idx = dp_rank * samples_per_dp + remainder + + print_with_rank( + f"DP Rank {dp_rank} will process {samples_per_dp} samples, " + f"starting from index {start_idx}" + ) + + # Generate hidden states + try: + # Pass configurable arguments from args if needed + with HiddenStatesGenerator( + target_model, + enable_aux_hidden_states=args.enable_aux_hidden_states, + num_io_threads=args.num_io_threads, + io_queue_size=args.io_queue_size, + file_group_size=args.file_group_size, + # Other params like io_queue_size can also be added to argparse + ) as hidden_states_generator: + + # Generate hidden states + hidden_states_generator.generate( + data_loader, + output_path=args.output_path, + start_idx=start_idx, + samples_per_dp=samples_per_dp, + ) + + finally: + # The finally block ensures destroy_distributed is always called + print_with_rank("All hidden states generated or job finished.") + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/SpecForge-ext/scripts/regenerate_train_data.py b/SpecForge-ext/scripts/regenerate_train_data.py new file mode 100644 index 0000000000000000000000000000000000000000..1781a58d858d80c10f63c5dd390a838c8a820507 --- /dev/null +++ b/SpecForge-ext/scripts/regenerate_train_data.py @@ -0,0 +1,407 @@ +""" +This script will re-generate the dataset from target model, +which better aligns the draft model with the target model’s output distribution. + +Usage: +1. Set up one or more SGLang servers for the target model. + +python3 -m sglang.launch_server \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --mem-fraction-static 0.75 \ + --cuda-graph-max-bs 128 \ + --tp 1 \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 30000 \ + --dtype bfloat16 + + +2. Regenerate the dataset using the `regenerate_train_data.py` script. +python scripts/regenerate_train_data.py \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --concurrency 128 \ + --max-tokens 4096 \ + --server-address localhost:30000 \ + --temperature 0.8 \ + --input-file-path ./cache/dataset/sharegpt_train.jsonl \ + --output-file-path ./cache/dataset/sharegpt_train_regen.jsonl +""" + +import argparse +import json +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List + +from openai import OpenAI +from tqdm import tqdm + + +def parse_arguments(): + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description="Re-generate training data using sglang model server" + ) + + # model related arguments + model_group = parser.add_argument_group("model") + model_group.add_argument("--model", type=str, required=True) + model_group.add_argument( + "--is-reasoning-model", + action="store_true", + help="Whether the model is a reasoning model", + ) + model_group.add_argument( + "--is-gpt-oss", + action="store_true", + help="Whether the model is a GPT-OSS model", + ) + + # sampling params + sampling_params_group = parser.add_argument_group("sampling parameters") + sampling_params_group.add_argument( + "--temperature", + type=float, + default=0.7, + help="Temperature for sglang model server", + ) + sampling_params_group.add_argument( + "--top-p", + type=float, + default=None, + help="Nucleus sampling top_p", + ) + sampling_params_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling value sent via extra_body", + ) + sampling_params_group.add_argument( + "--repetition-penalty", + type=float, + default=None, + help="Mapped to presence_penalty in the OpenAI API", + ) + sampling_params_group.add_argument( + "--max-tokens", + type=int, + default=4096, + help="Maximum number of tokens (default: 4096)", + ) + + # optimization + optimization_group = parser.add_argument_group("optimization") + optimization_group.add_argument( + "--concurrency", + type=int, + default=64, + help="The number of requests to send to a single server concurrently, the total number of concurrent requests is concurrency * number of server addresses", + ) + + # data related arguments + data_group = parser.add_argument_group("data") + data_group.add_argument( + "--input-file-path", type=str, required=True, help="Path to the input file" + ) + data_group.add_argument( + "--output-file-path", type=str, required=True, help="Path to the output file" + ) + data_group.add_argument( + "--num-samples", + type=int, + default=None, + help="The number of samples to regenerate, if not provided, all samples will be regenerated", + ) + + # sglang server + server_group = parser.add_argument_group("sglang server") + server_group.add_argument( + "--server-address", + type=str, + nargs="+", + help="Server address and port for sglang model server", + ) + return parser.parse_args() + + +def get_random_reasoning_effort() -> str: + """Get a random reasoning effort level for the model with weighted probabilities.""" + # usage example: https://huggingface.co/openai/gpt-oss-20b/discussions/28 + # Reasoning effort levels with weights: LOW(4), MEDIUM(4), HIGH(2) + reasoning_efforts = [ + "low", + "medium", + "high", + ] + weights = [4, 4, 2] + return random.choices(reasoning_efforts, weights=weights, k=1)[0] + + +def compute_context_length(conversations: List[Dict[str, Any]]) -> int: + """ + This is a rough estimate of the context length measured in untokenized + tokens. + """ + length = 0 + for message in conversations: + content = message.get("content") + if isinstance(content, str): + # {"role": "assistant", "content": "Hi, how can I help?"} + length += len(content.split()) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict): + text = part.get("text") + if isinstance(text, str): + length += len(text.split()) + return length + + +def build_query_kwargs(args, messages, max_tokens=None): + effective_max_tokens = max_tokens if max_tokens is not None else args.max_tokens + + query_kwargs = dict( + model=args.model, + messages=messages, + max_tokens=effective_max_tokens, + temperature=args.temperature, + stream=False, + ) + if args.top_p is not None: + query_kwargs["top_p"] = args.top_p + if args.repetition_penalty is not None: + query_kwargs["presence_penalty"] = args.repetition_penalty + extra_body = {} + if args.top_k is not None: + extra_body["top_k"] = args.top_k + if extra_body: + query_kwargs["extra_body"] = extra_body + if args.is_gpt_oss: + query_kwargs["reasoning_effort"] = get_random_reasoning_effort() + return query_kwargs + + +def call_sglang( + args, + server_address: str, + data: List[Dict[str, Any]], + max_tokens=None, +) -> str: + """Send a batch of prompts to sglang /v1/completions.""" + client = OpenAI(base_url=f"http://{server_address}/v1", api_key="None") + + messages = data["conversations"] + regenerated_messages = [] + + # ignore data which starts with an assistant message + if messages[0]["role"] == "assistant": + data["status"] = "error" + data["error"] = "Data starts with an assistant message" + return data + + for message in messages: + if message["role"] == "system": + regenerated_messages.append(message) + elif message["role"] == "assistant": + continue + elif message["role"] == "user": + regenerated_messages.append(message) + + query_kwargs = build_query_kwargs(args, regenerated_messages, max_tokens) + + try: + resp = client.chat.completions.create(**query_kwargs) + except Exception as e: + data["status"] = "error" + data["error"] = str(e) + return data + response_text = resp.choices[0].message.content + resp_msg = { + "role": "assistant", + "content": response_text, + } + if args.is_reasoning_model: + resp_msg["thinking"] = resp.choices[0].message.reasoning_content + regenerated_messages.append(resp_msg) + else: + data["status"] = "error" + data["error"] = f"Invalid message role: {message['role']}" + return data + data["conversations"] = regenerated_messages + data["status"] = "success" + return data + + +def main(): + # Parse command line arguments + args = parse_arguments() + + # Validate parameters + if not (0.0 <= args.temperature <= 1.0): + raise ValueError("Temperature must be between 0.0 and 1.0") + + if args.max_tokens <= 0: + raise ValueError("Max tokens must be greater than 0") + + print(f"Configuration:") + print(f" Model path: {args.model}") + print(f" Max tokens: {args.max_tokens}") + print(f" Concurrency: {args.concurrency}") + print(f" Temperature: {args.temperature}") + print(f" API URL: {args.server_address}") + print(f" Input file: {args.input_file_path}") + print(f" Output file: {args.output_file_path}") + print("-" * 50) + total_lines = sum(1 for _ in open(args.input_file_path)) + + # test all server addresses + valid_server_addresses = [] + for server_address in args.server_address: + dummy_data = dict( + conversations=[{"role": "user", "content": "Hello, how are you?"}] + ) + result = call_sglang( + args, + server_address, + dummy_data, + max_tokens=1, + ) + if result is not None: + valid_server_addresses.append(server_address) + else: + print(f"Server {server_address} is not available") + + if len(valid_server_addresses) == 0: + raise ValueError("No server address is available") + print( + f"Using {len(valid_server_addresses)} server addresses: {valid_server_addresses}" + ) + print("-" * 50) + + # create error file path if not exists + error_file_path = args.output_file_path.replace(".jsonl", "_error.jsonl") + print( + f"Regenerating dataset and saving the output to {args.output_file_path} and error log to {error_file_path}" + ) + print("-" * 50) + context_token_sum = 0 + context_token_min = None + context_token_max = 0 + success_samples = 0 + error_samples = 0 + + # Create progress bar + with ( + open(args.input_file_path, "r") as input_file, + open(args.output_file_path, "w") as output_file_handle, + open(error_file_path, "w") as error_file_handle, + ): + executor = ThreadPoolExecutor( + max_workers=args.concurrency * len(valid_server_addresses) + ) + waiting_queue = { + server_address: [] for server_address in valid_server_addresses + } + pbar = tqdm(total=total_lines, desc="Processing") + start_server_index = 0 + + for line in input_file: + if ( + args.num_samples is not None + and success_samples + error_samples >= args.num_samples + ): + break + + data = json.loads(line.strip()) + + # find server address with the least waiting requests + server_address = valid_server_addresses[start_server_index] + start_server_index = (start_server_index + 1) % len(valid_server_addresses) + + # submit prompt to sglang + while len(waiting_queue[server_address]) >= args.concurrency: + finished_on_request = False + # check if any future is done, if so, write the result to the output file + for req_future in waiting_queue[server_address]: + if req_future.done(): + regen_data = req_future.result() + + if regen_data["status"] == "error": + error_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + error_samples += 1 + else: + ctx_len = compute_context_length( + regen_data.get("conversations", []) + ) + context_token_sum += ctx_len + if context_token_min is None: + context_token_min = ctx_len + else: + context_token_min = min(context_token_min, ctx_len) + context_token_max = max(context_token_max, ctx_len) + + output_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + success_samples += 1 + waiting_queue[server_address].remove(req_future) + finished_on_request = True + + if finished_on_request: + break + + req_future = executor.submit( + call_sglang, + args, + server_address, + data, + ) + waiting_queue[server_address].append(req_future) + pbar.update(1) + + # deal with all the remaining requests + for server_address, waiting_queue_items in waiting_queue.items(): + for req_future in waiting_queue_items: + regen_data = req_future.result() + if regen_data["status"] == "error": + error_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + error_samples += 1 + else: + ctx_len = compute_context_length( + regen_data.get("conversations", []) + ) + context_token_sum += ctx_len + if context_token_min is None: + context_token_min = ctx_len + else: + context_token_min = min(context_token_min, ctx_len) + context_token_max = max(context_token_max, ctx_len) + + output_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + success_samples += 1 + + print(f"\nProcessing completed!") + if success_samples > 0: + avg_len = context_token_sum / success_samples + print("Context length statistics (token count over conversations):") + print(f"Number of successful examples: {success_samples}") + print(f"Shortest context length: {context_token_min}") + print(f"Longest context length: {context_token_max}") + print(f"Average context length: {avg_len:.2f}") + else: + print("No successful examples to compute context length statistics.") + + print( + f"\nProcessing completed! {success_samples} samples regenerated, {error_samples} samples failed." + ) + + +if __name__ == "__main__": + main() diff --git a/SpecForge-ext/scripts/train_dflash.py b/SpecForge-ext/scripts/train_dflash.py new file mode 100644 index 0000000000000000000000000000000000000000..7baf5ee94632f6deaac1d20ce0c54a693bce729f --- /dev/null +++ b/SpecForge-ext/scripts/train_dflash.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +# coding=utf-8 +"""DFlash Training Script.""" + +import argparse +import logging +import math +import os +import shutil +import time +import warnings +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoConfig, AutoTokenizer + +from datasets import load_dataset +from specforge.args import SGLangBackendArgs, TrackerArgs +from specforge.core.dflash import OnlineDFlashModel +from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders +from specforge.distributed import destroy_distributed, get_dp_group, init_distributed +from specforge.modeling.draft.dflash import DFlashDraftModel +from specforge.modeling.target.dflash_target_model import ( + DFlashTargetModel, + get_dflash_target_model, +) +from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead +from specforge.optimizer import BF16Optimizer +from specforge.tracker import create_tracker +from specforge.utils import print_on_rank0, print_with_rank + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train DFlash Draft Model") + + model_group = parser.add_argument_group("model") + model_group.add_argument("--target-model-path", type=str, required=True) + model_group.add_argument( + "--target-model-backend", + type=str, + default="hf", + choices=["sglang", "hf"], + help="Backend for target model: 'sglang' (service) or 'hf' (local)", + ) + model_group.add_argument("--draft-config-path", type=str, default=None) + model_group.add_argument("--block-size", type=int, default=16) + model_group.add_argument("--num-draft-layers", type=int, default=1) + model_group.add_argument( + "--mask-token-id", + type=int, + default=None, + help="MASK token ID. If not provided, auto-detect from tokenizer.", + ) + model_group.add_argument( + "--attention-backend", + type=str, + default="flex_attention", + choices=["eager", "sdpa", "flex_attention"], + help="Attention backend for draft model.", + ) + model_group.add_argument( + "--trust-remote-code", action="store_true", help="Trust remote code" + ) + + dataset_group = parser.add_argument_group("dataset") + dataset_group.add_argument("--train-data-path", type=str, required=True) + dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument("--chat-template", type=str, default="qwen") + dataset_group.add_argument("--is-preformatted", action="store_true") + dataset_group.add_argument("--dataloader-num-workers", type=int, default=8) + dataset_group.add_argument( + "--build-dataset-num-proc", + type=int, + default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)), + ) + + training_group = parser.add_argument_group("training") + training_group.add_argument("--num-epochs", type=int, default=3) + training_group.add_argument("--batch-size", type=int, default=1) + training_group.add_argument("--learning-rate", type=float, default=1e-4) + training_group.add_argument("--max-length", type=int, default=2048) + training_group.add_argument("--warmup-ratio", type=float, default=0.01) + training_group.add_argument("--max-grad-norm", type=float, default=1.0) + training_group.add_argument("--accumulation-steps", type=int, default=1) + training_group.add_argument("--seed", type=int, default=42) + training_group.add_argument("--resume", action="store_true") + + output_group = parser.add_argument_group("output") + output_group.add_argument("--output-dir", type=str, required=True) + output_group.add_argument("--cache-dir", type=str, default="./cache") + output_group.add_argument("--log-interval", type=int, default=50) + output_group.add_argument("--eval-interval", type=int, default=1000) + output_group.add_argument("--save-interval", type=int, default=1000) + + optimization_group = parser.add_argument_group("optimization") + optimization_group.add_argument( + "--tp-size", + type=int, + default=1, + help="The size of the tensor parallel for the target model", + ) + + tracker_group = parser.add_argument_group("tracker") + TrackerArgs.add_args(tracker_group) + + dist_group = parser.add_argument_group("distributed") + dist_group.add_argument("--dist-timeout", type=int, default=30) + + # SGLang specific args + sglang_group = parser.add_argument_group("sglang backend") + SGLangBackendArgs.add_args(sglang_group) + + return parser.parse_args() + + +def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: + """Build target model (backend wrapper) and draft model.""" + print_on_rank0( + f"Loading target model from {args.target_model_path} using {args.target_model_backend} backend" + ) + + # 1. Build Target Model Wrapper + target_model_kwargs = {} + if args.target_model_backend == "sglang": + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + + target_model = get_dflash_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend=args.target_model_backend, + torch_dtype=torch.bfloat16, + device="cuda" if args.target_model_backend == "hf" else None, + trust_remote_code=args.trust_remote_code, + **target_model_kwargs, + ) + + # 2. Build Draft Model + if args.draft_config_path: + draft_config = AutoConfig.from_pretrained(args.draft_config_path) + print_on_rank0(f"Loaded draft config from {args.draft_config_path}") + else: + # Load config from HF (needed for structure info even if backend is sglang) + target_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config.num_hidden_layers = args.num_draft_layers + draft_config.block_size = args.block_size + draft_config.num_target_layers = target_config.num_hidden_layers + print_on_rank0("Auto-generated draft config from target model") + + # Set attention implementation based on backend + draft_config._attn_implementation = args.attention_backend + print_on_rank0(f"Using attention backend: {args.attention_backend}") + + draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16) + + # Set capture layers for target model based on draft model config + target_model.set_capture_layers(draft_model.target_layer_ids) + + print_on_rank0( + f"Draft config: block_size={draft_config.block_size}, " + f"num_hidden_layers={draft_config.num_hidden_layers}, " + f"num_target_layers={draft_config.num_target_layers}" + ) + print_on_rank0( + f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}" + ) + + return target_model, draft_model + + +def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]: + """Build train and eval dataloaders.""" + import hashlib + + # convert to dataloader + cache_params_string = ( + f"{args.train_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}" + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + + # Filter out samples with too few loss tokens (DFlash requires >= 2 * block_size) + min_loss_tokens = 2 * args.block_size + original_size = len(train_eagle3_dataset) + train_eagle3_dataset = train_eagle3_dataset.filter( + lambda x: x["loss_mask"].sum() >= min_loss_tokens + ) + print_on_rank0( + f"Filtered train dataset: {original_size} -> {len(train_eagle3_dataset)} samples" + ) + + train_dataloader = prepare_dp_dataloaders( + train_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=True, + process_group=get_dp_group(), + ) + + eval_dataloader = None + if args.eval_data_path: + eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"] + eval_eagle3_dataset = build_eagle3_dataset( + dataset=eval_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + ) + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=False, + process_group=get_dp_group(), + ) + + return train_dataloader, eval_dataloader + + +def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer): + """Save checkpoint.""" + save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") + if dist.get_rank() == 0: + os.makedirs(save_dir, exist_ok=True) + dist.barrier() + + with FSDP.state_dict_type(dflash_model, StateDictType.FULL_STATE_DICT): + state_dict = dflash_model.state_dict() + draft_state_dict = { + k.replace("draft_model.", ""): v + for k, v in state_dict.items() + if "draft_model." in k + } + + if dist.get_rank() == 0: + torch.save( + { + "epoch": epoch, + "global_step": step, + "args": args, + **optimizer.state_dict(), + }, + os.path.join(save_dir, "training_state.pt"), + ) + + draft_model.save_pretrained(save_dir, state_dict=draft_state_dict) + + # Copy modeling_dflash.py for inference compatibility + modeling_src = os.path.join( + os.path.dirname(__file__), + "..", + "specforge", + "modeling", + "draft", + "dflash.py", + ) + modeling_dst = os.path.join(save_dir, "modeling_dflash.py") + if os.path.exists(modeling_src): + shutil.copy(modeling_src, modeling_dst) + + print_on_rank0(f"Saved checkpoint to {save_dir}") + + dist.barrier() + + +def record_metrics( + args, + loss: float, + accuracy: float, + global_step: int, + tracker, + optimizer, + train_dataloader=None, + mode: str = "train", +) -> None: + logdict = {} + + if mode == "train" and optimizer is not None: + logdict["train/lr"] = optimizer.get_learning_rate() + + logdict[f"{mode}/loss"] = loss + logdict[f"{mode}/accuracy"] = accuracy + + print_on_rank0( + f"{mode.capitalize()} - Step {global_step} [{global_step}/{args.num_epochs * len(train_dataloader) // args.accumulation_steps}?], Loss: {loss:.4f}, Acc: {accuracy:.4f}" + ) + + tracker.log(logdict, step=global_step) + + +def main(): + # Configure logging to ensure we see INFO logs + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Force the root logger to INFO as well, just in case + logging.getLogger().setLevel(logging.INFO) + + # Filter annoying FSDP warnings + warnings.filterwarnings( + "ignore", + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed", + ) + + args = parse_args() + set_seed(args.seed) + + init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_with_rank("Initialized distributed") + + target_model, draft_model = build_models(args) + + tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + + # Get mask_token_id + if args.mask_token_id is not None: + mask_token_id = args.mask_token_id + elif tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + else: + tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + mask_token_id = tokenizer.mask_token_id + print_on_rank0(f"Using mask_token_id: {mask_token_id}") + + train_dataloader, eval_dataloader = build_dataloader(args, tokenizer) + + steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps) + total_steps = args.num_epochs * steps_per_epoch + print_on_rank0(f"Total training steps: {total_steps}") + + # Note: We need embedding layer for DFlash wrapper. + # For SGLang backend, we can't easily get the embedding layer object. + # We use TargetEmbeddingsAndHead to efficiently load only needed weights. + print_on_rank0("Loading target embeddings and head efficiently...") + target_components = TargetEmbeddingsAndHead.from_pretrained( + args.target_model_path, + embed_key="model.embed_tokens.weight", # Adjust if Qwen/Llama differs + lm_head_key="lm_head.weight", + device="cuda", + trust_remote_code=args.trust_remote_code, + ) + + dflash_model = OnlineDFlashModel( + draft_model=draft_model, + target_lm_head=target_components.lm_head, + target_embed_tokens=target_components.embed_tokens, + block_size=draft_model.block_size, + mask_token_id=mask_token_id, + attention_backend=args.attention_backend, + ) + + dflash_model = FSDP( + dflash_model, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, + ) + print_with_rank("Initialized FSDP") + + optimizer = BF16Optimizer( + draft_model, + lr=args.learning_rate, + max_grad_norm=args.max_grad_norm, + warmup_ratio=args.warmup_ratio, + total_steps=total_steps, + ) + + print_on_rank0(f"Initializing tracker (report_to={args.report_to})...") + tracker = create_tracker(args, args.output_dir) + print_on_rank0("Tracker initialized successfully.") + + global_step = 0 + last_time = time.time() + + for epoch in range(args.num_epochs): + train_dataloader.sampler.set_epoch(epoch) + draft_model.train() + + if dist.get_rank() == 0: + progress_bar = tqdm( + train_dataloader, desc=f"Training Epoch {epoch}", leave=True + ) + else: + progress_bar = train_dataloader + + for data in progress_bar: + global_step += 1 + + input_ids = data["input_ids"].cuda() + attention_mask = data["attention_mask"].cuda() + loss_mask = data["loss_mask"].cuda() + + # Generate context from Target Model (SGLang or HF) + # This calls the backend to get hidden states + target_output = target_model.generate_dflash_data( + input_ids, attention_mask, loss_mask + ) + hidden_states = target_output.hidden_states.cuda() # Ensure on GPU + + # Forward pass (Parallel Training) + loss, accuracy = dflash_model( + input_ids=input_ids, + attention_mask=attention_mask, + hidden_states=hidden_states, + loss_mask=loss_mask, + ) + + (loss / args.accumulation_steps).backward() + + if global_step % args.accumulation_steps == 0: + optimizer.step() + + if global_step % args.log_interval == 0: + loss_log = loss.clone() + acc_log = accuracy.clone() + dist.all_reduce(loss_log) + dist.all_reduce(acc_log) + loss_log = loss_log / dist.get_world_size() + acc_log = acc_log / dist.get_world_size() + + record_metrics( + args, + loss_log.item(), + acc_log.item(), + global_step, + tracker, + optimizer, + train_dataloader, + mode="train", + ) + + if dist.get_rank() == 0: + elapsed = time.time() - last_time + last_time = time.time() + progress_bar.set_postfix( + { + "loss": f"{loss.item():.4f}", + "acc": f"{accuracy.item():.4f}", + "iter_time": f"{elapsed:.2f}s", + } + ) + + if global_step % args.save_interval == 0: + save_checkpoint( + args, epoch, global_step, dflash_model, draft_model, optimizer + ) + + save_checkpoint( + args, args.num_epochs, global_step, dflash_model, draft_model, optimizer + ) + + tracker.close() + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/SpecForge-ext/scripts/train_eagle3.py b/SpecForge-ext/scripts/train_eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..4237dc94a43f38b2f39ddf08f4e333e64188bc4e --- /dev/null +++ b/SpecForge-ext/scripts/train_eagle3.py @@ -0,0 +1,988 @@ +import argparse +import hashlib +import math +import os +import time +from argparse import ArgumentParser, Namespace +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from accelerate.utils import set_seed +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoProcessor, AutoTokenizer + +from datasets import Dataset +from specforge import ( + AutoDraftModelConfig, + AutoEagle3DraftModel, + OnlineEagle3Model, + QwenVLOnlineEagle3Model, +) +from specforge.args import SGLangBackendArgs, TrackerArgs +from specforge.data import ( + build_eagle3_dataset, + build_offline_eagle3_dataset, + generate_vocab_mapping_file, + prepare_dp_dataloaders, +) +from specforge.distributed import ( + destroy_distributed, + get_dp_group, + get_draft_dp_group, + get_draft_sp_group, + get_tp_group, + init_distributed, +) +from specforge.modeling.target import ( + Eagle3TargetModel, + TargetHead, + get_eagle3_target_model, +) +from specforge.optimizer import BF16Optimizer +from specforge.tracker import Tracker, create_tracker, get_tracker_class +from specforge.utils import ( + create_draft_config_from_target, + get_last_checkpoint, + print_args_with_dots, + print_on_rank0, + print_with_rank, + rank_0_priority, + safe_conversations_generator, +) + + +def parse_args() -> Tuple[ArgumentParser, Namespace]: + """ + This function is used to parse the arguments for the training script. + """ + parser = argparse.ArgumentParser(description="Train Eagle3 with online data") + + # add model-related arguments + model_group = parser.add_argument_group("model") + model_group.add_argument("--target-model-path", type=str, required=True) + model_group.add_argument( + "--trust-remote-code", action="store_true", help="Trust remote code" + ) + model_group.add_argument( + "--draft-model-config", + type=str, + required=False, + help="Draft model config path. If not provided, will auto-generate from target model.", + ) + model_group.add_argument( + "--embedding-key", + type=str, + default="model.embed_tokens.weight", + help="The key of the embedding weight to load from the target model", + ) + model_group.add_argument( + "--lm-head-key", + type=str, + default="lm_head.weight", + help="The key of the lm head weight to load from the target model, this is only required for offline training", + ) + model_group.add_argument( + "--is-vlm", action="store_true", help="Whether the target model is a VLM" + ) + model_group.add_argument( + "--target-model-backend", + type=str, + default="sglang", + choices=["sglang", "hf", "custom"], + help="The backend of the target model", + ) + + # dataset arguments + dataset_group = parser.add_argument_group("dataset") + dataset_group.add_argument("--train-data-path", type=str, required=True) + dataset_group.add_argument("--train-hidden-states-path", type=str, default=None) + dataset_group.add_argument("--eval-hidden-states-path", type=str, default=None) + dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument("--chat-template", type=str, default="llama3") + dataset_group.add_argument( + "--is-preformatted", + action="store_true", + help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.", + ) + dataset_group.add_argument( + "--train-only-last-turn", + action="store_true", + help="If set, only the last assistant turn in each conversation contributes to the loss. " + "Useful for thinking models where conversation history may lack thought processes.", + ) + dataset_group.add_argument("--build-dataset-num-proc", type=int, default=8) + dataset_group.add_argument( + "--dataloader-num-workers", + type=int, + default=4, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + # training hyper params + training_group = parser.add_argument_group("training") + training_group.add_argument("--num-epochs", type=int, default=10) + training_group.add_argument( + "--max-num-steps", + type=int, + default=None, + help="The maximum number of steps to train. If not provided, will be calculated as num_epochs * steps_per_epoch", + ) + training_group.add_argument("--batch-size", type=int, default=1) + training_group.add_argument("--learning-rate", type=float, default=1e-4) + training_group.add_argument("--max-length", type=int, default=2048) + training_group.add_argument("--warmup-ratio", type=float, default=0.015) + training_group.add_argument( + "--total-steps", + type=int, + default=None, + help="Total training steps. If not provided, will be calculated as num_epochs * steps_per_epoch", + ) + training_group.add_argument("--max-grad-norm", type=float, default=0.5) + training_group.add_argument( + "--ttt-length", + type=int, + default=7, + help="The length for Test-Time Training (TTT).", + ) + training_group.add_argument("--resume", action="store_true") + training_group.add_argument( + "--ckpt-dir", + type=str, + default=None, + help="directory includes the checkpoint to start training with", + ) + training_group.add_argument("--eval-interval", type=int, default=5000) + training_group.add_argument("--save-interval", type=int, default=5000) + training_group.add_argument( + "--log-interval", + type=int, + default=50, + help="Log training metrics every N steps", + ) + training_group.add_argument("--seed", type=int, default=0) + training_group.add_argument("--draft-accumulation-steps", type=int, default=1) + + # data processing type + optimization_group = parser.add_argument_group("optimization") + optimization_group.add_argument( + "--tp-size", + type=int, + default=1, + help="The size of the tensor parallel for the target model", + ) + # distributed training + optimization_group.add_argument("--sp-ulysses-size", type=int, default=1) + optimization_group.add_argument("--sp-ring-size", type=int, default=1) + optimization_group.add_argument( + "--attention-backend", + type=str, + default="flex_attention", + help="The attention backend for the draft model", + ) + + # other args + other_group = parser.add_argument_group("others") + other_group.add_argument("--cache-key", type=str, default=None) + other_group.add_argument("--cache-dir", type=str, default="./cache") + other_group.add_argument("--output-dir", type=str, required=True) + other_group.add_argument("--verbose", action="store_true") + other_group.add_argument( + "--dist-timeout", + type=int, + default=20, + help="Timeout for collective communication in minutes", + ) + other_group.add_argument( + "--model-download-dir", + type=str, + default=None, + help="The directory to download the target model to", + ) + + # vlm related args + vlm_group = parser.add_argument_group("vlm") + vlm_group.add_argument( + "--min-pixels", type=int, default=50176 + ) # 64*28*28 for qwen2.5-vl + vlm_group.add_argument( + "--max-pixels", type=int, default=802816 + ) # 1024*28*28 for qwen2.5-vl + + # profiling related args + profiling_group = parser.add_argument_group("profiling") + profiling_group.add_argument("--profile", action="store_true") + profiling_group.add_argument("--profile-start-step", type=int, default=30) + profiling_group.add_argument("--profile-num-steps", type=int, default=4) + profiling_group.add_argument("--profile-record-shapes", action="store_true") + + # sglang target model backend related args + sglang_group = parser.add_argument_group("sglang target model backend") + SGLangBackendArgs.add_args(sglang_group) + + # tracker related args + tracker_group = parser.add_argument_group("tracker") + TrackerArgs.add_args(tracker_group) + + args = parser.parse_args() + return parser, args + + +def build_tracker(args: Namespace, parser: ArgumentParser) -> Tracker: + """ + Build the experiment tracker according to the report_to argument. + + Args: + args: The arguments for the training script. + parser: The parser for the training script. + + Returns: + The experiment tracker. + """ + tracker_class = get_tracker_class(args.report_to) + if tracker_class: + tracker_class.validate_args(parser, args) + else: + parser.error(f"Unknown tracker: {args.report_to}") + tracker = create_tracker(args, args.output_dir) + return tracker + + +def build_target_model( + args: Namespace, draft_model_config: AutoDraftModelConfig, is_online: bool = True +) -> Tuple[Union[Eagle3TargetModel, TargetHead], Optional[AutoProcessor]]: + """ + Build the target model according to the arguments. + + Args: + args: The arguments for the training script. + draft_model_config: The draft model config. + + Returns: + The target model. + """ + if is_online: + if ( + args.is_vlm + and draft_model_config.target_model_type == "qwen2_5_vl" + and args.target_model_backend == "custom" + ): + from transformers import Qwen2_5_VLForConditionalGeneration + + target_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + ) + .eval() + .cuda() + ) + else: + if args.target_model_backend == "sglang": + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + else: + target_model_kwargs = {} + target_model = get_eagle3_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend=args.target_model_backend, + torch_dtype=torch.bfloat16, + device="cuda", + cache_dir=args.model_download_dir, + **target_model_kwargs, + trust_remote_code=args.trust_remote_code, + ) + + # set the aux hidden states layers + if ( + hasattr(draft_model_config, "eagle_config") + and draft_model_config.eagle_config is not None + and "eagle_aux_hidden_state_layer_ids" in draft_model_config.eagle_config + ): + target_model.set_aux_hidden_states_layers( + draft_model_config.eagle_config["eagle_aux_hidden_state_layer_ids"] + ) + else: + target_model.set_aux_hidden_states_layers() + + if args.is_vlm: + processor = AutoProcessor.from_pretrained( + args.target_model_path, + min_pixels=args.min_pixels, + max_pixels=args.max_pixels, + ) + else: + processor = None + + return target_model, processor + else: + target_head = TargetHead.from_pretrained( + model_path=args.target_model_path, + lm_head_key=args.lm_head_key, + cache_dir=args.model_download_dir, + trust_remote_code=args.trust_remote_code, + ) + return target_head, None + + +def sanity_check(args: Namespace) -> None: + """ + Perform sanity checks on the arguments. + + Args: + args: The arguments for the training script. + + Returns: + None + """ + args.dp_size = dist.get_world_size() // args.tp_size + args.target_batch_size = args.tp_size * args.batch_size + args.draft_accumulation_steps = ( + args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size + ) + + if args.eval_data_path is not None and args.eval_hidden_states_path is not None: + raise ValueError( + "Cannot set both eval_data_path and eval_hidden_states_path. " + "For online mode, set only eval_data_path. " + "For offline mode, set only eval_hidden_states_path." + ) + + +def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]: + # Handle draft model config + if args.draft_model_config is None: + # Auto-generate and save config file + auto_config_path = create_draft_config_from_target( + target_model_path=args.target_model_path, cache_dir=args.model_download_dir + ) + draft_model_config = AutoDraftModelConfig.from_file(auto_config_path) + else: + # Use provided config file + draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) + + # Handle base ckpt, config file + draft_model_last_checkpoint = None + if args.ckpt_dir is not None: + if os.path.isdir(args.ckpt_dir): + draft_model_config = AutoDraftModelConfig.from_file( + os.path.join(args.ckpt_dir, "config.json") + ) + draft_model_last_checkpoint = args.ckpt_dir + print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}") + else: + raise ValueError( + f"Provided base model dir {args.ckpt_dir} is not a valid directory." + ) + + # detecting last ckpt for draft model + if args.resume and os.path.isdir(args.output_dir): + print_on_rank0(args.output_dir) + draft_model_last_checkpoint = get_last_checkpoint(args.output_dir) + print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}") + + if draft_model_last_checkpoint: + draft_model = AutoEagle3DraftModel.from_pretrained( + draft_model_last_checkpoint, + attention_backend=args.attention_backend, + torch_dtype=torch.bfloat16, + ).cuda() + else: + draft_model = AutoEagle3DraftModel.from_config( + draft_model_config, + attention_backend=args.attention_backend, + torch_dtype=torch.bfloat16, + ).cuda() + + draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) + draft_model.freeze_embedding() + return draft_model_config, draft_model + + +def build_dataloaders( + args: Namespace, + draft_model_config: AutoDraftModelConfig, + processor: Optional[AutoProcessor] = None, +) -> Tuple[DataLoader, str, Optional[DataLoader]]: + # build dataloaders + tokenizer = AutoTokenizer.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + + # convert to dataloader + cache_params_string = ( + f"{args.train_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}" # Tokenizer may also different + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + train_dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.train_data_path}, + ) + is_online = ( + args.train_data_path is not None and args.train_hidden_states_path is None + ) + with rank_0_priority(): + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + is_vlm=args.is_vlm, + is_preformatted=args.is_preformatted, + processor=processor, + num_proc=args.build_dataset_num_proc, + train_only_last_turn=args.train_only_last_turn, + ) + vocab_mapping_path = generate_vocab_mapping_file( + dataset=train_eagle3_dataset, + target_vocab_size=draft_model_config.vocab_size, + draft_vocab_size=draft_model_config.draft_vocab_size, + cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), + cache_key=cache_key, + ) + + if not is_online: + train_eagle3_dataset = build_offline_eagle3_dataset( + args.train_hidden_states_path, + args.max_length, + ) + + train_dataloader = prepare_dp_dataloaders( + train_eagle3_dataset, + args.target_batch_size, + num_workers=args.dataloader_num_workers, + shuffle=True, + process_group=( + get_draft_dp_group() + if args.attention_backend == "usp" and not is_online + else get_dp_group() + ), + is_vlm=args.is_vlm, + ) + if args.eval_data_path is not None or args.eval_hidden_states_path is not None: + if args.eval_data_path is not None: + eval_dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.eval_data_path}, + ) + eval_eagle3_dataset = build_eagle3_dataset( + eval_dataset, + tokenizer, + args.chat_template, + args.max_length, + is_vlm=args.is_vlm, + processor=processor, + num_proc=args.build_dataset_num_proc, + is_preformatted=args.is_preformatted, + train_only_last_turn=args.train_only_last_turn, + ) + elif args.eval_hidden_states_path is not None: + eval_eagle3_dataset = build_offline_eagle3_dataset( + args.eval_hidden_states_path, + args.max_length, + ) + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.target_batch_size, + num_workers=args.dataloader_num_workers, + shuffle=False, + process_group=( + get_draft_dp_group() + if args.attention_backend == "usp" and not is_online + else get_dp_group() + ), + is_vlm=args.is_vlm, + ) + print_with_rank("Initialized eval dataloader") + else: + eval_dataloader = None + return ( + train_dataloader, + vocab_mapping_path, + eval_dataloader, + ) + + +def save_checkpoints( + args: Namespace, + epoch: int, + step: int, + eagle3_model: nn.Module, + optimizer: Optimizer, +): + epoch_output_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") + if dist.get_rank() == 0: + os.makedirs(epoch_output_dir, exist_ok=True) + dist.barrier() + + with FSDP.state_dict_type(eagle3_model, StateDictType.FULL_STATE_DICT): + model_state_dict = eagle3_model.state_dict() + state_to_save = { + "epoch": epoch, + "global_step": step, + "args": args, + } + state_to_save.update(optimizer.state_dict()) + draft_model_state_dict = { + k.replace("draft_model.", ""): v + for k, v in model_state_dict.items() + if "draft_model." in k and "embed" not in k.lower() + } + + if dist.get_rank() == 0: + torch.save( + state_to_save, + os.path.join(epoch_output_dir, "training_state.pt"), + ) + print_on_rank0( + f"Saved full training state to {epoch_output_dir}/training_state.pt" + ) + eagle3_model.draft_model.save_pretrained( + epoch_output_dir, + state_dict=draft_model_state_dict, + ) + print_on_rank0(f"Saved model configuration to {epoch_output_dir}") + dist.barrier() + + +def run_forward( + args: Namespace, + eagle3_model: nn.Module, + data: dict, + target_model: Optional[Eagle3TargetModel] = None, + is_online: bool = True, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + if args.is_vlm and args.target_model_backend == "custom": + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + pixel_values=data["pixel_values"].cuda(), + image_grid_thw=data["image_grid_thw"].cuda(), + ) + else: + image_grid_thw = None + if is_online: + # we generate the eagle3 using the target model in an online fashion + # Handle VLM data: pixel_values and image_grid_thw are lists + # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None + if args.is_vlm: + image_grid_thw = ( + [thw.cuda().squeeze() for thw in data["image_grid_thw"]] + if args.is_vlm + else None + ) + pixel_values = data["pixel_values"].cuda() + eagle3_data = target_model.generate_eagle3_data( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + is_vlm=args.is_vlm, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + else: + eagle3_data = target_model.generate_eagle3_data( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + ) + + input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids) + attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask) + loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask) + target = get_dp_data_shard_from_tp(eagle3_data.target) + hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states) + else: + # we generate the logits using the hidden states loaded from disk + attention_mask = data["attention_mask"].cuda() + hidden_states = data["hidden_state"].cuda() + input_ids, target, loss_mask = target_model.preprocess( + data["input_ids"], data["target"], data["loss_mask"] + ) + input_ids = input_ids.cuda() + target = target_model( + target.cuda() + ) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU. + loss_mask = loss_mask.cuda() + plosses, _, acces = eagle3_model( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + target=target, + hidden_states=hidden_states, + image_grid_thw=image_grid_thw, + is_vlm=args.is_vlm, + ) + return plosses, acces + + +def run_backward_and_update( + args: Namespace, plosses: List[torch.Tensor], optimizer: Optimizer, global_step: int +) -> None: + ploss_weight = [0.8**i for i in range(len(plosses))] + ploss = ( + sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) + / args.draft_accumulation_steps + ) + ploss.backward() + + if global_step % args.draft_accumulation_steps == 0: + optimizer.step() + + +def record_metrcs( + args: Namespace, + accuracies: List[torch.Tensor], + plosses: List[torch.Tensor], + global_step: int, + tracker: Tracker, + optimizer: Optional[Optimizer] = None, + mode: str = "train", +) -> None: + logdict = {} + + if mode == "train" and optimizer is not None: + logdict["train/lr"] = optimizer.get_learning_rate() + + accuracies = torch.stack(accuracies) + plosses = torch.stack(plosses) + + assert accuracies.shape[0] == args.ttt_length + dist.all_reduce(accuracies, op=dist.ReduceOp.AVG) + accuracies = accuracies.cpu().tolist() + for i in range(len(accuracies)): + logdict[f"{mode}/acc_{i}"] = accuracies[i] + print_on_rank0( + f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, Acc: {accuracies[i]:.2f}" + ) + + dist.all_reduce(plosses, op=dist.ReduceOp.AVG) + plosses = plosses.cpu().tolist() + for i in range(len(plosses)): + logdict[f"{mode}/ploss_{i}"] = plosses[i] + print_on_rank0( + f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, pLoss: {plosses[i]}" + ) + tracker.log(logdict, step=global_step) + + +def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor: + """ + Process: TP split -> Pad to Max Len -> SP gather. + """ + # 1. TP: Slice the tensor along the batch dimension + tp_group = get_tp_group() + tp_size = dist.get_world_size(tp_group) + tp_rank = dist.get_rank(tp_group) + + local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank] + + # 2. SP: Handle dynamic sequence lengths and Gather + sp_group = get_draft_sp_group() + + if sp_group is not None and dist.get_world_size(sp_group) > 1: + sp_world_size = dist.get_world_size(sp_group) + local_seq_len = local_tp_shard.size(sp_dim) + + # Find global max sequence length in SP group + len_tensor = torch.tensor( + [local_seq_len], device=local_tp_shard.device, dtype=torch.long + ) + dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group) + max_seq_len = len_tensor.item() + + # Pad local tensor if necessary + # Shape is [Batch, Seq, Hidden] or [Batch, Seq], and sp_dim=1 + if local_seq_len < max_seq_len: + pad_size = max_seq_len - local_seq_len + + pad_config = [0] * (local_tp_shard.ndim * 2) + + pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1 + pad_config[pad_idx] = pad_size + + # Pad value: 0 is standard, ensure it matches your pad_token_id logic if needed + local_tp_shard_padded = nn.F.pad(local_tp_shard, pad_config, value=0) + else: + local_tp_shard_padded = local_tp_shard + + gathered_shards = [ + torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size) + ] + dist.all_gather( + gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group + ) + + return torch.cat(gathered_shards, dim=sp_dim) + + return local_tp_shard + + +def main(): + # ================================================ + # 1. Initialize + # ================================================ + parser, args = parse_args() + set_seed(args.seed) + init_distributed( + timeout=args.dist_timeout, + tp_size=args.tp_size, + sp_ring_size=args.sp_ring_size, + sp_ulysses_size=args.sp_ulysses_size, + ) + is_online = ( + args.train_data_path is not None and args.train_hidden_states_path is None + ) + + sanity_check(args) + print_args_with_dots(args) + print_with_rank("Initialized distributed environment") + + # ================================================ + # 2. Build models + # ================================================ + draft_model_config, draft_model = build_draft_model(args) + target_model, processor = build_target_model(args, draft_model_config, is_online) + + # ================================================ + # 3. Build dataloader + # ================================================ + train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders( + args, draft_model_config, processor + ) + + # we load the vocab mapping then + draft_model.load_vocab_mapping(vocab_mapping_path) + print_with_rank("Loaded vocab mapping") + + # Calculate total steps if not provided + if args.total_steps is None: + steps_per_epoch = math.ceil( + len(train_dataloader) / args.draft_accumulation_steps + ) + args.total_steps = args.num_epochs * steps_per_epoch + print_with_rank( + f"Auto-calculated total_steps: {args.total_steps} (num_epochs={args.num_epochs} * steps_per_epoch={steps_per_epoch})" + ) + else: + print_with_rank(f"Using provided total_steps: {args.total_steps}") + + # ================================================ + # 4. Build Eagle3 model + # ================================================ + if ( + args.is_vlm + and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl" + and args.tp_size == 1 + and args.target_model_backend != "sglang" + ): + eagle3_model = QwenVLOnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + processor=processor, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + else: + if is_online: + eagle3_model = OnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + else: + # offline: the target_model is TargetHead not a model + eagle3_model = OnlineEagle3Model( + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + eagle3_model = FSDP( + eagle3_model, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, + process_group=dist.group.WORLD, # the draft model should run dp for all processes + ) + print_with_rank("Initialized Eagle3 FSDP model") + + # ================================================ + # 5. Build optimizer and scheduler + # ================================================ + optimizer = BF16Optimizer( + draft_model, + lr=args.learning_rate, + max_grad_norm=args.max_grad_norm, + warmup_ratio=args.warmup_ratio, + total_steps=args.total_steps, + ) + print_with_rank("Initialized optimizer and scheduler") + + # ================================================ + # 6. Build tracker + # ================================================ + tracker = build_tracker(args, parser) + global_step = 0 + start_epoch = 0 + dist.barrier() + + last_time = time.time() + + # ================================================ + # 7. Start training + # ================================================ + print_on_rank0(f"Starting training from epoch {start_epoch}") + + for epoch in range(start_epoch, args.num_epochs): + # Run training + train_dataloader.sampler.set_epoch(epoch + 1) + draft_model.train() + + if dist.get_rank() == 0: + progress_bar = tqdm( + train_dataloader, desc=f"Training Epoch {epoch}", leave=True + ) + else: + progress_bar = train_dataloader + + for data in progress_bar: + global_step += 1 + + # ================================================ + # 7.0 Profiling + # ================================================ + if args.profile: + # we add the step by 1 to align with global step + if global_step == args.profile_start_step + 1: + print("Start profile") + torch_profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + record_shapes=args.profile_record_shapes, + ) + torch_profiler.start() + if global_step == args.profile_start_step + args.profile_num_steps + 1: + output_path = os.path.join( + args.output_dir, + f"profile_rank{torch.distributed.get_rank()}_{time.time()}.trace.json.gz", + ) + print(f"End profile {output_path=}") + torch_profiler.stop() + torch_profiler.export_chrome_trace(output_path) + + # ================================================ + # 7.1 Training Step + # ================================================ + plosses, acces = run_forward( + args, eagle3_model, data, target_model, is_online + ) + run_backward_and_update(args, plosses, optimizer, global_step) + + # log training metrics + if global_step % (args.log_interval * args.draft_accumulation_steps) == 0: + record_metrcs( + args, + acces, + plosses, + global_step // args.draft_accumulation_steps, + tracker, + optimizer, + mode="train", + ) + + if dist.get_rank() == 0: + time_per_step = time.time() - last_time + last_time = time.time() + avg_loss = sum(pl for pl in plosses) / len(plosses) + avg_acc = sum(acces) / len(acces) + progress_bar.set_postfix( + { + "loss": f"{avg_loss:.2f}", + "acc": f"{avg_acc:.2f}", + "time": f"{time_per_step:.2f}s", + } + ) + + # ================================================ + # 7.2 Evaluation Step + # ================================================ + should_evaluate = ( + args.eval_data_path is not None + or args.eval_hidden_states_path is not None + ) + if ( + should_evaluate + and global_step % (args.eval_interval * args.draft_accumulation_steps) + == 0 + ): + # Run evaluation + draft_model.eval() + eval_acces = [[] for _ in range(eagle3_model.length)] + eval_plosses = [[] for _ in range(eagle3_model.length)] + + for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"): + with torch.no_grad(): + plosses, acces = run_forward( + args, eagle3_model, data, target_model, is_online + ) + eval_acces = [ + eval_acces[i] + [acces[i]] for i in range(len(acces)) + ] + eval_plosses = [ + eval_plosses[i] + [plosses[i]] for i in range(len(plosses)) + ] + + # compute average over all minibatches + eval_acces = [torch.stack(acc).mean() for acc in eval_acces] + eval_plosses = [torch.stack(pl).mean() for pl in eval_plosses] + + record_metrcs( + args, + eval_acces, + eval_plosses, + global_step // args.draft_accumulation_steps, + tracker, + mode="eval", + ) + # ================================================ + # 7.3 Save Checkpoints + # ================================================ + if global_step % args.save_interval == 0: + # Save the model + save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) + + if args.max_num_steps is not None and global_step >= args.max_num_steps: + break + + if args.max_num_steps is not None and global_step >= args.max_num_steps: + break + # Save final checkpoint if training ended without saving + if global_step % args.save_interval != 0: + print_on_rank0( + f"Training completed at step {global_step}, saving final checkpoint..." + ) + save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) + + # Close the tracker + tracker.close() + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/SpecForge-ext/specforge.egg-info/PKG-INFO b/SpecForge-ext/specforge.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..72ae9413c7dc29ef3cf20c49a283a88397acc806 --- /dev/null +++ b/SpecForge-ext/specforge.egg-info/PKG-INFO @@ -0,0 +1,107 @@ +Metadata-Version: 2.4 +Name: specforge +Version: 0.2.0 +Summary:
+Home-page: https://github.com/sgl-project/SpecForge +Author: SGLang Team +Requires-Python: >=3.11 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: pre-commit +Requires-Dist: torch==2.9.1 +Requires-Dist: torchaudio==2.9.1 +Requires-Dist: torchvision==0.24.1 +Requires-Dist: transformers==4.57.1 +Requires-Dist: qwen-vl-utils==0.0.11 +Requires-Dist: datasets +Requires-Dist: setuptools +Requires-Dist: tqdm +Requires-Dist: wandb +Requires-Dist: psutil +Requires-Dist: numpy +Requires-Dist: accelerate +Requires-Dist: pydantic +Requires-Dist: sglang==0.5.6 +Requires-Dist: openai-harmony +Requires-Dist: ninja +Requires-Dist: packaging +Requires-Dist: yunchang +Provides-Extra: dev +Requires-Dist: pre-commit; extra == "dev" +Requires-Dist: unittest; extra == "dev" +Provides-Extra: fa +Requires-Dist: flash-attn; extra == "fa" +Dynamic: author +Dynamic: home-page +Dynamic: license-file + +
+logo + +[![documentation](https://img.shields.io/badge/📖-Documentation-red.svg?style=flat)](https://docs.sglang.ai/SpecForge/) +[![SpecBundle](https://img.shields.io/badge/🤗%20SpecBundle-yellow.svg?style=flat)](https://huggingface.co/collections/lmsys/specbundle) +[![DeepWiki](https://img.shields.io/badge/DeepWiki-SpecForge-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/sgl-project/SpecForge) + +[![github badge](https://img.shields.io/badge/📃%20LMSYS-Blog-black.svg?style=flat)](https://lmsys.org/blog/2025-07-25-spec-forge/) +[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://sgl-fru7574.slack.com/archives/C09784E3EN6) +[![license](https://img.shields.io/badge/License-MIT%202.0-blue)](./LICENSE) + +
+ +## 📍 Overview + +SpecForge is an ecosystem project developed by the SGLang team. It is a framework for training speculative decoding models so that you can smoothly port them over to the SGLang serving framework to speed up your inference. + +We have seen many open-source projects for speculative decoding, but most of them are not well-maintained or not directly compatible with SGLang. We prepared this project because we wish that the open-source community can enjoy a speculative decoding framework that is +- regularly maintained by the SpecForge team: the code is runnable out-of-the-box +- directly compatible with SGLang: there is no additional efforts for porting to SGLang +- provide performant training capabilities: we provided online/offline/tensor-parallel/FSDP to suit your needs + + +Check out [**our documentation**](https://docs.sglang.ai/SpecForge/) to get started. + + +## 🚀 Accelerate with SpecBundle + +SpecBundle is a collection of production-grade speculative decoding models that are released by the SpecForge team and our industry partners. They provide higher acceptance rate compared to the existing open-source checkpoints over a wide range of domains. Together with SGLang, you can experience up to 4x speedup for inference. Check out our resources below: + + +| Item | Link | +| --- | --- | +| 📝 Documentation | [Link](https://docs.sglang.io/SpecForge/community_resources/specbundle.html) | +| 📊 Performance Dashboard | [Link](https://docs.sglang.io/SpecForge/SpecBundle/index.html) | +| 🤗 Hugging Face Collection | [Link](https://huggingface.co/collections/lmsys/specbundle) | + + +## 🎉 News + +- [2025-12] 🎉 Released SpecBundle (phase 1) and SpecForge v0.2. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-12-23-spec-bundle-phase-1/) +- [2025-12] 🔔 Released the roadmap for 2026 Q1. +- [2025-08] 🔔 SpecForge is listed as a [flagship project](https://lmsys.org/about/) in LMSYS. Congratulations to the SpecForge team! +- [2025-08] 🔥 SpecForge powered the Eagle3 draft model for GPT-OSS. Check out the blog at [LMSYS.org](https://lmsys.org/blog/2025-08-27-gpt-oss/) +- [2025-07] 🔥 SpecForge is released together with Llama4-Eagle3 checkpoints. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-07-25-spec-forge/) + +## ✨ Acknowledgements + +acknowledgements + +We would like to express our sincere gratitude to the official EAGLE team, especially Hongyang Zhang and Yuhui Li, for their invaluable contributions and support. Our thanks also go to the NVIDIA team—particularly Avery H and Izzy Putterman—and to the Google team, especially Ying Wang, for their insightful discussions and generous assistance throughout the project. + +We are especially grateful to Meituan for their strong backing and meaningful contributions, which played a vital role in driving this project forward. + +This project has also been inspired by many outstanding open-source projects from the LLM community, including [EAGLE](https://github.com/SafeAILab/EAGLE), [BaldEagle](https://github.com/NickL77/BaldEagle), and [TensorRT-Model-Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) and others. Their contributions and shared knowledge have greatly benefited our work. + +## 💡 Special Thanks to Voltage Park + +We would like to extend our sincere thanks to [Voltage Park](https://www.voltagepark.com/), our official infrastructure partner. As part of a formal collaboration with the SGLang team, Voltage Park provided critical GPU resources that empowered us to train and evaluate large-scale speculative decoding models efficiently and reliably. This partnership was instrumental in making SpecForge possible. We deeply appreciate Voltage Park’s mission to make cutting-edge AI infrastructure more accessible, and we look forward to continued collaboration as we push the boundaries of open-source LLM serving and optimization. + +## 📃 Citation + +```bibtex +@misc{specforge2025, + title={SpecForge: Train speculative decoding models effortlessly}, + author={Shenggui Li, Yikai Zhu, Chao Wang, Fan Yin, Shuai Shi, Yubo Wang, Yi Zhang, Yingyi Huang, Haoshuai Zheng, Yineng Zhang}, + year={2025}, + publisher={GitHub}, + howpublished={\url{https://github.com/sgl-project/specforge}}, +} diff --git a/SpecForge-ext/specforge.egg-info/SOURCES.txt b/SpecForge-ext/specforge.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..14a046bed3047542f0f022f41d6adeb262976585 --- /dev/null +++ b/SpecForge-ext/specforge.egg-info/SOURCES.txt @@ -0,0 +1,19 @@ +LICENSE +MANIFEST.in +README.md +pyproject.toml +requirements.txt +setup.py +version.txt +specforge/__init__.py +specforge/args.py +specforge/distributed.py +specforge/lr_scheduler.py +specforge/optimizer.py +specforge/tracker.py +specforge/utils.py +specforge.egg-info/PKG-INFO +specforge.egg-info/SOURCES.txt +specforge.egg-info/dependency_links.txt +specforge.egg-info/requires.txt +specforge.egg-info/top_level.txt \ No newline at end of file diff --git a/SpecForge-ext/specforge.egg-info/dependency_links.txt b/SpecForge-ext/specforge.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/SpecForge-ext/specforge.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/SpecForge-ext/specforge.egg-info/requires.txt b/SpecForge-ext/specforge.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..7131cf58693b7dd85a7753275b225ae62cb09131 --- /dev/null +++ b/SpecForge-ext/specforge.egg-info/requires.txt @@ -0,0 +1,26 @@ +pre-commit +torch==2.9.1 +torchaudio==2.9.1 +torchvision==0.24.1 +transformers==4.57.1 +qwen-vl-utils==0.0.11 +datasets +setuptools +tqdm +wandb +psutil +numpy +accelerate +pydantic +sglang==0.5.6 +openai-harmony +ninja +packaging +yunchang + +[dev] +pre-commit +unittest + +[fa] +flash-attn diff --git a/SpecForge-ext/specforge.egg-info/top_level.txt b/SpecForge-ext/specforge.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..3aa0f69562ec1873dd2cb390694dbd84fca3ce40 --- /dev/null +++ b/SpecForge-ext/specforge.egg-info/top_level.txt @@ -0,0 +1 @@ +specforge diff --git a/SpecForge-ext/specforge/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c7ea22991c43562a50648305179cb1f955eb850 Binary files /dev/null and b/SpecForge-ext/specforge/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/__pycache__/args.cpython-311.pyc b/SpecForge-ext/specforge/__pycache__/args.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3e5376d59d11001bfcfbbce260109a3ca469cc6 Binary files /dev/null and b/SpecForge-ext/specforge/__pycache__/args.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/__pycache__/distributed.cpython-311.pyc b/SpecForge-ext/specforge/__pycache__/distributed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f074b7958451fa376a11fa1fb640c99bbd2491 Binary files /dev/null and b/SpecForge-ext/specforge/__pycache__/distributed.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/__pycache__/lr_scheduler.cpython-311.pyc b/SpecForge-ext/specforge/__pycache__/lr_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55595a97585fc654bc6a1c470642275da5c264c4 Binary files /dev/null and b/SpecForge-ext/specforge/__pycache__/lr_scheduler.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/__pycache__/optimizer.cpython-311.pyc b/SpecForge-ext/specforge/__pycache__/optimizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..426af79d23a9698e1eeb9118f425def8ef93eae7 Binary files /dev/null and b/SpecForge-ext/specforge/__pycache__/optimizer.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/__pycache__/tracker.cpython-311.pyc b/SpecForge-ext/specforge/__pycache__/tracker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8be4996edff5fe7ac90fa49a99682f98c9bf65ac Binary files /dev/null and b/SpecForge-ext/specforge/__pycache__/tracker.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/benchmarks/benchmark_flex_attention.py b/SpecForge-ext/specforge/benchmarks/benchmark_flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..20f989565727ffe42ab112c8818ac32371b9313d --- /dev/null +++ b/SpecForge-ext/specforge/benchmarks/benchmark_flex_attention.py @@ -0,0 +1,336 @@ +import argparse +import time + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch._dynamo as dynamo +from transformers import LlamaConfig +from transformers.cache_utils import DynamicCache + +from specforge.modeling.draft.llama3_eagle import ( + LlamaAttention, + LlamaFlexAttention, + prepare_decoder_attention_mask, +) + +dynamo.config.recompile_limit = 64 + +config_dict = { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "max_position_embeddings": 16384, + "rms_norm_eps": 1e-05, + "vocab_size": 32000, + "hidden_act": "silu", + "num_hidden_layers": 1, +} + +config = LlamaConfig(**config_dict) + +TTT_LENGTH = 7 +BATCH_SIZE = 4 +HIDDEN_SIZE = config.hidden_size * 2 + + +def run_attention( + seq_len: int, + hidden_states_list: list[torch.Tensor], + attention_backend: str = "sdpa", + enable_profile: bool = False, +): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + batch_size = hidden_states_list[0].shape[0] + # Initialize cache and attention function based on backend + if attention_backend == "sdpa": + cache_hidden = [[], []] + past_key_values = None + attn_func = LlamaAttention(config).to(device).to(torch.bfloat16) + elif attention_backend == "flex_attention": + cache_hidden = None + past_key_values = DynamicCache() + attn_func = LlamaFlexAttention(config).to(device).to(torch.bfloat16) + else: + raise ValueError(f"Unknown attention backend: {attention_backend}") + + # Simulate inputs - move to device + position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to(device) + input_embeds = torch.randn(batch_size, seq_len, config.hidden_size).to(device) + attention_mask = torch.ones(batch_size, seq_len).to(device) + decoder_attention_mask = prepare_decoder_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_len), + inputs_embeds=input_embeds, + past_key_values_length=0, + ) + + loss_list = [] + + if attention_backend == "flex_attention" and enable_profile: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"./profiler_logs/{attention_backend}" + ), + record_shapes=False, + profile_memory=False, + with_stack=True, + with_modules=False, + ) + profiler.start() + for idx in range(TTT_LENGTH): + is_last = idx == TTT_LENGTH - 1 + hidden_states = hidden_states_list[idx] + # Call attention function with appropriate parameters + if attention_backend == "sdpa": + output = attn_func( + hidden_states=hidden_states, + attention_mask=decoder_attention_mask, + position_ids=position_ids, + cache_hidden=cache_hidden, + output_attentions=False, + use_cache=True, + ) + else: # flex_attention + output = attn_func( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=False, + use_cache=True, + ) + + # Compute a simple loss for benchmarking + loss = output[0].sum() + loss_list.append(loss) + + # Compute mean loss and backward pass + if loss_list: + mean_loss = sum(loss_list) / len(loss_list) + mean_loss.backward() + + if attention_backend == "flex_attention" and enable_profile: + profiler.stop() + + +def benchmark_function( + attention_backend: str, + seq_lengths: list, + enable_profile: bool = False, + enable_warmup: bool = True, +): + """Benchmark a function for speed and GPU memory usage per sequence length.""" + print(f"\n=== Benchmarking {attention_backend} ===") + + results_per_seq_len = [] + + for seq_len in seq_lengths: + print(f"\nTesting sequence length: {seq_len}") + + # Clear GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Warm up runs for this sequence length + if enable_warmup: + print("Warming up...") + for _ in range(2): + hidden_states = [ + torch.randn( + BATCH_SIZE, + seq_len, + HIDDEN_SIZE, + requires_grad=True, + device="cuda", + dtype=torch.bfloat16, + ) + for _ in range(TTT_LENGTH) + ] + run_attention(seq_len, hidden_states, attention_backend) + # Clear cache again after warmup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + # Record initial memory + initial_memory = 0 + if torch.cuda.is_available(): + initial_memory = torch.cuda.memory_allocated() + hidden_states = [ + torch.randn( + BATCH_SIZE, + seq_len, + HIDDEN_SIZE, + requires_grad=True, + device="cuda", + dtype=torch.bfloat16, + ) + for _ in range(TTT_LENGTH) + ] + start_time = time.time() + run_attention( + seq_len, + hidden_states, + attention_backend, + enable_profile and seq_len == seq_lengths[0], + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + end_time = time.time() + + # Record memory usage + peak_memory = 0 + current_memory = 0 + if torch.cuda.is_available(): + peak_memory = torch.cuda.max_memory_allocated() + current_memory = torch.cuda.memory_allocated() + results_per_seq_len.append( + { + "seq_len": seq_len, + "time": end_time - start_time, + "peak_memory": peak_memory, + "memory_increase": current_memory - initial_memory, + } + ) + + print(f" Time: {end_time - start_time:.3f}s") + print(f" Peak memory: {peak_memory / 1024**3:.3f} GB") + print( + f" Memory increase: {(current_memory - initial_memory) / 1024**3:.3f} GB" + ) + + return results_per_seq_len + + +def plot_results(eagle_results, flex_results, seq_lengths): + """Plot speed and memory comparison between Eagle and Flex attention.""" + + # Extract data for plotting + eagle_times = [r["time"] for r in eagle_results] + flex_times = [r["time"] for r in flex_results] + eagle_memory = [r["peak_memory"] / 1024**3 for r in eagle_results] # Convert to GB + flex_memory = [r["peak_memory"] / 1024**3 for r in flex_results] # Convert to GB + + # Create subplots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + + # Speed comparison plot + ax1.plot( + seq_lengths, eagle_times, "b-o", label="Eagle (SDPA)", linewidth=2, markersize=8 + ) + ax1.plot( + seq_lengths, + flex_times, + "r-s", + label="Flex Attention", + linewidth=2, + markersize=8, + ) + ax1.set_xlabel("Sequence Length") + ax1.set_ylabel("Time (seconds)") + ax1.set_title("Speed Comparison: Eagle vs Flex Attention") + ax1.legend() + ax1.grid(True, alpha=0.3) + ax1.set_xscale("linear") + ax1.set_yscale("log") + + # Memory comparison plot + ax2.plot( + seq_lengths, + eagle_memory, + "b-o", + label="Eagle (SDPA)", + linewidth=2, + markersize=8, + ) + ax2.plot( + seq_lengths, + flex_memory, + "r-s", + label="Flex Attention", + linewidth=2, + markersize=8, + ) + ax2.set_xlabel("Sequence Length") + ax2.set_ylabel("Peak Memory (GB)") + ax2.set_title("Memory Usage Comparison: Eagle vs Flex Attention") + ax2.legend() + ax2.grid(True, alpha=0.3) + + # Set y-axis ticks every 10GB + max_memory = max(max(eagle_memory), max(flex_memory)) + ax2.set_yticks(np.arange(0, max_memory + 10, 10)) + + plt.tight_layout() + plt.savefig("attention_benchmark_comparison.png", dpi=300, bbox_inches="tight") + plt.show() + + # Print summary statistics + print(f"\n=== Performance Summary ===") + print(f"Sequence lengths tested: {seq_lengths}") + print(f"\nSpeed ratios (Eagle/Flex):") + for i, seq_len in enumerate(seq_lengths): + ratio = eagle_times[i] / flex_times[i] if flex_times[i] > 0 else float("inf") + print(f" {seq_len:4d}: {ratio:.2f}x") + + print(f"\nMemory ratios (Eagle/Flex):") + for i, seq_len in enumerate(seq_lengths): + ratio = eagle_memory[i] / flex_memory[i] if flex_memory[i] > 0 else float("inf") + print(f" {seq_len:4d}: {ratio:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark attention mechanisms") + parser.add_argument( + "--enable-profile", action="store_true", help="Enable profiling" + ) + args = parser.parse_args() + + print("PyTorch version:", torch.__version__) + if torch.cuda.is_available(): + print("CUDA available:", torch.cuda.is_available()) + print("GPU:", torch.cuda.get_device_name()) + print( + "GPU memory:", + torch.cuda.get_device_properties(0).total_memory / 1024**3, + "GB", + ) + else: + print("CUDA not available - running on CPU") + + # Define sequence lengths to test + seq_lengths = [128 * i for i in range(1, 28, 4)] + # Add extra long context + seq_lengths.extend([16384, 32768]) + + print(f"Testing sequence lengths: {seq_lengths}") + + # Run benchmarks + print("\n" + "=" * 50) + # Truncate seqlen after 2560 since naive eagle goes OOM + eagle_seq_lengths = [seq_len for seq_len in seq_lengths if seq_len <= 2560] + eagle_results = benchmark_function("sdpa", eagle_seq_lengths) + print("\n" + "=" * 50) + flex_results = benchmark_function( + "flex_attention", seq_lengths, enable_profile=args.enable_profile + ) + # Pad the memory usage on eagle to max memory 80GB when data not available + max_time = max(result["time"] for result in flex_results) + for result in flex_results: + if result["seq_len"] not in eagle_seq_lengths: + eagle_results.append( + { + "seq_len": result["seq_len"], + "time": max_time, + "peak_memory": 80 * 1024**3, + "memory_increase": 0, # Not used in plotting + } + ) + + # Plot results + plot_results(eagle_results, flex_results, seq_lengths) diff --git a/SpecForge-ext/specforge/benchmarks/benchmark_loss.py b/SpecForge-ext/specforge/benchmarks/benchmark_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..940787a860d98ee406bee6a29df9127bae675d92 --- /dev/null +++ b/SpecForge-ext/specforge/benchmarks/benchmark_loss.py @@ -0,0 +1,179 @@ +import argparse +import time + +import torch + +from specforge.core.loss import LogSoftmaxLoss, _compute_loss + +TTT_LENGTH = 7 + + +def benchmark_loss_method( + loss_method: str, + test_configs: list, +): + """Benchmark a loss computation method for speed and GPU memory usage.""" + print(f"\n=== Benchmarking {loss_method} Loss ===") + + results = [] + + for config in test_configs: + B, T, V = config + print(f"\nTesting config: B={B}, T={T}, V={V}") + + # Clear GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Create tensors outside timing measurement + target = torch.softmax( + torch.randn(B, T, V, device="cuda", dtype=torch.float32), dim=-1 + ) + position_mask = torch.ones((B, T, 1), dtype=torch.bool, device="cuda") + + # Pre-allocate logits tensors for each TTT step + logits_list = [] + for i in range(TTT_LENGTH): + logits = torch.randn( + B, T, V, device="cuda", requires_grad=True, dtype=torch.float32 + ) + logits_list.append(logits) + + torch.cuda.synchronize() # Ensure all operations are complete + start_time = time.time() + + plosses = [] + for i in range(TTT_LENGTH): + logits = logits_list[i] + if loss_method == "triton": + loss = LogSoftmaxLoss.apply(logits, target, position_mask) + else: + loss = _compute_loss(logits, target, position_mask) + plosses.append(loss) + + ploss_weight = [0.8**i for i in range(len(plosses))] + ploss = ( + sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) + / TTT_LENGTH + ) + ploss.backward() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + end_time = time.time() + total_time = end_time - start_time + # Record memory usage + peak_memory = 0 + if torch.cuda.is_available(): + peak_memory = torch.cuda.max_memory_allocated() + + results.append( + { + "B": B, + "T": T, + "V": V, + "time_total": total_time, + "peak_memory": peak_memory, + } + ) + + print(f" Total time (forward + backward): {total_time*1000:.3f}ms") + print(f" Peak memory: {peak_memory / 1024**3:.3f} GB") + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark loss computation methods") + parser.add_argument( + "--num-runs", type=int, default=5, help="Number of runs for averaging" + ) + args = parser.parse_args() + + print("PyTorch version:", torch.__version__) + if torch.cuda.is_available(): + print("CUDA available:", torch.cuda.is_available()) + print("GPU:", torch.cuda.get_device_name()) + print( + "GPU memory:", + torch.cuda.get_device_properties(0).total_memory / 1024**3, + "GB", + ) + else: + print("CUDA not available - running on CPU") + + # Define test configurations (B, T, V) + test_configs = [ + (1, 1024, 32000), + (1, 1024, 64000), + (1, 4096, 32000), + (1, 4096, 64000), + (1, 8192, 32000), + (1, 8192, 64000), + (1, 16384, 32000), + ] + + print(f"Testing configurations: {test_configs}") + + # Run benchmarks + print("\n" + "=" * 60) + pytorch_results = benchmark_loss_method("pytorch", test_configs) + + print("\n" + "=" * 60) + triton_results = benchmark_loss_method("triton", test_configs) + + # Print results summary + print(f"\n=== Performance Summary ===") + print(f"Configurations tested: {len(test_configs)}") + + # Print detailed results table + print( + f"\n{'Config (B,T,V)':<15} {'PyTorch (ms)':<15} {'Triton (ms)':<15} {'Speedup':<10} {'PyTorch Mem (GB)':<18} {'Triton Mem (GB)':<15} {'Memory Save':<12}" + ) + print("-" * 115) + + for i, config in enumerate(test_configs): + B, T, V = config + config_str = f"({B},{T},{V})" + + pytorch_result = next( + (r for r in pytorch_results if r["B"] == B and r["T"] == T and r["V"] == V), + None, + ) + triton_result = next( + (r for r in triton_results if r["B"] == B and r["T"] == T and r["V"] == V), + None, + ) + + if pytorch_result and triton_result: + pytorch_time_str = f"{pytorch_result['time_total']*1000:.2f}" + pytorch_mem_str = f"{pytorch_result['peak_memory']/1024**3:.2f}" + + triton_time_str = f"{triton_result['time_total']*1000:.2f}" + triton_mem_str = f"{triton_result['peak_memory']/1024**3:.2f}" + + if triton_result["time_total"] > 0: + speedup = pytorch_result["time_total"] / triton_result["time_total"] + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + + # Calculate memory savings percentage + if pytorch_result["peak_memory"] > 0: + memory_save_pct = ( + (pytorch_result["peak_memory"] - triton_result["peak_memory"]) + / pytorch_result["peak_memory"] + ) * 100 + memory_save_str = f"{memory_save_pct:.1f}%" + else: + memory_save_str = "N/A" + + print( + f"{config_str:<15} {pytorch_time_str:<15} {triton_time_str:<15} {speedup_str:<10} {pytorch_mem_str:<18} {triton_mem_str:<15} {memory_save_str:<12}" + ) + + +if __name__ == "__main__": + main() diff --git a/SpecForge-ext/specforge/core/__init__.py b/SpecForge-ext/specforge/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0ebc907b5d9a05954e1a69057c32cd070fd66e --- /dev/null +++ b/SpecForge-ext/specforge/core/__init__.py @@ -0,0 +1,9 @@ +from .dflash import OnlineDFlashModel, create_dflash_loss_mask +from .eagle3 import OnlineEagle3Model, QwenVLOnlineEagle3Model + +__all__ = [ + "OnlineDFlashModel", + "create_dflash_loss_mask", + "OnlineEagle3Model", + "QwenVLOnlineEagle3Model", +] diff --git a/SpecForge-ext/specforge/core/dflash.py b/SpecForge-ext/specforge/core/dflash.py new file mode 100644 index 0000000000000000000000000000000000000000..dabdae6b520158cdb083d0e71cd5f447dd81968c --- /dev/null +++ b/SpecForge-ext/specforge/core/dflash.py @@ -0,0 +1,215 @@ +# coding=utf-8 +"""DFlash Training Wrapper.""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from specforge.modeling.draft.dflash import DFlashDraftModel + +try: + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + + FLEX_ATTENTION_AVAILABLE = True +except ImportError: + FLEX_ATTENTION_AVAILABLE = False + BlockMask = None + create_block_mask = None + + +class OnlineDFlashModel(nn.Module): + """DFlash online training wrapper with block-wise CE loss.""" + + def __init__( + self, + draft_model: DFlashDraftModel, + target_lm_head: nn.Module, + target_embed_tokens: nn.Module, + mask_token_id: int, + block_size: int = 16, + attention_backend: str = "flex_attention", + ): + super().__init__() + self.draft_model = draft_model + self.lm_head = target_lm_head + self.embed_tokens = target_embed_tokens + self.block_size = block_size + self.mask_token_id = mask_token_id + self.attention_backend = attention_backend + + # Cache for BlockMask + self._cached_block_mask: Optional[BlockMask] = None + self._cached_seq_len: Optional[int] = None + self._cached_bsz: Optional[int] = None + + def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor: + """Prepare noise input: first token of each block is real, rest are MASK.""" + seq_len = input_ids.shape[1] + device = input_ids.device + + positions = torch.arange(seq_len, device=device) + is_block_start = (positions % self.block_size) == 0 + + noise_input_ids = torch.full_like(input_ids, self.mask_token_id) + noise_input_ids[:, is_block_start] = input_ids[:, is_block_start] + + return noise_input_ids + + def _get_or_create_block_mask( + self, bsz: int, q_len: int, kv_len: int, device: torch.device + ) -> "BlockMask": + """Get cached BlockMask or create a new one.""" + if ( + self._cached_block_mask is not None + and self._cached_seq_len == q_len + and self._cached_bsz == bsz + ): + return self._cached_block_mask + + block_size = self.block_size + + def dflash_mask_fn(b, h, q_idx, kv_idx): + L = q_len + is_ctx = kv_idx < L + q_block = q_idx // block_size + k_block_ctx = kv_idx // block_size + k_block_noise = (kv_idx - L) // block_size + ctx_visible = is_ctx & (k_block_ctx < q_block) + noise_visible = (~is_ctx) & (k_block_noise == q_block) + return ctx_visible | noise_visible + + block_mask = create_block_mask( + dflash_mask_fn, + B=bsz, + H=1, + Q_LEN=q_len, + KV_LEN=kv_len, + device=device, + ) + + self._cached_block_mask = block_mask + self._cached_seq_len = q_len + self._cached_bsz = bsz + + return block_mask + + def _create_parallel_attention_mask( + self, seq_len: int, device: torch.device + ) -> torch.Tensor: + """ + Create [L, 2L] attention mask for parallel training. + - Left half (ctx): Q can see K_ctx if K's block < Q's block + - Right half (noise): Q can see K_noise if same block (bidirectional) + """ + indices = torch.arange(seq_len, device=device) + block_ids = indices // self.block_size + + q_block_ids = block_ids.unsqueeze(1) + k_block_ids = block_ids.unsqueeze(0) + + ctx_mask = k_block_ids < q_block_ids + noise_mask = q_block_ids == k_block_ids + + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) + full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) + + return full_mask + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + hidden_states: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Parallel block-wise training forward pass.""" + bsz, seq_len = input_ids.shape + device = input_ids.device + + # Truncate to multiple of block_size + n_blocks = seq_len // self.block_size + effective_len = n_blocks * self.block_size + input_ids = input_ids[:, :effective_len] + hidden_states = hidden_states[:, :effective_len, :] + loss_mask = loss_mask[:, :effective_len] + attention_mask = attention_mask[:, :effective_len] + + # Prepare inputs + noise_input_ids = self.prepare_noise_input(input_ids) + noise_embedding = self.embed_tokens(noise_input_ids) + + # Position IDs: [ctx_pos, noise_pos] both 0..L-1 + pos_seq = torch.arange(effective_len, device=device) + position_ids = torch.cat([pos_seq, pos_seq], dim=0).unsqueeze(0).expand(bsz, -1) + + # Construct attention mask + if ( + self.attention_backend == "flex_attention" + and FLEX_ATTENTION_AVAILABLE + and create_block_mask is not None + ): + dflash_attn_mask = self._get_or_create_block_mask( + bsz=bsz, + q_len=effective_len, + kv_len=effective_len * 2, + device=device, + ) + else: + dflash_attn_mask = self._create_parallel_attention_mask( + effective_len, device + ) + dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype) + dflash_attn_mask = ( + dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1) + ) + + # Forward pass + hidden = self.draft_model( + position_ids=position_ids, + noise_embedding=noise_embedding, + target_hidden=hidden_states, + attention_mask=dflash_attn_mask, + ) + + # Compute loss (skip block 0 and block starts) + dflash_loss_mask_base = create_dflash_loss_mask( + effective_len, self.block_size, device + ) + combined_mask = loss_mask * dflash_loss_mask_base.unsqueeze(0) + + logits = self.lm_head(hidden) + + logits_flat = logits.reshape(-1, logits.size(-1)) + labels_flat = input_ids.reshape(-1) + mask_flat = combined_mask.reshape(-1) + + active_indices = mask_flat > 0.5 + active_logits = logits_flat[active_indices] + active_labels = labels_flat[active_indices] + + loss = F.cross_entropy(active_logits, active_labels) + + with torch.no_grad(): + preds = active_logits.argmax(dim=-1) + correct = (preds == active_labels).float().sum() + total = active_labels.numel() + accuracy = correct / total + + return loss, accuracy + + +def create_dflash_loss_mask( + seq_len: int, block_size: int, device: torch.device +) -> torch.Tensor: + """Create DFlash loss mask: excludes block 0 and first position of each block.""" + positions = torch.arange(seq_len, device=device) + block_ids = positions // block_size + + is_block_0 = block_ids == 0 + is_block_start = (positions % block_size) == 0 + + valid_mask = ~is_block_0 & ~is_block_start + return valid_mask.float() diff --git a/SpecForge-ext/specforge/core/eagle3.py b/SpecForge-ext/specforge/core/eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..abf43527f200aedc02ac06b47ee7120c150eb60a --- /dev/null +++ b/SpecForge-ext/specforge/core/eagle3.py @@ -0,0 +1,596 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from transformers.cache_utils import DynamicCache +from yunchang import EXTRACT_FUNC_DICT + +from specforge.core.loss import LogSoftmaxLoss +from specforge.distributed import ( + gather_outputs_and_unpad, + get_sp_ring_group, + get_sp_ulysses_group, +) +from specforge.modeling.draft import Eagle3DraftModel +from specforge.utils import padding + + +class Eagle3Model(nn.Module): + pass + + +class OnlineEagle3Model(Eagle3Model): + """ + In sgl-spec, we implement offline/online training. + Online training means we have the target hidden_states available during training. + Eagle3 using test time training technique (TTT) to train the draft model. + 1. We first extract the hidden states from the target model. + 2. Then concatenate the hidden states from 3 aux layers (layer 1, layer num_layers//2, layer num_layers-4). + 3. We project the concatenated hidden states to the target hidden size. from (batch, seq_len, 3*hidden_size) to (batch, seq_len, hidden_size) + 4. We concat the projected hidden states and embedding output as the input for the draft model. + 5. finally, we run TTT to train the draft model. input size is (batch, seq_len, hidden_size * 2) + """ + + def __init__( + self, + draft_model: Eagle3DraftModel, + length: int = 7, + attention_backend="sdpa", + target_model: Optional[Eagle3Model] = None, + ): + """ + Args: + target_model: the target model to extract hidden states. + draft_model: the draft model to be trained. + length: TTT length, it means how many turns to unroll during TTT. + """ + super().__init__() + self.draft_model = draft_model + self.length = length + self.attention_backend = attention_backend + self.target_model = target_model + + if self.attention_backend == "usp": + self.extract_func = EXTRACT_FUNC_DICT["basic"] + self.sp_ring_degree = torch.distributed.get_world_size(get_sp_ring_group()) + self.sp_ulysses_degree = torch.distributed.get_world_size( + get_sp_ulysses_group() + ) + self.sp_world_size = self.sp_ring_degree * self.sp_ulysses_degree + self.sp_rank = torch.distributed.get_rank() % self.sp_world_size + + @torch.compile() + def prepare_usp_input(self, full_input): + shared_input = self.extract_func( + full_input, + rank=self.sp_rank, + world_size=self.sp_world_size, + ).clone() + return shared_input + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor, + hidden_states: torch.Tensor, + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + is_vlm: bool = False, + **kwargs, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711 + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + past_key_values: We dont use this past_key_values in eagle3, but keep it for compatibility. We control kvcache by cache_hidden. + position_ids: (batch, seq_len) + """ + # Step 1: handle vocab size + target_p_padded, position_mask = _compute_target_p_padded( + target=target, + t2d=self.draft_model.t2d, + loss_mask=loss_mask, + length=self.length, + ) + del target + torch.cuda.empty_cache() + + # basic info + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + # Step 2: project the concatenated hidden states to the target hidden size + if self.attention_backend == "usp": + # NOTE: Split first for USP to parallelize computation and ensure + # gradient consistency without redundant full-sequence projection. + hidden_states = self.prepare_usp_input(hidden_states) + hidden_states = self.draft_model.project_hidden_states(hidden_states) + + # Step 3: process kv cache, position ids and position ids + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: + if is_vlm: + mrope_positions_ids, mrope_position_delta = ( + self.target_model.get_rope_index( + input_ids=input_ids, image_grid_thw=image_grid_thw + ) + ) + position_ids = mrope_positions_ids + else: + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # Step 4: handle attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + if self.attention_backend == "sdpa": + attention_mask = self.draft_model.prepare_decoder_attention_mask( + attention_mask=attention_mask, + hidden_states=hidden_states, + batch_size=batch_size, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + ) + + def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask): + # 1. Compute Logits(The part that consumes the most VRAM.) + logits_ = self.draft_model.compute_logits(hs) + logits = gather_outputs_and_unpad(logits_, gather_dim=1) + + # 2. Compute Loss + loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask) + + # 3. Compute Accuracy + with torch.no_grad(): + acc_val = _compute_metric_acc( + logits=logits, + target_p=tgt_p, + position_mask=pos_mask, + loss_mask=l_mask, + ) + return loss_val, acc_val + + # Step 5: run TTT + plosses = [] + vlosses = [] + acces = [] + # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift + global_input_ids = input_ids + if self.attention_backend in ["sdpa", "fa", "usp"]: + cache_hidden = [[], []] + past_key_values = None + elif self.attention_backend == "flex_attention": + cache_hidden = None + past_key_values = DynamicCache() + else: + raise ValueError(f"Unknown attention backend: {self.attention_backend}") + + for idx in range(self.length): + target_p = target_p_padded[:, idx : idx + seq_length, :] + if self.attention_backend == "usp": + input_ids = self.prepare_usp_input(global_input_ids) + else: + input_ids = global_input_ids + + is_last = idx == self.length - 1 + + # Step 5.1: embed the input ids + inputs_embeds = self.draft_model.embed_input_ids(input_ids) + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + + # Step 5.2: run the draft model backbone + hidden_states_out = self.draft_model.backbone( + input_embeds=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=True, + ) + + # update hidden states for next step + hidden_states = hidden_states_out + + if hidden_states.requires_grad: + loss, acc = checkpoint( + compute_loss_and_acc_checkpointed, + hidden_states, + target_p, + position_mask, + loss_mask, + use_reentrant=False, + ) + else: + loss, acc = compute_loss_and_acc_checkpointed( + hidden_states, target_p, position_mask, loss_mask + ) + + plosses.append(loss) + acces.append(acc) + if not is_last: + # Step 5.7: we need to update the loss mask + global_input_ids = padding(global_input_ids, left=False) + position_mask = padding(position_mask, left=False) + loss_mask = padding(loss_mask, left=False) + # Flex attention mask shirnking is handled inside attention module + return plosses, vlosses, acces + + +class QwenVLOnlineEagle3Model(Eagle3Model): + """ + In sgl-spec, we implement offline/online training. + Online training means we have the target hidden_states available during training. + Eagle3 using test time training technique (TTT) to train the draft model. + 1. We first extract the hidden states from the target model. + 2. Then concatenate the hidden states from 3 aux layers (layer 1, layer num_layers//2, layer num_layers-4). + 3. We project the concatenated hidden states to the target hidden size. from (batch, seq_len, 3*hidden_size) to (batch, seq_len, hidden_size) + 4. We concat the projected hidden states and embedding output as the input for the draft model. + 5. finally, we run TTT to train the draft model. input size is (batch, seq_len, hidden_size * 2) + """ + + def __init__( + self, + target_model, + draft_model: Eagle3DraftModel, + processor, + length: int = 7, + attention_backend: str = "sdpa", + ): + """ + Args: + target_model: the target model to extract hidden states. + draft_model: the draft model to be trained. + length: TTT length, it means how many turns to unroll during TTT. + """ + super().__init__() + self.target_model = target_model + self.draft_model = draft_model + self.processor = processor + self.length = length + self.attention_backend = attention_backend + + @torch.no_grad() + def _prepare_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L692 + Extract the hidden states from the target model outputs. + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + device: the device to run the target model, if None, use the input_ids device + pixel_values: image pixel values, used for VLM models + image_grid_thw: image grid thw, used for VLM models + + Returns: + hidden_states: (batch, seq_len, 3*hidden_size) + target: (batch, seq_len, vocab_size) + loss_mask: (batch, seq_len) + input_ids: (batch, seq_len) + """ + + if device is None: + device = input_ids.device + + # run the target model to get the hidden states + outputs = self.target_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + use_cache=False, + ) + + # extract the aux hidden states + # output_hidden_states = True will return the embedding output as well + # so we have an offset of 1 + num_hidden_states = len(outputs.hidden_states) + offset = 1 + num_layers = num_hidden_states - 1 + + # Eagle3 uses 3 aux layers from layer 1, num_layers//2, num_layers-4 + low_aux_layer = 1 + offset + mid_aux_layer = num_layers // 2 - 1 + offset + last_aux_layer = num_layers - 4 + offset + + hidden_states0 = outputs.hidden_states[low_aux_layer] + hidden_states1 = outputs.hidden_states[mid_aux_layer] + hidden_states2 = outputs.hidden_states[last_aux_layer] + + hidden_states = torch.cat( + (hidden_states0, hidden_states1, hidden_states2), dim=-1 + ) + + # apply pading + target = outputs.logits + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + + if target is not None: + target = target.to(device) + loss_mask = loss_mask[..., None] + loss_mask = loss_mask.to(device) + + return hidden_states, target, loss_mask, input_ids + + @torch.no_grad() + def _get_input_embeds( + self, + input_ids: torch.Tensor, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + ) -> torch.Tensor: + # get input embeding with image + # inputs_embeds = self.target_model.model.get_input_embeddings()(input_ids) + inputs_embeds = self.draft_model.embed_input_ids(input_ids) + image_embeds = self.target_model.model.get_image_features( + pixel_values, image_grid_thw + ) + image_embeds = torch.cat(image_embeds, dim=0) + n_image_tokens = ( + input_ids == self.target_model.model.config.image_token_id + ).sum() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.target_model.model.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711 + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + past_key_values: We dont use this past_key_values in eagle3, but keep it for compatibility. We control kvcache by cache_hidden. + position_ids: (batch, seq_len) + pixel_values: batch image pixel values, used for VLM models + image_grid_thw: (batch, 3), image grid thw, used for VLM models + """ + # Step 0: prepare data with the target model + hidden_states, target, loss_mask, input_ids = self._prepare_data( + input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw + ) + + # Step 1: handle vocab size + target_p_padded, position_mask = _compute_target_p_padded( + target=target, + t2d=self.draft_model.t2d, + loss_mask=loss_mask, + length=self.length, + ) + del target + + # basic info + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + # Step 2: project the concatenated hidden states to the target hidden size + hidden_states = self.draft_model.project_hidden_states(hidden_states) + + # Step 3: process kv cache, position ids and position ids + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + attention_mask_tensor = ( + attention_mask + if not isinstance(attention_mask, dict) + else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal( + attention_mask_tensor[:, 0], dim1=1, dim2=2 + ) + attention_mask_tensor = ( + attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + ) + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + position_ids, rope_deltas = self.target_model.model.get_rope_index( + input_ids, + image_grid_thw, + None, + second_per_grid_ts=None, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + else: + position_ids = position_ids + + # Step 4: handle attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + if self.attention_backend == "sdpa": + attention_mask = self.draft_model.prepare_decoder_attention_mask( + attention_mask=attention_mask, + hidden_states=hidden_states, + batch_size=batch_size, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + ) + + # Step 5: run TTT + plosses = [] + vlosses = [] + acces = [] + if self.attention_backend in ["sdpa", "fa"]: + cache_hidden = [[], []] + past_key_values = None + elif self.attention_backend == "flex_attention": + cache_hidden = None + past_key_values = DynamicCache() + else: + raise ValueError(f"Unknown attention backend: {self.attention_backend}") + + for idx in range(self.length): + target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous() + is_last = idx == self.length - 1 + + # Step 5.1: embed the input ids + # inputs_embeds = self._get_input_embeds(input_ids, pixel_values, image_grid_thw) + inputs_embeds = self.draft_model.embed_input_ids(input_ids) + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + + # Step 5.2: run the draft model backbone + hidden_states_out = self.draft_model.backbone( + input_embeds=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=True, + ) + + # update hidden states for next step + hidden_states = hidden_states_out + + # Step 5.4: get logits + logits = self.draft_model.compute_logits(hidden_states) + + # Step 5.5: record metrics first as we in-place modify logits + with torch.no_grad(): + acces.append( + _compute_metric_acc( + logits=logits, + target_p=target_p, + position_mask=position_mask, + loss_mask=loss_mask, + ) + ) + + # Step 5.6: calculate loss, in-place modifies logits! + loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) + plosses.append(loss) + + if not is_last: + # Step 5.7: we need to update the loss mask + input_ids = padding(input_ids, left=False) + position_mask = padding(position_mask, left=False) + loss_mask = padding(loss_mask, left=False) + # Flex attention mask shirnking is handled inside attention module + return plosses, vlosses, acces + + +def _compute_target_p_padded(target, t2d, loss_mask, length): + with torch.no_grad(): + target_p, position_mask = _compute_target_p( + target=target, + t2d=t2d, + loss_mask=loss_mask, + ) + + assert len(target_p.shape) == 3 + target_p_padded = F.pad( + target_p, + pad=(0, 0, 0, length), + mode="constant", + # For bitwise equality with previous code + value=1 / target_p.shape[-1], + ) + + return target_p_padded, position_mask + + +@torch.compile(dynamic=None) +def _compute_target_p(target, t2d, loss_mask): + target_head = target + target_max_token = target_head.argmax(-1) + target_mask = t2d[target_max_token] + target_mask = target_mask[..., None].int() + position_mask = target_mask * loss_mask + target_head = target_head[..., t2d] + target_head = target_head.float() + target_p = nn.Softmax(dim=2)(target_head) + target_p = target_p.detach() + return target_p, position_mask + + +@torch.compile(dynamic=None) +def _compute_metric_acc(logits, target_p, position_mask, loss_mask): + return ( + (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1) + ).sum() / loss_mask.sum().clamp_min(1e-6) diff --git a/SpecForge-ext/specforge/data/__init__.py b/SpecForge-ext/specforge/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1385db853aeca011f86d4fbb85c4d2cb426b73e --- /dev/null +++ b/SpecForge-ext/specforge/data/__init__.py @@ -0,0 +1,13 @@ +from .preprocessing import ( + build_eagle3_dataset, + build_offline_eagle3_dataset, + generate_vocab_mapping_file, +) +from .utils import prepare_dp_dataloaders + +__all__ = [ + "build_eagle3_dataset", + "build_offline_eagle3_dataset", + "generate_vocab_mapping_file", + "prepare_dp_dataloaders", +] diff --git a/SpecForge-ext/specforge/data/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe2b056310731c45604c4afe1407961bd1430f76 Binary files /dev/null and b/SpecForge-ext/specforge/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/data/__pycache__/preprocessing.cpython-311.pyc b/SpecForge-ext/specforge/data/__pycache__/preprocessing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bb77fcac64e563dfd92abee0b4758c813a3cd6f Binary files /dev/null and b/SpecForge-ext/specforge/data/__pycache__/preprocessing.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/data/__pycache__/template.cpython-311.pyc b/SpecForge-ext/specforge/data/__pycache__/template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d5465b247ae05f17a50ac586564710476567029 Binary files /dev/null and b/SpecForge-ext/specforge/data/__pycache__/template.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/data/preprocessing.py b/SpecForge-ext/specforge/data/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..648c0d1a3285467303616ba9e7c48b69bee83f1e --- /dev/null +++ b/SpecForge-ext/specforge/data/preprocessing.py @@ -0,0 +1,608 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import warnings +from collections import Counter +from typing import Dict, List, Optional, Tuple, Union + +import torch +from tqdm import tqdm +from transformers import ImageProcessingMixin, PreTrainedTokenizer + +from datasets import Dataset as HFDataset + +try: + from qwen_vl_utils import process_vision_info + + HAS_QWEN_VL_UTILS = True +except ImportError: + HAS_QWEN_VL_UTILS = False + process_vision_info = None + + +from .parse import GeneralParser, HarmonyParser, ThinkingParser +from .template import TEMPLATE_REGISTRY, ChatTemplate + +# define a type called conversation +Conversation = List[Dict[str, str]] + + +# ============================== +# This file is for preprocessing the data +# ============================== + + +def _apply_loss_mask_from_chat_template( + text: str, + offsets: torch.Tensor, + chat_template: ChatTemplate, +) -> torch.Tensor: + """ + Apply loss mask to identify assistant response spans using chat template. + + Args: + text: The formatted conversation text. + offsets: Token offset mapping from tokenizer. + chat_template: The chat template to use for identifying assistant spans. + + Returns: + A tensor indicating which tokens should contribute to the loss (1) or not (0). + """ + loss_mask = torch.zeros(len(offsets), dtype=torch.long) + + user_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.user_header}" + ) + assistant_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.assistant_header}" + ) + + # Find spans of assistant responses using regex + assistant_pattern = ( + re.escape(assistant_message_separator) + + r"(.*?)(?=" + + re.escape(user_message_separator) + + "|$)" + ) + + matches_found = 0 + + for match in re.finditer(assistant_pattern, text, re.DOTALL): + matches_found += 1 + # Assistant response text span (excluding assistant_header itself) + assistant_start_char = match.start(1) + assistant_end_char = match.end(1) + + # Mark tokens overlapping with assistant response + for idx, (token_start, token_end) in enumerate(offsets): + # Token is part of the assistant response span + if token_end <= assistant_start_char: + continue # token before assistant text + if token_start > assistant_end_char: + continue # token after assistant text + loss_mask[idx] = 1 + + if matches_found == 0: + print("WARNING: No assistant response spans found in the conversation text.") + + return loss_mask + + +# Copied from https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py +def preprocess_conversations( + tokenizer: PreTrainedTokenizer, + conversations: Union[List[Conversation], List[str]], + chat_template: ChatTemplate, + max_length: int = 2048, + is_preformatted: bool = False, + train_only_last_turn: bool = False, + **kwargs, +) -> Dict[str, List[torch.Tensor]]: + """ + Preprocess a batch of ShareGPT style conversations or pre-formatted text. + + Args: + tokenizer: The tokenizer to use for tokenization. + conversations: A list of conversations (if is_preformatted=False) or + a list of pre-formatted text strings (if is_preformatted=True). + chat_template: The chat template to use for formatting/identifying spans. + max_length: The maximum length of the tokenized input. + is_preformatted: Whether the input is already formatted text strings. + train_only_last_turn: If True, only the last assistant turn contributes to the loss. + + Returns: + A dictionary containing: + - input_ids: List of tokenized input IDs. + - loss_mask: List of loss masks indicating which tokens should contribute to the loss. + - attention_mask: List of attention masks. + """ + + # prepare result + results = {"input_ids": [], "loss_mask": [], "attention_mask": []} + + if chat_template.parser_type == "general": + parser = GeneralParser(tokenizer, chat_template) + elif chat_template.parser_type == "thinking": + parser = ThinkingParser(tokenizer, chat_template) + elif chat_template.parser_type == "openai-harmony": + parser = HarmonyParser(tokenizer, chat_template) + else: + raise ValueError(f"Invalid parser type: {chat_template.parser_type}") + + kwargs_list = [{} for _ in range(len(conversations))] + for key, value_list in kwargs.items(): + for i, value in enumerate(value_list): + kwargs_list[i][key] = value + for source, kwargs_item in zip(conversations, kwargs_list): + if not source: + # if the source is None, skip it + continue + input_ids, loss_mask = parser.parse( + source, + max_length, + preformatted=is_preformatted, + train_only_last_turn=train_only_last_turn, + **kwargs_item, + ) + results["input_ids"].append(input_ids[None, :]) + results["loss_mask"].append(loss_mask[None, :]) + results["attention_mask"].append(torch.ones_like(loss_mask)[None, :]) + return results + + +def preprocess_vlm_conversations( + processor: ImageProcessingMixin, + examples: List[Conversation], + chat_template: ChatTemplate, + max_length: int = 2048, +) -> Dict[str, List[torch.Tensor]]: + """ + Preprocess a batch of ShareGPT style conversations. + + Args: + processor: The image processor to use for processing images. + examples: A list of examples, where each example is a dictionary containing: + - image: The image in the conversation. + - conversations: A list of conversations, where each conversation is a list of messages. + chat_template: The chat template to use for formatting the conversations. + max_length: The maximum length of the tokenized input. + + Returns: + A dictionary containing: + - input_ids: List of tokenized input IDs. + - loss_mask: List of loss masks indicating which tokens should contribute to the loss. + - attention_mask: List of attention masks. + - pixel_values: List of pixel values for images in the examples. + - image_grid_thw: List of image grid tensors. + """ + system_prompt = chat_template.system_prompt + + # prepare result + results = { + "input_ids": [], + "loss_mask": [], + "attention_mask": [], + "pixel_values": [], + "image_grid_thw": [], + } + + # Note: currently, we assume that each example has only one image + for i, image in enumerate(examples["image"]): + source = examples["conversations"][i] + messages = [{"role": "system", "content": system_prompt}] + if not source: + # if the source is None, skip it + continue + + if source[0]["role"] != "user": + # if the first message is not from user, skip it + source = source[1:] + + convroles = ["user", "assistant"] + for j, sentence in enumerate(source): + role = sentence["role"] + assert role == convroles[j % 2], f"unexpected role {role}" + if role == "user": + # if the message is from user and has image, process the image + messages.append( + { + "role": role, + "content": [ + { + "type": "image", + "image": image, + }, + {"type": "text", "text": sentence["content"]}, + ], + } + ) + else: + messages.append({"role": role, "content": sentence["content"]}) + + conversation = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + # get vision infor use qwen_vl_utils + if not HAS_QWEN_VL_UTILS: + raise ImportError( + "qwen_vl_utils is required for VLM preprocessing but is not installed. " + "Please install it to use VLM features." + ) + image_inputs, video_inputs = process_vision_info(messages) + assert image_inputs is not None, "image_inputs must not be None" + + encoding = processor( + text=[conversation], + images=image_inputs, + videos=video_inputs, + max_length=max_length, + truncation=True, + return_tensors="pt", + return_offsets_mapping=True, + add_special_tokens=False, + ) + input_ids = encoding.input_ids[0] + offsets = encoding.offset_mapping[0] + pixel_values = encoding.pixel_values + image_grid_thw = encoding.image_grid_thw[0] + + # get conversation with image info for loss mask generation + decoded_conversation = processor.tokenizer.decode( + encoding.input_ids[0], skip_special_tokens=False + ) + + # Apply loss mask + loss_mask = _apply_loss_mask_from_chat_template( + decoded_conversation, offsets, chat_template + ) + + results["input_ids"].append(input_ids[None, :]) + results["loss_mask"].append(loss_mask[None, :]) + results["attention_mask"].append(torch.ones_like(loss_mask)[None, :]) + results["pixel_values"].append(pixel_values) + results["image_grid_thw"].append(image_grid_thw[None, :]) + return results + + +def build_eagle3_dataset( + dataset: HFDataset, + tokenizer: PreTrainedTokenizer, + chat_template: Optional[str] = None, + max_length: Optional[int] = 2048, + shuffle_seed: Optional[int] = 42, + num_proc: Optional[int] = 8, + cache_dir: Optional[str] = None, + cache_key: Optional[str] = None, + is_vlm: Optional[bool] = False, + processor: Optional[ImageProcessingMixin] = None, + is_preformatted: Optional[bool] = False, + train_only_last_turn: Optional[bool] = False, +) -> HFDataset: + """ + build eagle3 dataset + + Args: + dataset: HF dataset to process. + tokenizer: The tokenizer to use for tokenization. + chat_template: The chat template to use for formatting conversations. + This includes the system prompt and user/assistant tokens + required to delineate different parts of the conversation + for loss mask generation. + max_length: The maximum length of the tokenized input. + shuffle_seed: The seed for shuffling the dataset. + num_proc: The number of processes to use for multiprocessing. + cache_dir: The directory to use for caching the processed dataset. + cache_key: The key to use for caching the processed dataset. + is_vlm: Whether the dataset is for VLM models. + processor: The image processor to use for processing images. + is_preformatted: Whether the dataset contains preformatted text of the conversation + (e.g. includes system prompt, user and assistant start and end tokens) + and doesn't need to have the chat template applied. + Note that the chat_template still needs to be specified to determine + the assistant spans for loss mask generation. + If True, expects "text" column with ready-to-train text. + If False, expects "conversations" column with ShareGPT format. + train_only_last_turn: If True, only the last assistant turn contributes to the loss. + Useful for thinking models where history may not contain thoughts. + + Returns: + The processed HF dataset. + """ + if is_vlm: + assert processor is not None, "processor must be provided when is_vlm is True" + + # Validate chat_template requirement + if chat_template is None: + raise ValueError("chat_template must be provided for all dataset types") + + assert ( + chat_template in TEMPLATE_REGISTRY.get_all_template_names() + ), f"Chat template {chat_template} not found in TEMPLATE_REGISTRY, you may need to register it first" + + template: ChatTemplate = TEMPLATE_REGISTRY.get(chat_template) + + dataset = dataset.shuffle(seed=shuffle_seed) + original_cols = dataset.column_names + + def preprocess_function(examples): + # Handle different dataset formats + if is_vlm: + processed = preprocess_vlm_conversations( + processor, + examples, + template, + max_length, + ) + elif is_preformatted: + # Handle pre-formatted text (should be in "text" column) + if "text" not in examples: + raise ValueError( + f"Expected 'text' column for is_preformatted=True, but found columns: {list(examples.keys())}" + ) + processed = preprocess_conversations( + tokenizer, + examples["text"], + template, + max_length, + is_preformatted=True, + train_only_last_turn=train_only_last_turn, + ) + else: + # Handle ShareGPT conversations + if "conversations" not in examples: + raise ValueError( + f"Expected 'conversations' column for is_preformatted=False, but found columns: {list(examples.keys())}" + ) + conversations = examples.pop("conversations") + if "id" in examples: + examples.pop("id") + processed = preprocess_conversations( + tokenizer, + conversations, + template, + max_length, + is_preformatted=False, + train_only_last_turn=train_only_last_turn, + **examples, + ) + + return processed + + # Process dataset only once + if cache_dir and cache_key: + load_from_cache_file = True + os.makedirs(cache_dir, exist_ok=True) + cache_file_name = os.path.join(cache_dir, f"{cache_key}.pkl") + print(f"dataset is cached at {cache_file_name}") + elif cache_dir is None and cache_key is None: + load_from_cache_file = False + cache_file_name = None + print(f"dataset is not cached") + else: + warnings.warn( + f"cache_dir and cache_key must be provided together to make caching work" + ) + + # adjust batch size based on dataset type + if is_vlm: + batch_size = ( + 200 # reduce batch size for VLM datasets to avoid PyArrow offset overflow + ) + else: + batch_size = 1000 # default for conversations + dataset = dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + batch_size=batch_size, + remove_columns=original_cols, + # keep_in_memory=True, + load_from_cache_file=load_from_cache_file, + cache_file_name=cache_file_name, + ) + + dataset.set_format(type="torch") + return dataset + + +# ============================== +# Offline Eagle3 Dataset +# ============================== +# modified from https://github.com/NickL77/BaldEagle/blob/master/train/modules/data/data.py +def list_local_files(path, suffixes=[".ckpt"]): + datapaths = [] + for root, directories, files in os.walk(path): + for file in files: + file_path = os.path.join(root, file) + datapaths.append(file_path) + for suffix in suffixes: + datapaths = [f_name for f_name in datapaths if f_name.endswith(suffix)] + return datapaths + + +class OfflineEagle3Dataset(torch.utils.data.Dataset): + def __init__(self, datapath, transform=None, max_len=2048): + self.datapaths = datapath + self.transform = transform + self._epoch = 0 + self.max_len = max_len + + @staticmethod + def process_data(data, max_len, transform=None): + new_data = {} + # Squeeze due to our data generation script adding a batch dimension + hidden_state = data["aux_hidden_state"].squeeze(0)[:max_len][None, :] + target = data["hidden_state"].squeeze(0)[:max_len][None, :] + + input_ids = data["input_ids"][:max_len][None, :] + loss_mask = data["loss_mask"][:max_len][None, :] + loss_mask[0, -1] = 0 + + new_data["attention_mask"] = torch.ones_like(loss_mask, dtype=torch.long) + new_data["loss_mask"] = loss_mask + new_data["target"] = target + new_data["hidden_state"] = hidden_state + new_data["input_ids"] = input_ids + if transform: + new_data = transform(new_data) + return new_data + + def __len__(self): + return len(self.datapaths) + + def _open_file(self, index): + return torch.load(self.datapaths[index], weights_only=False) + + def __getitem__(self, index): + try: + data = self._open_file(index) + except Exception as e: + print(f"ERROR Failed to load {self.datapaths[index]} with error {e}") + data = self._open_file(0) + return self.process_data(data, self.max_len, self.transform) + + def set_epoch(self, epoch): + self._epoch = epoch + + +def build_offline_eagle3_dataset( + hidden_states_path: str, + max_len: int = 2048, +) -> torch.utils.data.Dataset: + return OfflineEagle3Dataset( + list_local_files(hidden_states_path), + max_len=max_len, + ) + + +# ============================== +# Vocab Mapping +# ============================== +def generate_vocab_mapping_file( + dataset: HFDataset, + target_vocab_size: int, + draft_vocab_size: int, + cache_dir: str = "./cache/vocab_mapping", + cache_key: str = "vocab_mapping", +) -> str: + """ + Generate a vocab mapping file for the dataset. + + Args: + dataset: The dataset to process. + target_vocab_size: The target vocabulary size. + draft_vocab_size: The draft vocabulary size. + cache_dir: The directory to use for caching the vocab mapping file. + cache_key: The key to use for caching the vocab mapping file. + + Returns: + The path to the vocab mapping file. + """ + # prepare cache direcotory + os.makedirs(cache_dir, exist_ok=True) + vocab_mapping_path = os.path.join(cache_dir, f"{cache_key}.pt") + + if os.path.exists(vocab_mapping_path): + print(f"Loading vocab mapping from the cached file at: {vocab_mapping_path}") + return vocab_mapping_path + + # we first count the frequency of effectiev tokens in the dataset + token_dict = Counter() + for input_ids, loss_mask in tqdm( + zip(dataset["input_ids"], dataset["loss_mask"]), + total=len(dataset), + desc="Counting tokens for vocab mapping", + ): + masked_ids = input_ids[loss_mask == 1] + unique_ids, counts = masked_ids.unique(return_counts=True) + batch_token_dict = dict(zip(unique_ids.tolist(), counts.tolist())) + token_dict.update(batch_token_dict) + + # generate the d2t and t2d mapping + d2t, t2d = process_token_dict_to_mappings( + token_dict, + draft_vocab_size, + target_vocab_size, + ) + + vocab_mapping = { + "d2t": d2t, + "t2d": t2d, + } + torch.save(vocab_mapping, vocab_mapping_path) + print(f"Saved vocab mapping to: {vocab_mapping_path}") + return vocab_mapping_path + + +def process_token_dict_to_mappings( + token_dict: Counter, + draft_vocab_size: int, + target_vocab_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Process token_dict to create d2t and t2d mappings, with optional caching. + + Args: + token_dict: A Counter object mapping token ids to their frequencies. + draft_vocab_size: The size of the draft vocabulary. + target_vocab_size: The size of the target vocabulary. + + Returns: + A tuple containing: + - d2t: A tensor mapping draft token ids to target token ids. + - t2d: A tensor mapping target token ids to draft token ids. + """ + if len(token_dict) < draft_vocab_size: + existing_tokens = set(token_dict.keys()) + missing_tokens = set(range(draft_vocab_size)) - existing_tokens + for token in missing_tokens: + token_dict[token] = 0 + if len(token_dict) >= draft_vocab_size: + break + print(f"Added missing tokens to reach draft vocab size: {draft_vocab_size}") + print(f"Total tokens after addition: {len(token_dict)}") + total_frequency = sum(token_dict.values()) + top_N = token_dict.most_common(draft_vocab_size) + top_N_frequency_sum = sum(freq for key, freq in top_N) + + if total_frequency == 0: + print( + "Warning: Total token frequency is zero. All tokens will have zero ratio." + ) + top_N_ratio = 0.0 + else: + top_N_ratio = top_N_frequency_sum / total_frequency + + print(f"top {draft_vocab_size} token frequency ratio: {top_N_ratio:.2%}") + used_tokens = [key for key, freq in top_N] + used_tokens.sort() + + d2t = [used_tokens[i] - i for i in range(len(used_tokens))] + t2d = [i in used_tokens for i in range(target_vocab_size)] + d2t = torch.tensor(d2t) + t2d = torch.tensor(t2d) + + return d2t, t2d diff --git a/SpecForge-ext/specforge/data/template.py b/SpecForge-ext/specforge/data/template.py new file mode 100644 index 0000000000000000000000000000000000000000..4803db9af06fb4feb4d6cb5b577ca5766be82a4b --- /dev/null +++ b/SpecForge-ext/specforge/data/template.py @@ -0,0 +1,310 @@ +# Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/chat_template.py#L13 +from typing import List + +from pydantic import BaseModel + + +class ChatTemplate(BaseModel): + """ + This is a dataclass for the chat template. + + Args: + assistant_header(str): The header for the assistant. + user_header(str): The header for the user. + system_prompt(str): The system prompt. + end_of_turn_token(str): The end token of a turn of conversation. + """ + + assistant_header: str | None + user_header: str | None + system_prompt: str | None + end_of_turn_token: str | None + parser_type: str = "general" + assistant_pattern_type: str = "general" + enable_thinking: bool = False + + +class TemplateRegistry: + """ + This is a registry for the chat template. Sgl-spec will register some common chat templates here. + If you have a custom chat template, you can register it via the example below. + + Example: + ```python + from specforge.data.template import TEMPLATE_REGISTRY, ChatTemplate + TEMPLATE_REGISTRY.register( + name="custom", + template=ChatTemplate( + assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n", + user_header="<|start_header_id|>user<|end_header_id|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|eot_id|>" + ) + ) + ``` + """ + + def __init__(self): + self.templates = {} + + def register(self, name: str, template: ChatTemplate, override: bool = False): + """ + Register a chat template for a model type. + + Args: + name(str): The name of the chat template. + template(ChatTemplate): The chat template. + override(bool): Whether to override the existing template, default to False + """ + assert ( + not override and name not in self.templates + ), f"Chat template for the model type {name} has already been registered" + self.templates[name] = template + + def get(self, name: str) -> ChatTemplate: + """ + Get the chat template for a model type. + + Args: + name(str): The name of the chat template. + + Returns: + ChatTemplate: The chat template. + """ + return self.templates[name] + + def get_all_template_names(self) -> List[str]: + """ + Get all the template names. + + Returns: + List[str]: The list of template names. + """ + return list(self.templates.keys()) + + +# global registry +TEMPLATE_REGISTRY = TemplateRegistry() + +# Register the common template here +TEMPLATE_REGISTRY.register( + name="llama3", + template=ChatTemplate( + assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n", + user_header="<|start_header_id|>user<|end_header_id|>", + system_prompt="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.", + end_of_turn_token="<|eot_id|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="llama4", + template=ChatTemplate( + assistant_header="<|header_start|>assistant<|header_end|>\n\n", + user_header="<|header_start|>user<|header_end|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|eot|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen2-vl", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="phi3", + template=ChatTemplate( + assistant_header="<|assistant|>\n", + user_header="<|user|>\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="phi4", + template=ChatTemplate( + assistant_header="<|im_start|>assistant<|im_sep|>", + user_header="<|im_start|>user<|im_sep|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="phi4-mini", + template=ChatTemplate( + assistant_header="<|assistant|>", + user_header="<|user|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="gpt-oss-naive", + template=ChatTemplate( + assistant_header="<|start|>assistant<|channel|>analysis<|message|>", + user_header="<|start|>user<|message|>", + system_prompt=None, + end_of_turn_token="<|end|>", + ), +) + + +TEMPLATE_REGISTRY.register( + name="gpt-oss", + template=ChatTemplate( + assistant_header=None, # the headers are not applicable to openai-harmony's channel tags + user_header=None, + system_prompt=None, + end_of_turn_token=None, + parser_type="openai-harmony", + ), +) + +TEMPLATE_REGISTRY.register( + name="deepseek-r1-distill", + template=ChatTemplate( + assistant_header="<|Assistant|>", + user_header="<|User|>", + end_of_turn_token=None, + system_prompt=None, + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen3-thinking", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + parser_type="thinking", + enable_thinking=True, + ), +) + + +TEMPLATE_REGISTRY.register( + name="qwen3-instruct", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n\n\n\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen3-next-thinking", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + parser_type="thinking", + enable_thinking=True, + ), +) + +TEMPLATE_REGISTRY.register( + name="kimi-k2-thinking", + template=ChatTemplate( + assistant_header="<|im_assistant|>assistant<|im_middle|>", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>", + parser_type="thinking", + enable_thinking=True, + ), +) + +TEMPLATE_REGISTRY.register( + name="kimi-k2-instruct", + template=ChatTemplate( + assistant_header="<|im_assistant|>assistant<|im_middle|>", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="deepseek-v3", + template=ChatTemplate( + assistant_header="<|Assistant|>", + user_header="<|User|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|end▁of▁sentence|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="ling-flash-2.0", + template=ChatTemplate( + assistant_header="ASSISTANT", + user_header="HUMAN", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|role_end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="deepseek-v32", + template=ChatTemplate( + assistant_header="<|Assistant|>", + user_header="<|User|>", + system_prompt="", + end_of_turn_token="<|end▁of▁sentence|>", + parser_type="thinking", + enable_thinking=True, + ), +) + +TEMPLATE_REGISTRY.register( + name="gemma", + template=ChatTemplate( + assistant_header="model\n", + user_header="user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="longcat", + template=ChatTemplate( + assistant_header=" ASSISTANT:", + user_header=" USER:", + system_prompt="You are a helpful assistant.", + end_of_turn_token="", + assistant_pattern_type="longcat", + ), +) + +TEMPLATE_REGISTRY.register( + name="longcat_xml", + template=ChatTemplate( + assistant_header="", + user_header="", + system_prompt="You are a helpful assistant.", + end_of_turn_token="", + ), +) diff --git a/SpecForge-ext/specforge/data/utils.py b/SpecForge-ext/specforge/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9668680dc696230e97938ccba7f9277918220b87 --- /dev/null +++ b/SpecForge-ext/specforge/data/utils.py @@ -0,0 +1,326 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler + +from datasets import Dataset +from specforge.distributed import get_draft_sp_group + + +class DataCollatorWithPadding: + """ + Datacollator that will dynamically pad the inputs for batching. + """ + + def __init__(self): + self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group()) + + def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad to the longest sequence in the batch. + + Args: + intensors: (B, n, S) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N, S) + """ + B, n, S = intensors.shape + padding_tensor = torch.zeros( + B, N - n, S, dtype=intensors.dtype, device=intensors.device + ) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad 2D tensor to the longest sequence in the batch. + + Args: + intensors: (B, n) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N) + """ + B, n = intensors.shape + padding_tensor = torch.zeros( + B, N - n, dtype=intensors.dtype, device=intensors.device + ) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Collate a batch of features. + + Args: + features: A list of features, where each feature is a dictionary containing: + - input_ids: torch.Tensor of shape (n,) + - attention_mask: torch.Tensor of shape (n,) + - loss_mask: torch.Tensor of shape (n,) + + Returns: + A dictionary containing: + - input_ids: torch.Tensor of shape (B, N) + - attention_mask: torch.Tensor of shape (B, N) + - loss_mask: torch.Tensor of shape (B, N) + """ + max_length = max(item["input_ids"].shape[1] for item in features) + # pad for sequence parrel + max_length = ( + (max_length + self.sp_degree - 1) // self.sp_degree + ) * self.sp_degree + batch_input_ids = torch.cat( + [self.paddingtensor2D(item["input_ids"], max_length) for item in features] + ) + batch_attention_mask = torch.cat( + [ + self.paddingtensor2D(item["attention_mask"], max_length) + for item in features + ] + ) + batch_loss_mask = torch.cat( + [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] + ) + batch = { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + "loss_mask": batch_loss_mask, + "hidden_state": None, + "target": None, + } + if all("hidden_state" in item for item in features): + assert all( + "target" in item for item in features + ), "target is required when hidden_state is provided" + batch["hidden_state"] = torch.cat( + [ + self.paddingtensor(item["hidden_state"], max_length) + for item in features + ] + ) + batch["target"] = torch.cat( + [self.paddingtensor(item["target"], max_length) for item in features] + ) + return batch + + +class VlmDataCollatorWithPadding: + """ + Datacollator that will dynamically pad the inputs for batching. + """ + + def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad to the longest sequence in the batch. + + Args: + intensors: (B, n, S) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N, S) + """ + B, n, S = intensors.shape + padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad 2D tensor to the longest sequence in the batch. + + Args: + intensors: (B, n) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N) + """ + B, n = intensors.shape + padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Collate a batch of features. + + Args: + features: A list of features, where each feature is a dictionary containing: + - input_ids: torch.Tensor of shape (n,) + - attention_mask: torch.Tensor of shape (n,) + - loss_mask: torch.Tensor of shape (n,) + - pixel_values: torch.Tensor of shape (grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size) + - image_grid_thw: torch.Tensor of shape (3,) + + Returns: + A dictionary containing: + - input_ids: torch.Tensor of shape (B, N) + - attention_mask: torch.Tensor of shape (B, N) + - loss_mask: torch.Tensor of shape (B, N) + """ + max_length = max(item["input_ids"].shape[1] for item in features) + batch_input_ids = torch.cat( + [self.paddingtensor2D(item["input_ids"], max_length) for item in features] + ) + batch_attention_mask = torch.cat( + [ + self.paddingtensor2D(item["attention_mask"], max_length) + for item in features + ] + ) + batch_loss_mask = torch.cat( + [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] + ) + batch_pixel_values = torch.cat( + [item["pixel_values"] for item in features], dim=0 + ) + batch_image_grid_thw = torch.cat( + [item["image_grid_thw"] for item in features], dim=0 + ) + batch = { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + "loss_mask": batch_loss_mask, + "pixel_values": batch_pixel_values, + "image_grid_thw": batch_image_grid_thw, + "hidden_state": None, + "target": None, + } + if all("hidden_state" in item for item in features): + assert all( + "target" in item for item in features + ), "target is required when hidden_state is provided" + batch["hidden_state"] = torch.cat( + [ + self.paddingtensor(item["hidden_state"], max_length) + for item in features + ] + ) + batch["target"] = torch.cat( + [self.paddingtensor(item["target"], max_length) for item in features] + ) + return batch + + +def prepare_dp_dataloaders( + dataset: Dataset, + batch_size: int, + num_workers: int = 4, + process_group: Optional[dist.ProcessGroup] = None, + pin_memory: Optional[bool] = False, + shuffle: Optional[bool] = False, + is_vlm: Optional[bool] = False, + prefetch_factor: Optional[int] = 2, + **dataloader_kwargs, +) -> DataLoader: + """ + Prepare dataloader for distributed data parallel training. + + Args: + dataset: The dataset to load data from. + batch_size: The batch size for each GPU. + num_workers: The number of workers for data loading. + process_group: The process group for distributed training. + pin_memory: Whether to pin memory for data loading. + shuffle: Whether to shuffle the dataset. + is_vlm: Whether the dataset is a vision-language model dataset. + **dataloader_kwargs: Additional keyword arguments for the DataLoader. + + Returns: + A DataLoader for the dataset. + """ + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle + ) + if is_vlm: + datacollator_cls = VlmDataCollatorWithPadding + else: + datacollator_cls = DataCollatorWithPadding + + if num_workers == 0: + prefetch_factor = None + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + collate_fn=datacollator_cls(), + drop_last=True, + **dataloader_kwargs, + ) + return dataloader + + +def parse_harmony_message_content(content): + """ + 解析 content 字符串中的 Harmony 格式。 + 如果匹配到 Harmony 格式,返回包含 channel 和 content 的列表; + 否则,返回原内容并标记为默认 channel。 + """ + # 匹配 <|channel|>xxx<|message|>yyy<|end|> + pattern = r"<\|channel\|>(.*?)<\|message\|>(.*?)<\|end|>" + matches = re.findall(pattern, content, re.DOTALL) + + if not matches: + # 如果没有匹配到 Harmony 标签,视作普通文本 + return [{"channel": "text", "content": content}] + + results = [] + for channel, msg_body in matches: + results.append({"channel": channel.strip(), "content": msg_body.strip()}) + return results + + +def process_harmony_conversations(conversation): + """ + 处理传入的 list[list[dict]] 结构 + """ + new_conversation = [] + for msg in conversation: + role = msg.get("role") + original_content = msg.get("content", "") + + # 解析 content 中的 Harmony 结构 + segments = parse_harmony_message_content(original_content) + + # 为每个解析出的通道生成一个新的消息字典 + for seg in segments: + new_msg = { + "role": role, + "channel": seg["channel"], # 新增字段标识通道 + "content": seg["content"], + } + new_conversation.append(new_msg) + + return new_conversation diff --git a/SpecForge-ext/specforge/layers/__init__.py b/SpecForge-ext/specforge/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b71718d39de7248cd0c33732c920f129ddd40001 --- /dev/null +++ b/SpecForge-ext/specforge/layers/__init__.py @@ -0,0 +1,10 @@ +from .embedding import VocabParallelEmbedding +from .linear import ColumnParallelLinear, RowParallelLinear +from .lm_head import ParallelLMHead + +__all__ = [ + "VocabParallelEmbedding", + "ColumnParallelLinear", + "RowParallelLinear", + "ParallelLMHead", +] diff --git a/SpecForge-ext/specforge/layers/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6db9357a7a5ae98aa143d916e254b6ab994dc60c Binary files /dev/null and b/SpecForge-ext/specforge/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/layers/__pycache__/lm_head.cpython-311.pyc b/SpecForge-ext/specforge/layers/__pycache__/lm_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6928487ea33c88dab9a736a303a4fbc2a956e715 Binary files /dev/null and b/SpecForge-ext/specforge/layers/__pycache__/lm_head.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/layers/embedding.py b/SpecForge-ext/specforge/layers/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..336d5776c5d1585a0a8ec0f3a24d2458d77703aa --- /dev/null +++ b/SpecForge-ext/specforge/layers/embedding.py @@ -0,0 +1,132 @@ +import math +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from specforge.distributed import get_tp_group, shard_tensor + + +class VocabParallelEmbedding(nn.Module): + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + + # tp-realted + self.tp_group = get_tp_group() + self.tp_rank = dist.get_rank(self.tp_group) + self.tp_size = dist.get_world_size(self.tp_group) + + # deal with the case where the embedding is not divisible by the TP size + self.num_embeddings_per_shard = math.ceil(num_embeddings / self.tp_size) + self.padded_num_embeddings = ( + self.num_embeddings_per_shard * self.tp_size - self.num_embeddings + ) + self.vocab_start_index = self.tp_rank * self.num_embeddings_per_shard + self.vocab_end_index = min( + self.vocab_start_index + self.num_embeddings_per_shard, + self.num_embeddings, + ) + + if ( + padding_idx is not None + and padding_idx >= self.vocab_start_index + and padding_idx < self.vocab_end_index + ): + self.padding_idx = padding_idx - self.vocab_start_index + else: + self.padding_idx = None + + self.weight = nn.Parameter( + torch.empty( + (self.num_embeddings_per_shard, self.embedding_dim), **factory_kwargs + ), + requires_grad=True, + ) + self.reset_parameters() + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "weight" in state_dict: + value = state_dict["weight"] + + # pad this value if it is not divisible by the TP size + if value.shape[0] % self.tp_size != 0: + padding_size = self.tp_size - value.shape[0] % self.tp_size + value = F.pad(value, (0, 0, 0, padding_size)) + state_dict["weight"] = shard_tensor(value, self.tp_group, 0) + + def reset_parameters(self) -> None: + torch.nn.init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def generate_mask(self, input_): + # generate the mask for the vocab which is only owned by the current rank + mask = (input_ >= self.vocab_start_index) & (input_ < self.vocab_end_index) + return mask + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + mask = self.generate_mask(input_) + masked_input = input_ - self.vocab_start_index + masked_input[~mask] = 0 + else: + masked_input = input_ + + output_parallel = F.embedding( + masked_input, + self.weight, + padding_idx=self.padding_idx, + max_norm=self.max_norm, + norm_type=self.norm_type, + scale_grad_by_freq=self.scale_grad_by_freq, + sparse=self.sparse, + ) + + # Mask the output embedding. + if self.tp_size > 1: + output_parallel[~mask] = 0 + # Reduce across all the model parallel GPUs. + dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM, group=self.tp_group) + output = output_parallel + else: + output = output_parallel + return output diff --git a/SpecForge-ext/specforge/layers/linear.py b/SpecForge-ext/specforge/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c512d2139d32b233efa5795a234a1426d0e5e4 --- /dev/null +++ b/SpecForge-ext/specforge/layers/linear.py @@ -0,0 +1,204 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from specforge.distributed import get_tp_group, shard_tensor + + +class RowParallelLinear(nn.Module): + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + kv_head_replicas=False, + layout_type: str = "normal", + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.layout_type = layout_type + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.tp_rank = dist.get_rank(self.tp_group) + + self.in_features = in_features + self.out_features = out_features + + if kv_head_replicas: + self.in_features_per_shard = in_features + else: + self.in_features_per_shard = in_features // self.tp_size + self.weight = nn.Parameter( + torch.empty(self.out_features, self.in_features_per_shard, **factory_kwargs) + ) + if bias: + self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + """ + This is a state dict hook to be triggered before loading the state dict. This will shard the weights and biases according to the layout type. + """ + if self.layout_type == "normal": + self.handle_normal_layout(state_dict, *args) + else: + raise ValueError(f"Invalid layout type: {self.layout_type}") + + def handle_normal_layout(self, state_dict, *args): + # shard the weights + if "weight" in state_dict: + state_dict["weight"] = shard_tensor(state_dict["weight"], self.tp_group, -1) + + if "bias" in state_dict and self.tp_rank != 0: + state_dict["bias"] = torch.zeros_like(state_dict["bias"]) + + def forward(self, x): + return F.linear(x, self.weight, self.bias) + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self): + return f"RowParallelLinear(in_features={self.in_features_per_shard}, out_features={self.out_features}, tp_size={self.tp_size}, tp_rank={self.tp_rank})" + + +class ColumnParallelLinear(nn.Module): + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + layout_type: str = "normal", + kv_head_replicas=False, + kv_head_idx=None, + total_num_kv_heads=None, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.layout_type = layout_type + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.tp_rank = dist.get_rank(self.tp_group) + + self.in_features = in_features + self.out_features = out_features + self.kv_head_replicas = kv_head_replicas + self.kv_head_idx = kv_head_idx + self.total_num_kv_heads = total_num_kv_heads + if self.kv_head_replicas: + self.out_features_per_shard = out_features + else: + self.out_features_per_shard = out_features // self.tp_size + + self.weight = nn.Parameter( + torch.empty(self.out_features_per_shard, self.in_features, **factory_kwargs) + ) + if bias: + self.bias = nn.Parameter( + torch.empty(self.out_features_per_shard, **factory_kwargs) + ) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + """ + This is a state dict hook to be triggered before loading the state dict. This will shard the weights and biases according to the layout type. + """ + if self.kv_head_replicas: + assert self.kv_head_idx is not None + assert self.layout_type == "normal" + self.handle_kv_head_replicas(state_dict, *args) + else: + if self.layout_type == "normal": + self.handle_normal_layout(state_dict, *args) + elif self.layout_type == "merged_qkv": + self.handle_merged_qkv(state_dict, *args) + elif self.layout_type == "gate_up": + self.handle_gate_up_layout(state_dict, *args) + else: + raise ValueError(f"Invalid layout type: {self.layout_type}") + + def handle_kv_head_replicas(self, state_dict, *args): + """ + This is a special case for GQA where the key/value are split according to the number of kv heads and the head which belongs to this rank. + As the TP size is larger than the number of kv heads, we only keep one kv head per rank. + """ + if "weight" in state_dict: + state_dict["weight"] = state_dict["weight"].chunk( + self.total_num_kv_heads, dim=0 + )[self.kv_head_idx] + if "bias" in state_dict and state_dict["bias"] is not None: + state_dict["bias"] = state_dict["bias"].chunk( + self.total_num_kv_heads, dim=0 + )[self.kv_head_idx] + + def handle_normal_layout(self, state_dict, *args): + """ + This shards the weights and biases along the column dimension. + """ + # shard the weights + if "weight" in state_dict: + state_dict["weight"] = shard_tensor(state_dict["weight"], self.tp_group, 0) + + if "bias" in state_dict and state_dict["bias"] is not None: + state_dict["bias"] = shard_tensor(state_dict["bias"], self.tp_group, 0) + + def handle_gate_up_layout(self, state_dict, *args): + """ + This handles the gate_up layout where the gate and up weights are concatenated along the column dimension. + """ + if "weight" in state_dict: + gate, up = state_dict["weight"].chunk(2, dim=0) + gate = shard_tensor(gate, self.tp_group, 0) + up = shard_tensor(up, self.tp_group, 0) + state_dict["weight"] = torch.cat((gate, up), dim=0) + + if "bias" in state_dict and state_dict["bias"] is not None: + gate, up = state_dict["bias"].chunk(2, dim=0) + gate = shard_tensor(gate, self.tp_group, 0) + up = shard_tensor(up, self.tp_group, 0) + state_dict["bias"] = torch.cat((gate, up), dim=0) + + def handle_merged_qkv(self, state_dict, *args): + """ + This handles the merged QKV layout where the q, k, v weights are concatenated along the column dimension. + """ + if "weight" in state_dict: + # need to split into qkv and take the correct chunk for the rank + q, k, v = state_dict["weight"].chunk(3, dim=0) + q = shard_tensor(q, self.tp_group, 0) + k = shard_tensor(k, self.tp_group, 0) + v = shard_tensor(v, self.tp_group, 0) + state_dict["weight"] = torch.cat((q, k, v), dim=0) + + if "bias" in state_dict and state_dict["bias"] is not None: + q, k, v = state_dict["bias"].chunk(3, dim=0) + q = shard_tensor(q, self.tp_group, 0) + k = shard_tensor(k, self.tp_group, 0) + v = shard_tensor(v, self.tp_group, 0) + state_dict["bias"] = torch.cat((q, k, v), dim=0) + + def forward(self, x): + return F.linear(x, self.weight, self.bias) + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self): + return f"ColumnParallelLinear(in_features={self.in_features}, out_features={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})" diff --git a/SpecForge-ext/specforge/layers/lm_head.py b/SpecForge-ext/specforge/layers/lm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d50b089da761b2b85396fbf8b40aa1a5d65133 --- /dev/null +++ b/SpecForge-ext/specforge/layers/lm_head.py @@ -0,0 +1,109 @@ +import math +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from specforge.distributed import get_tp_group, shard_tensor + + +class ParallelLMHead(nn.Module): + + def __init__( + self, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.in_features = in_features + self.out_features = out_features + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.tp_rank = dist.get_rank(self.tp_group) + + # tp-related + self.out_features_per_shard = math.ceil(out_features / self.tp_size) + self.padded_out_features = ( + self.out_features_per_shard * self.tp_size - out_features + ) + assert ( + self.out_features_per_shard * self.tp_size + == out_features + self.padded_out_features + ) + + self.weight = nn.Parameter( + torch.empty(self.out_features_per_shard, self.in_features, **factory_kwargs) + ) + self.bias = ( + nn.Parameter(torch.zeros(self.out_features_per_shard, **factory_kwargs)) + if bias + else None + ) + + # init params + self.reset_parameters() + + # handle weight loading + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "weight" in state_dict: + value = state_dict["weight"] + + # pad this value if it is not divisible by the TP size + if value.shape[0] % self.tp_size != 0: + padding_size = self.tp_size - value.shape[0] % self.tp_size + value = F.pad(value, (0, 0, 0, padding_size)) + state_dict["weight"] = shard_tensor(value, self.tp_group, 0) + + if "bias" in state_dict: + value = state_dict["bias"] + + # pad this value if it is not divisible by the TP size + if value.shape[0] % self.tp_size != 0: + padding_size = self.tp_size - value.shape[0] % self.tp_size + value = F.pad(value, (0, padding_size)) + state_dict["bias"] = shard_tensor(value, self.tp_group, 0) + + def forward(self, hidden: torch.Tensor, gather_output: bool = False): + """ + hidden: [B, T, H] or [N, H] + returns: + - if gather_output=False: local logits [*, local_vocab] and (start,end) for stitching + - if gather_output=True: full logits [*, vocab] via all-gather (use for inference) + """ + orig_shape = hidden.shape + hidden = hidden.reshape(-1, self.in_features) # [N, H] + + local_logits = hidden @ self.weight.T # [N, local_vocab] + + if self.bias is not None: + local_logits = local_logits + self.bias + + if not gather_output or self.tp_size == 1: + return local_logits.view( + *orig_shape[:-1], self.out_features_per_shard + ).contiguous() + else: + # all-gather shards along vocab dim + chunks = [torch.empty_like(local_logits) for _ in range(self.tp_size)] + dist.all_gather(chunks, local_logits, group=self.tp_group) + full = torch.cat(chunks, dim=-1)[ + :, : self.out_features + ] # trim padding from ceil-div + return full.view(*orig_shape[:-1], self.out_features).contiguous() + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self): + return f"ParallelLMHead(in_features={self.in_features}, out_features={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})" diff --git a/SpecForge-ext/specforge/layers/ring/__init__.py b/SpecForge-ext/specforge/layers/ring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0a04a8f5eae08db74b9697dd8d1e2ae946edc9 --- /dev/null +++ b/SpecForge-ext/specforge/layers/ring/__init__.py @@ -0,0 +1,12 @@ +# adapt from https://github.com/feifeibear/long-context-attention/tree/main/yunchang +from .ring_flash_attn import ( + ring_flash_attn_func, + ring_flash_attn_kvpacked_func, + ring_flash_attn_qkvpacked_func, +) + +__all__ = [ + "ring_flash_attn_func", + "ring_flash_attn_kvpacked_func", + "ring_flash_attn_qkvpacked_func", +] diff --git a/SpecForge-ext/specforge/layers/ring/ring_flash_attn.py b/SpecForge-ext/specforge/layers/ring/ring_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..3c89b7e4a33431b83d4abe7bb64fd7b52a396523 --- /dev/null +++ b/SpecForge-ext/specforge/layers/ring/ring_flash_attn.py @@ -0,0 +1,336 @@ +import torch +from yunchang.kernels import AttnType, select_flash_attn_impl + +from .utils import RingComm, update_out_and_lse + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + attn_type: AttnType = AttnType.FA, + attn_processor=None, +): + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if not causal or step <= comm.rank: + fn = select_flash_attn_impl( + attn_type, stage="fwd-only", attn_processor=attn_processor + ) + block_out, block_lse = fn( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal and step == 0, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + if attn_type == AttnType.SPARSE_SAGE: + out, lse = block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + if attn_type != AttnType.SPARSE_SAGE: + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + attn_type: AttnType = AttnType.FA, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + fn = select_flash_attn_impl(attn_type, stage="bwd-only") + fn( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + dropout_p, + softmax_scale, + bwd_causal, + window_size, + softcap, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk = next_dk + dv = next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk) + next_dv = d_kv_comm.send_recv(dv) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + group, + attn_type, + attn_processor, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=False, + attn_type=attn_type, + attn_processor=attn_processor, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + ctx.attn_type = attn_type + ctx.attn_processor = attn_processor + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + softcap=ctx.softcap, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + attn_type=ctx.attn_type, + ) + return ( + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, + attn_processor=None, +): + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + attn_processor, + ) diff --git a/SpecForge-ext/specforge/modeling/__init__.py b/SpecForge-ext/specforge/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09999d60bc39243219b2c346154fe51ff4594dce --- /dev/null +++ b/SpecForge-ext/specforge/modeling/__init__.py @@ -0,0 +1,19 @@ +# from .auto import AutoDistributedTargetModel, AutoDraftModelConfig, AutoEagle3DraftModel +from .auto import AutoDraftModelConfig, AutoEagle3DraftModel +from .draft.llama3_eagle import LlamaForCausalLMEagle3 +from .target.eagle3_target_model import ( + CustomEagle3TargetModel, + HFEagle3TargetModel, + SGLangEagle3TargetModel, + get_eagle3_target_model, +) + +__all__ = [ + "LlamaForCausalLMEagle3", + "SGLangEagle3TargetModel", + "HFEagle3TargetModel", + "CustomEagle3TargetModel", + "get_eagle3_target_model", + "AutoDraftModelConfig", + "AutoEagle3DraftModel", +] diff --git a/SpecForge-ext/specforge/modeling/_mask_utils.py b/SpecForge-ext/specforge/modeling/_mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bedb200299e24ecf531c117618e55837c49facbe --- /dev/null +++ b/SpecForge-ext/specforge/modeling/_mask_utils.py @@ -0,0 +1,73 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) diff --git a/SpecForge-ext/specforge/modeling/auto.py b/SpecForge-ext/specforge/modeling/auto.py new file mode 100644 index 0000000000000000000000000000000000000000..1e48a43e7a62748f802671500b23adf74c6dd03a --- /dev/null +++ b/SpecForge-ext/specforge/modeling/auto.py @@ -0,0 +1,175 @@ +import json +import os +from typing import Optional, Union + +import torch +from transformers import AutoConfig +from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase +from transformers import ( + GptOssConfig, + Llama4Config, + Llama4TextConfig, + LlamaConfig, + Phi3Config, + PretrainedConfig, + Qwen2Config, + Qwen3Config, + Qwen3MoeConfig, + modeling_utils, +) + +from .draft.llama3_eagle import LlamaForCausalLMEagle3 +from .target.custom_backend import ( + GptOssForCausalLM, + Llama4ForCausalLM, + LlamaForCausalLM, + Phi3ForCausalLM, + Qwen2ForCausalLM, + Qwen3ForCausalLM, + Qwen3MoeForCausalLM, +) + + +class AutoEagle3DraftModel(AutoModelForCausalLMBase): + # the model mapping is currently hardcoded, we should support lazy model mapping via registry + _model_mapping = { + LlamaConfig: LlamaForCausalLMEagle3, + } + + @classmethod + def from_config(cls, config: PretrainedConfig, torch_dtype=None, **config_kwargs): + """ + This class method takes a configuration object and create its model based on the + _model_mapping class variable. + + Args: + config (PretrainedConfig): A configuration object. + + Returns: + A model instance. + """ + # get the model class from the + _model_cls = cls._model_mapping[type(config)] + model = _model_cls(config, **config_kwargs) + + # Convert model to specified dtype if provided + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike[str]], + *model_args, + **kwargs, + ): + original_warn = modeling_utils.logger.warning + + def filtered_warning(msg): + if "embed_tokens.weight" in str(msg) and "initialized" in str(msg): + return + original_warn(msg) + + modeling_utils.logger.warning = filtered_warning + + try: + model = super().from_pretrained( + pretrained_model_name_or_path, *model_args, **kwargs + ) + finally: + modeling_utils.logger.warning = original_warn + + return model + + +class AutoDistributedTargetModel(AutoModelForCausalLMBase): + # the model mapping is currently hardcoded, we should support lazy model mapping via registry + _model_mapping = { + Llama4TextConfig: [Llama4ForCausalLM], + Qwen3MoeConfig: [Qwen3MoeForCausalLM], + Qwen2Config: [Qwen2ForCausalLM], + LlamaConfig: [LlamaForCausalLM], + Qwen3Config: [Qwen3ForCausalLM], + Phi3Config: [Phi3ForCausalLM], + GptOssConfig: [GptOssForCausalLM], + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike[str]], + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **config_kwargs, + ): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + ) + + if isinstance(config, Llama4Config): + config = config.text_config + + assert ( + type(config) in cls._model_mapping + ), f"Unsupported config type: {type(config)}" + model_cls = cls._model_mapping[type(config)][0] + model = model_cls.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + **config_kwargs, + ) + + if device is not None: + model = model.to(device) + else: + model = model.cuda() + return model + + +class AutoDraftModelConfig: + + _config_mapping = { + "LlamaForCausalLMEagle3": LlamaConfig, + } + + @classmethod + def from_file(cls, config_path: str): + """ + This class method takes a configuration file path and create its configuration object based on the + _config_mapping class variable. + + Args: + config_path (str): A path to a configuration file. + + Returns: + A configuration object. + """ + with open(config_path, "r") as f: + config = json.load(f) + + if "tie_word_embeddings" in config: + print("Set draft model tie_word_embeddings to False") + config["tie_word_embeddings"] = False + + # check for architectures + architectures = config.get("architectures", None) + + if architectures is None: + raise ValueError("No architectures found in the config file") + + if len(architectures) != 1: + raise ValueError("Only one architecture is supported") + + architecture = architectures[0] + + if architecture not in cls._config_mapping: + raise ValueError(f"Architecture {architecture} not supported") + + # If draft_vocab_size is not in config or is None, set draft_vocab_size to vocab_size + if "draft_vocab_size" not in config or config["draft_vocab_size"] is None: + config["draft_vocab_size"] = config.get("vocab_size", None) + + return cls._config_mapping[architecture].from_dict(config) diff --git a/SpecForge-ext/specforge/modeling/target/__init__.py b/SpecForge-ext/specforge/modeling/target/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f70b3b740d055ae72dfebacc5b9f7434f3eed0e --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/__init__.py @@ -0,0 +1,17 @@ +from .eagle3_target_model import ( + CustomEagle3TargetModel, + Eagle3TargetModel, + HFEagle3TargetModel, + SGLangEagle3TargetModel, + get_eagle3_target_model, +) +from .target_head import TargetHead + +__all__ = [ + "Eagle3TargetModel", + "SGLangEagle3TargetModel", + "HFEagle3TargetModel", + "CustomEagle3TargetModel", + "get_eagle3_target_model", + "TargetHead", +] diff --git a/SpecForge-ext/specforge/modeling/utils.py b/SpecForge-ext/specforge/modeling/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4cdd45642c0761e7178c1990e2e1bab6420b15ea --- /dev/null +++ b/SpecForge-ext/specforge/modeling/utils.py @@ -0,0 +1,11 @@ +import torch + + +@torch.no_grad() +def padding(tensor, left=True): + zeropadding = torch.zeros_like(tensor[:, -1:]) + if left: + tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1) + else: + tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1) + return tensor diff --git a/SpecForge-ext/wandb/debug-internal.log b/SpecForge-ext/wandb/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..ebfcabee217138939944197fa6e5181c2b81c43f --- /dev/null +++ b/SpecForge-ext/wandb/debug-internal.log @@ -0,0 +1,18 @@ +{"time":"2026-02-02T07:13:24.265773006Z","level":"INFO","msg":"stream: starting","core version":"0.24.1"} +{"time":"2026-02-02T07:13:54.532856025Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:14:27.017208123Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:15:01.749909771Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:15:40.607517629Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:16:29.823418554Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:17:34.179120391Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:19:04.180785637Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:20:34.182642172Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:22:04.183160003Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:23:34.184196875Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:25:04.184985971Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:26:34.186601319Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:28:04.190713607Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:29:34.193064558Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:31:04.193954293Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:32:34.195503708Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:34:04.197459211Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"} diff --git a/SpecForge-ext/wandb/debug.log b/SpecForge-ext/wandb/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..d007187260678991a840485a4796a2ee47ea12dc --- /dev/null +++ b/SpecForge-ext/wandb/debug.log @@ -0,0 +1,161 @@ +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_setup.py:_flush():81] Current SDK version is 0.24.1 +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_setup.py:_flush():81] Configure stats pid to 601 +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_setup.py:_flush():81] Loading settings from environment variables +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /workspace/hanrui/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug.log +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /workspace/hanrui/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug-internal.log +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:init():844] calling init triggers +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:init():849] wandb.init called with sweep_config: {} +config: {'target_model_path': '/workspace/Qwen3-8B', 'trust_remote_code': False, 'draft_model_config': 'configs/qwen3-8b-qwen3eagle-5layer.json', 'embedding_key': 'model.embed_tokens.weight', 'lm_head_key': 'lm_head.weight', 'is_vlm': False, 'target_model_backend': 'sglang', 'train_data_path': '/workspace/hanrui/qwen3-8b_dflash_regen/sharegpt_train_regenerated.jsonl', 'train_hidden_states_path': None, 'eval_hidden_states_path': None, 'eval_data_path': None, 'chat_template': 'qwen', 'is_preformatted': False, 'train_only_last_turn': False, 'build_dataset_num_proc': 8, 'dataloader_num_workers': 4, 'num_epochs': 10, 'max_num_steps': None, 'batch_size': 2, 'learning_rate': 0.0001, 'max_length': 2048, 'warmup_ratio': 0.015, 'total_steps': 49260, 'max_grad_norm': 0.5, 'ttt_length': 7, 'resume': False, 'ckpt_dir': None, 'eval_interval': 5000, 'save_interval': 5000, 'log_interval': 100, 'seed': 0, 'draft_accumulation_steps': 1, 'tp_size': 1, 'sp_ulysses_size': 1, 'sp_ring_size': 1, 'attention_backend': 'flex_attention', 'cache_key': None, 'cache_dir': 'cache', 'output_dir': 'outputs/qwen3-8b-qwen3eagle-5layer', 'verbose': False, 'dist_timeout': 20, 'model_download_dir': None, 'min_pixels': 50176, 'max_pixels': 802816, 'profile': False, 'profile_start_step': 30, 'profile_num_steps': 4, 'profile_record_shapes': False, 'sglang_attention_backend': 'flashinfer', 'sglang_mem_fraction_static': 0.4, 'sglang_context_length': None, 'sglang_enable_nccl_nvls': False, 'sglang_enable_symm_mem': False, 'sglang_enable_torch_compile': False, 'sglang_enable_dp_attention': False, 'sglang_enable_dp_lm_head': False, 'sglang_enable_piecewise_cuda_graph': False, 'sglang_piecewise_cuda_graph_max_tokens': 4096, 'sglang_piecewise_cuda_graph_tokens': None, 'sglang_ep_size': 1, 'report_to': 'wandb', 'wandb_project': 'qwen3-8b-qwen3eagle', 'wandb_name': '5layer-ttt7', 'wandb_key': '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', 'swanlab_project': None, 'swanlab_name': None, 'swanlab_key': None, 'mlflow_tracking_uri': None, 'mlflow_experiment_name': None, 'mlflow_run_name': None, 'dp_size': 8, 'target_batch_size': 2, '_wandb': {}} +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:init():892] starting backend +2026-02-02 07:13:24,247 INFO MainThread:601 [wandb_init.py:init():895] sending inform_init request +2026-02-02 07:13:24,263 INFO MainThread:601 [wandb_init.py:init():903] backend started and connected +2026-02-02 07:13:24,270 INFO MainThread:601 [wandb_init.py:init():973] updated telemetry +2026-02-02 07:13:24,285 INFO MainThread:601 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout +2026-02-02 07:13:55,052 INFO Thread-7 (wrapped_target):601 [retry.py:__call__():164] [no run ID] Retry attempt failed: +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 204, in _new_conn + sock = connection.create_connection( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 85, in create_connection + raise err + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 73, in create_connection + sock.connect(sa) +TimeoutError: timed out + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 787, in urlopen + response = self._make_request( + ^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 488, in _make_request + raise new_e + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 464, in _make_request + self._validate_conn(conn) + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 1093, in _validate_conn + conn.connect() + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 759, in connect + self.sock = sock = self._new_conn() + ^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 213, in _new_conn + raise ConnectTimeoutError( +urllib3.exceptions.ConnectTimeoutError: (, 'Connection to api.wandb.ai timed out. (connect timeout=20)') + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 644, in send + resp = conn.urlopen( + ^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 841, in urlopen + retries = retries.increment( + ^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/retry.py", line 535, in increment + raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)')) + +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/retry.py", line 157, in __call__ + result = self._call_fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/internal/internal_api.py", line 397, in execute + return self.client.execute(*args, **kwargs) # type: ignore + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 52, in execute + result = self._get_result(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 60, in _get_result + return self.transport.execute(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/gql_request.py", line 70, in execute + request = self.session.post(self.url, **post_args) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 637, in post + return self.request("POST", url, data=data, json=json, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 589, in request + resp = self.send(prep, **send_kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 703, in send + r = adapter.send(request, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 665, in send + raise ConnectTimeout(e, request=request) +requests.exceptions.ConnectTimeout: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)')) +2026-02-02 07:14:12,432 INFO Thread-6 (wrapped_target):601 [retry.py:__call__():164] [no run ID] Retry attempt failed: +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 204, in _new_conn + sock = connection.create_connection( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 85, in create_connection + raise err + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 73, in create_connection + sock.connect(sa) +TimeoutError: timed out + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 787, in urlopen + response = self._make_request( + ^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 488, in _make_request + raise new_e + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 464, in _make_request + self._validate_conn(conn) + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 1093, in _validate_conn + conn.connect() + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 759, in connect + self.sock = sock = self._new_conn() + ^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 213, in _new_conn + raise ConnectTimeoutError( +urllib3.exceptions.ConnectTimeoutError: (, 'Connection to api.wandb.ai timed out. (connect timeout=20)') + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 644, in send + resp = conn.urlopen( + ^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 841, in urlopen + retries = retries.increment( + ^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/retry.py", line 535, in increment + raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)')) + +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/retry.py", line 157, in __call__ + result = self._call_fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/internal/internal_api.py", line 397, in execute + return self.client.execute(*args, **kwargs) # type: ignore + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 52, in execute + result = self._get_result(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 60, in _get_result + return self.transport.execute(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/gql_request.py", line 70, in execute + request = self.session.post(self.url, **post_args) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 637, in post + return self.request("POST", url, data=data, json=json, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 589, in request + resp = self.send(prep, **send_kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 703, in send + r = adapter.send(request, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 665, in send + raise ConnectTimeout(e, request=request) +requests.exceptions.ConnectTimeout: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)')) diff --git a/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug-internal.log b/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..ebfcabee217138939944197fa6e5181c2b81c43f --- /dev/null +++ b/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug-internal.log @@ -0,0 +1,18 @@ +{"time":"2026-02-02T07:13:24.265773006Z","level":"INFO","msg":"stream: starting","core version":"0.24.1"} +{"time":"2026-02-02T07:13:54.532856025Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:14:27.017208123Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:15:01.749909771Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:15:40.607517629Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:16:29.823418554Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:17:34.179120391Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:19:04.180785637Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:20:34.182642172Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:22:04.183160003Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:23:34.184196875Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:25:04.184985971Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:26:34.186601319Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded"} +{"time":"2026-02-02T07:28:04.190713607Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:29:34.193064558Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:31:04.193954293Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:32:34.195503708Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers)"} +{"time":"2026-02-02T07:34:04.197459211Z","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/graphql\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"} diff --git a/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug.log b/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..d007187260678991a840485a4796a2ee47ea12dc --- /dev/null +++ b/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug.log @@ -0,0 +1,161 @@ +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_setup.py:_flush():81] Current SDK version is 0.24.1 +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_setup.py:_flush():81] Configure stats pid to 601 +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_setup.py:_flush():81] Loading settings from environment variables +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /workspace/hanrui/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug.log +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /workspace/hanrui/SpecForge-ext/wandb/run-20260202_071323-2yze80jn/logs/debug-internal.log +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:init():844] calling init triggers +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:init():849] wandb.init called with sweep_config: {} +config: {'target_model_path': '/workspace/Qwen3-8B', 'trust_remote_code': False, 'draft_model_config': 'configs/qwen3-8b-qwen3eagle-5layer.json', 'embedding_key': 'model.embed_tokens.weight', 'lm_head_key': 'lm_head.weight', 'is_vlm': False, 'target_model_backend': 'sglang', 'train_data_path': '/workspace/hanrui/qwen3-8b_dflash_regen/sharegpt_train_regenerated.jsonl', 'train_hidden_states_path': None, 'eval_hidden_states_path': None, 'eval_data_path': None, 'chat_template': 'qwen', 'is_preformatted': False, 'train_only_last_turn': False, 'build_dataset_num_proc': 8, 'dataloader_num_workers': 4, 'num_epochs': 10, 'max_num_steps': None, 'batch_size': 2, 'learning_rate': 0.0001, 'max_length': 2048, 'warmup_ratio': 0.015, 'total_steps': 49260, 'max_grad_norm': 0.5, 'ttt_length': 7, 'resume': False, 'ckpt_dir': None, 'eval_interval': 5000, 'save_interval': 5000, 'log_interval': 100, 'seed': 0, 'draft_accumulation_steps': 1, 'tp_size': 1, 'sp_ulysses_size': 1, 'sp_ring_size': 1, 'attention_backend': 'flex_attention', 'cache_key': None, 'cache_dir': 'cache', 'output_dir': 'outputs/qwen3-8b-qwen3eagle-5layer', 'verbose': False, 'dist_timeout': 20, 'model_download_dir': None, 'min_pixels': 50176, 'max_pixels': 802816, 'profile': False, 'profile_start_step': 30, 'profile_num_steps': 4, 'profile_record_shapes': False, 'sglang_attention_backend': 'flashinfer', 'sglang_mem_fraction_static': 0.4, 'sglang_context_length': None, 'sglang_enable_nccl_nvls': False, 'sglang_enable_symm_mem': False, 'sglang_enable_torch_compile': False, 'sglang_enable_dp_attention': False, 'sglang_enable_dp_lm_head': False, 'sglang_enable_piecewise_cuda_graph': False, 'sglang_piecewise_cuda_graph_max_tokens': 4096, 'sglang_piecewise_cuda_graph_tokens': None, 'sglang_ep_size': 1, 'report_to': 'wandb', 'wandb_project': 'qwen3-8b-qwen3eagle', 'wandb_name': '5layer-ttt7', 'wandb_key': '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', 'swanlab_project': None, 'swanlab_name': None, 'swanlab_key': None, 'mlflow_tracking_uri': None, 'mlflow_experiment_name': None, 'mlflow_run_name': None, 'dp_size': 8, 'target_batch_size': 2, '_wandb': {}} +2026-02-02 07:13:23,912 INFO MainThread:601 [wandb_init.py:init():892] starting backend +2026-02-02 07:13:24,247 INFO MainThread:601 [wandb_init.py:init():895] sending inform_init request +2026-02-02 07:13:24,263 INFO MainThread:601 [wandb_init.py:init():903] backend started and connected +2026-02-02 07:13:24,270 INFO MainThread:601 [wandb_init.py:init():973] updated telemetry +2026-02-02 07:13:24,285 INFO MainThread:601 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout +2026-02-02 07:13:55,052 INFO Thread-7 (wrapped_target):601 [retry.py:__call__():164] [no run ID] Retry attempt failed: +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 204, in _new_conn + sock = connection.create_connection( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 85, in create_connection + raise err + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 73, in create_connection + sock.connect(sa) +TimeoutError: timed out + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 787, in urlopen + response = self._make_request( + ^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 488, in _make_request + raise new_e + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 464, in _make_request + self._validate_conn(conn) + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 1093, in _validate_conn + conn.connect() + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 759, in connect + self.sock = sock = self._new_conn() + ^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 213, in _new_conn + raise ConnectTimeoutError( +urllib3.exceptions.ConnectTimeoutError: (, 'Connection to api.wandb.ai timed out. (connect timeout=20)') + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 644, in send + resp = conn.urlopen( + ^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 841, in urlopen + retries = retries.increment( + ^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/retry.py", line 535, in increment + raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)')) + +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/retry.py", line 157, in __call__ + result = self._call_fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/internal/internal_api.py", line 397, in execute + return self.client.execute(*args, **kwargs) # type: ignore + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 52, in execute + result = self._get_result(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 60, in _get_result + return self.transport.execute(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/gql_request.py", line 70, in execute + request = self.session.post(self.url, **post_args) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 637, in post + return self.request("POST", url, data=data, json=json, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 589, in request + resp = self.send(prep, **send_kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 703, in send + r = adapter.send(request, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 665, in send + raise ConnectTimeout(e, request=request) +requests.exceptions.ConnectTimeout: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)')) +2026-02-02 07:14:12,432 INFO Thread-6 (wrapped_target):601 [retry.py:__call__():164] [no run ID] Retry attempt failed: +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 204, in _new_conn + sock = connection.create_connection( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 85, in create_connection + raise err + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/connection.py", line 73, in create_connection + sock.connect(sa) +TimeoutError: timed out + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 787, in urlopen + response = self._make_request( + ^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 488, in _make_request + raise new_e + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 464, in _make_request + self._validate_conn(conn) + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 1093, in _validate_conn + conn.connect() + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 759, in connect + self.sock = sock = self._new_conn() + ^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connection.py", line 213, in _new_conn + raise ConnectTimeoutError( +urllib3.exceptions.ConnectTimeoutError: (, 'Connection to api.wandb.ai timed out. (connect timeout=20)') + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 644, in send + resp = conn.urlopen( + ^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/connectionpool.py", line 841, in urlopen + retries = retries.increment( + ^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/urllib3/util/retry.py", line 535, in increment + raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)')) + +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/retry.py", line 157, in __call__ + result = self._call_fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/internal/internal_api.py", line 397, in execute + return self.client.execute(*args, **kwargs) # type: ignore + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 52, in execute + result = self._get_result(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 60, in _get_result + return self.transport.execute(document, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/wandb/sdk/lib/gql_request.py", line 70, in execute + request = self.session.post(self.url, **post_args) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 637, in post + return self.request("POST", url, data=data, json=json, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 589, in request + resp = self.send(prep, **send_kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/sessions.py", line 703, in send + r = adapter.send(request, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/specforge/lib/python3.11/site-packages/requests/adapters.py", line 665, in send + raise ConnectTimeout(e, request=request) +requests.exceptions.ConnectTimeout: HTTPSConnectionPool(host='api.wandb.ai', port=443): Max retries exceeded with url: /graphql (Caused by ConnectTimeoutError(, 'Connection to api.wandb.ai timed out. (connect timeout=20)'))