Spaces:
Sleeping
Sleeping
| """Offline dataset builder for FlakySleuth. | |
| Examples: | |
| # Validate schema and show category/status summary only | |
| python dataset/build_dataset.py --input py-data.csv --validate-only | |
| # Build full task CSV (requires network access for repo cloning) | |
| export GITHUB_TOKEN=... | |
| python dataset/build_dataset.py --input py-data.csv --output dataset/py_tasks.csv | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import os | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| from urllib.parse import urlparse | |
| import pandas as pd | |
| import requests | |
| try: | |
| from tqdm import tqdm | |
| except Exception: # pragma: no cover | |
| tqdm = None | |
# Categories that make a row eligible for the "classify" and "root_cause"
# task types.
TASK12_CATEGORIES = [
    "NOD",
    "TD",
    "TZD",
    "NIO",
    "ID",
    "OD",
    "OD-Brit",
    "OD-Vic",
]

# Categories that may additionally get a "fix_proposal" task (the row must
# also have an accepted status and a parseable PR link).
TASK3_CATEGORIES = [
    "TD",
    "TZD",
    "NOD",
    "NIO",
    "ID",
]

# Input CSV column names (compared after header whitespace normalization).
PROJECT_URL_COL = "Project URL"
SHA_COL = "SHA Detected"
CATEGORY_COL = "Category"
STATUS_COL = "Status"
PR_LINK_COL = "PR Link"
NOTES_COL = "Notes"

# Accepted spellings of the pytest test-id column, tried in this order.
TEST_NAME_ALIASES = [
    "Pytest Test Name",
    "Pytest Test Name (PathToFile::TestClass::TestMethod or PathToFile::TestMethod)",
]

# Column order of the generated task CSV.
OUTPUT_COLUMNS = [
    "repo_url",
    "sha",
    "test_name",
    "test_file",
    "category",
    "label",
    "status",
    "pr_link",
    "task_types",
    "test_code",
    "known_fix_diff",
]
def _normalize_header(text: str) -> str:
    """Return ``text`` trimmed, with every whitespace run collapsed to one space.

    The argument is coerced with ``str`` first, so non-string headers are
    accepted as well.
    """
    # str.split() with no separator already discards leading/trailing
    # whitespace and splits on any whitespace run.
    words = str(text).split()
    return " ".join(words)
def _resolve_test_name_column(columns: list[str]) -> str:
    """Pick the column that holds the pytest test id.

    Args:
        columns: Header names as they appear in the input frame.

    Returns:
        The original header whose normalized form matches the first entry of
        ``TEST_NAME_ALIASES`` that is present.

    Raises:
        KeyError: If none of the aliases matches any header.
    """
    # Later headers win when two headers normalize to the same text.
    lookup: dict[str, str] = {}
    for header in columns:
        lookup[_normalize_header(header)] = header

    for wanted in map(_normalize_header, TEST_NAME_ALIASES):
        try:
            return lookup[wanted]
        except KeyError:
            continue

    raise KeyError(
        "Could not find pytest test-name column. Expected one of: "
        + ", ".join(TEST_NAME_ALIASES)
    )
def _parse_pr_link(pr_link: str) -> tuple[str, str] | None:
    """Split a pull-request reference into ``("owner/repo", "number")``.

    Two spellings are understood: an ``http(s)`` URL whose path starts with
    ``/owner/repo/pull/<digits>``, and the short form ``owner/repo#<digits>``.
    Empty input, the literal ``nan`` (any case) and anything else that does
    not match yields ``None``.
    """
    text = (pr_link or "").strip()
    if text == "" or text.lower() == "nan":
        return None

    if text.startswith(("http://", "https://")):
        segments = [seg for seg in urlparse(text).path.split("/") if seg]
        # Expected path shape: /owner/repo/pull/number[/...]
        is_pull_path = (
            len(segments) >= 4
            and segments[2] == "pull"
            and segments[3].isdigit()
        )
        if not is_pull_path:
            return None
        owner, project, _, number = segments[:4]
        return f"{owner}/{project}", number

    repo, separator, number = text.partition("#")
    if not separator:
        return None
    repo = repo.strip()
    number = number.strip()
    if repo and number.isdigit():
        return repo, number
    return None
def _is_accepted_status(status: str) -> bool:
    """Tell whether ``status`` is accepted, merged or fixed.

    The comparison ignores case and surrounding whitespace; a falsy status
    counts as not accepted.
    """
    normalized = (status or "").strip().lower()
    for accepted in ("accepted", "merged", "fixed"):
        if normalized == accepted:
            return True
    return False
def _non_interactive_git_env() -> dict[str, str]:
    """Return a copy of the process environment with git prompting disabled.

    ``os.environ`` itself is left untouched. The two overrides keep git (and
    Git Credential Manager) from blocking on a credential prompt while many
    public repositories are processed unattended.
    """
    overrides = {
        "GIT_TERMINAL_PROMPT": "0",
        "GCM_INTERACTIVE": "Never",
    }
    return {**os.environ, **overrides}
def _has_value(value: str) -> bool:
    """Report whether ``value`` carries real content.

    Falsy input and, after trimming and lower-casing, the texts ``""``,
    ``"nan"`` and ``"none"`` all count as "no value".
    """
    normalized = str(value or "").strip().lower()
    is_placeholder = normalized in ("", "nan", "none")
    return not is_placeholder
def _is_non_unmaintained_status(status: str) -> bool:
    """Report whether ``status`` is a real status other than "unmaintained".

    Falsy input and, after trimming and lower-casing, ``""``, ``"nan"``,
    ``"none"`` and ``"unmaintained"`` all give ``False``.
    """
    normalized = str(status or "").strip().lower()
    rejected = ("", "nan", "none", "unmaintained")
    return not any(normalized == item for item in rejected)
def _row_preference_rank(row_out: dict[str, str]) -> tuple[int, int, int]:
    """Score an output row so duplicates can be compared tuple-wise.

    The three 0/1 flags are, in priority order: the row carries a
    ``fix_proposal`` task, it has a usable PR link, and its status is a real
    status other than "unmaintained". A larger tuple is the preferred row.
    """
    raw_tasks = str(row_out.get("task_types", ""))
    tasks = set(filter(None, (piece.strip() for piece in raw_tasks.split(";"))))

    has_fix_task = "fix_proposal" in tasks
    has_pr_link = _has_value(str(row_out.get("pr_link", "")))
    has_live_status = _is_non_unmaintained_status(str(row_out.get("status", "")))
    return (int(has_fix_task), int(has_pr_link), int(has_live_status))
def fetch_test_code(repo_url: str, sha: str, pytest_test_name: str) -> tuple[str, str, str]:
    """Fetch the source of a test file at one exact commit.

    A throwaway repository is initialised in a temporary directory, only the
    requested commit is fetched (``--depth=1``) and checked out detached, and
    the file named by the part of ``pytest_test_name`` before the first
    ``::`` is read.

    Args:
        repo_url: Remote URL handed to ``git remote add``.
        sha: Commit to fetch from ``origin``.
        pytest_test_name: Pytest id such as ``path/to/test_x.py::test_name``.

    Returns:
        ``(code, reason, detail)``. On success ``code`` is at most the first
        10000 characters of the file (decoded as UTF-8 with replacement) and
        ``reason``/``detail`` are empty. On failure ``code`` is ``""`` and
        ``reason`` is one of ``git_init_failed``, ``git_remote_add_failed``,
        ``git_fetch_sha_failed``, ``git_checkout_failed``, ``git_timeout`` or
        ``test_file_missing_at_sha``; ``detail`` is up to 200 characters of
        git output, ``"timeout"``, or the test file path.
    """
    test_file = pytest_test_name.split("::")[0]
    git_env = _non_interactive_git_env()

    with tempfile.TemporaryDirectory() as tmpdir:
        # (failure reason, command, timeout in seconds), run in this order.
        steps = (
            ("git_init_failed", ["git", "init", tmpdir], 20),
            (
                "git_remote_add_failed",
                ["git", "-C", tmpdir, "remote", "add", "origin", repo_url],
                10,
            ),
            # Fetch only the requested commit for speed and correctness.
            (
                "git_fetch_sha_failed",
                ["git", "-C", tmpdir, "fetch", "--depth=1", "origin", sha],
                90,
            ),
            (
                "git_checkout_failed",
                ["git", "-C", tmpdir, "checkout", "--detach", "FETCH_HEAD"],
                30,
            ),
        )
        try:
            for failure_reason, command, timeout in steps:
                result = subprocess.run(
                    command,
                    capture_output=True,
                    text=True,
                    check=False,
                    timeout=timeout,
                    env=git_env,
                    stdin=subprocess.DEVNULL,
                )
                if result.returncode != 0:
                    detail = (result.stderr or result.stdout or "").strip()[:200]
                    return "", failure_reason, detail
        except subprocess.TimeoutExpired:
            return "", "git_timeout", "timeout"

        # The test path comes from the input data: resolve it and refuse
        # anything (absolute path, "..", symlink) that lands outside the
        # checkout, so no unrelated local file is ever read.
        checkout_root = Path(tmpdir).resolve()
        file_path = (checkout_root / test_file).resolve()
        try:
            file_path.relative_to(checkout_root)
        except ValueError:
            return "", "test_file_missing_at_sha", test_file

        # is_file() rather than exists(): a directory path would otherwise
        # make read_text() raise and abort the caller.
        if not file_path.is_file():
            return "", "test_file_missing_at_sha", test_file
        return file_path.read_text(encoding="utf-8", errors="replace")[:10000], "", ""
def fetch_pr_diff(pr_link: str, github_token: str) -> str:
    """Download the diff of a GitHub pull request, best effort.

    Args:
        pr_link: PR URL or ``owner/repo#number`` reference.
        github_token: Token sent as ``Authorization: token <value>``; when
            empty the request is made without an Authorization header.

    Returns:
        At most the first 3000 characters of the diff, or ``""`` when the
        link cannot be parsed, the request fails, or the response status is
        not 200.
    """
    parsed = _parse_pr_link(pr_link)
    if not parsed:
        return ""
    repo, number = parsed
    url = f"https://api.github.com/repos/{repo}/pulls/{number}"

    headers = {"Accept": "application/vnd.github.diff"}
    if github_token:
        headers["Authorization"] = f"token {github_token}"

    try:
        response = requests.get(url, headers=headers, timeout=15)
    except requests.RequestException:
        # A timeout or connection error on one PR must not abort the whole
        # build; a missing diff is already a supported outcome.
        return ""

    if response.status_code != 200:
        return ""
    return response.text[:3000]
def _validate_schema(input_csv: str) -> tuple[pd.DataFrame, str]:
    """Load the input CSV and check that the required columns exist.

    Every cell is read as text and empty cells stay empty strings. With
    pandas' default inference an empty cell becomes NaN, which later
    stringifies to the literal ``"nan"``, and a digit-only commit id could be
    turned into a number.

    Args:
        input_csv: Path of the CSV file to read.

    Returns:
        ``(df, test_name_col)``: the frame with whitespace-normalized headers
        and the name of the column holding the pytest test id.

    Raises:
        KeyError: If a required column or the test-name column is missing.
    """
    df = pd.read_csv(input_csv, dtype=str, keep_default_na=False)
    df.columns = [_normalize_header(col) for col in df.columns]

    required_columns = (PROJECT_URL_COL, SHA_COL, CATEGORY_COL, STATUS_COL, PR_LINK_COL)
    missing = [name for name in required_columns if name not in df.columns]
    if missing:
        raise KeyError(f"Missing required columns: {missing}")

    test_name_col = _resolve_test_name_column(list(df.columns))
    return df, test_name_col
def _print_input_summary(df: pd.DataFrame, test_name_col: str) -> None:
    """Print a short report about the validated input frame.

    The report lists the row count, the chosen test-name column, all column
    names, and the 20 most frequent values of the category and status
    columns (missing cells are counted as the empty string).
    """
    print("Input schema check: OK")
    print(f"Rows: {len(df)}")
    print(f"Using test-name column: {test_name_col}")
    print("Columns:", list(df.columns))

    sections = (
        ("\nCategory distribution (top 20):", CATEGORY_COL),
        ("\nStatus distribution:", STATUS_COL),
    )
    for heading, column in sections:
        print(heading)
        counts = df[column].fillna("").astype(str).value_counts()
        print(counts.head(20))
def _cell_text(row: pd.Series, column: str) -> str:
    """Return ``row[column]`` as stripped text, or ``""`` for a missing/NaN cell."""
    value = row.get(column, "")
    if value is None:
        return ""
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value).strip()


def build(
    input_csv: str,
    output_csv: str,
    github_token: str,
    *,
    validate_only: bool = False,
    limit: int | None = None,
) -> None:
    """Build the task CSV from the input dataset.

    For every input row the task types are derived from category, status and
    PR link; the test file is fetched at the recorded commit; and, for
    ``fix_proposal`` rows when a token is available, the PR diff is fetched.
    Kept rows are appended to ``output_csv`` as they are produced, and at the
    end the file is rewritten with one row per
    ``(repo_url, sha, test_name)``, the preferred duplicate being chosen by
    ``_row_preference_rank``.

    Missing (NaN) input cells are treated as empty text, so rows lacking a
    core field are counted as ``skipped_missing_core_fields`` and no literal
    ``"nan"`` reaches the output.

    Args:
        input_csv: Path of the input CSV.
        output_csv: Path of the CSV to write; parent directories are created.
        github_token: Token for PR diff requests; no diff is fetched when empty.
        validate_only: Only validate the schema and print the input summary.
        limit: Maximum number of input rows to process; ``None`` means all,
            values below zero are treated as zero.

    Raises:
        KeyError: If the input schema is invalid (from ``_validate_schema``).
    """
    df, test_name_col = _validate_schema(input_csv)
    _print_input_summary(df, test_name_col)
    if validate_only:
        return

    total_rows = len(df) if limit is None else max(0, min(len(df), limit))
    print(
        f"\nStarting build over {total_rows} rows "
        f"(this can take a while: cloning repos + reading files + optional PR diff fetch)"
    )

    stats: dict[str, int] = {
        "kept": 0,
        "kept_unique": 0,
        "skipped_missing_core_fields": 0,
        "skipped_ud": 0,
        "skipped_no_task_types": 0,
        "skipped_test_code_fetch_failed": 0,
        "skipped_test_code_fetch_git_fail": 0,
        "skipped_test_code_fetch_file_missing": 0,
        "fix_diff_fetched": 0,
        "duplicate_key_rows_seen": 0,
        "duplicate_key_replaced": 0,
        "duplicate_key_kept_existing": 0,
    }
    git_failure_reasons = {
        "git_init_failed",
        "git_remote_add_failed",
        "git_fetch_sha_failed",
        "git_checkout_failed",
        "git_timeout",
    }
    fetch_fail_examples: list[dict[str, str]] = []
    canonical_rows: dict[tuple[str, str, str], dict[str, str]] = {}

    output_path = Path(output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Slice up front so the progress total and the rows iterated always agree.
    iterator = df.iloc[:total_rows].iterrows()
    if tqdm is not None:
        iterator = tqdm(iterator, total=total_rows, desc="Building tasks", unit="row")

    with output_path.open("w", encoding="utf-8", newline="") as out_fp:
        writer = csv.DictWriter(out_fp, fieldnames=OUTPUT_COLUMNS, extrasaction="ignore")
        writer.writeheader()
        out_fp.flush()

        for processed, (_, row) in enumerate(iterator, start=1):
            repo_url = _cell_text(row, PROJECT_URL_COL)
            sha = _cell_text(row, SHA_COL)
            test_name = _cell_text(row, test_name_col)
            category_raw = _cell_text(row, CATEGORY_COL)
            status = _cell_text(row, STATUS_COL)
            pr_link = _cell_text(row, PR_LINK_COL)

            if not (repo_url and sha and test_name and category_raw):
                stats["skipped_missing_core_fields"] += 1
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            # Only the first of several ";"-separated categories is used.
            category = category_raw.split(";")[0].strip()
            if category == "UD":
                stats["skipped_ud"] += 1
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            task_types: list[str] = []
            if category in TASK12_CATEGORIES:
                task_types.extend(["classify", "root_cause"])
            if (
                category in TASK3_CATEGORIES
                and _is_accepted_status(status)
                and _parse_pr_link(pr_link)
            ):
                task_types.append("fix_proposal")
            if not task_types:
                stats["skipped_no_task_types"] += 1
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            test_code, fetch_reason, fetch_detail = fetch_test_code(repo_url, sha, test_name)
            if not test_code:
                stats["skipped_test_code_fetch_failed"] += 1
                if fetch_reason in git_failure_reasons:
                    stats["skipped_test_code_fetch_git_fail"] += 1
                if fetch_reason == "test_file_missing_at_sha":
                    stats["skipped_test_code_fetch_file_missing"] += 1
                if len(fetch_fail_examples) < 10:
                    fetch_fail_examples.append(
                        {
                            "repo_url": repo_url,
                            "sha": sha,
                            "test_name": test_name,
                            "reason": fetch_reason,
                            "detail": fetch_detail,
                        }
                    )
                _update_progress(iterator, tqdm, stats, processed, total_rows)
                continue

            known_fix_diff = ""
            if "fix_proposal" in task_types and github_token:
                known_fix_diff = fetch_pr_diff(pr_link, github_token)
                if known_fix_diff:
                    stats["fix_diff_fetched"] += 1

            row_out = {
                "repo_url": repo_url,
                "sha": sha,
                "test_name": test_name,
                "test_file": test_name.split("::")[0],
                "category": category,
                "label": "flaky",
                "status": status,
                "pr_link": pr_link,
                "task_types": ";".join(task_types),
                "test_code": test_code,
                "known_fix_diff": known_fix_diff,
            }
            # Write and flush immediately so an interrupted run still leaves
            # every row produced so far on disk.
            writer.writerow(row_out)
            out_fp.flush()
            stats["kept"] += 1

            row_key = (row_out["repo_url"], row_out["sha"], row_out["test_name"])
            current = canonical_rows.get(row_key)
            if current is None:
                canonical_rows[row_key] = row_out
            else:
                stats["duplicate_key_rows_seen"] += 1
                if _row_preference_rank(row_out) > _row_preference_rank(current):
                    canonical_rows[row_key] = row_out
                    stats["duplicate_key_replaced"] += 1
                else:
                    stats["duplicate_key_kept_existing"] += 1

            _update_progress(iterator, tqdm, stats, processed, total_rows)

    # Replace the incremental file with the de-duplicated rows.
    out = pd.DataFrame(list(canonical_rows.values()), columns=OUTPUT_COLUMNS)
    stats["kept_unique"] = len(out)
    out.to_csv(output_csv, index=False)

    if tqdm is None:
        # Terminate the "\r" progress line written by _update_progress.
        print()
    print("\nBuild summary:")
    for key, value in stats.items():
        print(f"  {key}: {value}")
    print(f"Built {len(out)} task rows -> {output_csv}")

    if fetch_fail_examples:
        print("\nSample fetch failures (first 10):")
        for i, sample in enumerate(fetch_fail_examples, start=1):
            print(
                f"  {i}. reason={sample['reason']} "
                f"repo={sample['repo_url']} sha={sample['sha']} "
                f"test={sample['test_name']} detail={sample['detail']}"
            )

    if len(out):
        print(out["category"].value_counts())
        print(out["task_types"].value_counts())
def _update_progress(
    iterator,
    tqdm_mod,
    stats: dict[str, int],
    processed: int | None = None,
    total_rows: int | None = None,
) -> None:
    """Refresh the progress display with the current counters.

    When ``tqdm_mod`` is set and ``iterator`` has ``set_postfix``, the
    counters go to the bar's postfix. Otherwise, provided both ``processed``
    and ``total_rows`` are given, a one-line status is printed (carriage
    return, no newline) on the first row, on every 20th row and on the last
    row.
    """
    bar_available = tqdm_mod is not None and hasattr(iterator, "set_postfix")
    if bar_available:
        postfix_fields = (
            ("kept", "kept"),
            ("miss", "skipped_missing_core_fields"),
            ("ud", "skipped_ud"),
            ("no_task", "skipped_no_task_types"),
            ("fetch_fail", "skipped_test_code_fetch_failed"),
        )
        iterator.set_postfix(**{label: stats[key] for label, key in postfix_fields})
        return

    if processed is None or total_rows is None:
        return

    at_milestone = processed in (1, total_rows) or processed % 20 == 0
    if not at_milestone:
        return

    kept = stats["kept"]
    fetch_fail = stats["skipped_test_code_fetch_failed"]
    no_task = stats["skipped_no_task_types"]
    line = f"\r[{processed}/{total_rows}] kept={kept} fetch_fail={fetch_fail} no_task={no_task}"
    print(line, end="", flush=True)
def main() -> None:
    """Command-line entry point: parse the options and run ``build``.

    The GitHub token is read from the ``GITHUB_TOKEN`` environment variable
    and defaults to the empty string.
    """
    cli = argparse.ArgumentParser(description="Build FlakySleuth task dataset")
    cli.add_argument(
        "--input",
        default="idoft/py-data.csv",
        help="Path to IDoFT py-data.csv",
    )
    cli.add_argument(
        "--output",
        default="dataset/py_tasks.csv",
        help="Output CSV path",
    )
    cli.add_argument(
        "--validate-only",
        action="store_true",
        help="Validate input schema and print summary, without cloning/fetching.",
    )
    cli.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Optional max input rows to process (useful for quick sanity checks).",
    )
    options = cli.parse_args()

    token = os.environ.get("GITHUB_TOKEN", "")
    build(
        options.input,
        options.output,
        token,
        validate_only=options.validate_only,
        limit=options.limit,
    )


if __name__ == "__main__":
    main()