| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional |
|
|
| from distilabel.llms import OpenAILLM |
| from distilabel.pipeline import Pipeline |
| from distilabel.steps.tasks import TextGeneration |
|
|
|
|
def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
    input_batch_size: int = 64,
    timeout: int = 10 * 60,
) -> Pipeline:
    """Build a Ray-backed distilabel pipeline with a single text-generation step.

    The step talks to an OpenAI-compatible endpoint at ``base_url`` through
    ``OpenAILLM`` and is wrapped in a ``TextGeneration`` task.

    Args:
        model: Model name forwarded to ``OpenAILLM``.
        base_url: Base URL of the OpenAI-compatible server.
        prompt_column: Dataset column to map onto the task's ``instruction``
            input. When ``None`` no input mapping is applied.
        temperature: Sampling temperature; only added to the generation
            kwargs when it is not ``None``.
        top_p: Nucleus-sampling value; only added to the generation kwargs
            when it is not ``None``.
        max_new_tokens: Value of ``max_new_tokens`` in the generation kwargs.
        num_generations: Forwarded to ``TextGeneration`` as
            ``num_generations``.
        input_batch_size: Forwarded to ``TextGeneration`` as
            ``input_batch_size``. Defaults to the previously hard-coded 64.
        timeout: Forwarded to ``OpenAILLM`` as ``timeout``. Defaults to the
            previously hard-coded ``10 * 60`` (presumably seconds, i.e. ten
            minutes -- confirm against the distilabel documentation).

    Returns:
        The constructed ``Pipeline``; the caller is responsible for running it.
    """
    # Optional sampling parameters are left out entirely when unset, so the
    # backend's own defaults apply instead of an explicit ``None``.
    generation_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature is not None:
        generation_kwargs["temperature"] = temperature
    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    input_mappings = {"instruction": prompt_column} if prompt_column is not None else {}

    with Pipeline().ray() as pipeline:
        # The step is instantiated inside the pipeline context and is not
        # bound to a name; the context presumably registers it on `pipeline`.
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
                # Placeholder value; NOTE(review): presumably the local server
                # does not validate the key -- confirm for your deployment.
                api_key="something",
                model=model,
                timeout=timeout,
                generation_kwargs=generation_kwargs,
            ),
            input_mappings=input_mappings,
            input_batch_size=input_batch_size,
            num_generations=num_generations,
        )

    return pipeline
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: load a Hugging Face dataset, run the
    # generation pipeline over it, and optionally push the result to the Hub.
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset to load",
    )
    parser.add_argument(
        "--hf-dataset-config",
        type=str,
        required=False,
        help="Dataset config to use",
    )
    parser.add_argument(
        "--hf-dataset-split",
        type=str,
        default="train",
        help="Dataset split to use",
    )
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name to use for generation",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the vLLM server",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p value for generation",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8192,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=1,
        help="Number of generations per problem",
    )
    parser.add_argument(
        "--hf-output-dataset",
        type=str,
        required=False,
        help="HuggingFace repo to push results to",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to make the output dataset private when pushing to HF Hub",
    )

    args = parser.parse_args()

    # Echo the effective configuration so a run can be reproduced from its log.
    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    print()

    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
    # Pass the requested config through; previously --hf-dataset-config was
    # parsed and printed but never handed to load_dataset. When the flag is
    # omitted the value is None, which matches load_dataset's default.
    dataset = load_dataset(
        args.hf_dataset,
        name=args.hf_dataset_config,
        split=args.hf_dataset_split,
    )
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
    )

    print("Running generation pipeline...")
    distiset = pipeline.run(dataset=dataset, use_cache=False)
    print("Generation pipeline finished!")

    # Pushing is opt-in: without --hf-output-dataset nothing is uploaded.
    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")
|
|