Spaces:
Running
Running
| import os | |
| import json | |
| import yaml | |
| import gzip | |
| import tarfile | |
| import zipfile | |
| import tempfile | |
| from pathlib import Path | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| from huggingface_hub.errors import BadRequestError | |
| TEXT_EXTENSIONS = {".tex", ".text", ".txt", ".bib", ".bbl", ".md"} | |
| def _normalize_message_content(content): | |
| if isinstance(content, list): | |
| return "\n".join(_normalize_message_content(item) for item in content) | |
| if isinstance(content, str): | |
| return content | |
| if isinstance(content, dict): | |
| if "path" in content: | |
| return f"[uploaded file: {os.path.basename(content['path'])}]" | |
| return str(content.get("text", "")) | |
| if isinstance(content, gr.ChatMessage): | |
| return _normalize_message_content(content.content) | |
| return str(content) | |
def _sanitize_history(history):
    """Convert mixed chat-history entries into plain role/content dicts.

    Entries of unknown shape or with an unrecognized role are dropped;
    content is flattened to a string via _normalize_message_content.
    """
    sanitized = []
    for entry in history or []:
        if isinstance(entry, gr.ChatMessage):
            role, content, metadata = entry.role, entry.content, entry.metadata
        elif isinstance(entry, dict):
            role = entry.get("role")
            content = entry.get("content", "")
            metadata = entry.get("metadata")
        else:
            # Unknown entry shape — skip it.
            continue
        if role in ("user", "assistant", "system"):
            record = {"role": role, "content": _normalize_message_content(content)}
            if metadata:
                record["metadata"] = metadata
            sanitized.append(record)
    return sanitized
| def _extract_streaming_tag_content(text, tag): | |
| source = text or "" | |
| start_tag = f"<{tag}>" | |
| end_tag = f"</{tag}>" | |
| start_idx = source.find(start_tag) | |
| if start_idx == -1: | |
| return "", False, False | |
| content_start = start_idx + len(start_tag) | |
| end_idx = source.find(end_tag, content_start) | |
| if end_idx == -1: | |
| return source[content_start:].lstrip(), True, False | |
| return source[content_start:end_idx].strip(), True, True | |
| def _parse_predictions_jsonl_items(predictions_block, closed): | |
| items = [] | |
| lines = (predictions_block or "").splitlines(keepends=True) | |
| for idx, line in enumerate(lines): | |
| is_last = idx == len(lines) - 1 | |
| line_is_complete = line.endswith("\n") or line.endswith("\r") | |
| if not line_is_complete and not (closed and is_last): | |
| continue | |
| stripped = line.strip() | |
| if not stripped: | |
| continue | |
| try: | |
| items.append(json.loads(stripped)) | |
| except json.JSONDecodeError: | |
| continue | |
| return items | |
| def _format_predictions_markdown(items): | |
| if not items: | |
| return "" | |
| out = [] | |
| for i, item in enumerate(items, start=1): | |
| name = item.get("name", "Untitled Ablation") | |
| ablated_part = item.get("ablated_part", "") | |
| action = item.get("action", "") | |
| replacement = item.get("replacement") | |
| metrics = item.get("metrics", []) | |
| if not isinstance(metrics, list): | |
| metrics = [str(metrics)] | |
| out.append(f"### {i}. **{name}**") | |
| out.append(f" <u>Component</u>: {ablated_part}") | |
| out.append(f" <u>Action</u>: {action}") | |
| if replacement is not None and replacement != "": | |
| out.append(" <u>Replacement</u>:") | |
| if isinstance(replacement, list): | |
| for rep in replacement: | |
| out.append(f" - `{rep}`") | |
| else: | |
| out.append(f" - `{replacement}`") | |
| out.append(" <u>Metrics</u>:") | |
| if metrics: | |
| for metric in metrics: | |
| out.append(f" - `{metric}`") | |
| else: | |
| out.append(" - `(none)`") | |
| out.append("\n---\n") | |
| return "\n".join(out).strip("- \n") | |
| def _read_text_file(path: Path) -> str: | |
| return path.read_text(encoding="utf-8", errors="ignore") | |
def _is_relevant_text_file(path: Path) -> bool:
    """True when the file extension marks it as supported paper text."""
    suffix = path.suffix.lower()
    return suffix in TEXT_EXTENSIONS
| def _safe_extract_zip(zip_path: Path, output_dir: Path) -> None: | |
| with zipfile.ZipFile(zip_path, "r") as zf: | |
| for member in zf.infolist(): | |
| member_path = output_dir / member.filename | |
| resolved_member = member_path.resolve() | |
| resolved_root = output_dir.resolve() | |
| if not str(resolved_member).startswith(str(resolved_root)): | |
| continue | |
| zf.extract(member, output_dir) | |
| def _safe_extract_tar(tar_path: Path, output_dir: Path) -> None: | |
| with tarfile.open(tar_path, "r:*") as tf: | |
| for member in tf.getmembers(): | |
| member_path = output_dir / member.name | |
| resolved_member = member_path.resolve() | |
| resolved_root = output_dir.resolve() | |
| if not str(resolved_member).startswith(str(resolved_root)): | |
| continue | |
| tf.extract(member, output_dir) | |
def _archive_to_tagged_source(extracted_root: Path) -> str:
    """Concatenate all supported text files under `extracted_root`.

    Each file is wrapped in a <file name="..."> tag with its path relative
    to the root. Raises gr.Error when no supported text file is found.
    """
    tagged_files = []
    for candidate in sorted(extracted_root.rglob("*")):
        if not (candidate.is_file() and _is_relevant_text_file(candidate)):
            continue
        try:
            rel_name = candidate.relative_to(extracted_root).as_posix()
            body = _read_text_file(candidate)
        except Exception:
            # Unreadable entries are skipped rather than failing the upload.
            continue
        tagged_files.append(f'<file name="{rel_name}">\n{body}\n</file>\n')
    if not tagged_files:
        raise gr.Error(
            "No relevant text files found in the archive. Expected .tex/.text/.txt/.bib/.bbl/.md files."
        )
    return "\n".join(tagged_files)
def _convert_pdf_to_markdown(pdf_path: Path) -> str:
    """Convert a PDF to markdown text via the MarkItDown SDK.

    Raises gr.Error when the SDK is missing, the conversion fails, or the
    converted text is empty.
    """
    try:
        from markitdown import MarkItDown
    except Exception as import_err:
        raise gr.Error(
            "MarkItDown SDK is not available. Make sure `markitdown[pdf]` is installed."
        ) from import_err
    try:
        converted = MarkItDown(enable_plugins=False).convert(str(pdf_path))
        markdown_text = converted.text_content
    except Exception as convert_err:
        raise gr.Error(f"PDF conversion failed with MarkItDown SDK: {convert_err}") from convert_err
    markdown_text = (markdown_text or "").strip()
    if not markdown_text:
        raise gr.Error("MarkItDown SDK produced empty output for this PDF.")
    return markdown_text
def _build_paper_source_from_upload(uploaded_path: str) -> str:
    """Turn an uploaded file (text, archive, gzip, or PDF) into paper text.

    Dispatch by file type:
      * supported text files are read directly;
      * zip/tar archives are safely extracted and their text files tagged;
      * bare .gz/.gzip payloads are decompressed (tar archives first);
      * PDFs are converted through MarkItDown.

    Raises gr.Error for unsupported or unusable uploads.
    """
    source = Path(uploaded_path)
    lowered = source.name.lower()
    if _is_relevant_text_file(source):
        return _read_text_file(source)
    with tempfile.TemporaryDirectory(prefix="paper_extract_") as scratch:
        extraction_dir = Path(scratch) / "extracted"
        extraction_dir.mkdir(parents=True, exist_ok=True)
        if lowered.endswith(".zip"):
            _safe_extract_zip(source, extraction_dir)
            return _archive_to_tagged_source(extraction_dir)
        if lowered.endswith((".tar.gz", ".tgz", ".tar")):
            _safe_extract_tar(source, extraction_dir)
            return _archive_to_tagged_source(extraction_dir)
        if lowered.endswith((".gz", ".gzip")):
            # A .gz may wrap a tar archive; handle that case first.
            if tarfile.is_tarfile(source):
                _safe_extract_tar(source, extraction_dir)
                return _archive_to_tagged_source(extraction_dir)
            # Otherwise decompress the single payload and hope it is text.
            payload_name = source.name
            if payload_name.endswith(".gzip"):
                payload_name = payload_name[: -len(".gzip")]
            elif payload_name.endswith(".gz"):
                payload_name = payload_name[: -len(".gz")]
            target = extraction_dir / payload_name
            with gzip.open(source, "rb") as packed, open(target, "wb") as unpacked:
                unpacked.write(packed.read())
            if _is_relevant_text_file(target):
                return _read_text_file(target)
            raise gr.Error(
                "Unsupported .gz/.gzip payload. It must contain a relevant text file or a tar archive."
            )
        if lowered.endswith(".pdf"):
            return _convert_pdf_to_markdown(source)
        raise gr.Error(
            "Unsupported file type. Use text files (.tex/.text/.txt/.bib/.bbl/.md), "
            "archives (.zip/.tar/.tar.gz/.tgz/.gz/.gzip), or .pdf."
        )
def run_single_interaction(
    message_input,
    history,
    ablation_mode,
    num_ablations,
    temperature,
    top_p,
    model_id,
    provider_name,
    hf_token: gr.OAuthToken | None,
):
    """
    Stream one ablation-planning interaction and yield chat-history updates.

    Validates the user's text-or-file submission, builds the paper source,
    sends a single system+user prompt pair to the selected model/provider,
    and incrementally renders the model's <discussion> and <predictions>
    blocks as separate assistant messages.

    For more information on `huggingface_hub` Inference API support, see:
    https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
    """
    # Load the mode-specific prompt pair from the bundled YAML config.
    try:
        config = yaml.safe_load(Path("./prompts.yaml").read_text())
        prompts = (
            config["author_ablation"]
            if ablation_mode == "AuthorAblation"
            else config["reviewer_ablation"]
        )
    except Exception as e:
        raise gr.Error(f"Failed to load prompt configuration: {e}") from e
    prior_history = _sanitize_history(history)
    # MultimodalTextbox submits a {"text": ..., "files": [...]} dict;
    # fall back to treating the input as a plain string otherwise.
    text = ""
    files = []
    if isinstance(message_input, dict):
        text = (message_input.get("text") or "").strip()
        files = message_input.get("files") or []
    else:
        text = (message_input or "").strip()
    # Exactly one input channel (text XOR a single file) must be used.
    has_text = bool(text)
    has_file = len(files) > 0
    if has_text and has_file:
        raise gr.Error("Please submit either paper content text or one file upload, not both.")
    if not has_text and not has_file:
        raise gr.Error("Please provide paper content as text or upload one file.")
    if len(files) > 1:
        raise gr.Error("Please upload a single file only.")
    if hf_token is None:
        raise gr.Error("Please sign in with Hugging Face before submitting.")
    file_label = None
    file_path = None
    if has_file:
        # Gradio may hand back either a dict payload or a bare path string.
        file_item = files[0]
        file_path = file_item.get("path") if isinstance(file_item, dict) else file_item
        file_label = os.path.basename(file_path) if file_path else "uploaded_file"
    try:
        paper_source = text if has_text else _build_paper_source_from_upload(file_path)
    except gr.Error:
        # User-facing errors pass through unchanged.
        raise
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded paper: {e}") from e
    # Fill the prompt template's {{paper_source}} / {{num_ablations}} slots.
    user_prompt_template = prompts["user_prompt"]
    user_content = (
        user_prompt_template.replace("{{paper_source}}", paper_source)
        .replace("{{num_ablations}}", str(num_ablations))
    )
    # Short, human-readable stand-in shown in the chat instead of the full paper.
    if has_file:
        source_hint = f"file: {file_label}"
    else:
        first_line = (text.splitlines()[0] if text else "").strip()
        first_line_words = first_line.split()[:100]
        preview = " ".join(first_line_words)
        source_hint = f"text preview: {preview}" if preview else "text preview: (empty)"
    if ablation_mode == "AuthorAblation":
        user_display = f"Planning {num_ablations} ablations for submitted paper ({source_hint})."
    else:
        user_display = f"Reviewing and suggesting {num_ablations} missing ablations for submitted paper ({source_hint})."
    client = InferenceClient(
        token=hf_token.token,
        model=model_id,
        provider=provider_name,
    )
    # Keep full chat visible to users, but send only current input to model.
    messages = [
        {"role": "system", "content": prompts["system_prompt"]},
        {"role": "user", "content": user_content},
    ]
    # Rebuild the on-screen history as ChatMessage objects, then append the
    # display version of the new user turn.
    live_history = [
        gr.ChatMessage(
            role=item["role"],
            content=item["content"],
            metadata=item.get("metadata") or {},
        )
        for item in prior_history
    ]
    live_history.append(gr.ChatMessage(role="user", content=user_display))
    # Heuristic warning: AuthorAblation uploads should not already contain
    # ablation sections (matches any "ablat*" substring).
    if has_file and ablation_mode == "AuthorAblation" and "ablat" in paper_source.lower():
        gr.Warning("Uploaded paper appears to already contain ablation content (`ablat*`).")
    # Placeholder assistant bubble that the stream below keeps overwriting.
    live_history.append(
        gr.ChatMessage(
            role="assistant",
            content="",
            metadata={"title": "⏳ Discussion", "status": "pending"},
        )
    )
    emitted = False
    raw_output = ""
    predictions_message_idx = None
    try:
        for chunk in client.chat_completion(
            messages,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            choices = chunk.choices
            token = ""
            if len(choices) and choices[0].delta.content:
                token = choices[0].delta.content
            raw_output += token
            # Re-parse the accumulated output on every chunk.
            discussion_text, discussion_started, _ = _extract_streaming_tag_content(
                raw_output, "discussion"
            )
            predictions_text, predictions_started, predictions_closed = (
                _extract_streaming_tag_content(raw_output, "predictions")
            )
            prediction_items = _parse_predictions_jsonl_items(
                predictions_text, predictions_closed
            )
            # Discussion bubble: marked "done" once predictions start streaming.
            if discussion_started:
                discussion_status = "done" if predictions_started else "pending"
                live_history[-1] = gr.ChatMessage(
                    role="assistant",
                    content=discussion_text or "_(discussion is empty so far)_",
                    metadata={"title": "⏳ Discussion", "status": discussion_status},
                )
            else:
                live_history[-1] = gr.ChatMessage(
                    role="assistant",
                    content="",
                    metadata={"title": "⏳ Discussion", "status": "pending"},
                )
            # Predictions bubble: created once, then updated in place.
            if predictions_started:
                predictions_markdown = _format_predictions_markdown(prediction_items)
                if predictions_message_idx is None:
                    live_history.append(
                        gr.ChatMessage(
                            role="assistant",
                            content=predictions_markdown,
                        )
                    )
                    predictions_message_idx = len(live_history) - 1
                else:
                    live_history[predictions_message_idx] = gr.ChatMessage(
                        role="assistant",
                        content=predictions_markdown,
                    )
                # NOTE(review): indentation reconstructed — `emitted` appears
                # to flip only once a <predictions> block has started; the
                # fallback below fires when it never did.
                emitted = True
            yield live_history
    except BadRequestError as e:
        message = str(e)
        if "model_not_supported" in message:
            raise gr.Error(
                f"Model/provider mismatch for model '{model_id}' with provider '{provider_name}'. "
                "Your token is valid, but no enabled provider can serve this model. "
                "Pick a different model or provider."
            ) from e
        raise gr.Error(f"Inference request failed: {message}") from e
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Unexpected error during generation: {e}") from e
    # Fallback when the stream produced no recognizable tagged output.
    if not emitted:
        live_history[-1] = gr.ChatMessage(
            role="assistant",
            content="_No discussion block found in model output._",
            metadata={"title": "⏳ Discussion", "status": "done"},
        )
        live_history.append(
            gr.ChatMessage(
                role="assistant",
                content="_No valid predictions JSONL found._",
            )
        )
    # Final state (re-)emitted so the UI always ends consistent.
    yield live_history
def print_like_dislike(x: gr.LikeData):
    """Log like/dislike feedback events to stdout for debugging."""
    feedback = (x.index, x.value, x.liked)
    print(*feedback)
def change_ablation_mode(
    ablation_mode,
):
    """Return an updated ablation-count slider whose default depends on mode.

    AuthorAblation defaults to 5 planned ablations, ReviewerAblation to 3.
    """
    # `precision` is not a gr.Slider parameter (it belongs to gr.Number),
    # so it was passed as an ignored extra kwarg; `step=1` is what keeps
    # this slider integer-valued.
    return gr.Slider(
        label="Number of ablations to generate",
        minimum=1,
        maximum=10,
        step=1,
        value=5 if ablation_mode == "AuthorAblation" else 3,
    )
def clear_chat():
    """Reset the chatbot component to an empty history."""
    empty_history = []
    return empty_history
# ---------------------------------------------------------------------------
# Gradio UI: layout definition and event wiring for the AblationBench demo.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Intro panel; the <span> classes are styled by style.css, so HTML
    # sanitization must stay disabled.
    gr.Markdown(
        """
# <span class="ablationbench">AblationBench:</span> Demo
This demo lets you generate ablation plans for your paper using our baseline LM-Planner, introduced in *AblationBench: Evaluating Automated Planning of Ablations in Empirical AI Research*.
Choose one of two tasks:
- <span class="authorablation">AuthorAblation</span> - Generate ablations from a method section (for authors).
- <span class="reviewerablation">ReviewerAblation</span> - Suggest missing ablations from a full paper (for reviewers).
**Guidelines:**
- For <span class="authorablation">AuthorAblation</span>, include the method section and remove any existing ablation experiments.
- For <span class="reviewerablation">ReviewerAblation</span>, include the full paper.
- Text files work best (preferred over PDFs).
Upload your paper as text, a file, or a zipped Overleaf project.
Learn more on our [🌍 project page](https://ablation-bench.github.io/#/), explore the [🤗 benchmark](https://huggingface.co/collections/ai-coscientist/ablationbench) or read our [📎 paper](https://www.arxiv.org/abs/2507.08038).
        """,
        sanitize_html=False,
    )
    # Keep mode choice visible at the top.
    ablation_mode = gr.Radio(
        choices=["AuthorAblation", "ReviewerAblation"],
        value="AuthorAblation",
        label="Ablation mode",
        elem_id="ablation-mode",
    )
    chatbot = gr.Chatbot(
        label="Ablation Plan",
        buttons=["copy"],
        avatar_images=("https://ablation-bench.github.io/_media/user_avatar.png", "https://ablation-bench.github.io/_media/lm_avatar.png"),
    )
    # Single text-or-file input; the accepted types mirror the dispatch in
    # _build_paper_source_from_upload.
    message_input = gr.MultimodalTextbox(
        label="Paper content",
        placeholder="Enter your paper text here, or upload one file: TEX, MD, PDF, ZIP, or GZIP.",
        lines=5,
        file_count="single",
        file_types=[
            "text",
            ".tex",
            ".text",
            ".txt",
            ".bib",
            ".bbl",
            ".md",
            ".zip",
            ".tar",
            ".tar.gz",
            ".tgz",
            ".gz",
            ".gzip",
            ".pdf",
        ],
        max_lines=1000,
    )
    with gr.Accordion("Advanced settings", open=False):
        # NOTE(review): gr.Slider has no `precision` parameter (that belongs
        # to gr.Number); `step=1` is what keeps this slider integer-valued.
        # Default count matches the initial AuthorAblation mode.
        num_ablations = gr.Slider(
            label="Number of ablations to generate",
            minimum=1,
            maximum=10,
            step=1,
            precision=0,
            value=5 if ablation_mode.value == "AuthorAblation" else 3
        )
        model_id = gr.Dropdown(
            choices=[
                "openai/gpt-oss-120b",
                "MiniMaxAI/MiniMax-M2.5",
                "Qwen/Qwen3.5-397B-A17B",
                "moonshotai/Kimi-K2.5",
                "moonshotai/Kimi-K2-Thinking",
                "moonshotai/Kimi-K2-Instruct",
                "deepseek-ai/DeepSeek-V3.2",
                "zai-org/GLM-5",
                "Qwen/Qwen3-235B-A22B-Instruct-2507",
            ],
            value="openai/gpt-oss-120b",
            label="Model ID",
        )
        provider_name = gr.Dropdown(
            choices=[
                "black-forest-labs",
                "cerebras",
                "clarifai",
                "cohere",
                "fal-ai",
                "featherless-ai",
                "fireworks-ai",
                "groq",
                "hf-inference",
                "hyperbolic",
                "nebius",
                "novita",
                "nscale",
                "openai",
                "ovhcloud",
                "publicai",
                "replicate",
                "sambanova",
                "scaleway",
                "together",
                "wavespeed",
                "zai-org",
            ],
            value="groq",
            label="Provider",
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0,
            step=0.1,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        )
    with gr.Sidebar():
        gr.Markdown("""<center><img src="https://ablation-bench.github.io/_media/icon.png"></center>""", sanitize_html=False)
        gr.LoginButton()
    # `hf_token` is not listed in `inputs`: Gradio injects the OAuth token
    # automatically for parameters annotated with gr.OAuthToken.
    message_input.submit(
        run_single_interaction,
        inputs=[
            message_input,
            chatbot,
            ablation_mode,
            num_ablations,
            temperature,
            top_p,
            model_id,
            provider_name,
        ],
        outputs=[
            chatbot,
        ],
    )
    chatbot.clear(
        clear_chat,
        outputs=[
            chatbot,
        ]
    )
    # Switching modes resets the ablation-count slider to its mode default.
    ablation_mode.input(
        change_ablation_mode,
        inputs=[
            ablation_mode,
        ],
        outputs=[
            num_ablations,
        ]
    )
    chatbot.like(print_like_dislike)
if __name__ == "__main__":
    demo.launch(css_paths=Path("style.css"))