# FlakyTestSleuthOpenEnvRL/dataset/build_dataset.py
# Uploaded by vedkdev ("Upload folder using huggingface_hub"), commit 761f203 (verified).
"""Offline dataset builder for FlakySleuth.

Examples:
    # Validate the schema and show the category/status summary only
    python dataset/build_dataset.py --input py-data.csv --validate-only

    # Build the full task CSV (requires network access for repo cloning)
    export GITHUB_TOKEN=...
    python dataset/build_dataset.py --input py-data.csv --output dataset/py_tasks.csv
"""
from __future__ import annotations
import argparse
import csv
import os
import subprocess
import tempfile
from pathlib import Path
from urllib.parse import urlparse
import pandas as pd
import requests
try:
from tqdm import tqdm
except Exception: # pragma: no cover
tqdm = None
# Category codes (compared against the first ";"-separated entry of the input
# "Category" cell) that qualify a row for the "classify" and "root_cause" tasks.
TASK12_CATEGORIES = ["NOD", "TD", "TZD", "NIO", "ID", "OD", "OD-Brit", "OD-Vic"]
# Category codes that can additionally qualify a row for the "fix_proposal"
# task; build() also requires an accepted status and a parseable PR link.
TASK3_CATEGORIES = ["TD", "TZD", "NOD", "NIO", "ID"]
# Column headers looked up in the input CSV after whitespace normalization.
PROJECT_URL_COL = "Project URL"
SHA_COL = "SHA Detected"
CATEGORY_COL = "Category"
STATUS_COL = "Status"
PR_LINK_COL = "PR Link"
# NOTE(review): NOTES_COL is not referenced anywhere in the visible part of
# this file — confirm whether it is still needed.
NOTES_COL = "Notes"
# Accepted spellings of the test-name column header; the first one present in
# the input wins (see _resolve_test_name_column).
TEST_NAME_ALIASES = [
    "Pytest Test Name",
    "Pytest Test Name (PathToFile::TestClass::TestMethod or PathToFile::TestMethod)",
]
# Field order of the generated task CSV; used both for the streaming
# csv.DictWriter and for the final de-duplicated DataFrame.
OUTPUT_COLUMNS = [
    "repo_url",
    "sha",
    "test_name",
    "test_file",
    "category",
    "label",
    "status",
    "pr_link",
    "task_types",
    "test_code",
    "known_fix_diff",
]
def _normalize_header(text: str) -> str:
return " ".join(str(text).strip().split())
def _resolve_test_name_column(columns: list[str]) -> str:
    """Pick the input column that holds the pytest test name.

    Headers are compared after whitespace normalization. The aliases in
    ``TEST_NAME_ALIASES`` are tried in order and the original (un-normalized)
    header of the first match is returned.

    Raises:
        KeyError: if none of the known aliases is present in *columns*.
    """
    original_by_normalized = {}
    for column in columns:
        original_by_normalized[_normalize_header(column)] = column

    for alias in TEST_NAME_ALIASES:
        try:
            return original_by_normalized[_normalize_header(alias)]
        except KeyError:
            continue

    expected = ", ".join(TEST_NAME_ALIASES)
    raise KeyError("Could not find pytest test-name column. Expected one of: " + expected)
def _parse_pr_link(pr_link: str) -> tuple[str, str] | None:
"""Return (owner/repo, number) from URL or owner/repo#number."""
value = (pr_link or "").strip()
if not value or value.lower() == "nan":
return None
if value.startswith("http://") or value.startswith("https://"):
parsed = urlparse(value)
parts = [p for p in parsed.path.split("/") if p]
# Expected: /owner/repo/pull/number
if len(parts) >= 4 and parts[2] == "pull" and parts[3].isdigit():
return f"{parts[0]}/{parts[1]}", parts[3]
return None
if "#" in value:
repo, number = value.split("#", 1)
if repo.strip() and number.strip().isdigit():
return repo.strip(), number.strip()
return None
def _is_accepted_status(status: str) -> bool:
value = (status or "").strip().lower()
return value in {"accepted", "merged", "fixed"}
def _non_interactive_git_env() -> dict[str, str]:
env = os.environ.copy()
# Never block on credential prompts while iterating large public datasets.
env["GIT_TERMINAL_PROMPT"] = "0"
env["GCM_INTERACTIVE"] = "Never"
return env
def _has_value(value: str) -> bool:
text = str(value or "").strip().lower()
return text not in {"", "nan", "none"}
def _is_non_unmaintained_status(status: str) -> bool:
value = str(status or "").strip().lower()
return value not in {"", "nan", "none", "unmaintained"}
def _row_preference_rank(row_out: dict[str, str]) -> tuple[int, int, int]:
    """Score an output row so duplicates can be compared with plain tuple ordering.

    The three 0/1 components are, in priority order: the row carries the
    ``fix_proposal`` task, it has a usable ``pr_link``, and its ``status`` is
    present and not "unmaintained". A larger tuple is the preferred row.
    """
    raw_tasks = str(row_out.get("task_types", ""))
    tasks = set(filter(None, (piece.strip() for piece in raw_tasks.split(";"))))

    offers_fix_task = "fix_proposal" in tasks
    links_a_pr = _has_value(str(row_out.get("pr_link", "")))
    status_counts = _is_non_unmaintained_status(str(row_out.get("status", "")))
    return int(offers_fix_task), int(links_a_pr), int(status_counts)
def fetch_test_code(repo_url: str, sha: str, pytest_test_name: str) -> tuple[str, str, str]:
    """Fetch the source of the file that contains a pytest test at one commit.

    The repository is initialised in a temporary directory, only the requested
    commit is fetched (``--depth=1``) and checked out, and the file named by
    the part of *pytest_test_name* before the first ``::`` is read.

    Args:
        repo_url: Remote URL handed to ``git remote add``.
        sha: Commit to fetch and check out.
        pytest_test_name: Test id of the form ``path/to/file.py::...``.

    Returns:
        A ``(test_code, reason, detail)`` triple. On success ``test_code`` is
        the first 10000 characters of the file and the other two items are
        empty. On failure ``test_code`` is ``""``, ``reason`` is one of
        ``git_init_failed``, ``git_remote_add_failed``, ``git_fetch_sha_failed``,
        ``git_checkout_failed``, ``git_timeout`` or ``test_file_missing_at_sha``,
        and ``detail`` holds a short diagnostic (at most 200 characters of git
        output, ``"timeout"``, or the offending path).
    """
    test_file = pytest_test_name.split("::")[0]
    relative_path = Path(test_file)

    # Validate the inputs before touching git or the network.
    # An empty, absolute or parent-escaping path can never name a file inside
    # the checkout; letting it through would read an arbitrary local file (or
    # the checkout directory itself, which raises IsADirectoryError).
    if not test_file.strip() or relative_path.is_absolute() or ".." in relative_path.parts:
        return "", "test_file_missing_at_sha", test_file
    # git parses arguments with a leading dash as options even after
    # positional arguments, so such values must never reach the command line.
    if repo_url.startswith("-"):
        return "", "git_remote_add_failed", "repo_url must not start with '-'"
    if sha.startswith("-"):
        return "", "git_fetch_sha_failed", "sha must not start with '-'"

    git_env = _non_interactive_git_env()
    with tempfile.TemporaryDirectory() as tmpdir:
        # (argv, timeout in seconds, reason reported when the step fails)
        steps = [
            (["git", "init", tmpdir], 20, "git_init_failed"),
            (
                ["git", "-C", tmpdir, "remote", "add", "origin", repo_url],
                10,
                "git_remote_add_failed",
            ),
            # Fetch only the requested commit for speed and correctness.
            (
                ["git", "-C", tmpdir, "fetch", "--depth=1", "origin", sha],
                90,
                "git_fetch_sha_failed",
            ),
            (
                ["git", "-C", tmpdir, "checkout", "--detach", "FETCH_HEAD"],
                30,
                "git_checkout_failed",
            ),
        ]
        try:
            for argv, timeout_seconds, failure_reason in steps:
                completed = subprocess.run(
                    argv,
                    capture_output=True,
                    text=True,
                    # git output is not guaranteed to be valid in the locale
                    # encoding; never let a decode error abort the build.
                    errors="replace",
                    check=False,
                    timeout=timeout_seconds,
                    env=git_env,
                    stdin=subprocess.DEVNULL,
                )
                if completed.returncode != 0:
                    detail = (completed.stderr or completed.stdout or "").strip()[:200]
                    return "", failure_reason, detail
        except subprocess.TimeoutExpired:
            return "", "git_timeout", "timeout"

        # Resolve both sides so a symlink inside the repository cannot point
        # the read outside the checkout, and require a regular file.
        checkout_root = Path(tmpdir).resolve()
        file_path = (checkout_root / relative_path).resolve()
        if checkout_root not in file_path.parents or not file_path.is_file():
            return "", "test_file_missing_at_sha", test_file
        return file_path.read_text(encoding="utf-8", errors="replace")[:10000], "", ""
def fetch_pr_diff(pr_link: str, github_token: str) -> str:
    """Best-effort download of a pull request's diff from the GitHub API.

    Args:
        pr_link: PR reference in any form accepted by ``_parse_pr_link``.
        github_token: Token sent in the ``Authorization`` header. When empty,
            the header is omitted and the request is made unauthenticated.

    Returns:
        The first 3000 characters of the diff, or ``""`` when the link cannot
        be parsed, the request fails, or the response status is not 200.
    """
    parsed = _parse_pr_link(pr_link)
    if not parsed:
        return ""
    repo, number = parsed
    url = f"https://api.github.com/repos/{repo}/pulls/{number}"

    headers = {"Accept": "application/vnd.github.diff"}
    if github_token:
        # An empty token would produce a malformed "token " credential.
        headers["Authorization"] = f"token {github_token}"

    try:
        response = requests.get(url, headers=headers, timeout=15)
    except requests.RequestException:
        # The diff is optional enrichment: a timeout or connection error on
        # one PR must not abort a build that may already have run for hours.
        return ""

    if response.status_code == 200:
        return response.text[:3000]
    return ""
def _validate_schema(input_csv: str) -> tuple[pd.DataFrame, str]:
    """Load the input CSV and verify that every required column is present.

    Column headers are whitespace-normalized in place before the check.

    Args:
        input_csv: Path of the CSV file to read.

    Returns:
        ``(df, test_name_col)``: the loaded frame (all cells as strings, empty
        cells as ``""``) and the header of the column holding the test name.

    Raises:
        KeyError: if a required column or every test-name alias is missing.
    """
    # Read every cell verbatim as text:
    # * dtype=str stops pandas from coercing digit-only values (e.g. an
    #   all-numeric abbreviated SHA) into numbers;
    # * keep_default_na=False keeps empty cells as "" instead of NaN, which
    #   str() would otherwise turn into the truthy literal "nan" downstream.
    df = pd.read_csv(input_csv, dtype=str, keep_default_na=False)
    df.columns = [_normalize_header(col) for col in df.columns]

    required = [PROJECT_URL_COL, SHA_COL, CATEGORY_COL, STATUS_COL, PR_LINK_COL]
    missing = [name for name in required if name not in df.columns]
    if missing:
        raise KeyError(f"Missing required columns: {missing}")

    test_name_col = _resolve_test_name_column(list(df.columns))
    return df, test_name_col
def _print_input_summary(df: pd.DataFrame, test_name_col: str) -> None:
    """Print row count, column list and the 20 most frequent categories and statuses."""
    header_lines = (
        "Input schema check: OK",
        f"Rows: {len(df)}",
        f"Using test-name column: {test_name_col}",
    )
    for line in header_lines:
        print(line)
    print("Columns:", list(df.columns))

    sections = (
        ("\nCategory distribution (top 20):", CATEGORY_COL),
        ("\nStatus distribution:", STATUS_COL),
    )
    for heading, column in sections:
        print(heading)
        # Missing cells are counted under the empty string.
        counts = df[column].fillna("").astype(str).value_counts()
        print(counts.head(20))
def _cell_text(row, column: str) -> str:
    """Return the stripped text of ``row[column]``, or ``""`` for a missing cell.

    An empty CSV cell may arrive as NaN, which ``str()`` renders as the truthy
    literal ``"nan"``; placeholders rejected by ``_has_value`` are mapped to "".
    """
    text = str(row.get(column, "")).strip()
    return text if _has_value(text) else ""


def _task_types_for(category: str, status: str, pr_link: str) -> list[str]:
    """Return the task types a row qualifies for, in output order."""
    task_types: list[str] = []
    if category in TASK12_CATEGORIES:
        task_types.extend(["classify", "root_cause"])
    if category in TASK3_CATEGORIES and _is_accepted_status(status) and _parse_pr_link(pr_link):
        task_types.append("fix_proposal")
    return task_types


def build(
    input_csv: str,
    output_csv: str,
    github_token: str,
    *,
    validate_only: bool = False,
    limit: int | None = None,
) -> None:
    """Build the task CSV from the input CSV.

    For every usable input row the test file is fetched at the recorded commit
    and, for ``fix_proposal`` rows when a token is available, the PR diff is
    downloaded. Rows are streamed to *output_csv* as they are produced; once
    all rows are processed the file is rewritten with one row per
    ``(repo_url, sha, test_name)`` key, keeping the row preferred by
    ``_row_preference_rank``. A summary of counters is printed at the end.

    Args:
        input_csv: Path of the input CSV.
        output_csv: Path of the CSV to write; parent directories are created.
        github_token: Token for PR diff download; diffs are skipped when empty.
        validate_only: Only validate the schema and print the input summary.
        limit: Process at most this many leading input rows (``None`` = all;
            values below zero process nothing).

    Raises:
        KeyError: if the input is missing required columns.
    """
    df, test_name_col = _validate_schema(input_csv)
    _print_input_summary(df, test_name_col)
    if validate_only:
        return

    # Clamp at zero: DataFrame.head() would read a negative count as
    # "everything except the last N rows".
    total_rows = len(df) if limit is None else max(0, min(len(df), limit))
    print(
        f"\nStarting build over {total_rows} rows "
        f"(this can take a while: cloning repos + reading files + optional PR diff fetch)"
    )

    stats: dict[str, int] = {
        "kept": 0,
        "kept_unique": 0,
        "skipped_missing_core_fields": 0,
        "skipped_ud": 0,
        "skipped_no_task_types": 0,
        "skipped_test_code_fetch_failed": 0,
        "skipped_test_code_fetch_git_fail": 0,
        "skipped_test_code_fetch_file_missing": 0,
        "fix_diff_fetched": 0,
        "duplicate_key_rows_seen": 0,
        "duplicate_key_replaced": 0,
        "duplicate_key_kept_existing": 0,
    }
    git_failure_reasons = {
        "git_init_failed",
        "git_remote_add_failed",
        "git_fetch_sha_failed",
        "git_checkout_failed",
        "git_timeout",
    }
    fetch_fail_examples: list[dict[str, str]] = []
    canonical_rows: dict[tuple[str, str, str], dict[str, str]] = {}
    # Successful fetches keyed by (repo_url, sha, test_file): many tests live
    # in the same file, and each fetch costs a fresh clone. Failures are not
    # cached, so every row still gets its own attempt.
    code_cache: dict[tuple[str, str, str], str] = {}

    output_path = Path(output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Slicing up front lets the loop run to natural exhaustion, which also
    # lets a tqdm bar close itself instead of being abandoned by a `break`.
    iterator = df.head(total_rows).iterrows()
    if tqdm is not None:
        iterator = tqdm(iterator, total=total_rows, desc="Building tasks", unit="row")

    with output_path.open("w", encoding="utf-8", newline="") as out_fp:
        # Rows are written and flushed one by one, presumably so a partial
        # result survives an interrupted run; the file is rewritten below.
        writer = csv.DictWriter(out_fp, fieldnames=OUTPUT_COLUMNS, extrasaction="ignore")
        writer.writeheader()
        out_fp.flush()

        for processed, (_, row) in enumerate(iterator, start=1):
            repo_url = _cell_text(row, PROJECT_URL_COL)
            sha = _cell_text(row, SHA_COL)
            test_name = _cell_text(row, test_name_col)
            category_raw = _cell_text(row, CATEGORY_COL)
            status = _cell_text(row, STATUS_COL)
            pr_link = _cell_text(row, PR_LINK_COL)

            if not (repo_url and sha and test_name and category_raw):
                stats["skipped_missing_core_fields"] += 1
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            # Only the first of several ";"-separated categories is used.
            category = category_raw.split(";")[0].strip()
            if category == "UD":
                stats["skipped_ud"] += 1
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            task_types = _task_types_for(category, status, pr_link)
            if not task_types:
                stats["skipped_no_task_types"] += 1
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            test_file = test_name.split("::")[0]
            cache_key = (repo_url, sha, test_file)
            test_code = code_cache.get(cache_key, "")
            fetch_reason = fetch_detail = ""
            if not test_code:
                test_code, fetch_reason, fetch_detail = fetch_test_code(repo_url, sha, test_name)
                if test_code:
                    code_cache[cache_key] = test_code

            if not test_code:
                stats["skipped_test_code_fetch_failed"] += 1
                if fetch_reason in git_failure_reasons:
                    stats["skipped_test_code_fetch_git_fail"] += 1
                if fetch_reason == "test_file_missing_at_sha":
                    stats["skipped_test_code_fetch_file_missing"] += 1
                if len(fetch_fail_examples) < 10:
                    fetch_fail_examples.append(
                        {
                            "repo_url": repo_url,
                            "sha": sha,
                            "test_name": test_name,
                            "reason": fetch_reason,
                            "detail": fetch_detail,
                        }
                    )
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            known_fix_diff = ""
            if "fix_proposal" in task_types and github_token:
                known_fix_diff = fetch_pr_diff(pr_link, github_token)
                if known_fix_diff:
                    stats["fix_diff_fetched"] += 1

            row_out = {
                "repo_url": repo_url,
                "sha": sha,
                "test_name": test_name,
                "test_file": test_file,
                "category": category,
                "label": "flaky",
                "status": status,
                "pr_link": pr_link,
                "task_types": ";".join(task_types),
                "test_code": test_code,
                "known_fix_diff": known_fix_diff,
            }
            writer.writerow(row_out)
            out_fp.flush()
            stats["kept"] += 1

            row_key = (repo_url, sha, test_name)
            current = canonical_rows.get(row_key)
            if current is None:
                canonical_rows[row_key] = row_out
            else:
                stats["duplicate_key_rows_seen"] += 1
                if _row_preference_rank(row_out) > _row_preference_rank(current):
                    canonical_rows[row_key] = row_out
                    stats["duplicate_key_replaced"] += 1
                else:
                    stats["duplicate_key_kept_existing"] += 1

            _update_progress(iterator, tqdm, stats, processed, total_rows)

    # Replace the streamed file with the de-duplicated rows.
    out = pd.DataFrame(list(canonical_rows.values()), columns=OUTPUT_COLUMNS)
    stats["kept_unique"] = len(out)
    out.to_csv(output_csv, index=False)

    if tqdm is None:
        # Terminate the "\r"-style progress line printed by _update_progress.
        print()
    print("\nBuild summary:")
    for key, value in stats.items():
        print(f"  {key}: {value}")
    print(f"Built {len(out)} task rows -> {output_csv}")

    if fetch_fail_examples:
        print("\nSample fetch failures (first 10):")
        for i, sample in enumerate(fetch_fail_examples, start=1):
            print(
                f"  {i}. reason={sample['reason']} "
                f"repo={sample['repo_url']} sha={sample['sha']} "
                f"test={sample['test_name']} detail={sample['detail']}"
            )

    if len(out):
        print(out["category"].value_counts())
        print(out["task_types"].value_counts())
def _update_progress(
iterator,
tqdm_mod,
stats: dict[str, int],
processed: int | None = None,
total_rows: int | None = None,
) -> None:
if tqdm_mod is not None and hasattr(iterator, "set_postfix"):
iterator.set_postfix(
kept=stats["kept"],
miss=stats["skipped_missing_core_fields"],
ud=stats["skipped_ud"],
no_task=stats["skipped_no_task_types"],
fetch_fail=stats["skipped_test_code_fetch_failed"],
)
return
if processed is None or total_rows is None:
return
if processed == 1 or processed % 20 == 0 or processed == total_rows:
print(
f"\r[{processed}/{total_rows}] "
f"kept={stats['kept']} "
f"fetch_fail={stats['skipped_test_code_fetch_failed']} "
f"no_task={stats['skipped_no_task_types']}",
end="",
flush=True,
)
def main() -> None:
    """Command-line entry point: parse the options and run ``build``.

    The GitHub token is taken from the ``GITHUB_TOKEN`` environment variable
    and defaults to an empty string when the variable is unset.
    """
    cli = argparse.ArgumentParser(description="Build FlakySleuth task dataset")
    cli.add_argument(
        "--input",
        default="idoft/py-data.csv",
        help="Path to IDoFT py-data.csv",
    )
    cli.add_argument(
        "--output",
        default="dataset/py_tasks.csv",
        help="Output CSV path",
    )
    cli.add_argument(
        "--validate-only",
        action="store_true",
        help="Validate input schema and print summary, without cloning/fetching.",
    )
    cli.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Optional max input rows to process (useful for quick sanity checks).",
    )
    options = cli.parse_args()

    token = os.environ.get("GITHUB_TOKEN", "")
    build(
        options.input,
        options.output,
        token,
        validate_only=options.validate_only,
        limit=options.limit,
    )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()