| | import torch |
| | from typing import Optional |
| | from datasets import load_dataset, Features, Sequence, Value |
| |
|
def build_target_layer_ids(num_target_layers: int, num_draft_layers: int):
    """Choose which target-model layers each draft layer should attend to.

    Spreads `num_draft_layers` indices evenly over the span
    [1, num_target_layers - 3] of the target stack; a single draft layer is
    anchored at the middle of the target model instead.

    Args:
        num_target_layers: depth of the (large) target model.
        num_draft_layers: depth of the (small) draft model.

    Returns:
        A list of `num_draft_layers` integer layer indices.
    """
    if num_draft_layers == 1:
        # Degenerate case: a lone draft layer maps to the target's midpoint.
        return [num_target_layers // 2]

    lo = 1
    hi = num_target_layers - 3
    denom = num_draft_layers - 1

    ids = []
    for k in range(num_draft_layers):
        # Same arithmetic order as before: integer product first, then one
        # float division, so rounding behavior is unchanged.
        ids.append(int(round(lo + (k * (hi - lo)) / denom)))
    return ids
| |
|
def extract_context_feature(
    hidden_states: list[torch.Tensor],
    layer_ids: Optional[list[int]],
) -> torch.Tensor:
    """Concatenate selected target-model hidden states along the feature dim.

    Args:
        hidden_states: per-layer hidden states; index 0 is presumably the
            embedding output (HF convention), so layer `i`'s output lives at
            `hidden_states[i + 1]` — TODO confirm against the producing model.
        layer_ids: target layer indices to gather (e.g. from
            `build_target_layer_ids`). Must be a non-empty list.

    Returns:
        The selected states concatenated on the last dimension.

    Raises:
        ValueError: if `layer_ids` is None or empty. Previously a None value
            surfaced as an opaque TypeError during iteration, and an empty
            list reached `torch.cat` with no tensors.
    """
    if layer_ids is None:
        raise ValueError("layer_ids must be provided (got None)")
    if not layer_ids:
        raise ValueError("layer_ids must contain at least one layer index")

    # Shift past the embedding output so layer_ids index transformer layers.
    offset = 1
    selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids]
    return torch.cat(selected_states, dim=-1)
| |
|
def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Draw token ids from `logits` over the last (vocab) dimension.

    A near-zero temperature (< 1e-5) degenerates to greedy argmax decoding;
    otherwise logits are temperature-scaled and sampled via multinomial.

    Args:
        logits: tensor of shape (batch, seq_len, vocab_size).
        temperature: softmax temperature; 0.0 means greedy.

    Returns:
        Integer token ids of shape (batch, seq_len).
    """
    # Guard against division by (near) zero: treat tiny temperatures as greedy.
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)

    bsz, seq_len, vocab = logits.shape
    flat = logits.view(-1, vocab)
    probs = torch.softmax(flat / temperature, dim=-1)
    draws = torch.multinomial(probs, num_samples=1)
    return draws.view(bsz, seq_len)
| |
|
def load_and_process_dataset(data_name: str):
    """Load an evaluation dataset and normalize it to a `turns` column.

    Every branch maps the raw rows so each example carries `turns`: a list of
    user-turn prompt strings (a single element everywhere except mt-bench,
    whose `prompt` field already holds multiple turns).

    Args:
        data_name: one of "gsm8k", "math500", "aime24", "aime25", "alpaca",
            "mt-bench", "humaneval", "mbpp", "lbpp", "swe-bench",
            "livecodebench".

    Returns:
        A `datasets.Dataset` with a `turns` column.

    Raises:
        ValueError: if `data_name` is not one of the supported names.
            (Previously an unknown name fell through to `return dataset`
            and crashed with a confusing NameError.)
    """
    if data_name == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", split="test")
        # gsm8k's question field is named "question" (the other math sets use "problem").
        prompt_fmt = "Solve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{{}}.\n\n{question}"
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    elif data_name == "math500":
        dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        prompt_fmt = "Solve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{{}}.\n\n{problem}"
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    elif data_name == "aime24":
        dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
        prompt_fmt = "Solve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{{}}.\n\n{problem}"
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    elif data_name == "aime25":
        dataset = load_dataset("MathArena/aime_2025", split="train")
        prompt_fmt = "Solve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{{}}.\n\n{problem}"
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    elif data_name == "alpaca":
        dataset = load_dataset("tatsu-lab/alpaca", split="train")
        # Fold the optional "input" field into the instruction when present.
        dataset = dataset.map(lambda x: {"formatted_input": (f"{x['instruction']}\n\nInput:\n{x['input']}" if x['input'] else x['instruction'])})
        dataset = dataset.map(lambda x: {"turns": [x["formatted_input"]]})

    elif data_name == "mt-bench":
        dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
        # mt-bench prompts are already multi-turn lists; pass them through as-is.
        dataset = dataset.map(lambda x: {"turns": x["prompt"]})

    elif data_name == "humaneval":
        dataset = load_dataset("openai/openai_humaneval", split="test")
        prompt_fmt = "Write a solution to the following problem and make sure that it passes the tests:\n```python\n{prompt}\n```"
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    elif data_name == "mbpp":
        dataset = load_dataset("google-research-datasets/mbpp", "sanitized", split="test")
        dataset = dataset.map(lambda x: {"turns": [x["prompt"]]})

    elif data_name == "lbpp":
        # Only the Python split is used; load its parquet file directly.
        LBPP_PY_TEST_URL = "https://huggingface.co/datasets/CohereLabs/lbpp/resolve/main/python/test.parquet"
        dataset = load_dataset("parquet", data_files={"test": LBPP_PY_TEST_URL})["test"]
        dataset = dataset.map(lambda x: {"turns": [x["instruction"]]})

    elif data_name == "swe-bench":
        dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
        prompt_fmt = "Problem Statement:\n{problem_statement}\nPlease fix the issue described above."
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    elif data_name == "livecodebench":
        # code_generation_lite ships its test split as several jsonl shards.
        base = "https://huggingface.co/datasets/livecodebench/code_generation_lite/resolve/main/"
        allowed_files = ["test.jsonl", "test2.jsonl", "test3.jsonl", "test4.jsonl", "test5.jsonl", "test6.jsonl"]
        urls = [base + fn for fn in allowed_files]
        dataset = load_dataset("json", data_files={"test": urls})["test"]

        def format_lcb(doc):
            # Assemble the LCB prompt: system instructions, question, starter
            # code (when present), and the answer-format footer.
            system_prompt = (
                "You are an expert Python programmer. You will be given a question (problem specification) "
                "and will generate a correct Python program that matches the specification and passes all tests. "
                "You will NOT return anything except for the program"
            )
            question_block = f"### Question:\n{doc['question_content']}"
            if doc.get("starter_code"):
                format_message = "### Format: Use the following code structure:"
                code_block = f"```python\n{doc['starter_code']}\n```"
            else:
                format_message = "### Format: Write your code in the following format:"
                code_block = "```python\n# YOUR CODE HERE\n```"
            answer_footer = "### Answer: (use the provided format with backticks)"
            return f"{system_prompt}\n\n{question_block}\n\n{format_message}\n{code_block}\n\n{answer_footer}"

        # Drop the original columns and pin the schema so long prompts fit
        # (large_string avoids arrow's 2GB string-column limit).
        target_features = Features({"turns": Sequence(Value("large_string"))})
        dataset = dataset.map(
            lambda x: {"turns": [format_lcb(x)]},
            remove_columns=dataset.column_names,
            features=target_features,
        )

    else:
        # Fail loudly on unknown names instead of the NameError the old
        # fall-through produced at `return dataset`.
        raise ValueError(
            f"Unknown dataset name: {data_name!r}. Supported: gsm8k, math500, "
            "aime24, aime25, alpaca, mt-bench, humaneval, mbpp, lbpp, "
            "swe-bench, livecodebench"
        )

    return dataset