File size: 4,091 Bytes
c1596ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
from collections import Counter
from itertools import islice

from datasets import load_dataset
from dotenv import load_dotenv

# ============================================================
# [Settings]
# ============================================================
# Load variables from a local .env file (python-dotenv) so the Hugging Face
# token can be supplied there; HF_TOKEN is None when it is not set anywhere.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Hugging Face dataset repository to inspect, e.g. "uran66/animals".
DATASET_NAME = "jbarat/plant_species"

# Split of the dataset to scan.
SPLIT_NAME = "train"

# Name of the column that holds the label.
LABEL_FIELD_NAME = "label"

# Whether to stream the dataset:
#   True  -> read examples one by one without downloading the whole dataset first
#   False -> download the dataset into the local cache, then scan it
USE_STREAMING = True

# For datasets whose labels are plain strings, only a full scan gives exact
# counts. None scans every example; an integer scans only that many examples.
MAX_SCAN_ITEMS = None

# ============================================================


def get_label_name(dataset, label_value, label_feature=None):
    """Convert a raw label value into a printable label name.

    Integer labels are mapped through the label feature's ``int2str`` (as on a
    ``datasets.ClassLabel``); any other value is returned as ``str(label_value)``.

    Args:
        dataset: Loaded dataset. Only its ``features`` attribute is read, and
            only when ``label_feature`` is not supplied.
        label_value: Raw value of the label column for one example.
        label_feature: Optional pre-fetched feature of the label column. Pass it
            when converting many values so it is not looked up on every call.

    Returns:
        The label name as a string. Integers that ``int2str`` rejects (for
        example -1, which some datasets use for unlabeled rows) fall back to
        ``str(label_value)`` instead of raising.
    """
    if label_feature is None:
        # A streaming dataset may report features=None; in that case (or when
        # the label column is not described) the value is treated as plain.
        features = getattr(dataset, "features", None)
        label_feature = features.get(LABEL_FIELD_NAME) if features else None

    int2str = getattr(label_feature, "int2str", None)
    if int2str is not None and isinstance(label_value, int):
        try:
            return int2str(label_value)
        except ValueError:
            # ClassLabel.int2str raises ValueError for ids outside
            # [0, num_classes); keep the raw value rather than aborting.
            pass

    # Already a string label (or an id with no name): use its text form.
    return str(label_value)


def get_unique_labels_with_counts():
    """Load the configured split, count examples per label and print a report.

    Uses the module-level settings DATASET_NAME, SPLIT_NAME, USE_STREAMING,
    HF_TOKEN, LABEL_FIELD_NAME and MAX_SCAN_ITEMS.

    Returns:
        Counter mapping each label name (str) to the number of scanned
        examples carrying it. Examples whose label is None are not counted.
    """
    print(f"[{DATASET_NAME}] 데이터셋 로드 중...")

    dataset = load_dataset(
        DATASET_NAME,
        split=SPLIT_NAME,
        streaming=USE_STREAMING,
        token=HF_TOKEN,
    )

    # Look the label feature up once. A streaming dataset may expose
    # features=None; labels are then treated as plain values.
    features = getattr(dataset, "features", None)
    label_feature = features.get(LABEL_FIELD_NAME) if features else None

    print("\n클래스별 이미지 개수 집계 중...")
    label_counter = _count_labels(dataset, label_feature)

    print("\n라벨 목록 및 클래스별 이미지 개수")
    print("-" * 60)
    _print_label_counts(label_counter, label_feature)
    print("-" * 60)
    print(f"총 라벨 개수: {len(label_counter)}")
    print(f"총 이미지 개수: {sum(label_counter.values())}")

    return label_counter


def _to_label_name(label_feature, label_value):
    """Map a raw label value to its display name (ClassLabel id -> name, else str)."""
    int2str = getattr(label_feature, "int2str", None)
    if int2str is not None and isinstance(label_value, int):
        try:
            return int2str(label_value)
        except ValueError:
            # Out-of-range ids (e.g. -1 for unlabeled rows) are kept as text.
            pass
    return str(label_value)


def _count_labels(dataset, label_feature):
    """Count examples per label name, scanning at most MAX_SCAN_ITEMS examples."""
    # islice stops before pulling the next example, so a streaming dataset never
    # fetches more than the limit. None means "scan everything"; a negative
    # limit scans nothing, matching a plain index comparison.
    limit = None if MAX_SCAN_ITEMS is None else max(MAX_SCAN_ITEMS, 0)

    label_counter = Counter()
    for item in islice(dataset, limit):
        label_value = item.get(LABEL_FIELD_NAME)
        if label_value is None:  # skip examples without a label value
            continue
        label_counter[_to_label_name(label_feature, label_value)] += 1
    return label_counter


def _print_label_counts(label_counter, label_feature):
    """Print one "<index>: <label> - <count> 장" line per label."""
    class_names = getattr(label_feature, "names", None)

    if class_names is None:
        # Plain (e.g. string) labels: no canonical order, so sort by name.
        for idx, label_name in enumerate(sorted(label_counter)):
            print(f"{idx}: {label_name} - {label_counter[label_name]} 장")
        return

    # ClassLabel (e.g. Food101): follow the dataset's own class order,
    # including classes that never appeared in the scanned examples.
    for idx, label_name in enumerate(class_names):
        print(f"{idx}: {label_name} - {label_counter.get(label_name, 0)} 장")

    # Counted values outside the class vocabulary (e.g. -1) are listed too,
    # so the printed lines account for every counted example.
    known = set(class_names)
    for label_name in sorted(name for name in label_counter if name not in known):
        print(f"-: {label_name} - {label_counter[label_name]} 장")

if __name__ == "__main__":
    # Run the label scan and print the report only when executed as a script,
    # not when this module is imported.
    get_unique_labels_with_counts()