egor-bogomolov's picture
Sort dataset dropdown alphabetically by display name
454263d
"""
ML4SE Benchmark Viewer
A web-based interface for browsing and inspecting individual datapoints
from popular ML4SE benchmark datasets (REval, CRUXEval, HumanEval+,
BigOBench, and others).
"""
import ast as _ast
import io
import os

from flask import Flask, jsonify, render_template, request
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import PythonLexer, get_lexer_by_name
app = Flask(__name__)
def _extract_test_classes(test_code: str, cls_name: str) -> list:
"""
Parse a ClassEval unittest module and return one dict per test class
in definition order: {"name": ..., "code": ...}.
Matches top-level classes whose names start with f"{cls_name}Test",
which is the same pattern used by ClassFactory.create_test_classes().
Uses ast.parse only — no code execution, safe to call from the web server.
"""
try:
tree = _ast.parse(test_code)
except SyntaxError as e:
print(f"Warning: SyntaxError parsing test code for {cls_name}: {e}")
return []
lines = test_code.splitlines(keepends=True)
prefix = f"{cls_name}Test"
result = []
for node in tree.body: # top-level definitions, preserves source order
if isinstance(node, _ast.ClassDef) and node.name.startswith(prefix):
start = node.lineno - 1 # ast lineno is 1-indexed
end = node.end_lineno # end_lineno is inclusive; slice is exclusive
result.append(
{
"name": node.name,
"code": "".join(lines[start:end]),
}
)
return result
def _code_offset(code: str) -> int:
"""Number of leading newlines that Pygments will strip."""
offset = 0
for ch in code:
if ch == "\n":
offset += 1
else:
break
return offset
def highlight_code(code, highlight_lines=None, language="python"):
    """
    Render source code as syntax-highlighted HTML.

    Args:
        code: The source text to render.
        highlight_lines: List of line numbers (1-indexed) to emphasise;
            None or an empty list emphasises nothing.
        language: Lexer name passed to Pygments after lower-casing
            (default: "python"). If the lookup fails for any reason the
            Python lexer is used instead.

    Returns:
        HTML string with the highlighted code, rendered with a line-number
        table and the "source" CSS class.
    """
    formatter_options = {
        "linenos": "table",
        "cssclass": "source",
        "hl_lines": highlight_lines if highlight_lines else [],
        "linenostart": 1,
    }
    formatter = HtmlFormatter(**formatter_options)

    try:
        chosen_lexer = get_lexer_by_name(language.lower())
    except Exception:
        # Any failure to resolve the lexer falls back to Python.
        chosen_lexer = PythonLexer()

    return highlight(code, chosen_lexer, formatter)
def get_css():
    """Return the Pygments style definitions scoped to the ``.source`` selector."""
    stylesheet = HtmlFormatter().get_style_defs(".source")
    return stylesheet
# ---------------------------------------------------------------------------
# Dataset adapter registration
# ---------------------------------------------------------------------------
from adapters import REGISTRY, _set_helpers, register_hf_datasets # noqa: E402
# Inject helper functions into the adapters module (avoids circular imports):
# the three helpers defined above are passed in rather than imported by
# adapters from this module.
_set_helpers(highlight_code, _code_offset, _extract_test_classes)
# Register all HuggingFace datasets
# NOTE(review): presumably this is what populates REGISTRY — confirm in adapters.
register_hf_datasets()
def _get_adapter(dataset_slug: str):
    """Look up *dataset_slug* in REGISTRY; the result is None for an unknown slug."""
    adapter = REGISTRY.get(dataset_slug)
    return adapter
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.route("/")
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/api/datasets")
def get_datasets():
    """Return a JSON list describing every registered dataset.

    Each entry carries the slug, display name, problem count, total count
    and ground-truth flag of one adapter in REGISTRY. Entries are ordered
    by lower-cased display name.
    """
    entries = []
    for slug, adapter in REGISTRY.items():
        entries.append(
            {
                "slug": slug,
                "display_name": adapter.display_name,
                "problem_count": adapter.problem_count(),
                "total_count": adapter.total_count,
                "has_ground_truth": adapter.has_ground_truth,
            }
        )
    # In-place sort is stable, so entries with equal keys keep registry order.
    entries.sort(key=lambda entry: entry["display_name"].lower())
    return jsonify(entries)
@app.route("/api/<dataset_slug>/problems")
def get_problems(dataset_slug):
    """Return the summaries of all problems in one dataset as JSON.

    Responds with a 404 JSON error when the slug has no adapter.
    """
    adapter = _get_adapter(dataset_slug)
    if adapter is None:
        return jsonify({"error": f"Unknown dataset: {dataset_slug}"}), 404

    summaries = []
    for position in range(adapter.problem_count()):
        summaries.append(adapter.get_problem_summary(position))
    return jsonify(summaries)
@app.route("/api/<dataset_slug>/problem/<int:idx>")
def get_problem(idx, dataset_slug):
    """Return the detail record of a single problem as JSON.

    Responds with 404 for an unknown dataset or an index outside
    ``[0, problem_count)``, and with 500 when building the detail raises
    KeyError, IndexError or ValueError.
    """
    adapter = _get_adapter(dataset_slug)
    if adapter is None:
        return jsonify({"error": f"Unknown dataset: {dataset_slug}"}), 404

    if idx < 0 or idx >= adapter.problem_count():
        return jsonify({"error": "Invalid problem index"}), 404

    try:
        detail = adapter.get_problem_detail(idx)
        response = jsonify(detail)
    except (KeyError, IndexError, ValueError) as exc:
        return jsonify({"error": f"Internal error: {exc}"}), 500
    return response
@app.route("/api/highlight_code")
def highlight_code_api():
    """Highlight the ``code`` query parameter and return the HTML as JSON.

    The optional ``lines`` query parameter is a comma-separated list of
    line numbers to emphasise; blank items are skipped. A non-integer item
    produces a 400 JSON error.
    """
    code = request.args.get("code", "")
    raw_lines = request.args.get("lines", "")

    selected = None
    if raw_lines:
        selected = []
        for token in raw_lines.split(","):
            if not token.strip():
                continue
            try:
                selected.append(int(token))
            except ValueError:
                return jsonify({"error": "Invalid line numbers"}), 400

    return jsonify({"highlighted_code": highlight_code(code, selected)})
@app.route("/<dataset_slug>/problem/<int:idx>")
def problem_detail(idx, dataset_slug):
    """Render the ``problem.html`` page for one problem of a dataset.

    Responds with a 404 JSON error for an unknown dataset or an index
    outside ``[0, problem_count)``.
    """
    adapter = _get_adapter(dataset_slug)
    if adapter is None:
        return jsonify({"error": "Unknown dataset"}), 404

    if idx < 0 or idx >= adapter.problem_count():
        return jsonify({"error": "Problem not found"}), 404

    # Values handed to the template, gathered in one place.
    context = {
        "idx": idx,
        "css": get_css(),
        "total_problems": adapter.problem_count(),
        "dataset_slug": dataset_slug,
        "dataset_name": adapter.display_name,
        "has_ground_truth": adapter.has_ground_truth,
        "has_tasks": adapter.has_tasks,
    }
    return render_template("problem.html", **context)
@app.route("/api/css")
def get_css_api():
    """Serve the syntax-highlighting stylesheet with a ``text/css`` content type."""
    headers = {"Content-Type": "text/css"}
    return get_css(), 200, headers
@app.route("/api/<dataset_slug>/problem/<int:idx>/ground_truth/<int:input_idx>")
def get_ground_truth(idx, input_idx, dataset_slug):
    """Return ground truth execution data for one (problem, input) pair.

    Responses:
        404 JSON error for an unknown dataset or a problem index outside
            ``[0, problem_count)``.
        200 with ``{"status": "unavailable", ...}`` when the adapter has no
            ground truth.
        500 JSON error when the adapter lookup raises KeyError, IndexError
            or ValueError.
        Otherwise the adapter's result serialised as JSON.
    """
    adapter = _get_adapter(dataset_slug)
    if adapter is None:
        return jsonify({"error": f"Unknown dataset: {dataset_slug}"}), 404
    if not adapter.has_ground_truth:
        return jsonify({"status": "unavailable", "message": "Ground truth not available"}), 200
    if not (0 <= idx < adapter.problem_count()):
        return jsonify({"error": "Invalid problem index"}), 404
    # input_idx is not range-checked here, so a bad value can only surface as
    # an exception from the adapter. Answer in JSON, the same way get_problem
    # does, instead of letting the exception escape the API route.
    try:
        return jsonify(adapter.get_ground_truth(idx, input_idx))
    except (KeyError, IndexError, ValueError) as exc:
        # Keep the traceback in the server log; the client only gets the message.
        app.logger.exception(
            "Ground truth lookup failed for %s problem %s input %s",
            dataset_slug,
            idx,
            input_idx,
        )
        return jsonify({"error": f"Internal error: {exc}"}), 500
if __name__ == "__main__":
    # Script entry point: settings come from the environment.
    # PORT defaults to 7860; debug is on only when FLASK_DEBUG is "true"
    # (compared case-insensitively).
    listen_port = int(os.environ.get("PORT", 7860))
    debug_mode = os.environ.get("FLASK_DEBUG", "false").lower() == "true"
    app.run(host="0.0.0.0", port=listen_port, debug=debug_mode)