# Hugging Face Space: ablation-bench / app.py
# Author: Talor Abramovich
# Last commit: "removed debug print" (ce8e77a)
import os
import json
import yaml
import gzip
import tarfile
import zipfile
import tempfile
from pathlib import Path
import gradio as gr
from huggingface_hub import InferenceClient
from huggingface_hub.errors import BadRequestError
# File suffixes treated as readable paper text when scanning uploads and archives.
TEXT_EXTENSIONS = {".tex", ".text", ".txt", ".bib", ".bbl", ".md"}
def _normalize_message_content(content):
if isinstance(content, list):
return "\n".join(_normalize_message_content(item) for item in content)
if isinstance(content, str):
return content
if isinstance(content, dict):
if "path" in content:
return f"[uploaded file: {os.path.basename(content['path'])}]"
return str(content.get("text", ""))
if isinstance(content, gr.ChatMessage):
return _normalize_message_content(content.content)
return str(content)
def _sanitize_history(history):
clean = []
for msg in history or []:
if isinstance(msg, gr.ChatMessage):
role = msg.role
content = msg.content
metadata = msg.metadata
elif isinstance(msg, dict):
role = msg.get("role")
content = msg.get("content", "")
metadata = msg.get("metadata")
else:
continue
if role not in {"user", "assistant", "system"}:
continue
message = {"role": role, "content": _normalize_message_content(content)}
if metadata:
message["metadata"] = metadata
clean.append(message)
return clean
def _extract_streaming_tag_content(text, tag):
source = text or ""
start_tag = f"<{tag}>"
end_tag = f"</{tag}>"
start_idx = source.find(start_tag)
if start_idx == -1:
return "", False, False
content_start = start_idx + len(start_tag)
end_idx = source.find(end_tag, content_start)
if end_idx == -1:
return source[content_start:].lstrip(), True, False
return source[content_start:end_idx].strip(), True, True
def _parse_predictions_jsonl_items(predictions_block, closed):
items = []
lines = (predictions_block or "").splitlines(keepends=True)
for idx, line in enumerate(lines):
is_last = idx == len(lines) - 1
line_is_complete = line.endswith("\n") or line.endswith("\r")
if not line_is_complete and not (closed and is_last):
continue
stripped = line.strip()
if not stripped:
continue
try:
items.append(json.loads(stripped))
except json.JSONDecodeError:
continue
return items
def _format_predictions_markdown(items):
if not items:
return ""
out = []
for i, item in enumerate(items, start=1):
name = item.get("name", "Untitled Ablation")
ablated_part = item.get("ablated_part", "")
action = item.get("action", "")
replacement = item.get("replacement")
metrics = item.get("metrics", [])
if not isinstance(metrics, list):
metrics = [str(metrics)]
out.append(f"### {i}. **{name}**")
out.append(f"&nbsp;&nbsp;&nbsp;&nbsp;<u>Component</u>: {ablated_part}")
out.append(f"&nbsp;&nbsp;&nbsp;&nbsp;<u>Action</u>: {action}")
if replacement is not None and replacement != "":
out.append("&nbsp;&nbsp;&nbsp;&nbsp;<u>Replacement</u>:")
if isinstance(replacement, list):
for rep in replacement:
out.append(f"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- `{rep}`")
else:
out.append(f"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- `{replacement}`")
out.append("&nbsp;&nbsp;&nbsp;&nbsp;<u>Metrics</u>:")
if metrics:
for metric in metrics:
out.append(f"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- `{metric}`")
else:
out.append("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- `(none)`")
out.append("\n---\n")
return "\n".join(out).strip("- \n")
def _read_text_file(path: Path) -> str:
return path.read_text(encoding="utf-8", errors="ignore")
def _is_relevant_text_file(path: Path) -> bool:
    """Return True when the path's lowercased suffix is a supported text type."""
    suffix = path.suffix.lower()
    return suffix in TEXT_EXTENSIONS
def _safe_extract_zip(zip_path: Path, output_dir: Path) -> None:
with zipfile.ZipFile(zip_path, "r") as zf:
for member in zf.infolist():
member_path = output_dir / member.filename
resolved_member = member_path.resolve()
resolved_root = output_dir.resolve()
if not str(resolved_member).startswith(str(resolved_root)):
continue
zf.extract(member, output_dir)
def _safe_extract_tar(tar_path: Path, output_dir: Path) -> None:
with tarfile.open(tar_path, "r:*") as tf:
for member in tf.getmembers():
member_path = output_dir / member.name
resolved_member = member_path.resolve()
resolved_root = output_dir.resolve()
if not str(resolved_member).startswith(str(resolved_root)):
continue
tf.extract(member, output_dir)
def _archive_to_tagged_source(extracted_root: Path) -> str:
    """Concatenate every relevant text file under extracted_root into <file> tags.

    Files are visited in sorted order; unreadable files are skipped. Raises
    gr.Error when the tree contains no supported text files at all.
    """
    tagged_chunks = []
    for candidate in sorted(extracted_root.rglob("*")):
        if not candidate.is_file() or not _is_relevant_text_file(candidate):
            continue
        try:
            relative_name = candidate.relative_to(extracted_root).as_posix()
            body = _read_text_file(candidate)
        except Exception:
            continue
        tagged_chunks.append(f'<file name="{relative_name}">\n{body}\n</file>\n')
    if tagged_chunks:
        return "\n".join(tagged_chunks)
    raise gr.Error(
        "No relevant text files found in the archive. Expected .tex/.text/.txt/.bib/.bbl/.md files."
    )
def _convert_pdf_to_markdown(pdf_path: Path) -> str:
    """Convert a PDF into markdown text via the MarkItDown SDK.

    Raises gr.Error when the SDK is missing, the conversion fails, or the
    converted output is empty after stripping whitespace.
    """
    try:
        from markitdown import MarkItDown
    except Exception as e:
        raise gr.Error(
            "MarkItDown SDK is not available. Make sure `markitdown[pdf]` is installed."
        ) from e
    try:
        engine = MarkItDown(enable_plugins=False)
        converted = engine.convert(str(pdf_path))
        markdown_text = converted.text_content
    except Exception as e:
        raise gr.Error(f"PDF conversion failed with MarkItDown SDK: {e}") from e
    cleaned = (markdown_text or "").strip()
    if not cleaned:
        raise gr.Error("MarkItDown SDK produced empty output for this PDF.")
    return cleaned
def _build_paper_source_from_upload(uploaded_path: str) -> str:
    """Turn an uploaded file into paper source text for the prompt.

    Dispatches on the file name: plain text files are read directly; zip and
    tar archives are extracted (safely) and concatenated into <file> tags;
    bare .gz/.gzip payloads are decompressed and read if they turn out to be
    a supported text file; PDFs go through MarkItDown. Anything else raises
    gr.Error.
    """
    src_path = Path(uploaded_path)
    lowered_name = src_path.name.lower()
    if _is_relevant_text_file(src_path):
        return _read_text_file(src_path)
    with tempfile.TemporaryDirectory(prefix="paper_extract_") as tmpdir:
        extract_root = Path(tmpdir) / "extracted"
        extract_root.mkdir(parents=True, exist_ok=True)
        if lowered_name.endswith(".zip"):
            _safe_extract_zip(src_path, extract_root)
            return _archive_to_tagged_source(extract_root)
        if lowered_name.endswith((".tar.gz", ".tgz", ".tar")):
            _safe_extract_tar(src_path, extract_root)
            return _archive_to_tagged_source(extract_root)
        if lowered_name.endswith((".gz", ".gzip")):
            # A .gz upload may actually be a compressed tar; handle that first.
            if tarfile.is_tarfile(src_path):
                _safe_extract_tar(src_path, extract_root)
                return _archive_to_tagged_source(extract_root)
            # Strip exactly one compression suffix to recover the inner name.
            inner_name = src_path.name
            for compressed_suffix in (".gzip", ".gz"):
                if inner_name.endswith(compressed_suffix):
                    inner_name = inner_name[: -len(compressed_suffix)]
                    break
            decompressed_path = extract_root / inner_name
            with gzip.open(src_path, "rb") as gz_in, open(decompressed_path, "wb") as out_f:
                out_f.write(gz_in.read())
            if _is_relevant_text_file(decompressed_path):
                return _read_text_file(decompressed_path)
            raise gr.Error(
                "Unsupported .gz/.gzip payload. It must contain a relevant text file or a tar archive."
            )
    if lowered_name.endswith(".pdf"):
        return _convert_pdf_to_markdown(src_path)
    raise gr.Error(
        "Unsupported file type. Use text files (.tex/.text/.txt/.bib/.bbl/.md), "
        "archives (.zip/.tar/.tar.gz/.tgz/.gz/.gzip), or .pdf."
    )
def run_single_interaction(
    message_input,
    history,
    ablation_mode,
    num_ablations,
    temperature,
    top_p,
    model_id,
    provider_name,
    hf_token: gr.OAuthToken | None,
):
    """
    Run one streamed ablation-planning turn and yield chat-history updates.

    Validates the submitted text-or-file input, builds the prompt for the
    selected ablation mode from prompts.yaml, streams the model output via
    the Hugging Face Inference API, and incrementally renders a "Discussion"
    message plus a formatted predictions message.

    For more information on `huggingface_hub` Inference API support, see:
    https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
    """
    # Load the system/user prompt pair for the selected task.
    try:
        config = yaml.safe_load(Path("./prompts.yaml").read_text())
        prompts = (
            config["author_ablation"]
            if ablation_mode == "AuthorAblation"
            else config["reviewer_ablation"]
        )
    except Exception as e:
        raise gr.Error(f"Failed to load prompt configuration: {e}") from e
    prior_history = _sanitize_history(history)
    # MultimodalTextbox submits a {"text": ..., "files": [...]} dict; other
    # callers may pass a plain string.
    text = ""
    files = []
    if isinstance(message_input, dict):
        text = (message_input.get("text") or "").strip()
        files = message_input.get("files") or []
    else:
        text = (message_input or "").strip()
    has_text = bool(text)
    has_file = len(files) > 0
    # Exactly one input source is accepted: pasted text XOR a single file.
    if has_text and has_file:
        raise gr.Error("Please submit either paper content text or one file upload, not both.")
    if not has_text and not has_file:
        raise gr.Error("Please provide paper content as text or upload one file.")
    if len(files) > 1:
        raise gr.Error("Please upload a single file only.")
    if hf_token is None:
        raise gr.Error("Please sign in with Hugging Face before submitting.")
    file_label = None
    file_path = None
    if has_file:
        file_item = files[0]
        # Gradio may hand back either a {"path": ...} dict or a bare path string.
        file_path = file_item.get("path") if isinstance(file_item, dict) else file_item
        file_label = os.path.basename(file_path) if file_path else "uploaded_file"
    try:
        paper_source = text if has_text else _build_paper_source_from_upload(file_path)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded paper: {e}") from e
    # Fill the template with str.replace (not str.format) so that braces in
    # the paper source itself are left untouched.
    user_prompt_template = prompts["user_prompt"]
    user_content = (
        user_prompt_template.replace("{{paper_source}}", paper_source)
        .replace("{{num_ablations}}", str(num_ablations))
    )
    # Build a short, human-readable stand-in for the (potentially huge) input
    # to show in the visible chat instead of the full paper source.
    if has_file:
        source_hint = f"file: {file_label}"
    else:
        first_line = (text.splitlines()[0] if text else "").strip()
        first_line_words = first_line.split()[:100]
        preview = " ".join(first_line_words)
        source_hint = f"text preview: {preview}" if preview else "text preview: (empty)"
    if ablation_mode == "AuthorAblation":
        user_display = f"Planning {num_ablations} ablations for submitted paper ({source_hint})."
    else:
        user_display = f"Reviewing and suggesting {num_ablations} missing ablations for submitted paper ({source_hint})."
    client = InferenceClient(
        token=hf_token.token,
        model=model_id,
        provider=provider_name,
    )
    # Keep full chat visible to users, but send only current input to model.
    messages = [
        {"role": "system", "content": prompts["system_prompt"]},
        {"role": "user", "content": user_content},
    ]
    live_history = [
        gr.ChatMessage(
            role=item["role"],
            content=item["content"],
            metadata=item.get("metadata") or {},
        )
        for item in prior_history
    ]
    live_history.append(gr.ChatMessage(role="user", content=user_display))
    # Heuristic warning: the upload seems to already discuss ablations.
    if has_file and ablation_mode == "AuthorAblation" and "ablat" in paper_source.lower():
        gr.Warning("Uploaded paper appears to already contain ablation content (`ablat*`).")
    # Placeholder assistant message that the streaming loop keeps rewriting.
    live_history.append(
        gr.ChatMessage(
            role="assistant",
            content="",
            metadata={"title": "⏳ Discussion", "status": "pending"},
        )
    )
    emitted = False
    raw_output = ""
    predictions_message_idx = None
    try:
        for chunk in client.chat_completion(
            messages,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            choices = chunk.choices
            token = ""
            if len(choices) and choices[0].delta.content:
                token = choices[0].delta.content
            raw_output += token
            # The model is expected to emit <discussion>...</discussion>
            # followed by <predictions> with one JSON object per line.
            discussion_text, discussion_started, _ = _extract_streaming_tag_content(
                raw_output, "discussion"
            )
            predictions_text, predictions_started, predictions_closed = (
                _extract_streaming_tag_content(raw_output, "predictions")
            )
            prediction_items = _parse_predictions_jsonl_items(
                predictions_text, predictions_closed
            )
            if discussion_started:
                # Once predictions begin, the discussion section is complete.
                discussion_status = "done" if predictions_started else "pending"
                # NOTE(review): after the predictions message is appended below,
                # live_history[-1] refers to THAT message, so this assignment
                # briefly overwrites it before the predictions branch restores
                # it — consider tracking the discussion message index explicitly.
                live_history[-1] = gr.ChatMessage(
                    role="assistant",
                    content=discussion_text or "_(discussion is empty so far)_",
                    metadata={"title": "⏳ Discussion", "status": discussion_status},
                )
            else:
                live_history[-1] = gr.ChatMessage(
                    role="assistant",
                    content="",
                    metadata={"title": "⏳ Discussion", "status": "pending"},
                )
            if predictions_started:
                predictions_markdown = _format_predictions_markdown(prediction_items)
                if predictions_message_idx is None:
                    # First prediction tokens: create the dedicated message.
                    live_history.append(
                        gr.ChatMessage(
                            role="assistant",
                            content=predictions_markdown,
                        )
                    )
                    predictions_message_idx = len(live_history) - 1
                else:
                    # Subsequent chunks update the same message in place.
                    live_history[predictions_message_idx] = gr.ChatMessage(
                        role="assistant",
                        content=predictions_markdown,
                    )
            emitted = True
            yield live_history
    except BadRequestError as e:
        message = str(e)
        if "model_not_supported" in message:
            raise gr.Error(
                f"Model/provider mismatch for model '{model_id}' with provider '{provider_name}'. "
                "Your token is valid, but no enabled provider can serve this model. "
                "Pick a different model or provider."
            ) from e
        raise gr.Error(f"Inference request failed: {message}") from e
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Unexpected error during generation: {e}") from e
    if not emitted:
        # Nothing streamed at all: show explicit fallbacks instead of blanks.
        live_history[-1] = gr.ChatMessage(
            role="assistant",
            content="_No discussion block found in model output._",
            metadata={"title": "⏳ Discussion", "status": "done"},
        )
        live_history.append(
            gr.ChatMessage(
                role="assistant",
                content="_No valid predictions JSONL found._",
            )
        )
        yield live_history
def print_like_dislike(x: gr.LikeData):
    """Log a like/dislike event's position, value, and direction to stdout."""
    event_fields = (x.index, x.value, x.liked)
    print(*event_fields)
def change_ablation_mode(
    ablation_mode,
):
    """Return a refreshed ablation-count slider whose default tracks the mode.

    AuthorAblation defaults to 5 planned ablations, ReviewerAblation to 3.
    """
    default_count = 5 if ablation_mode == "AuthorAblation" else 3
    # NOTE(review): `precision` is documented for gr.Number, not gr.Slider —
    # confirm the installed Gradio version accepts it here.
    return gr.Slider(
        label="Number of ablations to generate",
        minimum=1,
        maximum=10,
        step=1,
        precision=0,
        value=default_count,
    )
def clear_chat():
    """Reset the chatbot component to an empty conversation."""
    return list()
# UI layout and event wiring for the Gradio Space.
with gr.Blocks() as demo:
    # Header: task overview, usage guidelines, and project links.
    gr.Markdown(
        """
# <span class="ablationbench">AblationBench:</span> Demo
This demo lets you generate ablation plans for your paper using our baseline LM-Planner, introduced in *AblationBench: Evaluating Automated Planning of Ablations in Empirical AI Research*.
Choose one of two tasks:
- <span class="authorablation">AuthorAblation</span> - Generate ablations from a method section (for authors).
- <span class="reviewerablation">ReviewerAblation</span> - Suggest missing ablations from a full paper (for reviewers).
**Guidelines:**
- For <span class="authorablation">AuthorAblation</span>, include the method section and remove any existing ablation experiments.
- For <span class="reviewerablation">ReviewerAblation</span>, include the full paper.
- Text files work best (preferred over PDFs).
Upload your paper as text, a file, or a zipped Overleaf project.
Learn more on our [🌍 project page](https://ablation-bench.github.io/#/), explore the [🤗 benchmark](https://huggingface.co/collections/ai-coscientist/ablationbench) or read our [📎 paper](https://www.arxiv.org/abs/2507.08038).
""",
        sanitize_html=False,
    )
    # Keep mode choice visible at the top.
    ablation_mode = gr.Radio(
        choices=["AuthorAblation", "ReviewerAblation"],
        value="AuthorAblation",
        label="Ablation mode",
        elem_id="ablation-mode",
    )
    # Streaming conversation view; run_single_interaction yields into this.
    chatbot = gr.Chatbot(
        label="Ablation Plan",
        buttons=["copy"],
        avatar_images=("https://ablation-bench.github.io/_media/user_avatar.png", "https://ablation-bench.github.io/_media/lm_avatar.png"),
    )
    # Single entry point for paper content: pasted text OR one uploaded file.
    # NOTE(review): keep file_types in sync with TEXT_EXTENSIONS and the
    # dispatch in _build_paper_source_from_upload.
    message_input = gr.MultimodalTextbox(
        label="Paper content",
        placeholder="Enter your paper text here, or upload one file: TEX, MD, PDF, ZIP, or GZIP.",
        lines=5,
        file_count="single",
        file_types=[
            "text",
            ".tex",
            ".text",
            ".txt",
            ".bib",
            ".bbl",
            ".md",
            ".zip",
            ".tar",
            ".tar.gz",
            ".tgz",
            ".gz",
            ".gzip",
            ".pdf",
        ],
        max_lines=1000,
    )
    # Generation knobs, collapsed by default.
    with gr.Accordion("Advanced settings", open=False):
        # NOTE(review): `precision` is documented for gr.Number, not gr.Slider —
        # confirm the installed Gradio version accepts it here.
        num_ablations = gr.Slider(
            label="Number of ablations to generate",
            minimum=1,
            maximum=10,
            step=1,
            precision=0,
            value=5 if ablation_mode.value == "AuthorAblation" else 3
        )
        model_id = gr.Dropdown(
            choices=[
                "openai/gpt-oss-120b",
                "MiniMaxAI/MiniMax-M2.5",
                "Qwen/Qwen3.5-397B-A17B",
                "moonshotai/Kimi-K2.5",
                "moonshotai/Kimi-K2-Thinking",
                "moonshotai/Kimi-K2-Instruct",
                "deepseek-ai/DeepSeek-V3.2",
                "zai-org/GLM-5",
                "Qwen/Qwen3-235B-A22B-Instruct-2507",
            ],
            value="openai/gpt-oss-120b",
            label="Model ID",
        )
        # Inference-provider routing for the Hugging Face Inference API; not
        # every provider serves every model (see the BadRequestError handler).
        provider_name = gr.Dropdown(
            choices=[
                "black-forest-labs",
                "cerebras",
                "clarifai",
                "cohere",
                "fal-ai",
                "featherless-ai",
                "fireworks-ai",
                "groq",
                "hf-inference",
                "hyperbolic",
                "nebius",
                "novita",
                "nscale",
                "openai",
                "ovhcloud",
                "publicai",
                "replicate",
                "sambanova",
                "scaleway",
                "together",
                "wavespeed",
                "zai-org",
            ],
            value="groq",
            label="Provider",
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0,
            step=0.1,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        )
    # Sidebar: project logo and the OAuth sign-in button (required to submit).
    with gr.Sidebar():
        gr.Markdown("""<center><img src="https://ablation-bench.github.io/_media/icon.png"></center>""", sanitize_html=False)
        gr.LoginButton()
    # Submit streams generator updates straight into the chatbot.
    message_input.submit(
        run_single_interaction,
        inputs=[
            message_input,
            chatbot,
            ablation_mode,
            num_ablations,
            temperature,
            top_p,
            model_id,
            provider_name,
        ],
        outputs=[
            chatbot,
        ],
    )
    chatbot.clear(
        clear_chat,
        outputs=[
            chatbot,
        ]
    )
    # Switching the mode resets the ablation-count slider to that mode's default.
    ablation_mode.input(
        change_ablation_mode,
        inputs=[
            ablation_mode,
        ],
        outputs=[
            num_ablations,
        ]
    )
    chatbot.like(print_like_dislike)
if __name__ == "__main__":
    demo.launch(css_paths=Path("style.css"))