File size: 8,082 Bytes
fb1cdce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Gradio app for Hugging Face Spaces: view detokenized samples by category.
No processing — only serves a pre-built sampled.jsonl. Data path via DATA_PATH env (default: sampled.jsonl).
"""

import json
import os
from pathlib import Path

import gradio as gr


def load_sampled_jsonl(path: Path) -> dict[str, list[dict]]:
    """Group the records of a JSONL file by their ``"category"`` field.

    Each non-blank line is parsed as one JSON object and kept unchanged
    (records presumably carry ``path``, ``text`` and ``num_tokens``).
    Records whose ``category`` is missing or null are filed under
    ``"unknown"``; any other category value is converted to ``str`` so that
    all keys can be sorted together by the caller.

    Args:
        path: JSONL file to read (decoded as UTF-8).

    Returns:
        Mapping of category name to its records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    by_cat: dict[str, list[dict]] = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            cat = rec.get("category")
            # A null or non-string category would make sorted(by_cat) raise
            # TypeError once it is mixed with ordinary string keys.
            cat = "unknown" if cat is None else str(cat)
            by_cat.setdefault(cat, []).append(rec)
    return by_cat


PAGE_SIZE = 500


def format_page(records: list[dict], page: int, page_size: int = PAGE_SIZE) -> str:
    """Render one page of samples as a single string.

    Args:
        records: Sample records; each record's ``"text"`` field is shown.
            A missing or null text renders as an empty string.
        page: 1-based page number. Pages before the first or past the last
            yield an empty string.
        page_size: Maximum number of records per page (defaults to PAGE_SIZE).

    Returns:
        The page's texts joined by a rule of em dashes padded with blank
        lines, or ``""`` when there is nothing to show.
    """
    # Reject non-positive pages explicitly: a negative slice start would
    # otherwise silently return records from the wrong part of the list.
    if not records or page < 1 or page_size < 1:
        return ""
    start = (page - 1) * page_size
    page_recs = records[start:start + page_size]  # slicing clamps at the end
    # Horizontal rule of 50 em dashes (U+2014), written as an escape so the
    # literal cannot be mangled by a wrong source-file encoding again.
    sep = "\n\n" + "\u2014" * 50 + "\n\n"
    return sep.join(r.get("text") or "" for r in page_recs)


def create_demo():
    """Build the Gradio Blocks UI for browsing detokenized samples.

    Reads the JSONL file named by the ``DATA_PATH`` environment variable
    (default ``sampled.jsonl``), groups it by category, and builds two views:
    a home page with a dataset card, and a viewer with a category dropdown
    plus previous/next/go-to-page pagination.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).

    Raises:
        FileNotFoundError: If the data file does not exist.
    """
    data_path = Path(os.environ.get("DATA_PATH", "sampled.jsonl")).resolve()
    if not data_path.is_file():
        raise FileNotFoundError(
            f"Data file not found: {data_path}. Upload sampled.jsonl to this Space (or set DATA_PATH)."
        )
    by_cat = load_sampled_jsonl(data_path)
    categories = sorted(by_cat.keys())

    def _page_label(page: int, max_page: int) -> str:
        return f"Page {page} of {max_page}"

    # Pagination state is stored in gr.State as a (current_page, max_page) tuple.

    def on_category_change(cat: str | None):
        # Reset to page 1 and rebuild the page picker for the selected category.
        if not cat or cat not in by_cat:
            return "", (1, 1), "Page 1 of 1", gr.update(choices=["1"], value="1")
        recs = by_cat[cat]
        max_page = max(1, (len(recs) + PAGE_SIZE - 1) // PAGE_SIZE)  # ceil division
        text = format_page(recs, 1)
        choices = [str(i) for i in range(1, max_page + 1)]
        return text, (1, max_page), _page_label(1, max_page), gr.update(choices=choices, value="1")

    def _step(cat: str | None, state: tuple, delta: int):
        """Move ``delta`` pages from the current one, clamped to [1, max_page]."""
        page, max_page = state
        if not cat or cat not in by_cat:
            return "", state, _page_label(1, max_page), gr.update()
        page = max(1, min(max_page, page + delta))
        text = format_page(by_cat[cat], page)
        # Also move the "Go to page" dropdown so it stays in sync.
        return text, (page, max_page), _page_label(page, max_page), gr.update(value=str(page))

    def on_next(cat: str | None, state: tuple):
        return _step(cat, state, 1)

    def on_prev(cat: str | None, state: tuple):
        return _step(cat, state, -1)

    def on_page_select(cat: str | None, page_str: str | None, state: tuple):
        # Jump to the page chosen in the dropdown (value arrives as a string).
        if not cat or cat not in by_cat or not page_str:
            return "", state, _page_label(state[0], state[1])
        try:
            page = int(page_str)
        except ValueError:
            return "", state, _page_label(state[0], state[1])
        _, max_page = state
        page = max(1, min(max_page, page))
        recs = by_cat[cat]
        text = format_page(recs, page)
        return text, (page, max_page), _page_label(page, max_page)

    custom_css = """
    .pagination-row { display: flex !important; flex-wrap: nowrap !important; align-items: center !important; gap: 12px !important; }
    .gradio-container textarea { font-size: 18px !important; line-height: 1.6; }
    .gradio-container .label { font-size: 17px !important; }
    .gradio-container input, .gradio-container select { font-size: 17px !important; }
    .home-page .prose, .home-page p, .home-page h1, .home-page .markdown { font-size: 24px !important; color: #212529 !important; }
    .home-page h1 { font-size: 32px !important; color: #212529 !important; }
    .home-page .card-wrap { display: inline-block; margin: 1em 0; border-radius: 16px; box-shadow: 0 8px 24px rgba(0,0,0,0.08); padding: 4px; background: #f8f9fa; border: 1px solid #e9ecef; }
    .home-page .card-wrap .primary, .home-page .card-wrap button { font-size: 26px !important; padding: 40px 64px !important; min-height: 100px !important; border-radius: 12px !important; border: none !important; background: linear-gradient(180deg, #ffffff 0%, #f1f3f5 100%) !important; box-shadow: 0 2px 8px rgba(0,0,0,0.06) !important; cursor: pointer !important; transition: box-shadow 0.2s ease, transform 0.2s ease !important; font-weight: 600 !important; color: #212529 !important; }
    .home-page .card-wrap .primary:hover, .home-page .card-wrap button:hover { box-shadow: 0 6px 20px rgba(0,0,0,0.12) !important; transform: translateY(-2px) !important; }
    .home-page .card-desc, .home-page .card-wrap .markdown, .home-page .card-wrap p, .home-page .card-wrap div[class*="markdown"] { font-size: 20px !important; color: #212529 !important; margin-top: 0.5em !important; }
    """
    with gr.Blocks(title="Token sample viewer", css=custom_css) as demo:
        # Home view: a single dataset card that opens the viewer.
        home_col = gr.Column(visible=True, elem_classes=["home-page"])
        with home_col:
            gr.Markdown("# Home")
            gr.Markdown("Select a dataset to view samples:")
            with gr.Column(elem_classes=["card-wrap"]):
                card_btn = gr.Button("OLMo3", elem_classes=["primary"])
                gr.Markdown("OLMo3 / Dolma 3 pre-training data (Allen AI)", elem_classes=["card-desc"])

        # Viewer: hidden until the card is clicked.
        olmo3_col = gr.Column(visible=False)
        with olmo3_col:
            back_btn = gr.Button("← Back to Home")
            gr.Markdown("""
**Dataset:** OLMo3 / Dolma 3 pre-training data (Allen AI).  
Source: Hugging Face (e.g. `allenai/dolma3_pool`, `dolma3_mix_*`). Tokenized with Dolma; each **category** is one data split.  
Text below is detokenized from `.npy` for inspection; document boundaries are not preserved.
            """)
            gr.Markdown("## View detokenized samples by category")
            cat_dd = gr.Dropdown(
                choices=categories,
                label="Category",
                value=categories[0] if categories else None,
            )
            page_state = gr.State((1, 1))
            with gr.Row(elem_classes=["pagination-row"]):
                prev_btn = gr.Button("← Previous")
                next_btn = gr.Button("Next →")
                page_label = gr.Markdown("Page 1 of 1", show_label=False)
                page_dd = gr.Dropdown(choices=["1"], value="1", label="Go to page")
            text_out = gr.Textbox(
                label=f"Samples (max {PAGE_SIZE} per page)",
                lines=25,
            )

        def show_olmo3():
            return gr.update(visible=False), gr.update(visible=True)

        def show_home():
            return gr.update(visible=True), gr.update(visible=False)

        card_btn.click(fn=show_olmo3, outputs=[home_col, olmo3_col])
        back_btn.click(fn=show_home, outputs=[home_col, olmo3_col])
        cat_dd.change(
            fn=on_category_change,
            inputs=[cat_dd],
            outputs=[text_out, page_state, page_label, page_dd],
        )
        next_btn.click(
            fn=on_next,
            inputs=[cat_dd, page_state],
            outputs=[text_out, page_state, page_label, page_dd],
        )
        prev_btn.click(
            fn=on_prev,
            inputs=[cat_dd, page_state],
            outputs=[text_out, page_state, page_label, page_dd],
        )
        # NOTE(review): `.change` also fires when the handlers above set
        # page_dd programmatically, so the same page is rendered a second
        # time; `.input` (user-only) would avoid that if the pinned Gradio
        # version supports it on Dropdown.
        page_dd.change(
            fn=on_page_select,
            inputs=[cat_dd, page_dd, page_state],
            outputs=[text_out, page_state, page_label],
        )
        # Render the initially selected category when the page loads.
        demo.load(
            fn=on_category_change,
            inputs=[cat_dd],
            outputs=[text_out, page_state, page_label, page_dd],
        )
    return demo


# Entry point: build the viewer, then serve it on all network interfaces.
# The listening port is taken from $PORT, falling back to 7860.
demo = create_demo()
port = int(os.getenv("PORT", "7860"))
demo.launch(server_port=port, server_name="0.0.0.0")