import ast import html import json import os import re import threading import time from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError from typing import Any, Dict, List, Optional, Tuple import requests from huggingface_hub import HfApi import gradio as gr import logging from openai import OpenAI from datasets import ( ClassLabel, Sequence, Value, get_dataset_config_names, get_dataset_split_names, load_dataset, ) import urllib.parse import urllib.request import urllib.error LOGO_SRC = "" MODEL_ID = "DataChef-32B" MAX_ATTEMPTS = 4 PREVIEW_LIMIT = 3 TRUNCATE_LIMIT = 400 PLAN_MAX_TOKENS = 16384 CODE_MAX_TOKENS = 8192 DATASET_PREVIEW_TIMEOUT = 60 # seconds RUN_CODE_URL = os.getenv("DATACHEF_RUN_URL", "http://127.0.0.1:7001/run_code") RUN_CODE_TIMEOUT = int(os.getenv("DATACHEF_RUN_TIMEOUT", "300")) DATASET_LOAD_PROBE_TIMEOUT = int(os.getenv("DATASET_LOAD_PROBE_TIMEOUT", "60")) DATASET_LOAD_PROBE_CONCURRENCY = int(os.getenv("DATASET_LOAD_PROBE_CONCURRENCY", "5")) RUN_CODE_KEY = os.getenv("DATACHEF_RUN_KEY") # optional shared secret sent as x-forward-key _DATASET_PREVIEW_CACHE: Dict[str, Dict] = {} _LAST_PREVIEW_IDS: List[str] = [] _LAST_PREVIEW_DATASET_INFOS: List[Dict] = [] _LAST_PREVIEW_ERRORS: Dict[str, str] = {} # Preset library PRESETS = { "Physics": { "task": { "description": "Train a language model so it gains strong physics-domain knowledge.", "benchmark": { "name": "PHYSICS", "description": ( "PHYSICS covers five physics disciplines and includes physics problems ranging from high school to graduate-level courses, with rigorous quality control.\nExample: {'id': 6, 'question': 'In an isothermal vaporization at 1 atm, how much does the internal energy of 1 mol of water increase?', 'answer': [['\\boxed{3.75 \\times 10^{4}}']]}" ), }, }, "datasets": [ {"dataset_id": "jeffmeloy/sonnet3.5_science_conversations"}, {"dataset_id": "TIGER-Lab/MMLU-STEM"}, {"dataset_id": "enesxgrahovac/the-feynman-lectures-on-physics"}, {"dataset_id": 
"Vikhrmodels/physics_big"}, ], }, "Math": { "task": { "description": "Train a math reasoning model with strong mathematical ability and logical multi-step reasoning for olympiad-level problems.", "benchmark": { "name": "AIME2025", "description": ( "The AIME is a prestigious high school mathematics competition. It features 30 challenging problems with integer answers (000-999).\nExample problem: Find the sum of all integer bases b > 9 for which 17_b is a divisor of 97_b." ), }, }, "datasets": [ {"dataset_id": "Jackrong/GPT-OSS-120B-Distilled-Reasoning-math"}, {"dataset_id": "joey234/mmlu-high_school_mathematics-neg-prepend"}, {"dataset_id": "joey234/mmlu-high_school_mathematics-neg"}, {"dataset_id": "ilsp/greek_lyceum_mathematics"}, {"dataset_id": "facebook/natural_reasoning"}, {"dataset_id": "Alibaba-Apsara/Superior-Reasoning-SFT-gpt-oss-120b"}, ], }, "Finance": { "task": { "description": "Fine-tune the LLM to equip it with finance-related knowledge.", "benchmark": { "name": "OpenFinData", "description": ( "The OpenFinData dataset comprises six modules, each encompassing multiple task dimensions to address diverse evaluation requirements in the financial sector: Financial Knowledge, Financial Discrimination, Financial Computation, Financial Analysis, Financial Interpretation, and Financial Compliance. Examples: [{\“question\”: \"You are a financial data review assistant. Identify which data point contains an obvious error. Provide the correct option. \nEuropean bond yields closed higher across the board. UK 10-year government bond yields rose 6.1 basis points to 4.328%, French 10-year yields gained 7 basis points to 3.298%, German 10-year yields increased 7.1 basis points to 2.714%, and Italian 10-year yields climbed 6.2 basis points to 4.567%. and the Spanish 10-year government bond yield rose 6.8 basis points to 3.758%.\nA. No obvious errors in the data\nB. UK 10-year government bond yield rose 6.1 basis points to 4.328%\nC. 
French 10-year government bond yield rose 7 basis points to 3.298%\nD. German 10-year government bond yields rose 7.1 basis points to 2.714%, Italian 10-year government bond yields rose 6.2 basis points to 4.567%“, ”answer“: ‘A’, {\”question\“: ”You are an entity recognition assistant. Please list the stock market concept sectors mentioned in the following content. Recently, autonomous driving concept stocks have surged strongly, with continuous inflow of incremental capital injecting new momentum into the automotive sector. Meanwhile, the implementation of relevant regulations has allowed Level 3 and Level 4 autonomous vehicles to be approved for road use, which is expected to significantly promote the development of domestic autonomous driving technology.“, \"answer\": \”Autonomous Driving\"}]." ), }, }, "datasets": [ {"dataset_id": "adityarane/financial-qa-dataset"}, {"dataset_id": "Josephgflowers/Financial-NER-NLP"}, {"dataset_id": "Josephgflowers/Finance-Instruct-500k"}, {"dataset_id": "sujet-ai/Sujet-Finance-Instruct-177k"}, {"dataset_id": "gbharti/finance-alpaca"}, ], }, "Custom": { "task": {"description": "", "benchmark": {"name": "", "description": ""}}, "datasets": [] } } # 设置默认启动时的配置 DEFAULT_SAMPLE = PRESETS["Physics"] PLANNER_SYSTEM_PROMPT = "你是一名 Planner,你的目标是为训练垂类 LLM 设计数据处理计划,该计划用于指导后续的代码生成和数据生产。请用英文输出。" CODER_SYSTEM_PROMPT = "你是一名 Coder。你的目标是为垂类 LLM 训练生成可执行的数据处理脚本。请用英文输出。" PLAN_INSTRUCTIONS = """\ 请用英文输出,including all headings and field values. Use the English section titles exactly as `## Training Data` and `## Processing Steps`. 
根据任务描述、测试数据 (Benchmark) 和可用的 Huggingface 训练数据集,设计一个可行的 **数据处理计划**,内容包括: (1) 从可用的 Hugging Face 训练数据集中选择合适的作为原始数据; (2) 数据处理流程: 将选择的原始数据处理为 SFT 训练数据。 数据处理计划将用于指导代码生成和数据生产,因此请确保详尽、具备实际可行性,并避免模糊、似是而非或无意义的陈述。 一个优质的 SFT 数据集应该是 **高质量(样本准确无噪声)**, **多样性(覆盖多样的目标)**,**相关性(与目标垂直领域高度相关)**的。 注意事项: - 在可用的 Huggingface 数据集中选择适合且常用的,优先选择信息较完整、各字段含义明确、与实验目标相符的数据集。**禁止将 测试数据 (Benchmark) 用于训练**。不要使用未提供的 Huggingface 数据集、split 或者 configuration。 - 计划数据处理流程时,要依据给定数据集的信息进行设计,不要对数据集内容进行猜测。 - 在构造训练数据时,输入输出应当能够组成合理的问答对话,如果数据中存在与输出有关的上下文信息,应当在构造方案中设计方法,将上下文以合理的方式嵌入输入。如果测试数据 (Benchmark) 中的对话有特殊形式(如选择、填空、格式限制),那么在构造数据时,注意对齐测试数据任务格式。 - 将数据最终处理为能直接用于模型训练的对话格式: {"dialogs": [{"role": "user", "content": "user 内容"}, {"role": "assistant", "content": "assistant 内容"}]} - 计划数据处理流程时,可以灵活使用 LLM 推理服务,包括但不限于从异构数据源提取构造指令数据、利用 LLM 进行数据增强和合成、对样本进行指令评估或者打标等。明确说明 LLM prompt 构造方式。 按照以下格式生成数据处理计划,不要回复其他内容,且全部使用英文: ## Training Data [ {"dataset_id": "dataset_id", "split": "split_to_load", "name": "config_name", "sample_num": "num_samples_to_load", "reason": "why_this_dataset"}, ... 
] ## Processing Steps [processing steps] """ CODE_INSTRUCTIONS = """\ 请用英文输出。根据可用的 Huggingface 训练数据,数据处理计划和可以调用的工具信息,生成需要执行的数据处理脚本。 当前仅需要验证数据处理计划的可行性,计划中的训练样本数可能较大,在初始数据加载时加载 20 个样本,来保证脚本运行的高效性。 注意事项: - 数据处理脚本应该提供数据处理的思路和调用工具的逻辑。 - 使用 `load_remote_dataset` 加载数据集,加载数据集时注意传入正确的的 `name` 和 `split` 参数。 - 如果想要将 Dataset 中的 sample 转化为文本,直接 `text = str(sample)` 操作即可, 不能使用 `text = json.dumps(sample, ensure_ascii=False)`,因为 sample is not JSON serializable。 - 如果需要合并数据集,请使用 concatenate_datasets([a, b, c]) (导入: from datasets import concatenate_datasets)。合并多个数据集时,需要注意它们的 features 应该是 aligned,例如,不能将 null 和 int32 的 feature 进行合并。 - 请将最后的数据处理为 `{'dialogs': [{'role': 'user', 'content': '[user content]'}, {'role': 'assistant', 'content': '[assistant content]'}]}` sharegpt 格式。可以使用 `format_to_sharegpt` 工具 (导入: from aidp import format_to_sharegpt)。 - 请把最后处理的数据保存在 data/processed/ 目录下。 - 生成代码时可以使用 `try/except` 来考虑特殊情况,例如代码 `"user": sample["prompt"].strip()` 可以在大部分情况下工作,但遇到 sample["prompt"] 为 None 时会执行失败。 按照以下格式生成一个代码块,不要回复其他内容: ```python # data-processing code block ``` """ AIDP_HEADER = """\ from aidp import ( load_remote_dataset, format_to_sharegpt, deduplicate_by_text_hash, select_by_filter, select_by_score, select_by_random, generate_dataset_with_llm, extract_json, generate_text_embeddings, score_dataset_with_llm, dump_dataset, ) """ TOOL_INFO = """\ ## Capabilities - You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of a Python class or function. - You can freely combine the use of any other public packages, like scikit-learn, NumPy, pandas, etc. ## Available Tools: Each tool is described in JSON format. When you call a tool, **import the tool at first**. ### load_remote_dataset 导入:`from aidp import load_remote_dataset` load_remote_dataset(path: str, name: str | None = None, split: str | None = None, num_samples: int | None = None, shuffle: bool = True) -> datasets.arrow_dataset.Dataset Load dataset from Hugging Face repo. 
Args: path: dataset id of hugging face repo to be loaded. name: The config name to load. split: The split name to load. num_samples: The number of samples to take from the dataset. shuffle: Whether shuffle the dataset before take. Returns: Dataset: loaded Hugging Face Dataset. Example: ```python from aidp import load_remote_dataset ds = load_remote_dataset("openlifescienceai/medmcqa", split="train") ``` ### format_to_sharegpt 导入:`from aidp import format_to_sharegpt` format_to_sharegpt(dataset: datasets.arrow_dataset.Dataset, system_map: Optional[Callable[[dict], str]] = None, user_map: Optional[Callable[[dict], str]] = None, assistant_map: Optional[Callable[[dict], str]] = None) -> datasets.arrow_dataset.Dataset Convert Dataset to ShareGPT format Dataset. Args: dataset (Dataset): Hugging Face 格式 dataset system_map (Optional[Callable[[dict], str]]): callable function,从数据源中得到对话的 system prompt user_map: callable function,从数据源中得到对话的 user prompt assistant_map: callable function,从数据源中得到对话的 assistant response Returns: Dataset: ShareGPT 格式数据集 - dialogs (list[dict]): 对话数据列表,每个元素是一个 dict,包括 role 和 content,role 为 system,user 或 assistant。 例如: {'dialogs': [ {"role": "system", "content": "xxx"}, {"role": "user", "content": "xxx"}, {"role": "assistant", "content": "xxx"}, ]} 要访问全部 user 角色内容,可以使用 '\\n'.join([d['content'] for d in sample['dialogs'] if d['role'] == 'user']) Example Code: ```python formatted_ds = format_to_sharegpt( ds, user_map=lambda x: f"{x['instruction']}\nx['input']", assistant_map=lambda x: f"{x['output']}" ) ``` 上述代码会把原 HuggingFace Dataset ds 转化为 ShareGPT 格式,其包含一个 feature `dialogs`,其中 dialogs 的每个元素是一个字典列表。 数据源里的 `"instruction"` 与 `"input"` 合并后作为 dialogs 中 user 角色内容,而 `"output"` 则映射为 assistant 角色内容。 ### deduplicate_by_text_hash 导入:`from aidp import deduplicate_by_text_hash` deduplicate_by_text_hash(dataset: datasets.arrow_dataset.Dataset, text_map: Callable[[dict], str] = None, lowercase: bool = False, ignore_non_character: bool = False) -> 
datasets.arrow_dataset.Dataset Deduplicate samples in the dataset using exact matching. Args: dataset: input dataset text_map: callable function to extract text from sample lowercase: Whether to convert sample text to lower case ignore_non_character: Whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations Returns: deduplicated dataset. ### select_by_filter 导入:`from aidp import select_by_filter` select_by_filter(dataset: datasets.arrow_dataset.Dataset, filter_fn: Callable[[dict], bool]) -> datasets.arrow_dataset.Dataset Select data using a boolean function. Args: dataset (Dataset): dataset to be selected. filter_fn (Callable[[dict], bool]): 输入单条样本,返回 True/False Returns: Dataset: selected dataset. ### select_by_score 导入:`from aidp import select_by_score` select_by_score(dataset: datasets.arrow_dataset.Dataset, score_fn: Callable[[dict], float], top_ratio: Optional[Annotated[float, FieldInfo(annotation=NoneType, required=True, metadata=[Ge(ge=0), Le(le=1)])]] = None, top_k: Optional[int] = None, reverse: bool = True) -> datasets.arrow_dataset.Dataset Select data using a score function. Args: dataset (Dataset): dataset to be selected. score_fn (Callable[[dict], float]): 任意指定的打分函数,输入样本,输出分数。例如可以基于文本长度、或者更加复杂的模型评分、困惑度等 top_ratio (Optional[float]): 范围 0 - 1,确定选择样本的比例 top_k (Optional[int]): 确定选择样本的数量 reverse (bool): 默认为 True,即选择分数更高的样本 Returns: Dataset: selected dataset. ### select_by_random 导入:`from aidp import select_by_random` select_by_random(dataset: datasets.arrow_dataset.Dataset, select_ratio: Optional[Annotated[float, FieldInfo(annotation=NoneType, required=True, metadata=[Ge(ge=0), Le(le=1)])]] = None, select_num: Optional[int] = None) -> datasets.arrow_dataset.Dataset Randomly select samples from the dataset. Args: dataset (Dataset): The input dataset. select_ratio (Optional[float]): The ratio of samples to select. select_num (Optional[int]): The number of samples to select. 
Returns: Dataset: The selected samples as a new dataset. ### generate_dataset_with_llm 导入:`from aidp import generate_dataset_with_llm` generate_dataset_with_llm(dataset: datasets.arrow_dataset.Dataset, system_prompt: str, response_parser: Callable[[str, dict], List[dict]]) -> datasets.arrow_dataset.Dataset Generate samples according to given dataset. Note: Limited by the inference speed, please limit the size of input dataset smaller than 10000. You can use `select_by_random` to sample the dataset at first. Args: dataset (Dataset): The input dataset. system_prompt (str): The system prompt to generate the llm prompt. You need to describe the response format that LLM need to follow. response_parser (Callable): The function to parse the llm response. input `response` (llm response) and `raw_sample` (original sample dict), output a list of parsed sample dict. If the output list longer than 1, generate more than one samples from the given raw sample. Returns: Dataset: The generated dataset. Example: ```python from aidp import generate_dataset_with_llm, extract_json # Assume the dataset is a news dataset and we need to generate sample according to `content` field. SYSTEM_PROMPT = ''' Read the `content` field provided by user. You need to ask a question and give the corresponding answer according to the news content. Your answer must be in JSON format as below, DO NOT reply any other content: ```json {"question": "[question content]", "answer": "[answer content]"} ``` ''' def response_parser(response: str, raw_sample: dict) -> list: parsed = extract_json(response) if parsed is not None: # Construct the final sample from the LLM response and the raw sample data. user = f'{raw_sample["content"]}\\nAccording to the above news content, answer the question: {parsed["question"]}' assistant = parsed['answer'] return [dict(user=user, assistant=assistant)] else: return [] # The new dataset will contains `user` and `assistant` field. 
new_dataset = generate_dataset_with_llm(dataset, system_prompt=SYSTEM_PROMPT, response_parser=response_parser) ``` Note: You need to describe the required response format to parse in the `system_prompt`. Recommend to require JSON format response and use `extract_json` function to parse the response. ### extract_json 导入:`from aidp import extract_json` extract_json(text: str) -> list | dict | None ### generate_text_embeddings 导入:`from aidp import generate_text_embeddings` generate_text_embeddings(text_map: Callable[[dict], str], dataset: datasets.arrow_dataset.Dataset) -> datasets.arrow_dataset.Dataset Generate embeddings for a given dataset. Args: text_map (Callable[[dict], str]): A function that extracts the text from each sample to calculate embedding. dataset: the input dataset. Returns: Dataset: A new dataset with an additional column `emb` containing the embedding for each sample. Examples: # We calculate the embedding according to the `question` field of each sample. ```python from aidp import generate_text_embeddings emb_dataset = generate_text_embeddings(text_map=lambda x: x['question'], dataset=dataset) ``` ### score_dataset_with_llm 导入:`from aidp import score_dataset_with_llm` score_dataset_with_llm(dataset: datasets.arrow_dataset.Dataset, query_map: Callable[[dict], str], response_map: Callable[[dict], str], task_description: str, evaluation_protocol: str) -> datasets.arrow_dataset.Dataset Evaluate dataset with an LLM grader. For each sample, the grader examines the query and the response under the provided `task_description` and `evaluation_protocol`, and returns a final score 1-5. Note: Limited by the inference speed, please limit the size of input dataset smaller than 10000. You can use `select_by_random` to sample the dataset at first. Args: dataset (Dataset): The input dataset to be evaluated. 
query_map (Callable[[dict], str]): A function that extracts the "question" text from each sample, e.g., `lambda x: next((i['content'] for i in x['dialogs'] if i['role'] == 'system'), "") + next((i['content'] for i in x['dialogs'] if i['role'] == 'user'), "")` or any custom serializer that returns `str`. response_map (Callable[[dict], str]): A function that extracts the candidate's "answer"/"response" text from each sample, e.g., `lambda x: next((i['content'] for i in x['dialogs'] if i['role'] == 'assistant'), "")`, returning `str`. task_description (str): A concise description of the target task, optionally including an example. evaluation_protocol (str): Evaluation protocol used to evaluate the data samples. Returns: Dataset: A new dataset with an additional column `'llm_score'` containing a score of 1-5 for each sample. Examples: ```python from aidp import score_dataset_with_llm # Example: evaluate a dataset based on the `question` and `answer` field, along with a concise task description. The task is in the form of multiple-choice questions. evaluated_dataset = score_dataset_with_llm( dataset, query_map=lambda s: s["question"], response_map=lambda s: s["answer"], test_description="微调 LLM 使其具备 finance 相关的知识。" evaluation_protocol="数据样本应该是 1) 选择题格式,可带有选项解析;2) 金融领域相关的。" ) # Example: evaluate a dataset based on the `dialogs` field, along with a detailed task description that includes examples under the task. 
evaluated_dataset = score_dataset_with_llm( dataset, query_map=lambda x: next((i['content'] for i in x['dialogs'] if i['role'] == 'system'), "") + next((i['content'] for i in x['dialogs'] if i['role'] == 'user'), "") response_map=lambda x: next((i['content'] for i in x['dialogs'] if i['role'] == 'assistant'), "") test_description="微调 LLM 使其具备 law 相关的知识。使用 LawBench 作为评测集。LawBench 经过精心设计,可对大语言模型的法律能力进行精确评估。 模拟了司法认知的三个维度,并选择了20个任务来评估大模型的能力。与一些仅有多项选择题的现有基准相比,LawBench 包含了更多与现实世界应用密切相关的任务类型,如法律实体识别、阅读理解、犯罪金额计算和咨询等。LawBench 数据集包括 20 个不同的任务,涵盖 3 个认知水平:(1)法律知识记忆:大语言模型能否记住必要的法律概念、术语、法条和事实。(2)法律知识理解:大语言模型能否理解法律文本中的实体、事件和关系,从而理解法律文本的意义和内涵。(3)法律知识应用:大语言模型能否正确利用其法律知识、对其进行推理从而解决下游应用中的现实法律任务。Example in LawBench: [{"instruction": "请根据具体场景与问题给出法律依据,只需要给出具体法条内容,每个场景仅涉及一个法条。", "question": "场景:某个地区的三个以上专业农民合作社想要出资设立农民专业合作社联合社,以提高其在市场中的竞争力和规模效应。根据哪条法律,三个以上的农民专业合作社可以出资设立农民专业合作社联合社?", "answer": "根据《农民专业合作社法》第五十六条,三个以上的农民专业合作社在自愿的基础上,可以出资设立农民专业合作社联合社。该联合社应当有自己的名称、组织机构和住所,由联合社全体成员制定并承认的章程,以及符合章程规定的成员出资。"}]" evaluation_protocol="数据样本应该是 1) 法律相关的;2) 带有丰富法律知识的、附例子或者法律条文依据。", ) # Filter: filter samples with `llm_score` greater than 3 filtered_dataset = evaluated.filter(lambda x: x['llm_score'] > 3) ``` ### dump_dataset 导入:`from aidp import dump_dataset` dump_dataset(dataset: datasets.arrow_dataset.Dataset, filename: str | pathlib.Path) Dump a dataset to a JSON Lines file. Args: dataset (Dataset): 要保存的数据集。 filename (str | Path): 要保存的路径。 Returns: None """ _DEMO_THEME: Optional[Any] = None _DEMO_CSS: Optional[str] = None _REMOTE_API_BASE = ( os.getenv("OPENAI_API_BASE") or os.getenv("DATACHEF_VLLM_URL") or "" ).strip().rstrip("/") _REMOTE_API_KEY = (os.getenv("DATACHEF_VLLM_API_KEY") or os.getenv("OPENAI_API_KEY") or "").strip() _REMOTE_FORWARD_KEY = (os.getenv("X_FORWARD_KEY") or os.getenv("FORWARD_KEY") or "").strip() # Default headers for all LLM calls; add x-forward-key when provided. 
_REMOTE_DEFAULT_HEADERS: Dict[str, str] = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
    "Content-Type": "application/json",
}
if _REMOTE_FORWARD_KEY:
    _REMOTE_DEFAULT_HEADERS["x-forward-key"] = _REMOTE_FORWARD_KEY

# Lazily created OpenAI client, guarded by the lock below.
_REMOTE_CLIENT: Optional[OpenAI] = None
_REMOTE_CLIENT_LOCK = threading.Lock()

# --- Logging ---
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
logger = logging.getLogger("datachef.app")

# Matches one complete reasoning block, tags included.
# NOTE(review): the tag literals were empty in the reviewed text (markup appears
# to have been stripped); `<think>` / `</think>` are restored on the assumption
# that the model emits that delimiter pair -- confirm against real model output.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL | re.IGNORECASE)


def _load_preset_config(preset_name: str):
    """Return the UI component values for the selected preset.

    Args:
        preset_name: Key into ``PRESETS``.

    Returns:
        A 4-tuple ``(task_description, benchmark_name, benchmark_description,
        dataset_rows)``. ``dataset_rows`` is a list of single-cell rows, one
        per preset dataset that has a non-empty ``dataset_id``. For an unknown
        preset, four ``gr.update()`` objects are returned instead.
    """
    if preset_name not in PRESETS:
        return gr.update(), gr.update(), gr.update(), gr.update()

    config = PRESETS[preset_name]
    task = config.get("task", {})
    benchmark = task.get("benchmark", {})

    # Convert the dataset entries into the list-of-rows shape built below
    # (one single-column row per dataset id).
    dataset_rows = []
    if config.get("datasets"):
        dataset_rows = [
            [d.get("dataset_id", "")]
            for d in config["datasets"]
            if d.get("dataset_id")
        ]

    return (
        task.get("description", ""),       # task_description
        benchmark.get("name", ""),         # benchmark_name
        benchmark.get("description", ""),  # benchmark_description
        dataset_rows,                      # dataset_ids
    )


def _truncate_text(text: str, limit: int = TRUNCATE_LIMIT) -> str:
    """Return ``text`` unchanged if it fits in ``limit`` characters.

    Longer text is cut to its first ``limit`` characters and prefixed with a
    marker that records the original length.
    """
    if len(text) <= limit:
        return text
    return f"[TRUNCATED, length={len(text)}] {text[:limit]}"


def _log_preview(label: str, text: str, limit: int = 400) -> None:
    """Log a single-line preview of ``text`` (newlines escaped, capped at ``limit``)."""
    preview = text.replace("\n", "\\n")
    if len(preview) > limit:
        preview = preview[:limit] + "...(truncated)"
    logger.info("[preview] %s len=%s preview=%s", label, len(text), preview)


def _log_full(label: str, text: str) -> None:
    """Log ``text`` in full together with its length."""
    logger.info("[full] %s len=%s text=%s", label, len(text), text)


def _strip_think(text: str) -> str:
    """Remove reasoning blocks from an LLM response.

    Complete ``<think>...</think>`` blocks are deleted. If an opening tag is
    left without a closing tag, everything from that tag onwards is dropped.
    Any stray closing tag is removed, runs of three or more newlines collapse
    to two, and the result is stripped.

    Args:
        text: Raw model output; falsy values are returned unchanged.

    Returns:
        The cleaned text.
    """
    if not text:
        return text
    cleaned = _THINK_BLOCK_RE.sub("", text)
    # An unterminated block: keep only what precedes the opening tag.
    if "<think>" in cleaned:
        cleaned = cleaned.split("<think>", 1)[0]
    cleaned = cleaned.replace("</think>", "")
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    return cleaned.strip()


def _escape_html(text: Any) -> str:
    """Return ``str(text)`` HTML-escaped, including quote characters."""
    return html.escape(str(text), quote=True)
html.escape(str(text), quote=True) def _render_collapsible_json(title: str, data: Any, preview: Optional[str] = None) -> str: """Render a collapsible JSON block with optional preview text.""" pretty = json.dumps(data, ensure_ascii=False, indent=2) if preview is None: preview = str(data) if len(preview) > 140: preview = preview[:140] + "..." return ( "
" f"{_escape_html(preview or title)}" f"
{_escape_html(pretty)}
" "
" ) def _render_processed_scores_table(data_scores: List[Dict]) -> str: """Render processed training data as an HTML table with collapsible samples.""" if not data_scores: return "

No processed data returned.

" judgement_map = { "A": "A: Invalid", "B": "B: Format Error", "C": "C: Incorrect", "D": "D: Task Mismatch", "E": "E: Pass", } rows_html = [] for idx, s in enumerate(data_scores, start=1): sample = { "user": s.get("question") or s.get("question_preview") or s.get("input") or "", "assistant": s.get("answer") or s.get("answer_preview") or s.get("output") or "", } judgement = s.get("judgement", s.get("feedback", "")) judgement_disp = judgement_map.get(str(judgement).strip().upper(), judgement or "") score_val = s.get("score") score_disp = f"{score_val:.2f}" if isinstance(score_val, (int, float)) else _escape_html(score_val if score_val is not None else "") sample_preview = sample["user"] if isinstance(sample["user"], str) else "" sample_cell = _render_collapsible_json("sample", sample, preview=sample_preview or "View sample") rows_html.append( "" f"{_escape_html(s.get('index', idx))}" f"{score_disp}" f"{_escape_html(judgement_disp)}" f"{sample_cell}" "" ) rows_joined = "\n".join(rows_html) table_html = ( '
' '' "" "" "" f"{rows_joined}" "
#scorejudgementsample
" "
" ) return table_html def _render_preview_table_html(dataset_infos: List[Dict], errors: Optional[Dict[str, str]] = None) -> str: """Render raw preview data as an HTML table with collapsible samples.""" rows_html = [] idx = 1 for info in dataset_infos or []: ds_id = info.get("dataset_id", "Unknown") for ex_group in info.get("examples", []): for sample in ex_group.get("preview_examples", []): sample_cell = _render_collapsible_json("sample", sample) rows_html.append( "" f"{idx}" f"{_escape_html(ds_id)}" f"{sample_cell}" "" ) idx += 1 if rows_html: base = ( "
" "" "" f"{''.join(rows_html)}" "
#Source Datasetsample
" ) else: base = "

No preview samples available.

" if errors: err_lines = "".join(f"
  • {_escape_html(k)}: {_escape_html(v)}
  • " for k, v in errors.items()) base += f"
    Loading Errors
    " return base def _truncate_dataset_examples_for_prompt(datasets: List[Dict], limit: int = TRUNCATE_LIMIT) -> List[Dict]: """Return a truncated copy of datasets for LLM prompts to avoid context overflow.""" truncated: List[Dict] = [] for ds in datasets or []: ds_copy = { "dataset_id": ds.get("dataset_id"), "revision": ds.get("revision"), "examples": [], } for ex in ds.get("examples", []): ex_copy = { "name": ex.get("name"), "split": ex.get("split"), "schema": ex.get("schema"), "preview_examples": [], } for sample in ex.get("preview_examples", []): sample_copy = {} for k, v in sample.items(): val = json.dumps(v, ensure_ascii=False) if not isinstance(v, str) else v sample_copy[k] = _truncate_text(val, limit=limit) ex_copy["preview_examples"].append(sample_copy) ds_copy["examples"].append(ex_copy) truncated.append(ds_copy) return truncated def _render_plan_prompt(task_description: str, benchmark: Dict[str, str], datasets: List[Dict]) -> str: datasets = _truncate_dataset_examples_for_prompt(datasets) parts = [ "# 任务描述", task_description, "", "# 测试数据 (Benchmark)", f"## {benchmark.get('name', '')}", benchmark.get("description", ""), "", "# 可用的 Huggingface 训练数据集", ] for item in datasets: parts.append(f"## {item['dataset_id']}") parts.append(str(item["examples"])) parts.append("") parts.append("---") parts.append(PLAN_INSTRUCTIONS.strip()) return "\n".join(parts).strip() def _render_code_prompt(datasets: List[Dict], plan: str, tool_info: str) -> str: datasets = _truncate_dataset_examples_for_prompt(datasets) parts = [ "# 可用的 Huggingface 训练数据集", ] for item in datasets: parts.append(f"## {item['dataset_id']}") parts.append(str(item["examples"])) parts.append("") parts.extend([ "# 数据处理计划", plan, "", "# 工具信息", tool_info, "", "---", CODE_INSTRUCTIONS.strip(), ]) return "\n".join(parts).strip() def _compose_task_context(task_description: str, benchmark_description: str) -> str: """Combine task description and benchmark description for downstream execution/logging.""" 
def _build_tool_info() -> str:
    """Return the tool reference text (``TOOL_INFO``) without surrounding whitespace."""
    return TOOL_INFO.strip()


def _resolve_model_id() -> str:
    """Return the model id: ``DATACHEF_MODEL_PATH`` if set and non-empty, else ``MODEL_ID``."""
    override = os.getenv("DATACHEF_MODEL_PATH")
    return override or MODEL_ID


def _use_remote_llm() -> bool:
    """Report whether a remote API base URL has been configured."""
    return bool(_REMOTE_API_BASE)


def _get_remote_client() -> OpenAI:
    """Return the shared ``OpenAI`` client, creating it on first use.

    The cached client is checked once without the lock and once more after
    acquiring ``_REMOTE_CLIENT_LOCK`` before a new client is constructed.

    Raises:
        RuntimeError: If no remote base URL is configured.
    """
    global _REMOTE_CLIENT
    if not _REMOTE_API_BASE:
        raise RuntimeError("Remote base URL is empty.")
    if _REMOTE_CLIENT is None:
        with _REMOTE_CLIENT_LOCK:
            if _REMOTE_CLIENT is None:
                # Some self-hosted vLLM endpoints need no authentication, but the
                # OpenAI client still expects a non-empty api_key, so a dummy
                # token is supplied when none is configured.
                token = _REMOTE_API_KEY or os.getenv("OPENAI_API_KEY") or "EMPTY"
                _REMOTE_CLIENT = OpenAI(
                    base_url=_REMOTE_API_BASE,
                    api_key=token,
                    default_headers=_REMOTE_DEFAULT_HEADERS,
                )
    return _REMOTE_CLIENT


def _generate_text_remote(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Run one chat completion against the remote endpoint.

    The request and the raw response are both logged in full. The returned
    text is the first choice's content passed through ``_strip_think``; an
    empty string is used when there are no choices or no content.

    Raises:
        RuntimeError: Wrapping any exception raised by the completion call.
    """
    client = _get_remote_client()
    request_payload = {
        "messages": messages,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    _log_full("llm_request", json.dumps(request_payload, ensure_ascii=False))
    try:
        completion = client.chat.completions.create(
            model=_resolve_model_id(),
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as exc:  # noqa: BLE001
        raise RuntimeError(f"Remote LLM error: {exc}") from exc

    if completion.choices:
        content = completion.choices[0].message.content
    else:
        content = ""
    content = content or ""
    _log_full("llm_response", content)
    return _strip_think(content)


def _generate_text(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate text through the remote LLM.

    Raises:
        RuntimeError: If no remote endpoint is configured.
    """
    if _use_remote_llm():
        return _generate_text_remote(messages, max_new_tokens, temperature, top_p)
    raise RuntimeError("DATACHEF_VLLM_URL is not set; please configure remote vLLM endpoint.")
# Phrase delimiters for the keyword fallback: ASCII and full-width
# semicolon/comma, ASCII period, ideographic full stop, and newline.
# (The previous pattern used `\\n` inside a raw string, which made the class
# match a backslash and the letter "n" rather than a newline.)
_KEYWORD_SPLIT_RE = re.compile(r"[;,.；，。\n]+")


def _fallback_keywords(task_description: str, n: int) -> List[str]:
    """Derive up to ``n`` short phrases from the task description.

    The text is split on punctuation and newlines; only non-empty pieces of at
    most five whitespace-separated words are kept.
    """
    phrases: List[str] = []
    for chunk in _KEYWORD_SPLIT_RE.split(task_description or ""):
        if len(phrases) >= n:
            break
        phrase = chunk.strip()
        if phrase and len(phrase.split()) <= 5:
            phrases.append(phrase)
    return phrases


def _request_keywords_via_proxy(task_description: str, benchmark_description: str, n: int = 3) -> List[str]:
    """Ask the proxy's ``/generate_keywords`` endpoint for search keywords.

    The endpoint URL is derived from ``RUN_CODE_URL`` by replacing its last
    path segment. ``RUN_CODE_KEY``, when set, is sent as ``x-forward-key``.

    Args:
        task_description: Free-text description of the training task.
        benchmark_description: Free-text description of the benchmark.
        n: Maximum number of keywords to return.

    Returns:
        Up to ``n`` non-empty keyword strings from the proxy. If the request
        fails or yields no usable keywords, short phrases split out of
        ``task_description`` are returned instead (possibly an empty list).
    """
    base = RUN_CODE_URL.rsplit("/", 1)[0].rstrip("/")
    url = f"{base}/generate_keywords"
    headers = {}
    if RUN_CODE_KEY:
        headers["x-forward-key"] = RUN_CODE_KEY
    payload = {
        "task_description": task_description,
        "benchmark_description": benchmark_description,
        "n": n,
    }
    try:
        resp = requests.post(url, json=payload, headers=headers, timeout=20)
        resp.raise_for_status()
        data = resp.json() if resp.content else {}
        raw_keywords = data.get("keywords") or []
        # Guard against a bare string, which would otherwise be iterated
        # character by character.
        if not isinstance(raw_keywords, list):
            raw_keywords = []
        keywords = [k.strip() for k in raw_keywords if isinstance(k, str) and k.strip()]
        if keywords:
            return keywords[:n]
    except Exception as e:  # noqa: BLE001 - best effort; fall back below
        logger.warning("generate_keywords proxy failed: %s", e)

    # Fallback: basic split of the task description.
    return _fallback_keywords(task_description, n)


def _search_hf_datasets(keywords: List[str], limit_per_kw: int = 5) -> List[str]:
    """Search HuggingFace datasets for each keyword and return dataset ids.

    Ids are returned in discovery order without duplicates. A keyword whose
    search fails is logged and skipped.
    """
    api = HfApi()
    found: List[str] = []
    seen = set()
    for kw in keywords:
        try:
            # The results are materialised inside the try block so that errors
            # raised while the result iterable is consumed are handled here too.
            # NOTE(review): presumably list_datasets yields lazily on current
            # huggingface_hub releases -- confirm for the pinned version.
            try:
                results = list(api.list_datasets(
                    search=kw,
                    limit=limit_per_kw,
                    sort="likes",
                    direction=-1,  # descending
                ))
            except TypeError:
                # The installed hub version rejected that keyword set; retry
                # with the `full_text_search` variant.
                results = list(api.list_datasets(
                    full_text_search=kw,
                    limit=limit_per_kw,
                    sort="likes",
                ))
        except Exception as e:  # noqa: BLE001 - best effort per keyword
            logger.warning("HF search failed for '%s': %s", kw, e)
            continue
        for ds in results:
            ds_id = getattr(ds, "id", None) or getattr(ds, "dataset", None)
            if ds_id and ds_id not in seen:
                seen.add(ds_id)
                found.append(ds_id)
    return found


def _auto_suggest_datasets(task_description: str, benchmark_description: str) -> List[List[str]]:
    """Suggest dataset ids for a task as single-cell table rows.

    Returns an empty list when ``task_description`` is empty; otherwise asks
    for three keywords and searches three datasets per keyword.
    """
    if not task_description:
        return []
    benchmark_description = benchmark_description or ""
    keywords = _request_keywords_via_proxy(task_description, benchmark_description, n=3)
    ds_ids = _search_hf_datasets(keywords, limit_per_kw=3)
    return [[d] for d in ds_ids]
def _rows(table) -> List[List[str]]:
    """Normalize Gradio dataframe output (which may be pandas DataFrame) to list of rows.

    ``None`` and other falsy values become ``[]``. Objects exposing
    ``to_numpy`` are converted through it; if that conversion fails the error
    is logged and ``[]`` is returned. Anything else is returned as-is.
    """
    if table is None:
        return []
    if hasattr(table, "to_numpy"):
        try:
            return table.to_numpy().tolist()
        except Exception as exc:  # noqa: BLE001 - best effort, but no longer silent
            logger.warning("Failed to convert table to rows: %s", exc)
            return []
    return table or []


def _first_cells(rows) -> set:
    """Return the stripped first-column values of all non-empty rows."""
    return {str(r[0]).strip() for r in rows if r and len(r) > 0}


def _add_selected_suggestion(
    current,
    suggestions,
    idx: Optional[int],
) -> List[List[str]]:
    """Append the suggestion at ``idx`` to the current table.

    The current rows are returned unchanged when ``idx`` is ``None`` or out of
    range, when the selected cell is blank, or when its value is already in
    the first column of the current table.
    """
    rows_current = _rows(current)
    rows_suggestions = _rows(suggestions)
    if idx is None or idx < 0 or idx >= len(rows_suggestions):
        return rows_current
    val = str(rows_suggestions[idx][0]).strip() if rows_suggestions[idx] else ""
    if not val or val in _first_cells(rows_current):
        return rows_current
    return rows_current + [[val]]


def _add_all_suggestions(
    current,
    suggestions,
) -> List[List[str]]:
    """Return the current rows plus every suggestion not already present.

    Suggestions are compared by their stripped first cell; blanks and
    duplicates (against the table or earlier suggestions) are skipped. The
    inputs are not modified.
    """
    rows_current = _rows(current)
    rows_suggestions = _rows(suggestions)
    # A set keeps the membership test constant-time per suggestion.
    existing = _first_cells(rows_current)
    merged = list(rows_current)
    for row in rows_suggestions:
        val = str(row[0]).strip() if row else ""
        if val and val not in existing:
            existing.add(val)
            merged.append([val])
    return merged


def _add_empty_row(current) -> List[List[str]]:
    """Return the current rows with one empty single-column row appended."""
    rows = _rows(current)
    # The appended row is a list holding one empty string (single column).
    return rows + [[""]]


def _delete_selected_row(current, idx: Optional[int]) -> Tuple[List[List[str]], None]:
    """Return the rows without the one at ``idx``, and a reset selection.

    Works on a copy so the caller's list is never mutated. When ``idx`` is
    ``None`` or out of range, all rows are kept.

    Returns:
        ``(rows, None)`` -- the second element resets the selected index.
    """
    rows = list(_rows(current))
    if idx is not None and 0 <= idx < len(rows):
        del rows[idx]
    return rows, None


def _auto_suggest_datasets_ui(task_description: str, benchmark_description: str) -> Tuple[List[List[str]], Optional[int]]:
    """Return fresh dataset suggestions together with a cleared selection index."""
    suggestions = _auto_suggest_datasets(task_description, benchmark_description)
    return suggestions, None
benchmark_description) return suggestions, None def _feature_repr(feature) -> str: if isinstance(feature, Value): if feature.dtype in ("string", "large_string"): return "str" return feature.dtype if isinstance(feature, ClassLabel): return "ClassLabel" if isinstance(feature, Sequence): return f"List[{_feature_repr(feature.feature)}]" return str(feature) def _feature_repr_from_hf(feature: Dict) -> str: """Lightweight repr using datasets-server feature schema.""" if not isinstance(feature, dict): return "unknown" type_info = feature.get("type", {}) if not isinstance(type_info, dict): return str(type_info) if type_info is not None else "unknown" type_name = type_info.get("_type") if type_name == "Value": dtype = type_info.get("dtype") if dtype in ("string", "large_string"): return "str" return dtype or "Value" if type_name == "ClassLabel": return "ClassLabel" if type_name == "Sequence": inner = type_info.get("feature", {}) inner_repr = _feature_repr_from_hf({"type": inner}) if isinstance(inner, dict) else "Any" return f"List[{inner_repr}]" return type_name or str(type_info) or "unknown" def _run_with_timeout(func, timeout: int, *args, **kwargs): """Run a callable with a timeout; cancel the future if it takes too long.""" executor = ThreadPoolExecutor(max_workers=1) future = executor.submit(func, *args, **kwargs) try: return future.result(timeout=timeout) finally: executor.shutdown(wait=False, cancel_futures=True) import time # 确保开头导入了 time def _render_progress_html(current_step: int, error: bool = False, detail_msg: str = "", est_time: int = 0) -> str: steps = ["Data Preview", "Processing Plan", "Code Generation", "Code Execution"] icons_html = '
    ' active_progress_html = "" # 生成一个唯一的 ID 和 动画名,确保每次切换步骤时浏览器都会重置动画 unique_id = f"step_bar_{current_step}_{int(time.time())}" anim_name = f"fill_{unique_id}" for idx, name in enumerate(steps): is_active = (idx == current_step) if idx < current_step: status_class, icon = "done", "✓" elif is_active: if error: status_class, icon = "error", "✕" elif current_step == 4: status_class, icon = "done", "✓" else: status_class, icon = "active", "⋯" if est_time > 0: # 通过内联 style 定义一个新的 keyframes 动画 active_progress_html = f"""
    Est. time: ~{est_time}s
    """ else: status_class, icon = "", str(idx + 1) icons_html += f"""
    {icon}
    {name}
    """ if idx < len(steps) - 1: icons_html += '
    ' icons_html += '
    ' text_html = f'
    {detail_msg}
    ' return f'
    {icons_html}{active_progress_html}{text_html}
    ' def _get_dataset_examples(dataset_id: str, timeout: Optional[int] = None) -> Dict: """Load dataset examples with optional timeout and cache results to avoid reloading.""" if dataset_id in _DATASET_PREVIEW_CACHE: return _DATASET_PREVIEW_CACHE[dataset_id] errors: List[str] = [] try: if timeout: result = _run_with_timeout(_build_dataset_examples, timeout, dataset_id) else: result = _build_dataset_examples(dataset_id) _DATASET_PREVIEW_CACHE[dataset_id] = result return result except FuturesTimeoutError: errors.append(f"Preview for {dataset_id} timed out after {timeout}s") except Exception as e: errors.append(str(e)) # Fallback to legacy full load if lightweight preview fails try: result = _build_dataset_examples_legacy(dataset_id) _DATASET_PREVIEW_CACHE[dataset_id] = result return result except Exception as e: errors.append(str(e)) raise RuntimeError("; ".join(errors)) def _run_code_probe(code_text: str, timeout: int) -> Tuple[bool, str]: """ Run lightweight code on the proxy runner to test dataset loading. Returns (success, error_message). 
""" if not RUN_CODE_URL: return False, "RUN_CODE_URL not configured" headers = {} if RUN_CODE_KEY: headers["x-forward-key"] = RUN_CODE_KEY try: resp = requests.post( RUN_CODE_URL, json={"code": code_text, "timeout": timeout}, headers=headers, timeout=min(timeout, 30) + 10, ) resp.raise_for_status() data = resp.json() if resp.content else {} except Exception as e: # noqa: BLE001 return False, f"probe request failed: {e}" job_id = data.get("job_id") if not job_id: return False, "probe runner did not return job_id" status_base = RUN_CODE_URL.rsplit("/", 1)[0].rstrip("/") status_url = f"{status_base}/status/{job_id}" deadline = time.time() + timeout poll_interval = 1.5 result_data = None while time.time() < deadline: try: r = requests.get(status_url, headers=headers, timeout=min(timeout, 30)) r.raise_for_status() job = r.json() if r.content else {} except Exception: time.sleep(poll_interval) continue status = job.get("status") if status in {"success", "failed", "timeout", "error"}: result_data = job break time.sleep(poll_interval) if result_data is None: return False, f"timed out after {timeout}s" status = result_data.get("status") retcode = result_data.get("returncode") stderr = result_data.get("stderr") or "" summary = result_data.get("summary") or "" if status == "success" and retcode == 0: return True, "" return False, summary or stderr or f"probe failed (status={status}, returncode={retcode})" def _probe_dataset_load(dataset_id: str, name: Optional[str], split: str) -> Tuple[bool, str]: """ Trigger a small load via proxy runner to weed out slow/invalid datasets. Loads up to 20 samples; returns (success, error_message). 
""" name_literal = "None" if not name or name == "default" else repr(name) split_literal = repr(split or "train") code = ( "from aidp import load_remote_dataset\n" "from datasets import disable_caching\n" "disable_caching()\n" f"ds = load_remote_dataset({repr(dataset_id)}, name={name_literal}, split={split_literal}, num_samples=20, shuffle=False)\n" "print('rows', len(ds))\n" "print('columns', list(ds.features.keys()))\n" ) return _run_code_probe(code, timeout=DATASET_LOAD_PROBE_TIMEOUT) def _filter_dataset_infos_by_load(dataset_infos: List[Dict]) -> Tuple[List[Dict], Dict[str, str]]: """ Run load probes for each dataset (concurrent, limited). Datasets that fail or timeout are dropped. Returns (kept_dataset_infos, load_errors). """ kept: List[Dict] = [] errors: Dict[str, str] = {} with ThreadPoolExecutor(max_workers=max(1, DATASET_LOAD_PROBE_CONCURRENCY)) as pool: future_map = {} for info in dataset_infos: dataset_id = info.get("dataset_id", "") examples = info.get("examples") or [] if not examples: errors[dataset_id] = "no examples to probe" continue first = examples[0] name = first.get("name") or None split = first.get("split") or "train" future = pool.submit(_probe_dataset_load, dataset_id, name, split) future_map[future] = (dataset_id, info) for future in future_map: dataset_id, info = future_map[future] try: ok, err = future.result() except Exception as e: # pragma: no cover - robustness ok, err = False, str(e) if ok: kept.append(info) else: errors[dataset_id] = err return kept, errors def _fetch_json(url: str, timeout: int) -> Dict: try: with urllib.request.urlopen(url, timeout=timeout) as resp: if resp.status != 200: raise RuntimeError(f"HTTP {resp.status}: {resp.read().decode('utf-8', 'ignore')}") data = resp.read() return json.loads(data.decode("utf-8")) except urllib.error.URLError as e: raise RuntimeError(f"Request failed: {e}") from e def _select_split(splits: List[Dict]) -> Tuple[str, str]: if not splits: raise RuntimeError("No splits available.") # 
Prefer train split for item in splits: if item.get("split") == "train": return item.get("config") or "default", "train" first = splits[0] return first.get("config") or "default", first.get("split") or "train" def _build_dataset_examples(dataset_id: str) -> Dict: """Lightweight preview using datasets-server first-rows API (no full download).""" quoted_id = urllib.parse.quote(dataset_id, safe="") splits_url = f"https://datasets-server.huggingface.co/splits?dataset={quoted_id}" splits_resp = _fetch_json(splits_url, DATASET_PREVIEW_TIMEOUT) splits = splits_resp.get("splits", []) config, split = _select_split(splits) first_rows_url = ( "https://datasets-server.huggingface.co/first-rows" f"?dataset={quoted_id}&config={urllib.parse.quote(config, safe='')}" f"&split={urllib.parse.quote(split, safe='')}&offset=0&length={PREVIEW_LIMIT}" ) first_rows = _fetch_json(first_rows_url, DATASET_PREVIEW_TIMEOUT) features = first_rows.get("features", []) schema = { feat["name"]: _feature_repr_from_hf(feat) for feat in features if isinstance(feat, dict) and "name" in feat } preview_examples: List[Dict] = [] for row_item in first_rows.get("rows", []): if not isinstance(row_item, dict): continue row = row_item.get("row", {}) if not isinstance(row, dict): preview_examples.append({"row": _truncate_text(str(row))}) continue trimmed = {} for k, v in row.items(): value_text = json.dumps(v, ensure_ascii=False) if not isinstance(v, str) else v trimmed[k] = value_text preview_examples.append(trimmed) if len(preview_examples) >= PREVIEW_LIMIT: break return { "dataset_id": dataset_id, "revision": "main", "examples": [{ "name": config or "default", "split": split, "schema": schema, "preview_examples": preview_examples, }], } def _build_dataset_examples_legacy(dataset_id: str) -> Dict: """Fallback preview by loading a small slice via datasets (may download data).""" fallback_revisions = [None, "convert/parquet", "refs/convert/parquet", "parquet"] last_err: Optional[Exception] = None script_err: 
Optional[RuntimeError] = None def load_with_revision(revision: Optional[str]) -> Dict: config_names = get_dataset_config_names(dataset_id, revision=revision) config = config_names[0] if config_names else None split_names = ( get_dataset_split_names(dataset_id, config, revision=revision) if config else get_dataset_split_names(dataset_id, revision=revision) ) if "train" in split_names: split = "train" else: split = split_names[0] if split_names else "train" load_kwargs = dict(split=f"{split}[:{PREVIEW_LIMIT}]") if revision: load_kwargs["revision"] = revision if config: ds = load_dataset(dataset_id, name=config, **load_kwargs) else: ds = load_dataset(dataset_id, **load_kwargs) schema = {k: _feature_repr(v) for k, v in ds.features.items()} preview_examples: List[Dict] = [] for item in ds: trimmed = {} for k, v in item.items(): value_text = json.dumps(v, ensure_ascii=False) if not isinstance(v, str) else v trimmed[k] = value_text preview_examples.append(trimmed) return { "dataset_id": dataset_id, "revision": revision or "main", "examples": [{ "name": config or "default", "split": split, "schema": schema, "preview_examples": preview_examples, }], } for rev in fallback_revisions: try: return load_with_revision(rev) except RuntimeError as e: last_err = e msg = str(e) if "Dataset scripts are no longer supported" in msg: script_err = e continue if "doesn't exist for dataset" in msg or "Not Found" in msg: continue raise except Exception as e: last_err = e raise if script_err: raise RuntimeError( f"{dataset_id} requires running its dataset script, which is disabled in datasets>=4. " "No parquet/convert revision was found." 
) from script_err raise last_err or RuntimeError("Unknown error while loading dataset.") def _normalize_dataset_ids(dataset_ids: Any) -> List[str]: if dataset_ids is None: return [] if hasattr(dataset_ids, "to_numpy"): rows = dataset_ids.to_numpy().tolist() else: rows = dataset_ids if not rows: return [] flat: List[str] = [] for row in rows: if not row: continue value = str(row[0]).strip() if value: flat.append(value) seen = set() result = [] for item in flat: if item in seen: continue seen.add(item) result.append(item) return result def _build_input_sample( task_description: str, benchmark_name: str, benchmark_description: str, dataset_ids: List[str], ) -> Dict: if not dataset_ids: raise ValueError("Dataset list is empty. Please select at least one dataset.") # If preview just ran with the same datasets, reuse its filtered result to avoid re-probing. global _LAST_PREVIEW_IDS, _LAST_PREVIEW_DATASET_INFOS, _LAST_PREVIEW_ERRORS if _LAST_PREVIEW_IDS == dataset_ids and _LAST_PREVIEW_DATASET_INFOS: dataset_infos = list(_LAST_PREVIEW_DATASET_INFOS) errors = dict(_LAST_PREVIEW_ERRORS) else: dataset_infos = [] errors: Dict[str, str] = {} for dataset_id in dataset_ids: try: dataset_infos.append(_get_dataset_examples(dataset_id, DATASET_PREVIEW_TIMEOUT)) except Exception as e: errors[dataset_id] = str(e) if dataset_infos: dataset_infos, load_errors = _filter_dataset_infos_by_load(dataset_infos) errors.update(load_errors) _LAST_PREVIEW_IDS = dataset_ids _LAST_PREVIEW_DATASET_INFOS = list(dataset_infos) _LAST_PREVIEW_ERRORS = dict(errors) # === 修改开始: 检查是否所有数据集都挂了 === if not dataset_infos: error_details = "; ".join([f"{k}: {v}" for k, v in errors.items()]) raise ValueError(f"All selected datasets failed to load (timeout or access error). Please try smaller or different datasets. 
def _collect_dataset_infos(dataset_ids: List[str]) -> Tuple[List[Dict], Dict[str, str]]:
    """Preview and probe ``dataset_ids``, remembering the outcome for reuse.

    Each dataset is previewed, the successful ones are filtered by a load
    probe, and the result is stored in the ``_LAST_PREVIEW_*`` module globals.

    Returns:
        ``(dataset_infos, errors)``, where ``errors`` maps a dataset id to the
        reason it was dropped.
    """
    global _LAST_PREVIEW_IDS, _LAST_PREVIEW_DATASET_INFOS, _LAST_PREVIEW_ERRORS
    dataset_infos: List[Dict] = []
    errors: Dict[str, str] = {}
    for dataset_id in dataset_ids:
        try:
            dataset_infos.append(_get_dataset_examples(dataset_id, DATASET_PREVIEW_TIMEOUT))
        except Exception as e:
            errors[dataset_id] = str(e)
    # Probe loadability and drop slow/failed datasets.
    if dataset_infos:
        dataset_infos, load_errors = _filter_dataset_infos_by_load(dataset_infos)
        errors.update(load_errors)
    # Copies are stored so later changes to the caller's objects cannot alter
    # the remembered state.
    _LAST_PREVIEW_IDS = list(dataset_ids)
    _LAST_PREVIEW_DATASET_INFOS = list(dataset_infos)
    _LAST_PREVIEW_ERRORS = dict(errors)
    return dataset_infos, errors


def _build_input_sample(
    task_description: str,
    benchmark_name: str,
    benchmark_description: str,
    dataset_ids: List[str],
) -> Dict:
    """Assemble the pipeline input sample from the task fields and datasets.

    Returns:
        A dict with ``id``, ``task``, ``datasets`` (the usable dataset infos)
        and ``errors`` (per-dataset failure messages).

    Raises:
        ValueError: When ``dataset_ids`` is empty or no dataset could be loaded.
    """
    if not dataset_ids:
        raise ValueError("Dataset list is empty. Please select at least one dataset.")

    # If preview just ran with the same datasets, reuse its filtered result to avoid re-probing.
    if _LAST_PREVIEW_IDS == dataset_ids and _LAST_PREVIEW_DATASET_INFOS:
        dataset_infos = list(_LAST_PREVIEW_DATASET_INFOS)
        errors = dict(_LAST_PREVIEW_ERRORS)
    else:
        dataset_infos, errors = _collect_dataset_infos(dataset_ids)

    # Fail when every selected dataset was dropped.
    if not dataset_infos:
        error_details = "; ".join([f"{k}: {v}" for k, v in errors.items()])
        raise ValueError(f"All selected datasets failed to load (timeout or access error). Please try smaller or different datasets. Details: {error_details}")

    return {
        "id": 0,
        "task": {
            "description": task_description,
            "benchmark": {
                "name": benchmark_name,
                "description": benchmark_description,
            },
        },
        "datasets": dataset_infos,
        "errors": errors,
    }


def _preview_dataset(dataset_ids_table: List[List[str]]) -> str:
    """Render the preview table for the datasets listed in the table.

    Returns a prompt string when no dataset id is given, and a
    ``"Preview error: ..."`` string when rendering or collection fails.
    """
    dataset_ids = _normalize_dataset_ids(dataset_ids_table)
    if not dataset_ids:
        return "Please provide at least one dataset_id."
    try:
        dataset_infos, errors = _collect_dataset_infos(dataset_ids)
        return _render_preview_table_html(dataset_infos, errors if errors else None)
    except Exception as e:
        return f"Preview error: {e}"


def _on_df_select(evt: gr.SelectData) -> Optional[int]:
    """Return the first element of the select event's index, or ``None`` when the index is empty."""
    if not evt.index:
        return None
    return evt.index[0]
response[start_pos:last_end].rstrip("\n\r ") blocks.append(content) return blocks def _strip_code_fence(code_text: str) -> str: """Remove surrounding ``` or ```python fences if present.""" code = code_text.strip() if code.startswith("```"): # remove first line fence lines = code.splitlines() if lines and lines[0].lstrip("`").lower().startswith("python"): lines = lines[1:] elif lines and lines[0].startswith("```"): lines = lines[1:] # drop trailing fence if lines and lines[-1].strip().startswith("```"): lines = lines[:-1] code = "\n".join(lines).strip() return code def _syntax_check(code_text: str) -> Tuple[bool, str]: try: ast.parse(code_text) except SyntaxError as e: location = f"line {e.lineno}, col {e.offset}" if e.lineno and e.offset else "unknown location" return False, f"SyntaxError: {e.msg} ({location})" return True, "" def _build_sample_from_inputs( task_description: str, benchmark_name: str, benchmark_description: str, dataset_ids_table: List[List[str]], ) -> Tuple[Optional[Dict], Optional[str]]: dataset_ids = _normalize_dataset_ids(dataset_ids_table) if not task_description or not benchmark_name or not dataset_ids: return None, "Missing required input fields." 
def _build_sample_from_inputs(
    task_description: str,
    benchmark_name: str,
    benchmark_description: str,
    dataset_ids_table: List[List[str]],
) -> Tuple[Optional[Dict], Optional[str]]:
    """Validate the form inputs and build the pipeline sample.

    Returns:
        ``(sample, None)`` on success, or ``(None, message)`` when a required
        field is missing or the sample cannot be built.
    """
    dataset_ids = _normalize_dataset_ids(dataset_ids_table)
    if not all((task_description, benchmark_name, dataset_ids)):
        return None, "Missing required input fields."

    try:
        sample = _build_input_sample(
            task_description=task_description,
            benchmark_name=benchmark_name,
            # An empty benchmark description is allowed; normalize it to "".
            benchmark_description=benchmark_description or "",
            dataset_ids=dataset_ids,
        )
    except Exception as exc:
        return None, f"Dataset build error: {exc}"
    return sample, None


def _generate_plan_from_sample(sample: Dict) -> str:
    """Ask the planner model for a plan for ``sample``.

    Returns the plan text, or a string starting with ``"Planner error"`` when
    generation fails.
    """
    task = sample["task"]
    system_message = {"role": "system", "content": PLANNER_SYSTEM_PROMPT}
    user_message = {
        "role": "user",
        "content": _render_plan_prompt(
            task_description=task["description"],
            benchmark=task["benchmark"],
            datasets=sample["datasets"],
        ),
    }
    try:
        plan_text = _generate_text(
            [system_message, user_message],
            max_new_tokens=PLAN_MAX_TOKENS,
            temperature=0.6,
            top_p=0.95,
        )
        _log_preview("planner_output", plan_text)
    except Exception as exc:
        return f"Planner error: {exc}"
    return plan_text
def _select_code_block(blocks: List[str], fallback: str) -> str:
    """Choose which extracted code block to use, or ``fallback`` if none fits."""
    # Prefer a block explicitly marked as the data-processing code.
    chosen = next((b for b in blocks if "data-processing code block" in b), "")
    if not chosen:
        # Otherwise take the first block that is not marked as test code.
        chosen = next((b for b in blocks if "test code block" not in b), "")
    if not chosen and blocks:
        chosen = blocks[0]
    return chosen or fallback


def _generate_code_from_plan(datasets: List[Dict], plan_text: str, progress=None) -> str:
    """Generate data-processing code for ``plan_text``, retrying on syntax errors.

    Up to ``MAX_ATTEMPTS`` responses are requested. The first one whose
    selected code passes the syntax check (with ``AIDP_HEADER`` prepended) is
    returned.

    Args:
        datasets: Dataset infos passed to the code prompt.
        plan_text: The plan produced by the planner.
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        The fence-stripped code; the first attempt's fence-stripped code when
        no attempt parses; or a string starting with ``"Coder error"`` when a
        generation call fails.
    """
    _log_preview("coder_plan_input", plan_text)
    tool_info = _build_tool_info()
    first_code = ""
    for attempt in range(MAX_ATTEMPTS):
        # Progress hint: show the current attempt number.
        if progress:
            progress(0.5 + (attempt * 0.1), desc=f"Coding (Attempt {attempt + 1}/{MAX_ATTEMPTS})...")
        logger.info("[attempt] coder attempt %s/%s", attempt + 1, MAX_ATTEMPTS)
        coder_messages = [{
            "role": "system",
            "content": CODER_SYSTEM_PROMPT,
        }, {
            "role": "user",
            "content": _render_code_prompt(
                datasets=datasets,
                plan=plan_text,
                tool_info=tool_info,
            ),
        }]
        _log_full("coder_request_messages", json.dumps(coder_messages, ensure_ascii=False))
        try:
            coder_response = _generate_text(
                coder_messages,
                max_new_tokens=CODE_MAX_TOKENS,
                temperature=0.6,
                top_p=0.95,
            )
        except Exception as e:
            return f"Coder error: {e}"
        _log_full("coder_raw_response", coder_response)

        syn_code = _select_code_block(_parse_code_block(coder_response), coder_response)
        cleaned_code = _strip_code_fence(syn_code)
        if attempt == 0:
            # Keep the cleaned form so the failure path returns code in the
            # same shape as the success path (no surrounding fences).
            first_code = cleaned_code
        _log_full("coder_cleaned_code", cleaned_code)

        ok, _ = _syntax_check(AIDP_HEADER + cleaned_code.strip() + "\n")
        if ok:
            if progress:
                progress(1.0, desc="Code Generation Complete")
            return cleaned_code
        if progress:
            progress(0.5 + (attempt * 0.1), desc=f"Retrying ({attempt + 1})...")
    return first_code


def _execute_generated_code_remote(code_text: str, timeout: int = RUN_CODE_TIMEOUT, task_description: str = "") -> Tuple[str, str, str]:
    """Submit code to the remote runner and wait for the job to finish.

    Args:
        code_text: The code to run.
        timeout: Seconds allowed for the job; also sent to the runner.
        task_description: Optional context forwarded with the request.

    Returns:
        eval_md: HTML table of evaluation scores and samples.
        preview_md: HTML table of processed data previews (kept for compatibility).
        logs_txt: Plain text of stdout/stderr/status.
    """
    if not code_text or not code_text.strip():
        return "", "", "No code to run. Please generate code first."

    _log_preview("run_code_request", code_text, limit=800)
    if task_description:
        _log_preview("run_code_task_description", task_description, limit=400)
    logger.info("[run_code] POST %s timeout=%ss", RUN_CODE_URL, timeout)
    headers = {}
    if RUN_CODE_KEY:
        headers["x-forward-key"] = RUN_CODE_KEY
    try:
        resp = requests.post(
            RUN_CODE_URL,
            json={"code": code_text, "timeout": timeout, "task_description": task_description},
            headers=headers,
            timeout=min(timeout, 30) + 10,
        )
        resp.raise_for_status()
        data = resp.json() if resp.content else {}
    except Exception as e:
        logger.exception("[run_code] request failed")
        return "", "", f"Run error: {e}"

    job_id = data.get("job_id")
    if not job_id:
        return "", "", "Run error: runner did not return job_id."

    status_base = RUN_CODE_URL.rsplit("/", 1)[0].rstrip("/")
    status_url = f"{status_base}/status/{job_id}"
    # Monotonic clock: the deadline must not move if the system time is adjusted.
    deadline = time.monotonic() + timeout
    poll_interval = 2
    result_data = None
    while time.monotonic() < deadline:
        try:
            r = requests.get(status_url, headers=headers, timeout=min(timeout, 30))
            r.raise_for_status()
            job = r.json() if r.content else {}
        except Exception as e:  # noqa: BLE001
            # A failed poll is logged and retried until the deadline.
            logger.warning("[run_code] status poll error: %s", e)
            time.sleep(poll_interval)
            continue
        if job.get("status") in {"success", "failed", "timeout", "error"}:
            result_data = job
            break
        time.sleep(poll_interval)

    if result_data is None:
        return "", "", f"Run error: job {job_id} did not complete within {timeout}s"
    data = result_data

    # 1. Parse Logs
    status = data.get("status", "unknown")
    retcode = data.get("returncode", "n/a")
    summary = data.get("summary", "")
    stdout = data.get("stdout", "")
    stderr = data.get("stderr", "")
    duration = data.get("duration_sec")
    logger.info(
        "[run_code] status=%s returncode=%s duration=%s summary=%s",
        status,
        retcode,
        f"{duration:.2f}s" if duration is not None else "n/a",
        _truncate_text(summary or "", limit=200),
    )
    logs = [f"Status: {status} (returncode={retcode})"]
    if duration is not None:
        logs.append(f"Duration: {duration:.2f}s")
    if summary:
        logs.append(f"Summary: {summary}")
    if stdout:
        _log_preview("run_code_stdout", stdout, limit=800)
        logs.append("\n--- stdout ---\n" + stdout)
    if stderr:
        _log_preview("run_code_stderr", stderr, limit=800)
        logs.append("\n--- stderr ---\n" + stderr)
    logs_txt = "\n".join(logs)

    # 2. Parse Evaluation Table -> Renamed to Processed Training Data
    data_scores = data.get("data_scores") or []
    data_score_error = data.get("data_score_error")
    if data_scores:
        eval_md = _render_processed_scores_table(data_scores)
    elif data_score_error:
        eval_md = f"**Data Processing Error:** {data_score_error}"
    else:
        eval_md = "No data returned from execution."

    # 3. Processed Preview (Unused in new UI but kept for compatibility)
    preview_md = ""
    return eval_md, preview_md, logs_txt
Processed Preview (Unused in new UI but kept for compatibility) preview_md = "" return eval_md, preview_md, logs_txt def _run_full_automation( task_description: str, benchmark_name: str, benchmark_description: str, dataset_ids_table: List[List[str]] ): """ Generator that runs the entire pipeline: Preview -> Plan -> Code -> Execute. Yields: (status_html, preview_html, plan_md, plan_raw, code_src, eval_html) """ empty = "" loading_html = "" dataset_ids = _normalize_dataset_ids(dataset_ids_table) # 数量限制检查 if len(dataset_ids) > 10: err_msg = f"Resource Limit: Please select no more than 10 datasets (currently selected {len(dataset_ids)}). Limited resources available." yield ( gr.update(value=_render_progress_html(0, error=True, detail_msg=err_msg), visible=True), f"
    {err_msg}
    ", empty, empty, empty, empty ) return # --- Step 1: Data Preview --- # 计算预览时间:<5个 60s,>=5个 120s preview_time = 60 if len(dataset_ids) < 5 else 120 yield ( gr.update(value=_render_progress_html(0, detail_msg="Fetching dataset info...", est_time=preview_time), visible=True), loading_html, empty, empty, empty, empty ) # 执行 Preview preview_md = _preview_dataset(dataset_ids_table) # --- Step 2: Planning --- # Plan 时间:固定 60s yield ( gr.update(value=_render_progress_html(1, detail_msg="Designing data recipe...", est_time=60)), preview_md, loading_html, empty, empty, empty ) # 构建 Sample sample, error = _build_sample_from_inputs(task_description, benchmark_name, benchmark_description, dataset_ids_table) if error: err_html = _render_progress_html(1, error=True, detail_msg="Dataset loading failed") yield (gr.update(value=err_html), preview_md, f"❌ Error: {error}", f"❌ {error}", empty, empty) return # 生成 Plan plan_text = _generate_plan_from_sample(sample) if plan_text.startswith("Planner error"): err_html = _render_progress_html(1, error=True, detail_msg="Planner failed") yield (gr.update(value=err_html), preview_md, plan_text, plan_text, empty, empty) return # --- Step 3: Coding --- # Coding 时间:固定 60s yield ( gr.update(value=_render_progress_html(2, detail_msg="Generating Python script...", est_time=60)), preview_md, plan_text, plan_text, loading_html, empty ) # 生成 Code code_text = _generate_code_from_plan(sample["datasets"], plan_text) if code_text.startswith("Coder error"): err_html = _render_progress_html(2, error=True, detail_msg="Coder failed") yield (gr.update(value=err_html), preview_md, plan_text, plan_text, code_text, empty) return # --- Step 4: Execution --- # Execution 时间:固定 180s yield ( gr.update(value=_render_progress_html(3, detail_msg="Executing code on remote runner...", est_time=180)), preview_md, plan_text, plan_text, code_text, "Processing data... 
this may take a moment" ) # 执行 Code exec_task_desc = _compose_task_context(task_description, benchmark_description) eval_md, _, logs_txt = _execute_generated_code_remote(code_text, task_description=exec_task_desc) # 检查运行结果 is_run_error = False run_msg = "Done" if "Run error" in logs_txt or "timeout" in logs_txt.lower() or "traceback" in logs_txt.lower(): if "Status: success" not in logs_txt: is_run_error = True run_msg = "Execution failed or timed out" if is_run_error: final_progress = _render_progress_html(3, error=True, detail_msg=run_msg) if "timeout" in logs_txt.lower(): eval_md = f"
    Timeout: Code execution took too long. Please simplify the logic or try fewer samples.

    " + eval_md else: eval_md = f"
    Runtime Error: Check the code or logs.

    " + eval_md yield ( gr.update(value=final_progress), preview_md, plan_text, plan_text, code_text, eval_md ) else: # 正常完成,隐藏进度条 yield ( gr.update(visible=False), preview_md, plan_text, plan_text, code_text, eval_md ) def _format_model_status() -> str: model_id = _resolve_model_id() if _use_remote_llm(): auth = "with api key" if (_REMOTE_API_KEY or os.getenv("OPENAI_API_KEY")) else "dummy key" return f"Remote vLLM @ {_REMOTE_API_BASE or 'unset'} ({auth}), model={model_id}" return "Remote vLLM not configured (set DATACHEF_VLLM_URL)." def _warmup_model() -> str: return _format_model_status() # --- CSS 样式定义 (全局) --- _CUSTOM_CSS = """ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500&family=Inter:wght@400;500;600;700;800&display=swap'); :root { --primary-col: #10b981; --panel-bg: #ffffff; --border-col: #e2e8f0; } .gradio-container { font-family: 'Inter', system-ui, sans-serif !important; background-color: #f8fafc !important; } /* 标题区域 */ .header-container { padding: 1.5rem 0 1rem; margin-bottom: 1rem; border-bottom: 1px solid var(--border-col); } /* 卡片容器 */ .control-panel { background: var(--panel-bg); border: 1px solid var(--border-col); border-radius: 12px; padding: 1.25rem; box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.05); } /* 右侧输出面板 */ .output-panel { background: var(--panel-bg); border: 1px solid var(--border-col); border-radius: 12px; padding: 1.25rem; box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.05); display: flex; flex-direction: column; justify-content: flex-start; height: 100%; } /* 绿色胶囊标题 */ .green-header { background-color: #ecfdf5; color: #047857; border: 1px solid #a7f3d0; padding: 4px 12px; border-radius: 6px; font-size: 0.85rem; font-weight: 700; display: inline-flex; align-items: center; margin-bottom: 12px; margin-top: 5px; letter-spacing: 0.02em; font-family: 'Inter', sans-serif; } /* 状态栏 */ .status-bar-compact { background: #f1f5f9; border-radius: 8px; padding: 8px 12px; margin-bottom: 1.25rem; display: flex; align-items: 
center; justify-content: space-between; font-size: 0.85rem; } /* 进度条样式 */ .step-container { background: #fff; padding: 12px 20px; border-radius: 8px; border: 1px solid #e2e8f0; margin-bottom: 15px; box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); } .step-row { display: flex; align-items: center; justify-content: center; width: 100%; } .step-item { display: flex; align-items: center; gap: 8px; color: #94a3b8; font-weight: 600; font-size: 0.9rem; } .step-icon { width: 24px; height: 24px; border-radius: 50%; display: flex; align-items: center; justify-content: center; font-size: 11px; border: 2px solid #cbd5e1; background: white; color: #64748b; font-weight: 700; } .step-item.done { color: #059669; } .step-item.done .step-icon { background: #10b981; border-color: #10b981; color: white; } .step-item.active { color: #0ea5e9; } .step-item.active .step-icon { border-color: #0ea5e9; color: #0ea5e9; background: #e0f2fe; animation: pulse 2s infinite; } .step-arrow { color: #cbd5e1; margin: 0 12px; font-size: 12px; } .step-status-text { margin-top: 8px; font-size: 0.8rem; color: #64748b; text-align: center; height: 16px; } @keyframes pulse { 0% { box-shadow: 0 0 0 0 rgba(14, 165, 233, 0.4); } 70% { box-shadow: 0 0 0 6px rgba(14, 165, 233, 0); } 100% { box-shadow: 0 0 0 0 rgba(14, 165, 233, 0); } } /* === 左侧表格与按钮美化 (Gradio Dataframe) === */ .unified-dataframe table { font-size: 12px !important; } .unified-dataframe td, .unified-dataframe th { padding: 6px 8px !important; } .unified-dataframe { margin-top: 0 !important; margin-bottom: 0 !important; } .sub-label { font-size: 0.85rem; font-weight: 700; color: #059669; margin-bottom: 6px; display: flex; align-items: center; gap: 6px; white-space: nowrap; } /* 小操作按钮 */ .action-btn { font-size: 0.75rem !important; padding: 0px 10px !important; height: 28px !important; min-height: 28px !important; border: 1px solid #cbd5e1 !important; background: linear-gradient(to bottom, #ffffff, #f8fafc); color: #475569 !important; border-radius: 6px 
!important; box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); } .action-btn:hover { background: #f1f5f9 !important; border-color: #94a3b8 !important; color: #0f172a !important; } .delete-btn { color: #b91c1c !important; border-color: #fca5a5 !important; background: #fef2f2 !important; } .delete-btn:hover { background: #fee2e2 !important; border-color: #f87171 !important; } /* 搜索按钮特化 */ .search-btn-primary { font-weight: 800 !important; font-size: 0.9rem !important; text-transform: uppercase; letter-spacing: 0.05em; box-shadow: 0 2px 4px rgba(16, 185, 129, 0.2); } /* === 右侧结果表格美化 (HTML Table) === */ .data-table-host { width: 100%; } .table-wrapper { width: 100%; overflow-x: auto; } .data-table { width: 100%; border-collapse: separate; border-spacing: 0; border: 1px solid #e2e8f0; border-radius: 8px; font-size: 0.85rem; background-color: #fff; table-layout: auto; /* 允许根据内容自动调整列宽 */ } .data-table thead th { background: #ecfdf5; color: #166534; font-weight: 700; padding: 10px 12px; border-bottom: 1px solid #bbf7d0; text-align: left; white-space: nowrap; } .data-table td { padding: 8px 12px; border-bottom: 1px solid #f1f5f9; color: #334155; vertical-align: top; /* 核心修复:防止 text 溢出,同时允许长单词换行 */ word-wrap: break-word; max-width: 600px; } .data-table tr:last-child td { border-bottom: none; } .data-table tr:hover td { background-color: #f8fafc; } /* 样本列特化:保证宽度,防止被挤压 */ .sample-cell { min-width: 300px; width: 50%; } .sample-details summary { cursor: pointer; font-weight: 600; color: #047857; outline: none; margin-bottom: 4px; } .sample-pre { margin-top: 6px; padding: 10px; background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 6px; font-family: 'JetBrains Mono', monospace; font-size: 0.75rem; /* 12px */ color: #0f172a; /* 核心修复:保证 JSON 代码块自动换行,不撑破表格 */ white-space: pre-wrap; word-break: break-all; max-height: 400px; /* 过长则滚动 */ overflow-y: auto; } .error-block { margin-top: 12px; padding: 10px 12px; background: #fef2f2; border: 1px solid #fecdd3; border-radius: 10px; color: 
#991b1b; } /* 通用工具类 */ .no-label-input > label > span { display: none; } .mt-0 { margin-top: 0 !important; } .mt-2 { margin-top: 0.5rem !important; } .mt-4 { margin-top: 1rem !important; } .mb-2 { margin-bottom: 0.5rem !important; } .gap-2 { gap: 0.5rem; } .code-box { font-size: 12px !important; } .gradio-container .tabs { border-bottom: 1px solid #e2e8f0; } .control-panel .gradio-dataframe { margin-top: 0 !important; margin-bottom: 0 !important; } .btn-row { display: flex; flex-direction: row; align-items: center; } .progress-bar-container { width: 100%; background-color: #f1f5f9; border-radius: 10px; /* 圆角稍微大一点更好看 */ height: 8px; margin-top: 8px; overflow: hidden; border: 1px solid #e2e8f0; } .progress-bar-fill { height: 100%; background: linear-gradient(90deg, #0ea5e9, #22c55e); /* 增加渐变效果 */ width: 0%; } .est-time-text { font-size: 0.7rem; color: #94a3b8; margin-top: 4px; text-align: center; font-family: 'JetBrains Mono', monospace; } /* === 新增: 用户指南样式 (保持不变) === */ .info-box { background: #f0f9ff; border: 1px solid #bae6fd; color: #0369a1; padding: 12px 16px; border-radius: 8px; font-size: 0.85rem; line-height: 1.5; margin-bottom: 16px; } @keyframes progressFill { 0% { width: 0%; } 90% { width: 90%; } /* 稍微留一点余地,防止还没跑完就满了 */ 100% { width: 95%; } } """


def create_demo() -> gr.Blocks:
    """Build the DataChef Gradio UI and wire up its event handlers.

    The function creates the Soft theme, records the theme and the custom CSS
    in the module globals ``_DEMO_THEME`` / ``_DEMO_CSS`` (they are read again
    by the ``demo.launch(...)`` call under the ``__main__`` guard), seeds the
    input widgets from ``DEFAULT_SAMPLE``, lays out a left control column and a
    right output column, and registers the click/select/change/load callbacks.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    global _DEMO_THEME, _DEMO_CSS

    # 1. Configure the theme and CSS
    theme = gr.themes.Soft(
        primary_hue="emerald",
        neutral_hue="slate",
        radius_size="md",
        text_size="sm",
        font=[gr.themes.GoogleFont("Inter"), "sans-serif"],
        font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "monospace"],
    )
    _DEMO_THEME = theme
    _DEMO_CSS = _CUSTOM_CSS

    # 2. Prepare default data
    default_sample = DEFAULT_SAMPLE
    default_task = default_sample.get("task", {})
    default_benchmark = default_task.get("benchmark", {})
    # One single-column row per dataset entry that has a non-empty "dataset_id".
    default_dataset_rows: List[List[str]] = []
    if default_sample.get("datasets"):
        default_dataset_rows = [
            [d.get("dataset_id", "")]
            for d in default_sample["datasets"]
            if d.get("dataset_id")
        ]

    # 3. HTML helper functions
    def get_header_html() -> str:
        """Return the static header text shown at the top of the page."""
        return f"""
    Logo

    DataChef

    Automatic Data Recipe Generation for LLM Adaptation
    """

    def get_status_text() -> str:
        """Return the status-bar text for the current LLM configuration.

        Shows "Ready" plus the resolved model id when ``_use_remote_llm()`` is
        truthy, otherwise "Not Configured" / "Remote vLLM unset".
        """
        model_id = _resolve_model_id()
        if _use_remote_llm():
            status_dot = ""
            status_text = f"Ready"
            model_info = f"{model_id}"
        else:
            status_dot = ""
            status_text = f"Not Configured"
            model_info = "Remote vLLM unset"
        return f"
    {status_dot} {status_text}
    {model_info}
    "

    # 4. Build the UI
    with gr.Blocks(title="DataChef") as demo:
        # --- Top header ---
        with gr.Row(elem_classes=["header-container"]):
            gr.HTML(get_header_html())

        # --- Main layout ---
        with gr.Row():
            # =============== Left control panel (Left Column) ===============
            with gr.Column(scale=32, elem_classes=["control-panel"]):
                # Top status bar
                with gr.Row(elem_classes=["status-bar-compact"]):
                    status_html = gr.HTML(get_status_text())
                    load_model_btn = gr.Button("↻", size="sm", variant="secondary", min_width=30)

                # user guide
                gr.HTML("""
    👋 Welcome to DataChef Demo
    """)

                # Preset selection
                gr.HTML("
    📋 Configuration Presets
    ")
                # Fix: define preset_dropdown
                preset_dropdown = gr.Dropdown(
                    choices=list(PRESETS.keys()),
                    value="Physics",
                    show_label=False,
                    container=False,
                    interactive=True
                )

                # Task description
                gr.HTML("
    🎯 Task Description
    ")
                task_description = gr.Textbox(
                    show_label=False,
                    lines=3,
                    value=default_task.get("description", ""),
                    placeholder="Describe task...",
                    elem_classes=["no-label-input"]
                )

                # Benchmark settings
                gr.HTML("
    ⚖️ Benchmark Settings
    ")
                benchmark_name = gr.Textbox(
                    label="Name",
                    placeholder="Benchmark Name (e.g., AIME2025)",
                    value=default_benchmark.get("name", ""),
                    lines=1
                )
                benchmark_description = gr.Textbox(
                    label="Description (Optional)",
                    placeholder="Recommand: Provide description and test examples to guide the generation (e.g., AIME is xxx. Example: {'question': 'Calculate the velocity of...', xxx})",
                    lines=3,
                    value=default_benchmark.get("description", "")
                )

                # Dataset selection area
                gr.HTML("
    📚 Raw Dataset Sources
    ")

                # 1. Suggested Dataset Area
                with gr.Row(elem_classes=["mb-2", "btn-row"]):
                    gr.HTML("
    🔍 Suggested
    ", scale=0)
                    auto_ds_btn = gr.Button(
                        "🔎 Auto Search",
                        variant="primary",
                        scale=1,
                        elem_classes=["search-btn-primary"]
                    )

                # Read-only table that receives the auto-search results.
                suggested_ds = gr.Dataframe(
                    headers=["Dataset ID"],
                    datatype=["str"],
                    row_count=(3, "fixed"),
                    value=[],
                    interactive=False,
                    wrap=True,
                    elem_classes=["unified-dataframe"]
                )

                # Middle action button group
                with gr.Row(elem_classes=["btn-row", "mt-2", "gap-2"]):
                    add_suggested_btn = gr.Button("⬇️ Add Selected", size="sm", elem_classes=["action-btn"])
                    add_all_suggested_btn = gr.Button("⏬ Add All", size="sm", elem_classes=["action-btn"])

                # 2. Selected Dataset Area
                gr.HTML("
    ✅ Selected
    ")
                # Fix: col_count -> column_count
                dataset_ids = gr.Dataframe(
                    headers=["Dataset ID"],
                    datatype=["str"],
                    row_count=(3, "dynamic"),
                    value=default_dataset_rows,
                    interactive=True,
                    column_count=(1, "fixed"),
                    wrap=True,
                    elem_classes=["unified-dataframe"]
                )

                # Action buttons for the Selected table
                selected_ds_idx = gr.State(value=None)
                with gr.Row(elem_classes=["btn-row", "mt-2", "gap-2"]):
                    del_row_btn = gr.Button("🗑️ Delete", size="sm", elem_classes=["action-btn", "delete-btn"])
                    add_row_btn = gr.Button("➕ Add Row", size="sm", elem_classes=["action-btn"])

                # State storage
                suggested_selected_idx = gr.State(value=None)

                # Run button
                with gr.Row(elem_classes=["mt-4"]):
                    run_btn = gr.Button("🚀 Generate Recipe", variant="primary", size="lg")

            # =============== Right output panel (Right Column) ===============
            with gr.Column(scale=68):
                # Progress status
                process_status = gr.HTML(visible=False, label=None)

                # Plan & Code split view
                with gr.Row(equal_height=False, elem_classes=["mt-0"]):
                    # Plan
                    with gr.Column(variant="panel", elem_classes=["output-panel"]):
                        gr.HTML("
    🧠 Processing Plan
    ")
                        with gr.Tabs():
                            with gr.Tab("Rendered"):
                                plan_out_md = gr.Markdown(value="_Waiting..._", line_breaks=True)
                            with gr.Tab("Source"):
                                plan_out_raw = gr.Code(
                                    language="markdown",
                                    interactive=False,
                                    lines=15,
                                    elem_classes=["code-box"]
                                )

                    # Code
                    with gr.Column(variant="panel", elem_classes=["output-panel"]):
                        gr.HTML("
    💻 Executable Code
    ")
                        code_out = gr.Code(
                            language="python",
                            interactive=False,
                            lines=18,
                            elem_classes=["code-box"]
                        )

                # Results panel
                with gr.Column(variant="panel", elem_classes=["output-panel", "mt-4"]):
                    gr.HTML("
    📊 Data Inspector & Results
    ")
                    with gr.Tabs():
                        with gr.Tab("🏆 Processed Training Data"):
                            eval_output_html = gr.HTML(
                                "Recipe generation not started yet.",
                                elem_classes=["data-table-host"]
                            )
                        with gr.Tab("📥 Raw Data"):
                            preview_out_html = gr.HTML(
                                "Waiting for input...",
                                elem_classes=["data-table-host"]
                            )

        # --- Event bindings ---
        # preset_dropdown is already defined at this point
        preset_dropdown.change(
            fn=_load_preset_config,
            inputs=[preset_dropdown],
            outputs=[task_description, benchmark_name, benchmark_description, dataset_ids]
        )

        # Auto search
        auto_ds_btn.click(
            fn=_auto_suggest_datasets_ui,
            inputs=[task_description, benchmark_description],
            outputs=[suggested_ds, suggested_selected_idx],
        )

        # Select a row in Suggested
        suggested_ds.select(
            _on_df_select,
            outputs=[suggested_selected_idx],
        )

        # Add Suggested (Single)
        add_suggested_btn.click(
            fn=_add_selected_suggestion,
            inputs=[dataset_ids, suggested_ds, suggested_selected_idx],
            outputs=[dataset_ids],
        )

        # Add All Suggested
        add_all_suggested_btn.click(
            fn=_add_all_suggestions,
            inputs=[dataset_ids, suggested_ds],
            outputs=[dataset_ids],
        )

        # Select a row in Selected (used for deletion)
        dataset_ids.select(
            _on_df_select,
            outputs=[selected_ds_idx]
        )

        # Delete the selected row from Selected
        del_row_btn.click(
            fn=_delete_selected_row,
            inputs=[dataset_ids, selected_ds_idx],
            outputs=[dataset_ids, selected_ds_idx]
        )

        # Append an empty row to Selected
        add_row_btn.click(
            fn=_add_empty_row,
            inputs=[dataset_ids],
            outputs=[dataset_ids]
        )

        # Run logic
        run_btn.click(
            _run_full_automation,
            inputs=[task_description, benchmark_name, benchmark_description, dataset_ids],
            outputs=[process_status, preview_out_html, plan_out_md, plan_out_raw, code_out, eval_output_html],
            show_progress="hidden"
        )

        # Both the refresh button and the initial page load run the warm-up
        # first and then redraw the status text.
        load_model_btn.click(_warmup_model, outputs=[]).then(get_status_text, outputs=[status_html])
        demo.load(_warmup_model, outputs=[]).then(get_status_text, outputs=[status_html])

    return demo


# Built at import time so that `demo` exists as a module-level attribute.
demo = create_demo()

if __name__ == "__main__":
    # NOTE(review): theme/css are passed to launch() rather than gr.Blocks();
    # presumably this targets a Gradio release that accepts them here — confirm.
    demo.launch(
        theme=_DEMO_THEME,
        css=_DEMO_CSS,
        server_name="0.0.0.0",
        share=True,
    )