Spaces:
Runtime error
Runtime error
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Optional | |
| from distilabel.llms import OpenAILLM | |
| from distilabel.pipeline import Pipeline | |
| from distilabel.steps.tasks import TextGeneration | |
def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
    input_batch_size: int = 64,
    timeout: int = 10 * 60,
) -> Pipeline:
    """Build a Ray-backed distilabel pipeline with a single text-generation step.

    The step sends prompts to an OpenAI-compatible endpoint through ``OpenAILLM``.

    Args:
        model: Model name forwarded to the endpoint.
        base_url: Base URL of the OpenAI-compatible server.
        prompt_column: Dataset column to map onto the task's ``instruction``
            input. When ``None`` no input mapping is applied.
        temperature: Sampling temperature. Omitted from the generation kwargs
            when ``None``.
        top_p: Nucleus-sampling value. Omitted from the generation kwargs when
            ``None``.
        max_new_tokens: Maximum number of new tokens to generate per response.
        num_generations: Number of generations requested per input.
        input_batch_size: Number of rows handed to the generation step per
            batch. Defaults to the previously hard-coded 64; the original
            author noted that on 4 nodes a batch size of ~60+ leads to
            preemption due to KV cache exhaustion, so lower it if that occurs.
        timeout: Client timeout passed to ``OpenAILLM`` (seconds, per the
            distilabel documentation). Defaults to the previously hard-coded
            10 minutes because "thinking" responses can take a long time.

    Returns:
        The configured ``Pipeline``; the caller is responsible for running it.
    """
    # Only forward sampling parameters the caller explicitly set.
    generation_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature is not None:
        generation_kwargs["temperature"] = temperature
    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    with Pipeline().ray() as pipeline:
        # Steps instantiated inside the context manager are attached to the pipeline.
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
                # Placeholder key; presumably the local server does not
                # validate it -- confirm before pointing at a hosted API.
                api_key="something",
                model=model,
                # thinking can take some time...
                timeout=timeout,
                generation_kwargs=generation_kwargs,
            ),
            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
            input_batch_size=input_batch_size,
            num_generations=num_generations,
        )

    return pipeline
if __name__ == "__main__":
    # Script entry point: parse CLI options, load the source dataset, run the
    # generation pipeline, and optionally push the results to the HF Hub.
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset to load",
    )
    parser.add_argument(
        "--hf-dataset-config",
        type=str,
        required=False,
        help="Dataset config to use",
    )
    parser.add_argument(
        "--hf-dataset-split",
        type=str,
        default="train",
        help="Dataset split to use",
    )
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name to use for generation",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the vLLM server",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p value for generation",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8192,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=1,
        help="Number of generations per problem",
    )
    parser.add_argument(
        "--hf-output-dataset",
        type=str,
        required=False,
        help="HuggingFace repo to push results to",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to make the output dataset private when pushing to HF Hub",
    )
    args = parser.parse_args()

    # Echo the effective configuration so a run can be reproduced from its logs.
    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f" {arg}: {value}")
    print()

    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
    # Pass the config name through; it was previously parsed and printed but
    # never used, so a requested config was silently ignored. ``None`` (flag
    # omitted) keeps load_dataset's default behaviour.
    dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
    )

    print("Running generation pipeline...")
    # use_cache=False is passed so the run does not reuse distilabel's cached results.
    distiset = pipeline.run(dataset=dataset, use_cache=False)
    print("Generation pipeline finished!")

    # Publishing is opt-in: only push when an output repo was given.
    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")