Mini-ImageNet / src /collection /get_label_list_hf.py
ImAMJayKIM's picture
Upload 96 files
c1596ac verified
Raw
History Blame Contribute Delete
2.95 kB
from datasets import load_dataset
# ============================================================
# [Settings]
# ============================================================
# Name of the Hugging Face dataset to inspect.
DATASET_NAME = "KrushiJethe/fashion_data" # another dataset id noted by the author: uran66/animals
# Name of the split to inspect.
SPLIT_NAME = "train"
# Name of the field that holds the label.
LABEL_FIELD_NAME = "articleType"
# Whether to load the dataset in streaming mode.
# Streaming (True) is sufficient when only the label structure is being checked.
USE_STREAMING = True
# For datasets whose labels are stored as plain strings, the whole dataset may need to be scanned.
# None scans every item; an integer limits the scan to that many leading items.
MAX_SCAN_ITEMS = None
# ============================================================
def _print_label_list(label_names):
    """Print every label with its index, followed by the total label count."""
    print("\n๋ผ๋ฒจ ๋ชฉ๋ก")
    print("-" * 50)
    for idx, label_name in enumerate(label_names):
        print(f"{idx}: {label_name}")
    print("-" * 50)
    print(f"์ด ๋ผ๋ฒจ ๊ฐœ์ˆ˜: {len(label_names)}")


def get_unique_labels():
    """Print and return the distinct labels of the configured dataset.

    The dataset is loaded with ``load_dataset`` using the module-level
    settings ``DATASET_NAME``, ``SPLIT_NAME`` and ``USE_STREAMING``; the
    label is read from the field named by ``LABEL_FIELD_NAME``.

    Two strategies are used:

    1. If the label feature exposes a non-``None`` ``names`` attribute
       (as a ``ClassLabel`` does), those names are returned directly and
       no data is iterated.
    2. Otherwise the items are iterated (at most ``MAX_SCAN_ITEMS`` of
       them when that setting is not ``None``), ``None`` labels are
       skipped, every other value is converted with ``str`` and the
       unique values are returned sorted.

    Returns:
        The label names: in the feature's declared order for strategy 1,
        sorted alphabetically for strategy 2.

    Raises:
        KeyError: if the dataset declares its features and
            ``LABEL_FIELD_NAME`` is not one of them.
    """
    print(f"[{DATASET_NAME}] ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์ค‘...")
    dataset = load_dataset(
        DATASET_NAME,
        split=SPLIT_NAME,
        streaming=USE_STREAMING,
    )

    # Look up the label field in the dataset's feature description.
    # A streaming dataset may not know its schema up front, in which case
    # ``features`` is None; indexing it would raise TypeError, so fall
    # through to the scan below instead.
    features = dataset.features
    label_feature = None if features is None else features[LABEL_FIELD_NAME]

    # ------------------------------------------------------------
    # 1. The label is a ClassLabel (as in Food101)
    # ------------------------------------------------------------
    # The full list of label names is available from the feature itself,
    # so the data does not need to be iterated.
    class_names = getattr(label_feature, "names", None)
    if class_names is not None:
        _print_label_list(class_names)
        return class_names

    # ------------------------------------------------------------
    # 2. The label is stored directly as a string (or other raw value)
    # ------------------------------------------------------------
    # The data has to be iterated and de-duplicated by hand.
    unique_labels = set()
    print("\n๋ผ๋ฒจ ํ•„๋“œ๊ฐ€ ClassLabel ํƒ€์ž…์ด ์•„๋‹ˆ๋ฏ€๋กœ ๋ฐ์ดํ„ฐ๋ฅผ ์ˆœํšŒํ•ฉ๋‹ˆ๋‹ค...")
    for idx, item in enumerate(dataset):
        # Optional cap on how many items are inspected.
        if MAX_SCAN_ITEMS is not None and idx >= MAX_SCAN_ITEMS:
            break
        label_value = item.get(LABEL_FIELD_NAME)
        if label_value is None:
            # Items without a label contribute nothing.
            continue
        unique_labels.add(str(label_value))

    label_names = sorted(unique_labels)
    _print_label_list(label_names)
    return label_names
# Script entry point: when the file is executed directly (not imported),
# load the configured dataset and print its label list.
if __name__ == "__main__":
    get_unique_labels()