import os
import subprocess
import sys
import time
from typing import List, Optional

from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import TextGenerationToArgilla
from distilabel.steps.expand import ExpandColumns
from distilabel.steps.generators.data import LoadDataFromDicts
from distilabel.steps.keep import KeepColumns
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.steps.tasks.self_instruct import SelfInstruct
from dotenv import load_dotenv

from domain import (
    DomainExpert,
    CleanNumberedList,
    create_topics,
    create_examples_template,
    APPLICATION_DESCRIPTION,
)
| |
|
| | load_dotenv() |
| |
|
| |
|
| | def define_pipeline( |
| | argilla_api_key: str, |
| | argilla_api_url: str, |
| | argilla_dataset_name: str, |
| | topics: List[str], |
| | perspectives: List[str], |
| | domain_expert_prompt: str, |
| | examples: List[dict], |
| | hub_token: str, |
| | endpoint_base_url: str, |
| | ): |
| | """Define the pipeline for the specific domain.""" |
| |
|
| | terms = create_topics(topics, perspectives) |
| | template = create_examples_template(examples) |
| | with Pipeline("farming") as pipeline: |
| | load_data = LoadDataFromDicts( |
| | name="load_data", |
| | data=[{"input": term} for term in terms], |
| | batch_size=64, |
| | ) |
| | llm = InferenceEndpointsLLM( |
| | base_url=endpoint_base_url, |
| | api_key=hub_token, |
| | ) |
| | self_instruct = SelfInstruct( |
| | name="self-instruct", |
| | application_description=APPLICATION_DESCRIPTION, |
| | num_instructions=5, |
| | input_batch_size=8, |
| | llm=llm, |
| | ) |
| |
|
| | evol_instruction_complexity = EvolInstruct( |
| | name="evol_instruction_complexity", |
| | llm=llm, |
| | num_evolutions=2, |
| | store_evolutions=True, |
| | input_batch_size=8, |
| | include_original_instruction=True, |
| | input_mappings={"instruction": "question"}, |
| | ) |
| |
|
| | expand_instructions = ExpandColumns( |
| | name="expand_columns", columns={"instructions": "question"} |
| | ) |
| | cleaner = CleanNumberedList(name="clean_numbered_list") |
| | expand_evolutions = ExpandColumns( |
| | name="expand_columns_evolved", |
| | columns={"evolved_instructions": "evolved_questions"}, |
| | ) |
| |
|
| | domain_expert = DomainExpert( |
| | name="domain_expert", |
| | llm=llm, |
| | input_batch_size=8, |
| | input_mappings={"instruction": "evolved_questions"}, |
| | output_mappings={"generation": "domain_expert_answer"}, |
| | ) |
| |
|
| | domain_expert._system_prompt = domain_expert_prompt |
| | domain_expert._template = template |
| |
|
| | keep_columns = KeepColumns( |
| | name="keep_columns", |
| | columns=["model_name", "evolved_questions", "domain_expert_answer"], |
| | ) |
| |
|
| | to_argilla = TextGenerationToArgilla( |
| | name="text_generation_to_argilla", |
| | dataset_name=argilla_dataset_name, |
| | dataset_workspace="admin", |
| | api_url=argilla_api_url, |
| | api_key=argilla_api_key, |
| | input_mappings={ |
| | "instruction": "evolved_questions", |
| | "generation": "domain_expert_answer", |
| | }, |
| | ) |
| |
|
| | load_data.connect(self_instruct) |
| | self_instruct.connect(expand_instructions) |
| | expand_instructions.connect(cleaner) |
| | cleaner.connect(evol_instruction_complexity) |
| | evol_instruction_complexity.connect(expand_evolutions) |
| | expand_evolutions.connect(domain_expert) |
| | domain_expert.connect(keep_columns) |
| | keep_columns.connect(to_argilla) |
| | return pipeline |
| |
|
| |
|
| | def serialize_pipeline( |
| | argilla_api_key: str, |
| | argilla_api_url: str, |
| | argilla_dataset_name: str, |
| | topics: List[str], |
| | perspectives: List[str], |
| | domain_expert_prompt: str, |
| | hub_token: str, |
| | endpoint_base_url: str, |
| | pipeline_config_path: str = "pipeline.yaml", |
| | examples: List[dict] = [], |
| | ): |
| | """Serialize the pipeline to a yaml file.""" |
| | pipeline = define_pipeline( |
| | argilla_api_key=argilla_api_key, |
| | argilla_api_url=argilla_api_url, |
| | argilla_dataset_name=argilla_dataset_name, |
| | topics=topics, |
| | perspectives=perspectives, |
| | domain_expert_prompt=domain_expert_prompt, |
| | hub_token=hub_token, |
| | endpoint_base_url=endpoint_base_url, |
| | examples=examples, |
| | ) |
| | pipeline.save(path=pipeline_config_path, overwrite=True, format="yaml") |
| |
|
| |
|
| | def create_pipelines_run_command( |
| | hub_token: str, |
| | argilla_api_key: str, |
| | argilla_api_url: str, |
| | pipeline_config_path: str = "pipeline.yaml", |
| | argilla_dataset_name: str = "domain_specific_datasets", |
| | ): |
| | """Create the command to run the pipeline.""" |
| | command_to_run = [ |
| | sys.executable, |
| | "-m", |
| | "distilabel", |
| | "pipeline", |
| | "run", |
| | "--config", |
| | pipeline_config_path, |
| | "--param", |
| | f"text_generation_to_argilla.dataset_name={argilla_dataset_name}", |
| | "--param", |
| | f"text_generation_to_argilla.api_key={argilla_api_key}", |
| | "--param", |
| | f"text_generation_to_argilla.api_url={argilla_api_url}", |
| | "--param", |
| | f"self-instruct.llm.api_key={hub_token}", |
| | "--param", |
| | f"evol_instruction_complexity.llm.api_key={hub_token}", |
| | "--param", |
| | f"domain_expert.llm.api_key={hub_token}", |
| | "--ignore-cache", |
| | ] |
| | return command_to_run |
| |
|
| |
|
| | def run_pipeline( |
| | hub_token: str, |
| | argilla_api_key: str, |
| | argilla_api_url: str, |
| | pipeline_config_path: str = "pipeline.yaml", |
| | argilla_dataset_name: str = "domain_specific_datasets", |
| | ): |
| | """Run the pipeline and yield the output as a generator of logs.""" |
| |
|
| | command_to_run = create_pipelines_run_command( |
| | hub_token=hub_token, |
| | pipeline_config_path=pipeline_config_path, |
| | argilla_dataset_name=argilla_dataset_name, |
| | argilla_api_key=argilla_api_key, |
| | argilla_api_url=argilla_api_url, |
| | ) |
| |
|
| | |
| | process = subprocess.Popen( |
| | args=command_to_run, |
| | stdout=subprocess.PIPE, |
| | stderr=subprocess.PIPE, |
| | env={"HF_TOKEN": hub_token}, |
| | ) |
| |
|
| | while process.stdout and process.stdout.readable(): |
| | time.sleep(0.2) |
| | line = process.stdout.readline() |
| | if not line: |
| | break |
| | yield line.decode("utf-8") |
| |
|