Mini-ImageNet / src /collection /count_label_hf.py
ImAMJayKIM's picture
Upload 96 files
c1596ac verified
Raw
History Blame Contribute Delete
4.09 kB
from datasets import load_dataset
from collections import Counter
from dotenv import load_dotenv
import os
# ============================================================
# [Configuration]
# ============================================================
# Load variables from a local .env file (if present) before reading HF_TOKEN.
load_dotenv()
# Access token forwarded to load_dataset(); None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Name of the Hugging Face dataset to inspect.
DATASET_NAME = "jbarat/plant_species" # e.g. "uran66/animals"
# Name of the split to inspect.
SPLIT_NAME = "train"
# Name of the field (column) that holds the label.
LABEL_FIELD_NAME = "label"
# Whether to load the dataset in streaming mode.
# True : read items one at a time without downloading the whole dataset first
# False : download the dataset into the local cache, then inspect it
USE_STREAMING = True
# For datasets whose labels are plain strings, the whole split has to be
# scanned to obtain exact counts.
# None scans everything; an integer limits the scan to that many items.
MAX_SCAN_ITEMS = None
# ============================================================
def get_label_name(dataset, label_value, label_field=None):
    """Return the display name for a raw label value.

    Integer labels are translated through the label feature's ``int2str``
    method when the feature provides one (as ``ClassLabel`` does); every
    other value is returned as ``str(label_value)``.

    Args:
        dataset: Object exposing a ``features`` attribute that is either a
            mapping from field name to feature, or ``None``.
        label_value: Raw label taken from a dataset item.
        label_field: Name of the label field. Defaults to the module-level
            ``LABEL_FIELD_NAME`` when omitted.

    Returns:
        The label name as a string.

    Raises:
        KeyError: If the schema is known but contains no ``label_field``.
    """
    if label_field is None:
        label_field = LABEL_FIELD_NAME

    features = dataset.features
    # Streaming datasets can report ``features`` as None when the schema is
    # not declared up front; the raw value is then the only name available.
    if features is None:
        return str(label_value)

    label_feature = features[label_field]
    if isinstance(label_value, int) and hasattr(label_feature, "int2str"):
        return label_feature.int2str(label_value)

    # Already a string (or some other non-integer) label: use it as-is.
    return str(label_value)
def get_unique_labels_with_counts():
    """Count the items per label in the configured dataset split and print a report.

    The dataset named by ``DATASET_NAME`` / ``SPLIT_NAME`` is loaded with
    ``load_dataset`` (streaming according to ``USE_STREAMING``) and iterated
    once. Items whose ``LABEL_FIELD_NAME`` value is missing or ``None`` are
    skipped. When ``MAX_SCAN_ITEMS`` is an integer, only that many leading
    items are examined.

    Returns:
        collections.Counter: Mapping from label name (str) to item count.

    Raises:
        KeyError: If the dataset declares a schema that has no
            ``LABEL_FIELD_NAME`` field.
    """
    print(f"[{DATASET_NAME}] ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์ค‘...")
    dataset = load_dataset(
        DATASET_NAME,
        split=SPLIT_NAME,
        streaming=USE_STREAMING,
        token=HF_TOKEN,
    )

    # Streaming datasets can report ``features`` as None when the schema is
    # not declared up front; in that case fall back to raw label values
    # instead of failing on the subscript.
    features = dataset.features
    label_feature = None if features is None else features[LABEL_FIELD_NAME]

    # Per-label item counts.
    label_counter = Counter()
    print("\nํด๋ž˜์Šค๋ณ„ ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜ ์ง‘๊ณ„ ์ค‘...")

    # Iterating works for both streaming and fully downloaded datasets.
    for idx, item in enumerate(dataset):
        # Stop once the optional scan limit has been reached.
        if MAX_SCAN_ITEMS is not None and idx >= MAX_SCAN_ITEMS:
            break

        label_value = item.get(LABEL_FIELD_NAME)
        # Items without a label are ignored.
        if label_value is None:
            continue

        if label_feature is None:
            # No schema available: the raw value is the label name.
            label_name = str(label_value)
        else:
            # Integer labels become class names; strings pass through.
            label_name = get_label_name(dataset, label_value)
        label_counter[label_name] += 1

    print("\n๋ผ๋ฒจ ๋ชฉ๋ก ๋ฐ ํด๋ž˜์Šค๋ณ„ ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜")
    print("-" * 60)

    # A feature that carries ``names`` (ClassLabel-style) dictates the report
    # order, so classes with no scanned items still appear with a count of 0.
    # Otherwise the label names that were actually seen are listed sorted.
    class_names = getattr(label_feature, "names", None)
    if class_names is not None:
        ordered_names = class_names
    else:
        ordered_names = sorted(label_counter)

    for idx, label_name in enumerate(ordered_names):
        # Counter returns 0 for labels that never occurred.
        count = label_counter[label_name]
        print(f"{idx}: {label_name} - {count} ์žฅ")

    print("-" * 60)
    print(f"์ด ๋ผ๋ฒจ ๊ฐœ์ˆ˜: {len(label_counter)}")
    print(f"์ด ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜: {sum(label_counter.values())}")
    return label_counter
# Script entry point: run the label count only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # The returned Counter is not used here; the function prints its own report.
    get_unique_labels_with_counts()