from dataclasses import dataclass, field
from trl.scripts.utils import ScriptArguments as ScriptArgs
from trl.trainer.model_config import ModelConfig


@dataclass
class ScriptArguments(ScriptArgs):
"""
Extended version of ScriptArguments with support for dataset mixtures.
"""
dataset_name: str | None = field(
default=None, metadata={"help": "Training dataset name. Contains chain-of-thought solutions."}
)
eval_dataset_config: str | None = field(default=None, metadata={"help": "Evaluation dataset config."})
take_n: int | None = field(default=None, metadata={"help": "Number of examples to take from the dataset."})
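
# Illustrative usage (a sketch, not part of the original module): these dataclasses
# carry argparse help text in `field(metadata=...)` and are meant to be parsed with
# transformers.HfArgumentParser (or TRL's TrlParser). A hypothetical entry point:
#
#     from transformers import HfArgumentParser
#
#     parser = HfArgumentParser(ScriptArguments)
#     (script_args,) = parser.parse_args_into_dataclasses()
#     # e.g. python train.py --dataset_name my-org/my-dataset --take_n 1000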


@dataclass
class SFTModelConfig(ModelConfig):
enforce_eager: bool | None = field(default=None, metadata={"help": "Whether to enforce eager execution."})


@dataclass
class SFTRunConfig:
add_special_tokens: bool = field(
metadata={"help": "Whether to add special tokens to the model."},
)
early_stopping_patience: int = field(
default=3, metadata={"help": "The number of epochs to wait before early stopping."}
)
early_stopping_threshold: float = field(
default=0.0, metadata={"help": "Minimum improvement required to reset patience counter."}
)
    benchmarks: list[str] = field(
        default_factory=list,
        metadata={"help": "The benchmarks to run after training."},
    )
    callbacks: list[str] = field(
        default_factory=list,
        metadata={"help": "The callbacks to run during training."},
    )
chat_template: str | None = field(default=None, metadata={"help": "The chat template to use."})
system_prompt: str | None = field(
default=None,
metadata={"help": "The optional system prompt to use for benchmarking."},
)
hub_model_revision: str | None = field(
default="main",
metadata={"help": "The Hub model branch to push the model to."},
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: str | None = field(
        default=None,
        metadata={"help": "The entity to store runs under."},
    )
    wandb_project: str | None = field(
        default=None,
        metadata={"help": "The project to store runs under."},
    )
    wandb_run_group: str | None = field(
        default=None,
        metadata={"help": "The group to store runs under."},
    )
    wandb_run_id: str | None = field(default=None, metadata={"help": "The wandb run id."})
eval_max_new_tokens: int | None = field(
default=None,
metadata={"help": "Max new tokens for evaluation callbacks (does not affect training)."},
)
max_seq_length: int | None = field(
default=None,
metadata={"help": "Max sequence length for evaluation callbacks (does not affect training)."},
)
gpu_memory_utilization: float | None = field(
default=0.95, metadata={"help": "Fraction of GPU memory to be used by vLLM (0-1)"}
)
# Evaluation sampling
max_eval_samples: int | None = field(
default=None,
metadata={"help": "Maximum number of eval samples for periodic evaluations. None/-1 for full dataset."},
)
    final_eval_max_samples: int | None = field(
        default=None,
        metadata={"help": "Maximum number of eval samples for the final evaluation. None/-1 for full dataset."},
    )
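
# Illustrative sketch (an assumption, not shown in the original code): the
# early-stopping fields above line up with transformers.EarlyStoppingCallback:
#
#     from transformers import EarlyStoppingCallback
#
#     run_config = SFTRunConfig(add_special_tokens=False)
#     early_stopping = EarlyStoppingCallback(
#         early_stopping_patience=run_config.early_stopping_patience,
#         early_stopping_threshold=run_config.early_stopping_threshold,
#     )
#     # Pass the callback to the Trainer via `callbacks=[early_stopping, ...]`.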


@dataclass
class DatasetGenerationConfig:
"""
Data class that stores the dataset generation parameters.
Args:
dataset_name (str): The name of the dataset to generate.
"""
dataset_name: str | None = field(
metadata={"help": "Should be the name used to store the dataset on the Hugging Face Hub."},
)
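
# Illustrative sketch (hypothetical usage): `dataset_name` serves as the Hub repo id
# when uploading a generated dataset, e.g. with datasets.Dataset.push_to_hub:
#
#     from datasets import Dataset
#
#     config = DatasetGenerationConfig(dataset_name="username/generated-dataset")
#     ds = Dataset.from_list([{"prompt": "2+2?", "completion": "4"}])
#     if config.dataset_name is not None:
#         ds.push_to_hub(config.dataset_name)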


@dataclass
class LlamaCppServerConfig:
"""
Data class that stores LlamaCPP server parameters with llama_cpp_ prefix.
"""
# Server parameters
host: str = field(
metadata={"help": "Host address to bind to"},
)
port: int = field(
metadata={"help": "Port to listen on"},
)
n_ctx: int = field(
metadata={"help": "Context size"},
)
split_mode: int = field(
metadata={"help": "Split mode (0=none, 1=layer, 2=row)"},
)
# Model parameters
n_gpu_layers: int = field(
metadata={"help": "Number of GPU layers to offload"},
)
model: str = field(
metadata={"help": "Model URL to download (GGUF format)"},
)
hf_pretrained_model_name_or_path: str | None = field(
default=None,
metadata={"help": "Huggingface repository ID to ensure that the correct tokenizer is used."},
)
hf_model_repo_id: str | None = field(
default=None,
metadata={"help": "Path to the repository where the model is stored."},
)
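
# Illustrative sketch (an assumption): llama-cpp-python's OpenAI-compatible server
# (`python -m llama_cpp.server`) accepts flags that mirror these field names, so a
# launch command can be assembled directly from the dataclass:
#
#     import dataclasses
#
#     cfg = LlamaCppServerConfig(
#         host="0.0.0.0", port=8080, n_ctx=4096, split_mode=1,
#         n_gpu_layers=-1, model="model.Q4_K_M.gguf",
#     )
#     args = ["python", "-m", "llama_cpp.server"]
#     for key, value in dataclasses.asdict(cfg).items():
#         if value is not None:
#             args += [f"--{key}", str(value)]
#     # subprocess.Popen(args) would then start the server.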


@dataclass
class VllmServerConfig:
"""
Data class that stores vLLM server parameters with vllm_ prefix.
"""
# Model parameters
model: str = field(
metadata={"help": "Model name (HuggingFace format)"},
)
# Server parameters
host: str = field(
metadata={"help": "Host address to bind to"},
)
port: int = field(
metadata={"help": "Port to listen on"},
)
enable_auto_tool_choice: bool = field(
metadata={"help": "Enable automatic tool choice"},
)
tool_call_parser: str = field(
metadata={"help": "Tool call parser to use"},
)
chat_template: str | None = field(
default=None,
metadata={"help": "Chat template to use"},
)
quantization: str | None = field(
default=None,
metadata={"help": "Quantization to use"},
)
api_key: str = field(
default="not-used",
metadata={"help": "API key for authentication (use 'not-used' for local development)"},
)
# Memory / performance tuning parameters
dtype: str | None = field(
default=None,
metadata={"help": "Computation dtype for model weights and activations (e.g., float16)"},
)
kv_cache_dtype: str | None = field(
default=None,
metadata={"help": "KV cache dtype (auto, fp8, fp8_e4m3, fp8_e5m2)"},
)
max_model_len: int | None = field(
default=None,
metadata={"help": "Maximum model context length (tokens)"},
)
max_num_seqs: int | None = field(
default=None,
metadata={"help": "Maximum number of concurrent sequences"},
)
gpu_memory_utilization: float | None = field(
default=None,
metadata={"help": "Fraction of GPU memory to be used by vLLM (0-1)"},
)
enforce_eager: bool | None = field(
default=None,
metadata={"help": "Disable CUDA graphs to reduce memory usage"},
)
swap_space: int | None = field(
default=None,
metadata={"help": "CPU swap space in GB per GPU for paging KV cache"},
)
max_num_batched_tokens: int | None = field(
default=None,
metadata={"help": "Limit number of tokens processed per batch (prefill)"},
)
tensor_parallel_size: int | None = field(
default=None,
metadata={"help": "Tensor parallelism degree"},
)
enable_chunked_prefill: bool | None = field(
default=None,
metadata={"help": "Enable chunked prefill to reduce peak prefill memory"},
)
# Model parameters
reasoning_parser: str | None = field(
default=None,
metadata={"help": "Reasoning parser to use"},
)
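
# Illustrative sketch (an assumption): `vllm serve` exposes these options as CLI
# flags with dashes instead of underscores, so the config maps onto a command line:
#
#     import dataclasses
#
#     cfg = VllmServerConfig(
#         model="Qwen/Qwen2.5-7B-Instruct", host="0.0.0.0", port=8000,
#         enable_auto_tool_choice=True, tool_call_parser="hermes",
#     )
#     cmd = ["vllm", "serve", cfg.model, "--host", cfg.host, "--port", str(cfg.port)]
#     for key, value in dataclasses.asdict(cfg).items():
#         if key in ("model", "host", "port") or value is None:
#             continue
#         flag = "--" + key.replace("_", "-")
#         if isinstance(value, bool):
#             cmd += [flag] if value else []
#         else:
#             cmd += [flag, str(value)]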


@dataclass
class DistillationConfig:
"""
Data class that stores the distillation pipeline parameters.
"""
# Dataset parameters
dataset_name: str | None = field(
metadata={"help": "HuggingFace dataset to load"},
)
# Prompt parameters
prompt_column: str = field(
metadata={"help": "Column name for prompt data"},
)
prompt_template: str = field(
metadata={"help": "Template string for formatting prompts"},
)
# Generation parameters (non-defaults first)
model_type: str | None = field(metadata={"help": "Model type for generation"})
enable_reasoning: bool = field(metadata={"help": "Whether to enable thinking"})
max_new_tokens: int = field(
metadata={"help": "Maximum number of new tokens to generate"},
)
num_generations: int = field(
metadata={"help": "Number of generations per problem"},
)
# Processing parameters
input_batch_size: int = field(
metadata={"help": "Batch size for input processing"},
)
use_cache: bool = field(
metadata={"help": "Whether to use cache for the pipeline. This can enable error recovery."},
)
timeout: int = field(
metadata={"help": "Request timeout in seconds"},
)
retries: int = field(
metadata={"help": "Number of retries for failed requests"},
)
# Output parameters
hf_output_dataset: str | None = field(
metadata={"help": "HuggingFace repo to push results to"},
)
argilla_output_dataset: str | None = field(
metadata={"help": "Argilla dataset to push results to. This is used for manual annotation."},
)
private: bool = field(
metadata={"help": "Whether to make the output dataset private when pushing to HF Hub"},
)
# Generation parameters
n_turns: int = field(
metadata={"help": "Number of turns to generate"},
)
min_successful_completions: int = field(
default=-1,
metadata={"help": "Minimum number of successful completions to generate"},
)
strip_think_prefix: bool = field(
default=True,
metadata={"help": "Whether to strip the think prefix from the conversation. This is needed for Qwen3 models."},
)
# Optional stopping sequences (must come after non-default fields)
stop: list[str] | None = field(
default=None,
metadata={"help": "Stop sequences for generation (each string is a stop token)"},
)
    debug_mode: bool = field(
        default=False,
        metadata={"help": "Whether to run in debug mode."},
    )
take_n: int | None = field(
default=None,
metadata={"help": "Number of examples to take from the dataset."},
)
structured_output: bool = field(
default=False,
metadata={"help": "Whether to use structured output"},
)
deterministic: bool = field(
default=True,
metadata={"help": "Make generation deterministic (temperature=0, top_p=1)"},
)
client_replicas: int | None = field(
default=None,
metadata={"help": "Number of client replicas for parallel processing"},
)
dataset_config: str | None = field(
default=None,
metadata={"help": "Dataset config to use"},
)
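
# Illustrative sketch (an assumption, matching the help text above): with
# `deterministic=True` generation uses temperature=0 and top_p=1. A helper that
# derives OpenAI-style sampling kwargs from the config might look like:
#
#     def generation_kwargs(cfg: DistillationConfig) -> dict:
#         kwargs = {"max_tokens": cfg.max_new_tokens, "n": cfg.num_generations}
#         if cfg.deterministic:
#             kwargs.update(temperature=0.0, top_p=1.0)
#         if cfg.stop is not None:
#             kwargs["stop"] = cfg.stop
#         return kwargs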