File size: 4,638 Bytes
0cc2540
 
 
 
485e71e
0cc2540
 
 
 
 
 
 
e60a692
485e71e
 
 
 
 
 
 
 
0cc2540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485e71e
 
 
 
e60a692
 
 
485e71e
e60a692
 
0cc2540
 
 
 
1082f9f
 
0cc2540
 
 
 
1082f9f
0cc2540
 
 
 
 
 
1082f9f
0cc2540
 
 
 
 
 
 
 
 
 
1082f9f
 
 
 
0cc2540
 
1082f9f
 
0cc2540
 
1082f9f
 
 
0cc2540
 
1082f9f
 
 
0cc2540
 
 
1082f9f
 
 
 
0cc2540
 
 
 
 
485e71e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import math
import os

import gradio as gr
from datasets import Image as HFImage
from datasets import load_dataset

REPO_ID = "teddyk251/self-imagine-sample"  # change me
# Optional access token taken from the HF_TOKEN environment variable (None if unset).
TOKEN = os.environ.get("HF_TOKEN")

# Load the "train" split and cast its "image" column to a decoding
# ``datasets.Image`` feature in a single chained expression.
ds = load_dataset(REPO_ID, split="train", token=TOKEN).cast_column(
    "image", HFImage(decode=True)
)

def img_src(img):
    """Resolve an image reference to a displayable source.

    Dict inputs yield their "path" entry when it is truthy, otherwise their
    "url" entry (``None`` if neither key exists). Any non-dict input is
    returned unchanged.
    """
    if not isinstance(img, dict):
        return img  # already a plain source such as a string
    local_path = img.get("path")
    return local_path if local_path else img.get("url")


# Dropdown choices: the "All" sentinel first, then every distinct value of the
# corresponding column in sorted order.
DATASETS = ["All", *sorted(set(ds["dataset"]))]
MODELS = ["All", *sorted(set(ds["model"]))]

def filter_rows(dataset_sel, model_sel, query):
    """Return the subset of the module-level ``ds`` matching the UI filters.

    Args:
        dataset_sel: Value the "dataset" column must equal, or ``"All"`` to
            skip this filter.
        model_sel: Value the "model" column must equal, or ``"All"`` to skip
            this filter.
        query: Case-insensitive substring searched for in the "filename"
            column; a falsy value (empty string / None) skips this filter.

    Returns:
        The filtered dataset (``ds`` itself when no filter applies).
    """
    rows = ds
    # ``input_columns`` hands each predicate only the one column it tests,
    # instead of a full row dict that also carries the image column.
    if dataset_sel != "All":
        rows = rows.filter(lambda value: value == dataset_sel, input_columns="dataset")
    if model_sel != "All":
        rows = rows.filter(lambda value: value == model_sel, input_columns="model")
    if query:
        needle = query.lower()
        # Missing/None filenames are treated as "" so they never match.
        rows = rows.filter(
            lambda name: needle in (name or "").lower(),
            input_columns="filename",
        )
    return rows

def paginate(rows, page, page_size):
    """Slice one page out of ``rows`` and build Gallery items for it.

    Args:
        rows: Sized collection with a ``select(indices)`` method returning the
            chosen rows (e.g. a ``datasets.Dataset``). Each row is a mapping
            with "image", "dataset", "model" and optionally "filename".
        page: Requested 1-based page number; clamped to ``[1, total_pages]``.
        page_size: Rows per page; coerced to an int of at least 1 because UI
            widgets may deliver floats.

    Returns:
        ``(gallery_items, total, total_pages, page)``, where:

        - ``gallery_items`` is a list of ``(image, label)`` tuples.
        - ``total`` is ``len(rows)``.
        - ``page`` is the clamped page number actually shown.

        Rows whose image is ``None`` are skipped but still counted in
        ``total``.
    """
    page_size = max(1, int(page_size))
    total = len(rows)
    total_pages = max(1, math.ceil(total / page_size))
    page = max(1, min(int(page), total_pages))
    start = (page - 1) * page_size
    end = min(start + page_size, total)

    gallery_items = []
    # Iterate the selected rows once; indexing ``sub[i]`` per field would
    # re-fetch (and re-format) the whole row, image included, every time.
    for row in rows.select(range(start, end)):
        image = row["image"]
        if image is None:
            continue  # skip records without an image
        label = f'{row["dataset"]} / {row["model"]}'
        filename = row.get("filename")
        if filename:
            label += f" — {filename}"
        gallery_items.append((image, label))
    return gallery_items, total, total_pages, page

def refresh(dataset_sel, model_sel, query, page, page_size):
    """Apply the filters and render one page of the gallery.

    Args:
        dataset_sel: Dataset filter value (``"All"`` disables it).
        model_sel: Model filter value (``"All"`` disables it).
        query: Filename substring filter (falsy disables it).
        page: Requested 1-based page number.
        page_size: Thumbnails per page.

    Returns:
        ``(gallery_items, info_markdown, page)``, where ``page`` is the page
        actually shown after clamping to the valid range.
    """
    rows = filter_rows(dataset_sel, model_sel, query)
    if len(rows) == 0:
        return [], "Page 1/1 — 0 images", 1
    items, total, total_pages, page = paginate(rows, page, page_size)
    # Same "Page X/Y — N images" layout as the empty-result message above.
    info = f"Page {page}/{total_pages} — {total} images"
    return items, info, page


with gr.Blocks(title="Self‑Imagine — Sample Gallery") as demo:
    gr.Markdown("## 🖼️ Self‑Imagine — Sample Gallery")

    # Initial thumbnails-per-page; shared by the slider and the first load so
    # the two cannot drift apart.
    DEFAULT_PAGE_SIZE = 24

    with gr.Row():
        dataset_dd = gr.Dropdown(DATASETS, value="All", label="Dataset")
        model_dd = gr.Dropdown(MODELS, value="All", label="Model")
        query_tb = gr.Textbox(label="Search filename", placeholder="e.g., 17.jpg or gsm8k")
        page_size = gr.Slider(6, 60, value=DEFAULT_PAGE_SIZE, step=6, label="Thumbnails per page")

    with gr.Row():
        prev_btn = gr.Button("◀ Prev")
        next_btn = gr.Button("Next ▶")
        page_num = gr.Number(value=1, precision=0, interactive=False, label="Page")
        page_info = gr.Markdown()

    gallery = gr.Gallery(columns=[6], height=700, allow_preview=False, show_label=False)

    # Every handler below writes the same three components.
    page_outputs = [gallery, page_info, page_num]
    filter_inputs = [dataset_dd, model_dd, query_tb, page_size]
    paging_inputs = [dataset_dd, model_dd, query_tb, page_num, page_size]

    def _initial():
        """Render page 1 with no filters when the app loads."""
        return refresh("All", "All", "", 1, DEFAULT_PAGE_SIZE)

    demo.load(_initial, inputs=None, outputs=page_outputs)

    def do_search(dataset_sel, model_sel, query, page_size):
        """Re-run the filters and jump back to page 1."""
        items, info, _ = refresh(dataset_sel, model_sel, query, 1, page_size)
        return items, info, 1

    # Any filter or page-size change restarts from the first page.
    for _component in filter_inputs:
        _component.change(do_search, filter_inputs, page_outputs)

    def on_prev(dataset_sel, model_sel, query, page, page_size):
        """Show the previous page (never below 1).

        Returns the page that ``refresh`` actually rendered, so the page
        counter always matches the gallery.
        """
        return refresh(dataset_sel, model_sel, query, max(1, int(page) - 1), page_size)

    def on_next(dataset_sel, model_sel, query, page, page_size):
        """Show the next page.

        ``refresh`` clamps the request to the last page, so the dataset is
        filtered only once here.
        """
        return refresh(dataset_sel, model_sel, query, int(page) + 1, page_size)

    prev_btn.click(on_prev, paging_inputs, page_outputs)
    next_btn.click(on_next, paging_inputs, page_outputs)

if __name__ == "__main__":
    # Start the Gradio server when this file is run as a script;
    # ssr_mode=False turns off Gradio's server-side rendering mode.
    demo.launch(ssr_mode=False)