" in cleaned:
cleaned = cleaned.split("", 1)[0]
cleaned = cleaned.replace("", "")
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
return cleaned.strip()
def _escape_html(text: Any) -> str:
return html.escape(str(text), quote=True)
def _render_collapsible_json(title: str, data: Any, preview: Optional[str] = None) -> str:
"""Render a collapsible JSON block with optional preview text.

``data`` is serialized with ``json.dumps`` (2-space indent, non-ASCII kept as-is).
When ``preview`` is None it defaults to ``str(data)``. A preview longer than 140
characters is cut to its first 140 characters followed by "...". The preview
(or ``title`` when the preview is empty) and the serialized JSON are both
HTML-escaped before being placed in the returned string.
"""
pretty = json.dumps(data, ensure_ascii=False, indent=2)
if preview is None:
preview = str(data)
# Cap the preview at 140 characters plus an ellipsis.
if len(preview) > 140:
preview = preview[:140] + "..."
# NOTE(review): the literal fragments around the two escaped values look
# incomplete in this view - confirm the surrounding markup against the full file.
return (
""
f"{_escape_html(preview or title)}
"
f"{_escape_html(pretty)}"
" "
)
def _render_processed_scores_table(data_scores: List[Dict]) -> str:
"""Render processed training data as an HTML table with collapsible samples.

Each entry of ``data_scores`` becomes one row with four cells: an index, a
score, a judgement label and a collapsible sample. An empty ``data_scores``
short-circuits to a fixed "No processed data returned." message.
"""
if not data_scores:
return "No processed data returned.
"
# Single-letter judgement codes and the labels shown for them.
judgement_map = {
"A": "A: Invalid",
"B": "B: Format Error",
"C": "C: Incorrect",
"D": "D: Task Mismatch",
"E": "E: Pass",
}
rows_html = []
for idx, s in enumerate(data_scores, start=1):
# Each field falls back through several alternative keys; the first truthy value wins.
sample = {
"user": s.get("question") or s.get("question_preview") or s.get("input") or "",
"assistant": s.get("answer") or s.get("answer_preview") or s.get("output") or "",
}
# "feedback" is used when the entry has no "judgement" key.
judgement = s.get("judgement", s.get("feedback", ""))
# Codes not present in judgement_map are displayed unchanged.
judgement_disp = judgement_map.get(str(judgement).strip().upper(), judgement or "")
score_val = s.get("score")
# Numeric scores are formatted with two decimals; anything else is escaped as text (None -> "").
score_disp = f"{score_val:.2f}" if isinstance(score_val, (int, float)) else _escape_html(score_val if score_val is not None else "")
# Only a string "user" value is used as the preview; otherwise "View sample" is shown.
sample_preview = sample["user"] if isinstance(sample["user"], str) else ""
sample_cell = _render_collapsible_json("sample", sample, preview=sample_preview or "View sample")
# The row index comes from the entry's own "index" key, defaulting to the 1-based position.
rows_html.append(
""
f"| {_escape_html(s.get('index', idx))} | "
f"{score_disp} | "
f"{_escape_html(judgement_disp)} | "
f"{sample_cell} | "
"
"
)
rows_joined = "\n".join(rows_html)
# NOTE(review): the literal fragments of the table wrapper look incomplete in
# this view - confirm the surrounding markup against the full file.
table_html = (
''
'
'
""
"| # | score | judgement | sample |
"
""
f"{rows_joined}"
"
"
"
"
)
return table_html
def _render_preview_table_html(dataset_infos: List[Dict], errors: Optional[Dict[str, str]] = None) -> str:
"""Render raw preview data as an HTML table with collapsible samples.

Walks every ``preview_examples`` entry of every ``examples`` group of every
dataset info and emits one row per sample (running index, dataset id,
collapsible sample). When there are no rows a fixed "No preview samples
available." message is used instead. If ``errors`` is non-empty, text is
appended to the result.
"""
rows_html = []
# Running 1-based row number across all datasets.
idx = 1
for info in dataset_infos or []:
ds_id = info.get("dataset_id", "Unknown")
for ex_group in info.get("examples", []):
for sample in ex_group.get("preview_examples", []):
sample_cell = _render_collapsible_json("sample", sample)
rows_html.append(
""
f"| {idx} | "
f"{_escape_html(ds_id)} | "
f"{sample_cell} | "
"
"
)
idx += 1
if rows_html:
base = (
""
"
"
"| # | Source Dataset | sample |
"
f"{''.join(rows_html)}"
"
"
)
else:
base = "No preview samples available.
"
if errors:
# Dataset ids and their error messages are escaped before being joined.
err_lines = "".join(f"{_escape_html(k)}: {_escape_html(v)}" for k, v in errors.items())
# NOTE(review): as visible here the appended literal is empty and err_lines is
# unused - presumably the original string embeds err_lines; confirm against the full file.
base += f""
return base
def _truncate_dataset_examples_for_prompt(datasets: List[Dict], limit: int = TRUNCATE_LIMIT) -> List[Dict]:
    """Build a size-limited copy of ``datasets`` for use in LLM prompts.

    Only ``dataset_id``, ``revision`` and ``examples`` are kept per dataset, and
    only ``name``, ``split``, ``schema`` and ``preview_examples`` per example.
    Every preview value is turned into text (non-strings via ``json.dumps``) and
    passed through ``_truncate_text`` with ``limit``. The input is not modified.
    """

    def _shorten_sample(sample: Dict) -> Dict:
        shortened = {}
        for key, value in sample.items():
            text = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
            shortened[key] = _truncate_text(text, limit=limit)
        return shortened

    def _copy_example(example: Dict) -> Dict:
        return {
            "name": example.get("name"),
            "split": example.get("split"),
            "schema": example.get("schema"),
            "preview_examples": [
                _shorten_sample(sample) for sample in example.get("preview_examples", [])
            ],
        }

    return [
        {
            "dataset_id": entry.get("dataset_id"),
            "revision": entry.get("revision"),
            "examples": [_copy_example(example) for example in entry.get("examples", [])],
        }
        for entry in datasets or []
    ]
def _render_plan_prompt(task_description: str, benchmark: Dict[str, str], datasets: List[Dict]) -> str:
    """Assemble the planner prompt text.

    The prompt contains, in order: the task description, the benchmark name and
    description, one section per dataset (id plus the ``str()`` of its truncated
    examples), a ``---`` separator and the stripped ``PLAN_INSTRUCTIONS``.
    """
    trimmed = _truncate_dataset_examples_for_prompt(datasets)
    header = [
        "# 任务描述",
        task_description,
        "",
        "# 测试数据 (Benchmark)",
        f"## {benchmark.get('name', '')}",
        benchmark.get("description", ""),
        "",
        "# 可用的 Huggingface 训练数据集",
    ]
    dataset_sections: List[str] = []
    for entry in trimmed:
        dataset_sections += [f"## {entry['dataset_id']}", str(entry["examples"]), ""]
    footer = ["---", PLAN_INSTRUCTIONS.strip()]
    return "\n".join(header + dataset_sections + footer).strip()
def _render_code_prompt(datasets: List[Dict], plan: str, tool_info: str) -> str:
    """Assemble the coder prompt text.

    The prompt contains, in order: one section per dataset (id plus the ``str()``
    of its truncated examples), the processing plan, the tool information, a
    ``---`` separator and the stripped ``CODE_INSTRUCTIONS``.
    """
    trimmed = _truncate_dataset_examples_for_prompt(datasets)
    lines: List[str] = ["# 可用的 Huggingface 训练数据集"]
    for entry in trimmed:
        lines += [f"## {entry['dataset_id']}", str(entry["examples"]), ""]
    lines += [
        "# 数据处理计划",
        plan,
        "",
        "# 工具信息",
        tool_info,
        "",
        "---",
        CODE_INSTRUCTIONS.strip(),
    ]
    return "\n".join(lines).strip()
def _compose_task_context(task_description: str, benchmark_description: str) -> str:
"""Combine task description and benchmark description for downstream execution/logging."""
benchmark_description = benchmark_description or ""
if benchmark_description.strip():
return f"{task_description}\n\nBenchmark: {benchmark_description}"
return task_description
def _build_tool_info() -> str:
    """Return the module-level ``TOOL_INFO`` text without surrounding whitespace."""
    tool_text = TOOL_INFO
    return tool_text.strip()
def _resolve_model_id() -> str:
return os.getenv("DATACHEF_MODEL_PATH") or MODEL_ID
def _use_remote_llm() -> bool:
    """Report whether a remote API base URL is configured (``_REMOTE_API_BASE`` is truthy)."""
    return True if _REMOTE_API_BASE else False
def _get_remote_client() -> OpenAI:
    """Return the shared ``OpenAI`` client, creating it on first use.

    The client is stored in the module-level ``_REMOTE_CLIENT``. Creation
    happens under ``_REMOTE_CLIENT_LOCK``, and the global is checked again after
    the lock is acquired so that only one client is built.

    Raises:
        RuntimeError: if ``_REMOTE_API_BASE`` is empty.
    """
    global _REMOTE_CLIENT
    if not _REMOTE_API_BASE:
        raise RuntimeError("Remote base URL is empty.")
    if _REMOTE_CLIENT is None:
        with _REMOTE_CLIENT_LOCK:
            if _REMOTE_CLIENT is None:
                # Some self-hosted vLLM endpoints do not require authentication; the
                # OpenAI client still expects a non-empty api_key, so fall back to a
                # dummy token when none is configured.
                token = _REMOTE_API_KEY or os.getenv("OPENAI_API_KEY") or "EMPTY"
                _REMOTE_CLIENT = OpenAI(
                    base_url=_REMOTE_API_BASE,
                    api_key=token,
                    default_headers=_REMOTE_DEFAULT_HEADERS,
                )
    return _REMOTE_CLIENT
def _generate_text_remote(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Send a chat completion request to the remote endpoint and return the cleaned reply.

    The request and the raw response are both logged via ``_log_full``. The
    first choice's content (or an empty string when there is none) is passed
    through ``_strip_think`` before being returned.

    Raises:
        RuntimeError: wrapping any exception raised by the completion call.
    """
    client = _get_remote_client()
    request_record = {
        "messages": messages,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    _log_full("llm_request", json.dumps(request_record, ensure_ascii=False))
    try:
        completion = client.chat.completions.create(
            model=_resolve_model_id(),
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(f"Remote LLM error: {e}") from e

    if completion.choices:
        content = completion.choices[0].message.content
    else:
        content = ""
    reply = content or ""
    _log_full("llm_response", reply)
    return _strip_think(reply)
def _generate_text(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate text through the remote LLM.

    Raises:
        RuntimeError: when no remote endpoint is configured.
    """
    if _use_remote_llm():
        return _generate_text_remote(messages, max_new_tokens, temperature, top_p)
    raise RuntimeError("DATACHEF_VLLM_URL is not set; please configure remote vLLM endpoint.")
def _request_keywords_via_proxy(task_description: str, benchmark_description: str, n: int = 3) -> List[str]:
    """
    Ask the proxy's ``/generate_keywords`` endpoint for search keywords.

    The endpoint URL is derived from ``RUN_CODE_URL`` by replacing its last path
    segment. When the proxy call fails for any reason, or returns no usable
    keywords, a heuristic fallback is used: the task description is split on
    ``;``, ``,``, ``.``, ``。`` and newlines, and the first ``n`` non-empty
    fragments of at most five whitespace-separated words are returned.

    Args:
        task_description: Free-text task description.
        benchmark_description: Free-text benchmark description sent to the proxy.
        n: Maximum number of keywords to return.

    Returns:
        Up to ``n`` keywords; may be empty when the fallback finds nothing.
    """
    try:
        # Built inside the try so a missing or malformed RUN_CODE_URL also falls
        # back to the heuristic instead of raising.
        base = RUN_CODE_URL.rsplit("/", 1)[0].rstrip("/")
        url = f"{base}/generate_keywords"
        headers = {}
        if RUN_CODE_KEY:
            headers["x-forward-key"] = RUN_CODE_KEY
        payload = {
            "task_description": task_description,
            "benchmark_description": benchmark_description,
            "n": n,
        }
        resp = requests.post(url, json=payload, headers=headers, timeout=20)
        resp.raise_for_status()
        data = resp.json() if resp.content else {}
        keywords = data.get("keywords") or []
        keywords = [k.strip() for k in keywords if isinstance(k, str) and k.strip()]
        if keywords:
            return keywords[:n]
    except Exception as e:  # noqa: BLE001 - best effort; the heuristic below takes over
        logger.warning("generate_keywords proxy failed: %s", e)

    # Fallback: split the task description on basic punctuation and newlines.
    # In a raw string the newline must be written as a single "\n"; the doubled
    # form would put a literal backslash and the letter "n" into the class.
    fallback: List[str] = []
    for token in re.split(r"[;,.。\n]+", task_description or ""):
        token = token.strip()
        if token and len(token.split()) <= 5:
            fallback.append(token)
            if len(fallback) >= n:
                break
    return fallback
def _search_hf_datasets(keywords: List[str], limit_per_kw: int = 5) -> List[str]:
    """Search HuggingFace datasets for each keyword and return dataset ids.

    Results are requested sorted by likes, at most ``limit_per_kw`` per keyword.
    Ids are de-duplicated across keywords, keeping first-seen order. A keyword
    whose search fails is logged and skipped.

    Args:
        keywords: Search terms, queried one at a time.
        limit_per_kw: Maximum number of results requested per keyword.

    Returns:
        Unique dataset ids in discovery order.
    """
    api = HfApi()
    found: List[str] = []
    seen = set()
    for kw in keywords or []:
        try:
            # list_datasets returns a lazy iterator, so it is materialized here:
            # request errors raised while iterating must land in this handler
            # rather than escape from the loop below.
            try:
                # huggingface_hub>=0.36 uses `search` instead of the deprecated `full_text_search`
                results = list(api.list_datasets(
                    search=kw,
                    limit=limit_per_kw,
                    sort="likes",
                    direction=-1,  # descending
                ))
            except TypeError:
                # fallback for older hub versions
                results = list(api.list_datasets(
                    full_text_search=kw,
                    limit=limit_per_kw,
                    sort="likes",
                ))
        except Exception as e:  # noqa: BLE001 - one failed keyword must not abort the search
            logger.warning("HF search failed for '%s': %s", kw, e)
            continue
        for ds in results:
            ds_id = getattr(ds, "id", None) or getattr(ds, "dataset", None)
            if ds_id and ds_id not in seen:
                seen.add(ds_id)
                found.append(ds_id)
    return found
def _auto_suggest_datasets(task_description: str, benchmark_description: str) -> List[List[str]]:
if not task_description:
return []
benchmark_description = benchmark_description or ""
keywords = _request_keywords_via_proxy(task_description, benchmark_description, n=3)
ds_ids = _search_hf_datasets(keywords, limit_per_kw=3)
return [[d] for d in ds_ids]
def _rows(table) -> List[List[str]]:
"""Normalize Gradio dataframe output (which may be pandas DataFrame) to list of rows."""
if table is None:
return []
if hasattr(table, "to_numpy"):
try:
return table.to_numpy().tolist()
except Exception:
return []
return table or []
def _add_selected_suggestion(
    current,
    suggestions,
    idx: Optional[int],
) -> List[List[str]]:
    """Append the suggestion at row ``idx`` to the current table.

    The current rows are returned unchanged when ``idx`` is ``None`` or out of
    range, when the selected value is empty, or when it is already present.
    """
    rows_current = _rows(current)
    candidates = _rows(suggestions)
    if idx is None or not (0 <= idx < len(candidates)):
        return rows_current
    picked_row = candidates[idx]
    picked = str(picked_row[0]).strip() if picked_row else ""
    if not picked:
        return rows_current
    already_listed = [str(row[0]).strip() for row in rows_current if row and len(row) > 0]
    if picked in already_listed:
        return rows_current
    return rows_current + [[picked]]
def _add_all_suggestions(
    current,
    suggestions,
) -> List[List[str]]:
    """Append every suggestion not yet present to a copy of the current rows.

    Values are compared after ``str()`` conversion and stripping; empty values
    and duplicates are skipped.
    """
    rows_current = _rows(current)
    candidates = _rows(suggestions)
    known = [str(row[0]).strip() for row in rows_current if row and len(row) > 0]
    merged = list(rows_current)
    for candidate in candidates:
        name = str(candidate[0]).strip() if candidate else ""
        if not name or name in known:
            continue
        known.append(name)
        merged.append([name])
    return merged
def _add_empty_row(current) -> List[List[str]]:
    """Return the current table rows followed by one blank row."""
    existing = _rows(current)
    # The table has a single column, so the new row is a one-element list
    # holding an empty string.
    blank_row = [""]
    return existing + [blank_row]
def _delete_selected_row(current, idx: Optional[int]) -> Tuple[List[List[str]], None]:
    """Remove the row at ``idx`` when it is a valid index.

    Returns the resulting rows together with ``None``, which resets the
    selection state.
    """
    rows = _rows(current)
    in_range = idx is not None and 0 <= idx < len(rows)
    if in_range:
        del rows[idx]
    return rows, None
def _auto_suggest_datasets_ui(task_description: str, benchmark_description: str) -> Tuple[List[List[str]], Optional[int]]:
    """UI wrapper around ``_auto_suggest_datasets``: returns the suggestions plus ``None`` as the second value."""
    return _auto_suggest_datasets(task_description, benchmark_description), None
def _feature_repr(feature) -> str:
    """Return a short type label for a ``datasets`` feature object.

    ``Value`` features map to their dtype (string dtypes become ``"str"``),
    ``ClassLabel`` to ``"ClassLabel"``, ``Sequence`` to ``"List[<inner>]"``;
    anything else falls back to ``str(feature)``.
    """
    if isinstance(feature, Value):
        dtype = feature.dtype
        return "str" if dtype in ("string", "large_string") else dtype
    if isinstance(feature, ClassLabel):
        return "ClassLabel"
    if isinstance(feature, Sequence):
        inner_label = _feature_repr(feature.feature)
        return f"List[{inner_label}]"
    return str(feature)
def _feature_repr_from_hf(feature: Dict) -> str:
"""Lightweight repr using datasets-server feature schema."""
if not isinstance(feature, dict):
return "unknown"
type_info = feature.get("type", {})
if not isinstance(type_info, dict):
return str(type_info) if type_info is not None else "unknown"
type_name = type_info.get("_type")
if type_name == "Value":
dtype = type_info.get("dtype")
if dtype in ("string", "large_string"):
return "str"
return dtype or "Value"
if type_name == "ClassLabel":
return "ClassLabel"
if type_name == "Sequence":
inner = type_info.get("feature", {})
inner_repr = _feature_repr_from_hf({"type": inner}) if isinstance(inner, dict) else "Any"
return f"List[{inner_repr}]"
return type_name or str(type_info) or "unknown"
def _run_with_timeout(func, timeout: int, *args, **kwargs):
"""Run a callable with a timeout; cancel the future if it takes too long."""
executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(func, *args, **kwargs)
try:
return future.result(timeout=timeout)
finally:
executor.shutdown(wait=False, cancel_futures=True)
import time # 确保开头导入了 time
def _render_progress_html(current_step: int, error: bool = False, detail_msg: str = "", est_time: int = 0) -> str:
"""Render the four-step pipeline progress indicator as an HTML string.

Steps before ``current_step`` are marked done, the step equal to
``current_step`` is marked active (or error when ``error`` is true), and later
steps show their 1-based number. ``detail_msg`` is embedded in the returned
text. When the active step has ``est_time > 0`` an additional progress
fragment is produced.
"""
steps = ["Data Preview", "Processing Plan", "Code Generation", "Code Execution"]
icons_html = ''
active_progress_html = ""
# Build a unique id and animation name so the browser restarts the animation on every step change.
unique_id = f"step_bar_{current_step}_{int(time.time())}"
anim_name = f"fill_{unique_id}"
for idx, name in enumerate(steps):
is_active = (idx == current_step)
if idx < current_step:
status_class, icon = "done", "✓"
elif is_active:
if error:
status_class, icon = "error", "✕"
elif current_step == 4:
status_class, icon = "done", "✓"
else:
status_class, icon = "active", "⋯"
if est_time > 0:
# Define a fresh keyframes animation through an inline style.
active_progress_html = f"""
"""
else:
status_class, icon = "", str(idx + 1)
# NOTE(review): the literal fragments below look incomplete in this view -
# confirm the markup against the full file.
icons_html += f"""
"""
if idx < len(steps) - 1:
icons_html += '
➜
'
icons_html += '
'
text_html = f'{detail_msg}
'
return f'{icons_html}{active_progress_html}{text_html}
'
def _get_dataset_examples(dataset_id: str, timeout: Optional[int] = None) -> Dict:
    """Return preview examples for ``dataset_id``, using ``_DATASET_PREVIEW_CACHE`` when possible.

    The lightweight preview is tried first (under ``timeout`` when one is
    given). If it fails or times out, the legacy loader is tried. Successful
    results are stored in the cache.

    Raises:
        RuntimeError: when both loaders fail; the message joins their errors with "; ".
    """
    if dataset_id in _DATASET_PREVIEW_CACHE:
        return _DATASET_PREVIEW_CACHE[dataset_id]

    failures: List[str] = []
    try:
        if timeout:
            preview = _run_with_timeout(_build_dataset_examples, timeout, dataset_id)
        else:
            preview = _build_dataset_examples(dataset_id)
    except FuturesTimeoutError:
        failures.append(f"Preview for {dataset_id} timed out after {timeout}s")
    except Exception as e:
        failures.append(str(e))
    else:
        _DATASET_PREVIEW_CACHE[dataset_id] = preview
        return preview

    # Fallback to legacy full load if lightweight preview fails
    try:
        legacy_preview = _build_dataset_examples_legacy(dataset_id)
    except Exception as e:
        failures.append(str(e))
    else:
        _DATASET_PREVIEW_CACHE[dataset_id] = legacy_preview
        return legacy_preview
    raise RuntimeError("; ".join(failures))
def _run_code_probe(code_text: str, timeout: int) -> Tuple[bool, str]:
    """
    Submit a small script to the proxy runner and wait for it to finish.

    The script is posted to ``RUN_CODE_URL``. The returned ``job_id`` is then
    polled at ``<base>/status/<job_id>`` every 1.5 seconds until the job reports
    a terminal status or ``timeout`` seconds have passed. Polling errors are
    ignored and retried.

    Returns:
        ``(True, "")`` when the job status is ``"success"`` with return code 0,
        otherwise ``(False, message)`` describing the failure.
    """
    if not RUN_CODE_URL:
        return False, "RUN_CODE_URL not configured"

    headers = {"x-forward-key": RUN_CODE_KEY} if RUN_CODE_KEY else {}

    try:
        submit_resp = requests.post(
            RUN_CODE_URL,
            json={"code": code_text, "timeout": timeout},
            headers=headers,
            timeout=min(timeout, 30) + 10,
        )
        submit_resp.raise_for_status()
        submitted = submit_resp.json() if submit_resp.content else {}
    except Exception as e:  # noqa: BLE001
        return False, f"probe request failed: {e}"

    job_id = submitted.get("job_id")
    if not job_id:
        return False, "probe runner did not return job_id"

    status_root = RUN_CODE_URL.rsplit("/", 1)[0].rstrip("/")
    status_url = f"{status_root}/status/{job_id}"
    terminal_states = {"success", "failed", "timeout", "error"}
    pause = 1.5
    deadline = time.time() + timeout

    finished_job = None
    while finished_job is None and time.time() < deadline:
        try:
            poll_resp = requests.get(status_url, headers=headers, timeout=min(timeout, 30))
            poll_resp.raise_for_status()
            job = poll_resp.json() if poll_resp.content else {}
        except Exception:
            time.sleep(pause)
            continue
        if job.get("status") in terminal_states:
            finished_job = job
        else:
            time.sleep(pause)

    if finished_job is None:
        return False, f"timed out after {timeout}s"

    status = finished_job.get("status")
    retcode = finished_job.get("returncode")
    if status == "success" and retcode == 0:
        return True, ""
    stderr = finished_job.get("stderr") or ""
    summary = finished_job.get("summary") or ""
    return False, summary or stderr or f"probe failed (status={status}, returncode={retcode})"
def _probe_dataset_load(dataset_id: str, name: Optional[str], split: str) -> Tuple[bool, str]:
    """
    Run a small load of the dataset on the proxy runner to weed out slow or invalid datasets.

    The generated script loads up to 20 samples with ``load_remote_dataset`` and
    prints the row count and column names. A missing name or the name
    ``"default"`` is passed as ``None``; a missing split defaults to ``"train"``.

    Returns:
        ``(success, error_message)`` as produced by ``_run_code_probe``.
    """
    if name and name != "default":
        name_literal = repr(name)
    else:
        name_literal = "None"
    split_literal = repr(split or "train")
    script_lines = [
        "from aidp import load_remote_dataset",
        "from datasets import disable_caching",
        "disable_caching()",
        f"ds = load_remote_dataset({repr(dataset_id)}, name={name_literal}, split={split_literal}, num_samples=20, shuffle=False)",
        "print('rows', len(ds))",
        "print('columns', list(ds.features.keys()))",
    ]
    code = "\n".join(script_lines) + "\n"
    return _run_code_probe(code, timeout=DATASET_LOAD_PROBE_TIMEOUT)
def _filter_dataset_infos_by_load(dataset_infos: List[Dict]) -> Tuple[List[Dict], Dict[str, str]]:
    """
    Probe every dataset concurrently and keep only those that load successfully.

    Each dataset is probed using the name and split of its first example, on a
    thread pool of ``DATASET_LOAD_PROBE_CONCURRENCY`` workers (at least one).
    Datasets without examples are recorded as errors without being probed.

    Returns:
        ``(kept_dataset_infos, load_errors)``, where ``load_errors`` maps
        dataset id to an error message.
    """
    kept: List[Dict] = []
    errors: Dict[str, str] = {}
    worker_count = max(1, DATASET_LOAD_PROBE_CONCURRENCY)
    with ThreadPoolExecutor(max_workers=worker_count) as pool:
        submitted = []
        for info in dataset_infos:
            dataset_id = info.get("dataset_id", "")
            examples = info.get("examples") or []
            if not examples:
                errors[dataset_id] = "no examples to probe"
                continue
            head = examples[0]
            probe = pool.submit(
                _probe_dataset_load,
                dataset_id,
                head.get("name") or None,
                head.get("split") or "train",
            )
            submitted.append((probe, dataset_id, info))

        # Collect results in submission order.
        for probe, dataset_id, info in submitted:
            try:
                ok, err = probe.result()
            except Exception as e:  # pragma: no cover - robustness
                ok, err = False, str(e)
            if ok:
                kept.append(info)
            else:
                errors[dataset_id] = err
    return kept, errors
def _fetch_json(url: str, timeout: int) -> Dict:
try:
with urllib.request.urlopen(url, timeout=timeout) as resp:
if resp.status != 200:
raise RuntimeError(f"HTTP {resp.status}: {resp.read().decode('utf-8', 'ignore')}")
data = resp.read()
return json.loads(data.decode("utf-8"))
except urllib.error.URLError as e:
raise RuntimeError(f"Request failed: {e}") from e
def _select_split(splits: List[Dict]) -> Tuple[str, str]:
if not splits:
raise RuntimeError("No splits available.")
# Prefer train split
for item in splits:
if item.get("split") == "train":
return item.get("config") or "default", "train"
first = splits[0]
return first.get("config") or "default", first.get("split") or "train"
def _build_dataset_examples(dataset_id: str) -> Dict:
    """Build a lightweight preview from the datasets-server ``splits`` and ``first-rows`` endpoints.

    A config/split is chosen with ``_select_split``, then up to
    ``PREVIEW_LIMIT`` rows are requested. Non-string cell values are serialized
    with ``json.dumps``. The returned dict carries the dataset id, the revision
    ``"main"`` and a single example group with name, split, schema and preview rows.
    """
    quoted_id = urllib.parse.quote(dataset_id, safe="")
    splits_url = f"https://datasets-server.huggingface.co/splits?dataset={quoted_id}"
    split_listing = _fetch_json(splits_url, DATASET_PREVIEW_TIMEOUT)
    config, split = _select_split(split_listing.get("splits", []))

    first_rows_url = (
        "https://datasets-server.huggingface.co/first-rows"
        f"?dataset={quoted_id}&config={urllib.parse.quote(config, safe='')}"
        f"&split={urllib.parse.quote(split, safe='')}&offset=0&length={PREVIEW_LIMIT}"
    )
    first_rows = _fetch_json(first_rows_url, DATASET_PREVIEW_TIMEOUT)

    schema = {}
    for feat in first_rows.get("features", []):
        if isinstance(feat, dict) and "name" in feat:
            schema[feat["name"]] = _feature_repr_from_hf(feat)

    preview_examples: List[Dict] = []
    for entry in first_rows.get("rows", []):
        if not isinstance(entry, dict):
            continue
        row = entry.get("row", {})
        if isinstance(row, dict):
            preview_examples.append({
                key: (value if isinstance(value, str) else json.dumps(value, ensure_ascii=False))
                for key, value in row.items()
            })
            if len(preview_examples) >= PREVIEW_LIMIT:
                break
        else:
            # Rows that are not dicts are kept as a single truncated text field.
            preview_examples.append({"row": _truncate_text(str(row))})

    example_group = {
        "name": config or "default",
        "split": split,
        "schema": schema,
        "preview_examples": preview_examples,
    }
    return {
        "dataset_id": dataset_id,
        "revision": "main",
        "examples": [example_group],
    }
def _build_dataset_examples_legacy(dataset_id: str) -> Dict:
    """Fallback preview that loads a small slice through ``datasets`` (may download data).

    The default revision is tried first, followed by several parquet-conversion
    revision names. A ``RuntimeError`` whose message indicates an unsupported
    dataset script or a missing revision moves on to the next revision; any
    other exception propagates immediately.

    Raises:
        RuntimeError: when no revision could be loaded. If a dataset-script
            error was seen, a dedicated message is raised with it as the cause.
    """

    def _as_text(value):
        return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)

    def _load_at(revision: Optional[str]) -> Dict:
        configs = get_dataset_config_names(dataset_id, revision=revision)
        config = configs[0] if configs else None
        if config:
            split_names = get_dataset_split_names(dataset_id, config, revision=revision)
        else:
            split_names = get_dataset_split_names(dataset_id, revision=revision)

        if "train" in split_names:
            split = "train"
        elif split_names:
            split = split_names[0]
        else:
            split = "train"

        load_kwargs = {"split": f"{split}[:{PREVIEW_LIMIT}]"}
        if revision:
            load_kwargs["revision"] = revision
        if config:
            load_kwargs["name"] = config
        ds = load_dataset(dataset_id, **load_kwargs)

        schema = {column: _feature_repr(feat) for column, feat in ds.features.items()}
        previews: List[Dict] = [
            {key: _as_text(value) for key, value in record.items()} for record in ds
        ]
        return {
            "dataset_id": dataset_id,
            "revision": revision or "main",
            "examples": [{
                "name": config or "default",
                "split": split,
                "schema": schema,
                "preview_examples": previews,
            }],
        }

    last_err: Optional[Exception] = None
    script_err: Optional[RuntimeError] = None
    for rev in (None, "convert/parquet", "refs/convert/parquet", "parquet"):
        try:
            return _load_at(rev)
        except RuntimeError as e:
            msg = str(e)
            needs_script = "Dataset scripts are no longer supported" in msg
            revision_missing = "doesn't exist for dataset" in msg or "Not Found" in msg
            if not (needs_script or revision_missing):
                raise
            last_err = e
            if needs_script:
                script_err = e

    if script_err:
        raise RuntimeError(
            f"{dataset_id} requires running its dataset script, which is disabled in datasets>=4. "
            "No parquet/convert revision was found."
        ) from script_err
    raise last_err or RuntimeError("Unknown error while loading dataset.")
def _normalize_dataset_ids(dataset_ids: Any) -> List[str]:
if dataset_ids is None:
return []
if hasattr(dataset_ids, "to_numpy"):
rows = dataset_ids.to_numpy().tolist()
else:
rows = dataset_ids
if not rows:
return []
flat: List[str] = []
for row in rows:
if not row:
continue
value = str(row[0]).strip()
if value:
flat.append(value)
seen = set()
result = []
for item in flat:
if item in seen:
continue
seen.add(item)
result.append(item)
return result
def _build_input_sample(
task_description: str,
benchmark_name: str,
benchmark_description: str,
dataset_ids: List[str],
) -> Dict:
if not dataset_ids:
raise ValueError("Dataset list is empty. Please select at least one dataset.")
# If preview just ran with the same datasets, reuse its filtered result to avoid re-probing.
global _LAST_PREVIEW_IDS, _LAST_PREVIEW_DATASET_INFOS, _LAST_PREVIEW_ERRORS
if _LAST_PREVIEW_IDS == dataset_ids and _LAST_PREVIEW_DATASET_INFOS:
dataset_infos = list(_LAST_PREVIEW_DATASET_INFOS)
errors = dict(_LAST_PREVIEW_ERRORS)
else:
dataset_infos = []
errors: Dict[str, str] = {}
for dataset_id in dataset_ids:
try:
dataset_infos.append(_get_dataset_examples(dataset_id, DATASET_PREVIEW_TIMEOUT))
except Exception as e:
errors[dataset_id] = str(e)
if dataset_infos:
dataset_infos, load_errors = _filter_dataset_infos_by_load(dataset_infos)
errors.update(load_errors)
_LAST_PREVIEW_IDS = dataset_ids
_LAST_PREVIEW_DATASET_INFOS = list(dataset_infos)
_LAST_PREVIEW_ERRORS = dict(errors)
# === 修改开始: 检查是否所有数据集都挂了 ===
if not dataset_infos:
error_details = "; ".join([f"{k}: {v}" for k, v in errors.items()])
raise ValueError(f"All selected datasets failed to load (timeout or access error). Please try smaller or different datasets. Details: {error_details}")
# === 修改结束 ===
if errors and not dataset_infos:
# Fallback (logic covered above, but kept for safety)
raise ValueError(f"Dataset build error: {errors}")
return {
"id": 0,
"task": {
"description": task_description,
"benchmark": {
"name": benchmark_name,
"description": benchmark_description,
},
},
"datasets": dataset_infos,
"errors": errors,
}
def _preview_dataset(dataset_ids_table: List[List[str]]) -> str:
    """Preview the datasets listed in the table and return the rendered result.

    Each dataset is previewed, the ones that fail the load probe are dropped,
    and the module-level ``_LAST_PREVIEW_*`` state is updated with the outcome.
    Any unexpected exception is returned as a ``"Preview error: ..."`` string
    instead of being raised.
    """
    global _LAST_PREVIEW_IDS, _LAST_PREVIEW_DATASET_INFOS, _LAST_PREVIEW_ERRORS
    dataset_ids = _normalize_dataset_ids(dataset_ids_table)
    if not dataset_ids:
        return "Please provide at least one dataset_id."
    try:
        loaded: List[Dict] = []
        failures: Dict[str, str] = {}
        for dataset_id in dataset_ids:
            try:
                loaded.append(_get_dataset_examples(dataset_id, DATASET_PREVIEW_TIMEOUT))
            except Exception as e:
                failures[dataset_id] = str(e)

        # Probe loadability and drop slow/failed datasets before rendering preview
        if loaded:
            loaded, load_errors = _filter_dataset_infos_by_load(loaded)
            failures.update(load_errors)

        _LAST_PREVIEW_IDS = dataset_ids
        _LAST_PREVIEW_DATASET_INFOS = list(loaded)
        _LAST_PREVIEW_ERRORS = dict(failures)
        return _render_preview_table_html(loaded, failures if failures else None)
    except Exception as e:
        return f"Preview error: {e}"
def _on_df_select(evt: gr.SelectData) -> Optional[int]:
    """Return the first element of the selection event's index, or ``None`` when the index is empty."""
    selected = evt.index
    return selected[0] if selected else None
def _parse_code_block(response: str) -> List[str]:
blocks = []
start_pattern = re.compile(r"(?m)^[ \\t]*```python[ \\t]*\\r?\\n")
for match in start_pattern.finditer(response):
start_pos = match.end()
next_match = start_pattern.search(response, start_pos)
search_end = next_match.start() if next_match else len(response)
last_end = -1
for end_match in re.finditer(r"```", response[start_pos:search_end]):
end_pos = start_pos + end_match.start()
after = end_pos + len("```")
if after >= len(response) or response[after] in ["\n", "\r"]:
last_end = end_pos
if last_end != -1:
content = response[start_pos:last_end].rstrip("\n\r ")
blocks.append(content)
return blocks
def _strip_code_fence(code_text: str) -> str:
"""Remove surrounding ``` or ```python fences if present."""
code = code_text.strip()
if code.startswith("```"):
# remove first line fence
lines = code.splitlines()
if lines and lines[0].lstrip("`").lower().startswith("python"):
lines = lines[1:]
elif lines and lines[0].startswith("```"):
lines = lines[1:]
# drop trailing fence
if lines and lines[-1].strip().startswith("```"):
lines = lines[:-1]
code = "\n".join(lines).strip()
return code
def _syntax_check(code_text: str) -> Tuple[bool, str]:
try:
ast.parse(code_text)
except SyntaxError as e:
location = f"line {e.lineno}, col {e.offset}" if e.lineno and e.offset else "unknown location"
return False, f"SyntaxError: {e.msg} ({location})"
return True, ""
def _build_sample_from_inputs(
    task_description: str,
    benchmark_name: str,
    benchmark_description: str,
    dataset_ids_table: List[List[str]],
) -> Tuple[Optional[Dict], Optional[str]]:
    """Validate the UI inputs and build the pipeline sample.

    Returns:
        ``(sample, None)`` on success, or ``(None, error_message)`` when a
        required field is missing or building the sample fails.
    """
    dataset_ids = _normalize_dataset_ids(dataset_ids_table)
    if not (task_description and benchmark_name and dataset_ids):
        return None, "Missing required input fields."
    try:
        sample = _build_input_sample(
            task_description=task_description,
            benchmark_name=benchmark_name,
            # An empty benchmark description is allowed; normalize it to "".
            benchmark_description=benchmark_description or "",
            dataset_ids=dataset_ids,
        )
    except Exception as e:
        return None, f"Dataset build error: {e}"
    return sample, None
def _generate_plan_from_sample(sample: Dict) -> str:
    """Ask the planner LLM for a data-processing plan for ``sample``.

    Returns the plan text, or a string starting with ``"Planner error: "``
    when generation fails.
    """
    task = sample["task"]
    user_prompt = _render_plan_prompt(
        task_description=task["description"],
        benchmark=task["benchmark"],
        datasets=sample["datasets"],
    )
    planner_messages = [
        {"role": "system", "content": PLANNER_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        plan_text = _generate_text(
            planner_messages,
            max_new_tokens=PLAN_MAX_TOKENS,
            temperature=0.6,
            top_p=0.95,
        )
        _log_preview("planner_output", plan_text)
    except Exception as e:
        return f"Planner error: {e}"
    return plan_text
def _generate_code_from_plan(datasets: List[Dict], plan_text: str, progress=None) -> str:
    """Ask the coder LLM for processing code, retrying until it passes a syntax check.

    Up to ``MAX_ATTEMPTS`` generations are made. From each response one code
    block is chosen: the first block mentioning "data-processing code block",
    else the first block not mentioning "test code block", else the first
    block, else the whole response. The chosen code is stripped of fences and
    syntax-checked with ``AIDP_HEADER`` prepended.

    Returns:
        The first cleaned code that parses; a ``"Coder error: ..."`` string if
        a generation call fails; otherwise the code selected on the first
        attempt (not fence-stripped).
    """
    _log_preview("coder_plan_input", plan_text)
    tool_info = _build_tool_info()
    first_code = ""

    for attempt in range(MAX_ATTEMPTS):
        attempt_no = attempt + 1
        attempt_fraction = 0.5 + (attempt * 0.1)

        # Report the current attempt number on the progress bar.
        if progress:
            progress(attempt_fraction, desc=f"Coding (Attempt {attempt_no}/{MAX_ATTEMPTS})...")
        logger.info("[attempt] coder attempt %s/%s", attempt_no, MAX_ATTEMPTS)

        coder_messages = [
            {"role": "system", "content": CODER_SYSTEM_PROMPT},
            {
                "role": "user",
                "content": _render_code_prompt(
                    datasets=datasets,
                    plan=plan_text,
                    tool_info=tool_info,
                ),
            },
        ]
        _log_full("coder_request_messages", json.dumps(coder_messages, ensure_ascii=False))

        try:
            coder_response = _generate_text(
                coder_messages,
                max_new_tokens=CODE_MAX_TOKENS,
                temperature=0.6,
                top_p=0.95,
            )
        except Exception as e:
            return f"Coder error: {e}"
        _log_full("coder_raw_response", coder_response)

        blocks = _parse_code_block(coder_response)
        syn_code = next((b for b in blocks if "data-processing code block" in b), "")
        if not syn_code:
            syn_code = next((b for b in blocks if "test code block" not in b), "")
        if not syn_code and blocks:
            syn_code = blocks[0]
        if not syn_code:
            syn_code = coder_response

        if attempt == 0:
            first_code = syn_code

        cleaned_code = _strip_code_fence(syn_code)
        _log_full("coder_cleaned_code", cleaned_code)

        ok, _ = _syntax_check(AIDP_HEADER + cleaned_code.strip() + "\n")
        if ok:
            if progress:
                progress(1.0, desc="Code Generation Complete")
            return cleaned_code
        if progress:
            progress(attempt_fraction, desc=f"Retrying ({attempt_no})...")

    return first_code
def _execute_generated_code_remote(code_text: str, timeout: int = RUN_CODE_TIMEOUT, task_description: str = "") -> Tuple[str, str, str]:
    """
    Run generated code on the remote runner and summarize the result.

    The code is posted to ``RUN_CODE_URL``. The returned job is polled every
    two seconds at ``<base>/status/<job_id>`` until it reports a terminal
    status or ``timeout`` seconds have passed.

    Returns:
        eval_md: HTML table of evaluation scores and samples, or a short message.
        preview_md: Always an empty string (kept for compatibility).
        logs_txt: Plain text of status/stdout/stderr, or a "Run error" message.
    """
    if not (code_text and code_text.strip()):
        return "", "", "No code to run. Please generate code first."

    _log_preview("run_code_request", code_text, limit=800)
    if task_description:
        _log_preview("run_code_task_description", task_description, limit=400)
    logger.info("[run_code] POST %s timeout=%ss", RUN_CODE_URL, timeout)

    headers = {"x-forward-key": RUN_CODE_KEY} if RUN_CODE_KEY else {}
    try:
        submit_resp = requests.post(
            RUN_CODE_URL,
            json={"code": code_text, "timeout": timeout, "task_description": task_description},
            headers=headers,
            timeout=min(timeout, 30) + 10,
        )
        submit_resp.raise_for_status()
        submitted = submit_resp.json() if submit_resp.content else {}
    except Exception as e:
        logger.exception("[run_code] request failed")
        return "", "", f"Run error: {e}"

    job_id = submitted.get("job_id")
    if not job_id:
        return "", "", "Run error: runner did not return job_id."

    status_root = RUN_CODE_URL.rsplit("/", 1)[0].rstrip("/")
    status_url = f"{status_root}/status/{job_id}"
    terminal_states = {"success", "failed", "timeout", "error"}
    pause = 2
    deadline = time.time() + timeout

    finished_job = None
    while finished_job is None and time.time() < deadline:
        try:
            poll_resp = requests.get(status_url, headers=headers, timeout=min(timeout, 30))
            poll_resp.raise_for_status()
            job = poll_resp.json() if poll_resp.content else {}
        except Exception as e:  # noqa: BLE001
            logger.warning("[run_code] status poll error: %s", e)
            time.sleep(pause)
            continue
        if job.get("status") in terminal_states:
            finished_job = job
        else:
            time.sleep(pause)

    if finished_job is None:
        return "", "", f"Run error: job {job_id} did not complete within {timeout}s"
    data = finished_job

    # 1. Logs
    status = data.get("status", "unknown")
    retcode = data.get("returncode", "n/a")
    summary = data.get("summary", "")
    stdout = data.get("stdout", "")
    stderr = data.get("stderr", "")
    duration = data.get("duration_sec")
    duration_text = f"{duration:.2f}s" if duration is not None else "n/a"
    logger.info(
        "[run_code] status=%s returncode=%s duration=%s summary=%s",
        status,
        retcode,
        duration_text,
        _truncate_text(summary or "", limit=200),
    )

    log_parts = [f"Status: {status} (returncode={retcode})"]
    if duration is not None:
        log_parts.append(f"Duration: {duration:.2f}s")
    if summary:
        log_parts.append(f"Summary: {summary}")
    if stdout:
        _log_preview("run_code_stdout", stdout, limit=800)
        log_parts.append("\n--- stdout ---\n" + stdout)
    if stderr:
        _log_preview("run_code_stderr", stderr, limit=800)
        log_parts.append("\n--- stderr ---\n" + stderr)
    logs_txt = "\n".join(log_parts)

    # 2. Evaluation table (shown as processed training data)
    data_scores = data.get("data_scores") or []
    data_score_error = data.get("data_score_error")
    if data_scores:
        eval_md = _render_processed_scores_table(data_scores)
    elif data_score_error:
        eval_md = f"**Data Processing Error:** {data_score_error}"
    else:
        eval_md = "No data returned from execution."

    # 3. Processed preview (unused in the new UI but kept for compatibility)
    preview_md = ""
    return eval_md, preview_md, logs_txt
def _run_full_automation(
task_description: str,
benchmark_name: str,
benchmark_description: str,
dataset_ids_table: List[List[str]]
):
"""
Generator that runs the entire pipeline: Preview -> Plan -> Code -> Execute.
Yields: (status_html, preview_html, plan_md, plan_raw, code_src, eval_html)
"""
# Every yield is a 6-tuple in the order given in the docstring. The first
# element is always a gr.update(...) carrying the progress HTML; the other
# five are plain strings. On any failure the generator yields a single
# error-state tuple and returns, so later steps never run.
empty = ""
loading_html = ""
dataset_ids = _normalize_dataset_ids(dataset_ids_table)
# Enforce the dataset-count limit (at most 10) before doing any work.
if len(dataset_ids) > 10:
err_msg = f"Resource Limit: Please select no more than 10 datasets (currently selected {len(dataset_ids)}). Limited resources available."
yield (
gr.update(value=_render_progress_html(0, error=True, detail_msg=err_msg), visible=True),
f"{err_msg}
", empty, empty, empty, empty
)
return
# --- Step 1: Data Preview ---
# Estimated preview time: 60s for fewer than 5 datasets, 120s otherwise.
preview_time = 60 if len(dataset_ids) < 5 else 120
yield (
gr.update(value=_render_progress_html(0, detail_msg="Fetching dataset info...", est_time=preview_time), visible=True),
loading_html, empty, empty, empty, empty
)
# Run the preview. Note it receives the raw table, not the normalized id list.
preview_md = _preview_dataset(dataset_ids_table)
# --- Step 2: Planning ---
# Planning time estimate: fixed 60s.
yield (
gr.update(value=_render_progress_html(1, detail_msg="Designing data recipe...", est_time=60)),
preview_md, loading_html, empty, empty, empty
)
# Build the sample dict from the user inputs; a truthy `error` aborts the run.
sample, error = _build_sample_from_inputs(task_description, benchmark_name, benchmark_description, dataset_ids_table)
if error:
err_html = _render_progress_html(1, error=True, detail_msg="Dataset loading failed")
yield (gr.update(value=err_html), preview_md, f"❌ Error: {error}", f"❌ {error}", empty, empty)
return
# Generate the plan. Failure is signalled in-band: the returned text starts
# with "Planner error" instead of an exception being raised here.
plan_text = _generate_plan_from_sample(sample)
if plan_text.startswith("Planner error"):
err_html = _render_progress_html(1, error=True, detail_msg="Planner failed")
yield (gr.update(value=err_html), preview_md, plan_text, plan_text, empty, empty)
return
# --- Step 3: Coding ---
# Coding time estimate: fixed 60s.
yield (
gr.update(value=_render_progress_html(2, detail_msg="Generating Python script...", est_time=60)),
preview_md, plan_text, plan_text, loading_html, empty
)
# Generate the code. Failure is signalled in-band by a "Coder error" prefix.
code_text = _generate_code_from_plan(sample["datasets"], plan_text)
if code_text.startswith("Coder error"):
err_html = _render_progress_html(2, error=True, detail_msg="Coder failed")
yield (gr.update(value=err_html), preview_md, plan_text, plan_text, code_text, empty)
return
# --- Step 4: Execution ---
# Execution time estimate: fixed 180s.
yield (
gr.update(value=_render_progress_html(3, detail_msg="Executing code on remote runner...", est_time=180)),
preview_md, plan_text, plan_text, code_text, "Processing data... this may take a moment"
)
# Execute the generated code remotely; the middle return value is ignored.
exec_task_desc = _compose_task_context(task_description, benchmark_description)
eval_md, _, logs_txt = _execute_generated_code_remote(code_text, task_description=exec_task_desc)
# Inspect the run result. The run counts as failed only when the logs contain
# an error marker ("Run error", "timeout" or "traceback", the latter two
# case-insensitive) AND do not contain the literal "Status: success".
is_run_error = False
run_msg = "Done"
if "Run error" in logs_txt or "timeout" in logs_txt.lower() or "traceback" in logs_txt.lower():
if "Status: success" not in logs_txt:
is_run_error = True
run_msg = "Execution failed or timed out"
if is_run_error:
final_progress = _render_progress_html(3, error=True, detail_msg=run_msg)
# A timeout gets its own hint; any other failure gets the generic runtime
# error hint. Either hint is prepended to the existing eval_md.
if "timeout" in logs_txt.lower():
eval_md = f"Timeout: Code execution took too long. Please simplify the logic or try fewer samples.
" + eval_md
else:
eval_md = f"Runtime Error: Check the code or logs.
" + eval_md
yield (
gr.update(value=final_progress),
preview_md, plan_text, plan_text, code_text, eval_md
)
else:
# Finished normally: hide the progress bar.
yield (
gr.update(visible=False),
preview_md, plan_text, plan_text, code_text, eval_md
)
def _format_model_status() -> str:
    """Describe the remote LLM configuration as a single status line.

    Returns:
        A "not configured" hint when ``_use_remote_llm()`` is falsy; otherwise
        a line naming the API base (or ``'unset'``), whether an API key is
        available (module-level key or the ``OPENAI_API_KEY`` environment
        variable), and the resolved model id.
    """
    # Resolve the model id first, exactly as before, even if it ends up unused.
    resolved_model = _resolve_model_id()
    if not _use_remote_llm():
        return "Remote vLLM not configured (set DATACHEF_VLLM_URL)."

    # The environment is consulted only when the module-level key is empty.
    if _REMOTE_API_KEY or os.getenv("OPENAI_API_KEY"):
        auth_label = "with api key"
    else:
        auth_label = "dummy key"
    endpoint = _REMOTE_API_BASE or 'unset'
    return f"Remote vLLM @ {endpoint} ({auth_label}), model={resolved_model}"
def _warmup_model() -> str:
    """Return the current model status line.

    Nothing is loaded here: the call simply delegates to
    ``_format_model_status`` and hands back its result.
    """
    status_line = _format_model_status()
    return status_line
# --- CSS style definitions (global) ---
# Stylesheet for the whole app, kept as one module-level string. create_demo()
# copies it into _DEMO_CSS, which the __main__ block passes to demo.launch(css=...).
# The /* ... */ comments inside the string are CSS comments and are part of the
# runtime value, so they are left exactly as written.
_CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500&family=Inter:wght@400;500;600;700;800&display=swap');
:root {
--primary-col: #10b981;
--panel-bg: #ffffff;
--border-col: #e2e8f0;
}
.gradio-container {
font-family: 'Inter', system-ui, sans-serif !important;
background-color: #f8fafc !important;
}
/* 标题区域 */
.header-container {
padding: 1.5rem 0 1rem;
margin-bottom: 1rem;
border-bottom: 1px solid var(--border-col);
}
/* 卡片容器 */
.control-panel {
background: var(--panel-bg);
border: 1px solid var(--border-col);
border-radius: 12px;
padding: 1.25rem;
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.05);
}
/* 右侧输出面板 */
.output-panel {
background: var(--panel-bg);
border: 1px solid var(--border-col);
border-radius: 12px;
padding: 1.25rem;
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.05);
display: flex;
flex-direction: column;
justify-content: flex-start;
height: 100%;
}
/* 绿色胶囊标题 */
.green-header {
background-color: #ecfdf5;
color: #047857;
border: 1px solid #a7f3d0;
padding: 4px 12px;
border-radius: 6px;
font-size: 0.85rem;
font-weight: 700;
display: inline-flex;
align-items: center;
margin-bottom: 12px;
margin-top: 5px;
letter-spacing: 0.02em;
font-family: 'Inter', sans-serif;
}
/* 状态栏 */
.status-bar-compact {
background: #f1f5f9;
border-radius: 8px;
padding: 8px 12px;
margin-bottom: 1.25rem;
display: flex;
align-items: center;
justify-content: space-between;
font-size: 0.85rem;
}
/* 进度条样式 */
.step-container {
background: #fff; padding: 12px 20px; border-radius: 8px;
border: 1px solid #e2e8f0; margin-bottom: 15px;
box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05);
}
.step-row { display: flex; align-items: center; justify-content: center; width: 100%; }
.step-item { display: flex; align-items: center; gap: 8px; color: #94a3b8; font-weight: 600; font-size: 0.9rem; }
.step-icon {
width: 24px; height: 24px; border-radius: 50%; display: flex; align-items: center; justify-content: center;
font-size: 11px; border: 2px solid #cbd5e1; background: white; color: #64748b; font-weight: 700;
}
.step-item.done { color: #059669; }
.step-item.done .step-icon { background: #10b981; border-color: #10b981; color: white; }
.step-item.active { color: #0ea5e9; }
.step-item.active .step-icon { border-color: #0ea5e9; color: #0ea5e9; background: #e0f2fe; animation: pulse 2s infinite; }
.step-arrow { color: #cbd5e1; margin: 0 12px; font-size: 12px; }
.step-status-text { margin-top: 8px; font-size: 0.8rem; color: #64748b; text-align: center; height: 16px; }
@keyframes pulse {
0% { box-shadow: 0 0 0 0 rgba(14, 165, 233, 0.4); }
70% { box-shadow: 0 0 0 6px rgba(14, 165, 233, 0); }
100% { box-shadow: 0 0 0 0 rgba(14, 165, 233, 0); }
}
/* === 左侧表格与按钮美化 (Gradio Dataframe) === */
.unified-dataframe table { font-size: 12px !important; }
.unified-dataframe td, .unified-dataframe th { padding: 6px 8px !important; }
.unified-dataframe { margin-top: 0 !important; margin-bottom: 0 !important; }
.sub-label {
font-size: 0.85rem; font-weight: 700; color: #059669;
margin-bottom: 6px; display: flex; align-items: center; gap: 6px;
white-space: nowrap;
}
/* 小操作按钮 */
.action-btn {
font-size: 0.75rem !important;
padding: 0px 10px !important;
height: 28px !important;
min-height: 28px !important;
border: 1px solid #cbd5e1 !important;
background: linear-gradient(to bottom, #ffffff, #f8fafc);
color: #475569 !important;
border-radius: 6px !important;
box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05);
}
.action-btn:hover { background: #f1f5f9 !important; border-color: #94a3b8 !important; color: #0f172a !important; }
.delete-btn { color: #b91c1c !important; border-color: #fca5a5 !important; background: #fef2f2 !important; }
.delete-btn:hover { background: #fee2e2 !important; border-color: #f87171 !important; }
/* 搜索按钮特化 */
.search-btn-primary {
font-weight: 800 !important;
font-size: 0.9rem !important;
text-transform: uppercase;
letter-spacing: 0.05em;
box-shadow: 0 2px 4px rgba(16, 185, 129, 0.2);
}
/* === 右侧结果表格美化 (HTML Table) === */
.data-table-host { width: 100%; }
.table-wrapper { width: 100%; overflow-x: auto; }
.data-table {
width: 100%;
border-collapse: separate;
border-spacing: 0;
border: 1px solid #e2e8f0;
border-radius: 8px;
font-size: 0.85rem;
background-color: #fff;
table-layout: auto; /* 允许根据内容自动调整列宽 */
}
.data-table thead th {
background: #ecfdf5;
color: #166534;
font-weight: 700;
padding: 10px 12px;
border-bottom: 1px solid #bbf7d0;
text-align: left;
white-space: nowrap;
}
.data-table td {
padding: 8px 12px;
border-bottom: 1px solid #f1f5f9;
color: #334155;
vertical-align: top;
/* 核心修复:防止 text 溢出,同时允许长单词换行 */
word-wrap: break-word;
max-width: 600px;
}
.data-table tr:last-child td { border-bottom: none; }
.data-table tr:hover td { background-color: #f8fafc; }
/* 样本列特化:保证宽度,防止被挤压 */
.sample-cell {
min-width: 300px;
width: 50%;
}
.sample-details summary {
cursor: pointer;
font-weight: 600;
color: #047857;
outline: none;
margin-bottom: 4px;
}
.sample-pre {
margin-top: 6px;
padding: 10px;
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 6px;
font-family: 'JetBrains Mono', monospace;
font-size: 0.75rem; /* 12px */
color: #0f172a;
/* 核心修复:保证 JSON 代码块自动换行,不撑破表格 */
white-space: pre-wrap;
word-break: break-all;
max-height: 400px; /* 过长则滚动 */
overflow-y: auto;
}
.error-block { margin-top: 12px; padding: 10px 12px; background: #fef2f2; border: 1px solid #fecdd3; border-radius: 10px; color: #991b1b; }
/* 通用工具类 */
.no-label-input > label > span { display: none; }
.mt-0 { margin-top: 0 !important; }
.mt-2 { margin-top: 0.5rem !important; }
.mt-4 { margin-top: 1rem !important; }
.mb-2 { margin-bottom: 0.5rem !important; }
.gap-2 { gap: 0.5rem; }
.code-box { font-size: 12px !important; }
.gradio-container .tabs { border-bottom: 1px solid #e2e8f0; }
.control-panel .gradio-dataframe { margin-top: 0 !important; margin-bottom: 0 !important; }
.btn-row { display: flex; flex-direction: row; align-items: center; }
.progress-bar-container {
width: 100%;
background-color: #f1f5f9;
border-radius: 10px; /* 圆角稍微大一点更好看 */
height: 8px;
margin-top: 8px;
overflow: hidden;
border: 1px solid #e2e8f0;
}
.progress-bar-fill {
height: 100%;
background: linear-gradient(90deg, #0ea5e9, #22c55e); /* 增加渐变效果 */
width: 0%;
}
.est-time-text {
font-size: 0.7rem;
color: #94a3b8;
margin-top: 4px;
text-align: center;
font-family: 'JetBrains Mono', monospace;
}
/* === 新增: 用户指南样式 (保持不变) === */
.info-box {
background: #f0f9ff;
border: 1px solid #bae6fd;
color: #0369a1;
padding: 12px 16px;
border-radius: 8px;
font-size: 0.85rem;
line-height: 1.5;
margin-bottom: 16px;
}
@keyframes progressFill {
0% { width: 0%; }
90% { width: 90%; } /* 稍微留一点余地,防止还没跑完就满了 */
100% { width: 95%; }
}
"""
def create_demo() -> gr.Blocks:
# Build the DataChef Gradio UI and return the Blocks object.
# Side effect: stores the theme and stylesheet in the module globals
# _DEMO_THEME / _DEMO_CSS. They are not passed to gr.Blocks(...) here; the
# __main__ block hands them to demo.launch(...) instead.
global _DEMO_THEME, _DEMO_CSS
# 1. Configure theme and CSS
theme = gr.themes.Soft(
primary_hue="emerald",
neutral_hue="slate",
radius_size="md",
text_size="sm",
font=[gr.themes.GoogleFont("Inter"), "sans-serif"],
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "monospace"],
)
_DEMO_THEME = theme
_DEMO_CSS = _CUSTOM_CSS
# 2. Prepare default data (initial widget values come from DEFAULT_SAMPLE)
default_sample = DEFAULT_SAMPLE
default_task = default_sample.get("task", {})
default_benchmark = default_task.get("benchmark", {})
default_dataset_rows: List[List[str]] = []
if default_sample.get("datasets"):
# One single-column row per dataset; entries with a falsy dataset_id are skipped.
default_dataset_rows = [[d.get("dataset_id", "")] for d in default_sample["datasets"] if d.get("dataset_id")]
# 3. HTML helper functions
def get_header_html():
return f"""
DataChef
Automatic Data Recipe Generation for LLM Adaptation
"""
def get_status_text():
# Recomputed on every call so the status bar reflects the current
# remote-LLM configuration.
model_id = _resolve_model_id()
if _use_remote_llm():
status_dot = ""
status_text = f"Ready"
model_info = f"{model_id}"
else:
status_dot = ""
status_text = f"Not Configured"
model_info = "Remote vLLM unset"
return f"{status_dot} {status_text}
{model_info}
"
# 4. Build the UI
with gr.Blocks(title="DataChef") as demo:
# --- Top header ---
with gr.Row(elem_classes=["header-container"]):
gr.HTML(get_header_html())
# --- Main layout ---
with gr.Row():
# =============== Left control panel (Left Column) ===============
with gr.Column(scale=32, elem_classes=["control-panel"]):
# Top status bar
with gr.Row(elem_classes=["status-bar-compact"]):
status_html = gr.HTML(get_status_text())
load_model_btn = gr.Button("↻", size="sm", variant="secondary", min_width=30)
# user guide
gr.HTML("""
👋 Welcome to DataChef Demo
- Mode 1 (Easy): Choose a preset in Configuration Presets and click Generate Recipe.
- Mode 2 (Custom): Provide a task description, benchmark details, and provide datasets manually or via Auto Search.
""")
# Preset selection
gr.HTML("")
# Fix: define preset_dropdown
preset_dropdown = gr.Dropdown(choices=list(PRESETS.keys()), value="Physics", show_label=False, container=False, interactive=True)
# Task description
gr.HTML("")
task_description = gr.Textbox(show_label=False, lines=3, value=default_task.get("description", ""), placeholder="Describe task...", elem_classes=["no-label-input"])
# Benchmark settings
gr.HTML("")
benchmark_name = gr.Textbox(label="Name", placeholder="Benchmark Name (e.g., AIME2025)", value=default_benchmark.get("name", ""), lines=1)
benchmark_description = gr.Textbox(
label="Description (Optional)",
placeholder="Recommand: Provide description and test examples to guide the generation (e.g., AIME is xxx. Example: {'question': 'Calculate the velocity of...', xxx})",
lines=3,
value=default_benchmark.get("description", "")
)
# Dataset selection area
gr.HTML("")
# 1. Suggested Dataset Area
with gr.Row(elem_classes=["mb-2", "btn-row"]):
gr.HTML("🔍 Suggested
", scale=0)
auto_ds_btn = gr.Button("🔎 Auto Search", variant="primary", scale=1, elem_classes=["search-btn-primary"])
suggested_ds = gr.Dataframe(
headers=["Dataset ID"],
datatype=["str"],
row_count=(3, "fixed"),
value=[],
interactive=False,
wrap=True,
elem_classes=["unified-dataframe"]
)
# Middle action button group
with gr.Row(elem_classes=["btn-row", "mt-2", "gap-2"]):
add_suggested_btn = gr.Button("⬇️ Add Selected", size="sm", elem_classes=["action-btn"])
add_all_suggested_btn = gr.Button("⏬ Add All", size="sm", elem_classes=["action-btn"])
# 2. Selected Dataset Area
gr.HTML("✅ Selected
")
# Fix: col_count -> column_count
dataset_ids = gr.Dataframe(
headers=["Dataset ID"],
datatype=["str"],
row_count=(3, "dynamic"),
value=default_dataset_rows,
interactive=True,
column_count=(1, "fixed"),
wrap=True,
elem_classes=["unified-dataframe"]
)
# Action buttons for the Selected table
selected_ds_idx = gr.State(value=None)
with gr.Row(elem_classes=["btn-row", "mt-2", "gap-2"]):
del_row_btn = gr.Button("🗑️ Delete", size="sm", elem_classes=["action-btn", "delete-btn"])
add_row_btn = gr.Button("➕ Add Row", size="sm", elem_classes=["action-btn"])
# State storage (selection index for the Suggested table)
suggested_selected_idx = gr.State(value=None)
# Run button
with gr.Row(elem_classes=["mt-4"]):
run_btn = gr.Button("🚀 Generate Recipe", variant="primary", size="lg")
# =============== Right output panel (Right Column) ===============
with gr.Column(scale=68):
# Progress status (created hidden)
process_status = gr.HTML(visible=False, label=None)
# Plan & Code side-by-side view
with gr.Row(equal_height=False, elem_classes=["mt-0"]):
# Plan
with gr.Column(variant="panel", elem_classes=["output-panel"]):
gr.HTML("")
with gr.Tabs():
with gr.Tab("Rendered"):
plan_out_md = gr.Markdown(value="_Waiting..._", line_breaks=True)
with gr.Tab("Source"):
plan_out_raw = gr.Code(language="markdown", interactive=False, lines=15, elem_classes=["code-box"])
# Code
with gr.Column(variant="panel", elem_classes=["output-panel"]):
gr.HTML("")
code_out = gr.Code(language="python", interactive=False, lines=18, elem_classes=["code-box"])
# Results panel
with gr.Column(variant="panel", elem_classes=["output-panel", "mt-4"]):
gr.HTML("")
with gr.Tabs():
with gr.Tab("🏆 Processed Training Data"):
eval_output_html = gr.HTML("Recipe generation not started yet.", elem_classes=["data-table-host"])
with gr.Tab("📥 Raw Data"):
preview_out_html = gr.HTML("Waiting for input...", elem_classes=["data-table-host"])
# --- Event bindings ---
# preset_dropdown is already defined at this point
preset_dropdown.change(
fn=_load_preset_config,
inputs=[preset_dropdown],
outputs=[task_description, benchmark_name, benchmark_description, dataset_ids]
)
# Auto search
auto_ds_btn.click(
fn=_auto_suggest_datasets_ui,
inputs=[task_description, benchmark_description],
outputs=[suggested_ds, suggested_selected_idx],
)
# Select a row in Suggested
suggested_ds.select(
_on_df_select,
outputs=[suggested_selected_idx],
)
# Add Suggested (Single)
add_suggested_btn.click(
fn=_add_selected_suggestion,
inputs=[dataset_ids, suggested_ds, suggested_selected_idx],
outputs=[dataset_ids],
)
# Add All Suggested
add_all_suggested_btn.click(
fn=_add_all_suggestions,
inputs=[dataset_ids, suggested_ds],
outputs=[dataset_ids],
)
# Select a row in Selected (used for deletion)
dataset_ids.select(
_on_df_select,
outputs=[selected_ds_idx]
)
# Delete the selected row from Selected
del_row_btn.click(
fn=_delete_selected_row,
inputs=[dataset_ids, selected_ds_idx],
outputs=[dataset_ids, selected_ds_idx]
)
# Append an empty row to Selected
add_row_btn.click(
fn=_add_empty_row,
inputs=[dataset_ids],
outputs=[dataset_ids]
)
# Run logic: the six outputs line up with the 6-tuples yielded by
# _run_full_automation.
run_btn.click(
_run_full_automation,
inputs=[task_description, benchmark_name, benchmark_description, dataset_ids],
outputs=[process_status, preview_out_html, plan_out_md, plan_out_raw, code_out, eval_output_html],
show_progress="hidden"
)
# The refresh button and the initial page load both call _warmup_model (with
# no outputs bound) and then re-render the status bar text.
load_model_btn.click(_warmup_model, outputs=[]).then(get_status_text, outputs=[status_html])
demo.load(_warmup_model, outputs=[]).then(get_status_text, outputs=[status_html])
return demo
# Build the app at import time so `demo` exists as a module-level name.
demo = create_demo()

if __name__ == "__main__":
    # Theme and CSS were stored in module globals by create_demo().
    launch_options = {
        "theme": _DEMO_THEME,
        "css": _DEMO_CSS,
        "server_name": "0.0.0.0",
        "share": True,
    }
    demo.launch(**launch_options)