File size: 3,812 Bytes
77f7fc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
"""Extract aligned sample rows from the FinePhrase dataset for the blog data explorer widget.

Streams through HuggingFaceFW/finephrase and collects aligned samples where the
same source document has outputs for all four prompt configs (faq, math, table, tutorial).
Stops once 1000 aligned samples are found.

Usage:
    python app/scripts/extract_finephrase_samples.py
"""

import json
import logging
from pathlib import Path

from datasets import load_dataset

# Module-level logger plus a simple root config so INFO progress lines are visible
# when the script is run directly.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Hugging Face dataset repo to stream from.
REPO_ID = "HuggingFaceFW/finephrase"
# The four prompt configs; a document is "aligned" when it appears in all of them.
PROMPTS = ["faq", "math", "table", "tutorial"]
# Stop collecting once this many aligned documents have been found.
TARGET_SAMPLES = 1000
# Output JSONL path, resolved relative to this script: app/public/data/...
OUTPUT_PATH = Path(__file__).parent.parent / "public" / "data" / "finephrase-samples.jsonl"


def extract_samples() -> list[dict]:
    """Stream all prompts and collect aligned samples until we reach TARGET_SAMPLES.

    Pulls rows from each prompt config in round-robin batches, tracking which
    document ids have been seen under every prompt. Returns a list of dicts,
    one per aligned document, with the source text and one output per prompt.
    """
    # Per-prompt accumulator: prompt name -> {doc id -> raw dataset row}.
    rows_by_prompt: dict[str, dict[str, dict]] = {name: {} for name in PROMPTS}
    # Insertion-ordered list of fully-aligned doc ids, with a set for O(1) membership.
    aligned_order: list[str] = []
    aligned_seen: set[str] = set()

    # One streaming iterator per prompt config; nothing is downloaded up front.
    streams = {}
    for name in PROMPTS:
        logger.info(f"Opening stream for {name}...")
        streams[name] = iter(
            load_dataset(REPO_ID, name=name, split="train", streaming=True)
        )

    chunk = 500
    while len(aligned_order) < TARGET_SAMPLES:
        made_progress = False
        for name in PROMPTS:
            stream = streams[name]
            pulled = 0
            # Drain up to `chunk` rows from this prompt's stream.
            while pulled < chunk:
                try:
                    row = next(stream)
                except StopIteration:
                    break
                pulled += 1
                doc_id = row["id"]
                rows_by_prompt[name][doc_id] = row

                # A doc becomes aligned the moment it exists under every prompt.
                if doc_id not in aligned_seen and all(
                    doc_id in rows_by_prompt[p] for p in PROMPTS
                ):
                    aligned_seen.add(doc_id)
                    aligned_order.append(doc_id)
                    if len(aligned_order) >= TARGET_SAMPLES:
                        break
            made_progress = made_progress or pulled > 0
            logger.info(
                f"  {name}: {len(rows_by_prompt[name])} rows loaded, "
                f"{len(aligned_order)} aligned so far"
            )
            if len(aligned_order) >= TARGET_SAMPLES:
                break

        # Every stream returned nothing this round: all exhausted, give up.
        if not made_progress:
            logger.warning("All streams exhausted before reaching target")
            break

    logger.info(f"Found {len(aligned_order)} aligned documents")

    # Assemble the output records; shared fields come from the first prompt's row.
    samples: list[dict] = []
    for doc_id in aligned_order[:TARGET_SAMPLES]:
        base = rows_by_prompt[PROMPTS[0]][doc_id]

        record: dict = {
            "id": doc_id,
            "url": base.get("url", ""),
            "file_path": base.get("file_path", ""),
            "source": str(base.get("text", "")),
        }
        for name in PROMPTS:
            rollouts = rows_by_prompt[name][doc_id].get("rollout_results", [])
            # Keep only the first rollout's text; empty string when none exist.
            record[name] = str(rollouts[0]["text"]) if rollouts else ""
        samples.append(record)

    logger.info(f"Built {len(samples)} samples")
    return samples


def main() -> None:
    """Extract aligned samples and write them to OUTPUT_PATH as JSON Lines.

    Creates the output directory if needed and writes one JSON object per line.
    """
    samples = extract_samples()
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: json.dumps(..., ensure_ascii=False) can emit non-ASCII
    # characters, which would raise UnicodeEncodeError under a platform
    # default encoding such as cp1252 on Windows.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    logger.info(f"Saved {len(samples)} samples to {OUTPUT_PATH}")


if __name__ == "__main__":
    main()