MuscleCare-Train-Hybrid / load_dataset.py
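"""Data-loading utilities for the MuscleCare-Train-Hybrid Space.

Loads session data either from local parquet files or directly from the
Merry99/MuscleCare-DataSet repository on the Hugging Face Hub, and normalizes
the `user_emb` column to a fixed-length float vector.
"""
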
import json
import os
from pathlib import Path
from typing import Iterable, List, Optional, Tuple
import pandas as pd
from datasets import get_dataset_config_names, get_dataset_split_names, load_dataset
from huggingface_hub import hf_hub_download
DEFAULT_DATASET_ID = "Merry99/MuscleCare-DataSet"
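# Fallback split names, used when the split list cannot be queried from the Hub:
# a local test split, one iOS and one Android device split, plus user_001..user_050.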
DEFAULT_DATASET_SPLITS = [
"local_user",
"ios_D7ED673185E248BD9DC1102E881E9111",
"android_SP1A.210812.016",
] + [f"user_{i:03d}" for i in range(1, 51)]
def download_parquet_from_hub(
repo_id: str,
filenames: Iterable[str],
local_dir: str = "./data",
repo_type: str = "dataset",
token: Optional[str] = None,
) -> List[Path]:
"""
(์˜ต์…˜) Hugging Face Hub์—์„œ parquet ํŒŒ์ผ์„ ๋‚ด๋ ค๋ฐ›์•„ ๋กœ์ปฌ์— ์ €์žฅ.
Space์™€ ๋™์ผํ•œ ํ™˜๊ฒฝ์„ ์œ„ํ•ด ํ•„์š” ์‹œ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
"""
target_dir = Path(local_dir)
target_dir.mkdir(parents=True, exist_ok=True)
downloaded: List[Path] = []
for name in filenames:
local_path = Path(
hf_hub_download(
repo_id=repo_id,
filename=name,
repo_type=repo_type,
token=token,
local_dir=target_dir,
local_dir_use_symlinks=False,
)
)
downloaded.append(local_path)
return downloaded
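
# A minimal usage sketch for download_parquet_from_hub. The shard file names
# below are assumed for illustration; check the dataset repo for the actual
# parquet file names before running.
#
#   files = [f"user_{i:03d}.parquet" for i in range(1, 4)]
#   download_parquet_from_hub(DEFAULT_DATASET_ID, files, local_dir="./data")
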
def resolve_parquet_files(data_dir: str = "./data", pattern: str = "user*.parquet") -> List[Path]:
"""
๋ฐ์ดํ„ฐ ๋””๋ ‰ํ† ๋ฆฌ์—์„œ parquet ํŒŒ์ผ ๋ชฉ๋ก์„ ์ •๋ ฌ๋œ ์ƒํƒœ๋กœ ๋ฐ˜ํ™˜.
"""
data_path = Path(data_dir)
if not data_path.exists():
raise FileNotFoundError(f"๋ฐ์ดํ„ฐ ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {data_dir}")
parquet_files = sorted(data_path.glob(pattern))
if not parquet_files:
raise FileNotFoundError(f"ํŒจํ„ด({pattern})์— ํ•ด๋‹นํ•˜๋Š” parquet ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค.")
return parquet_files
def parse_user_embedding(raw_emb, fallback_dim: int = 12) -> List[float]:
"""
๋ฌธ์ž์—ด/๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ์˜ user_emb๋ฅผ ๊ณ ์ • ๊ธธ์ด ๋ฆฌ์ŠคํŠธ๋กœ ๋ณ€ํ™˜.
"""
if isinstance(raw_emb, str):
try:
raw_emb = json.loads(raw_emb)
except json.JSONDecodeError:
raw_emb = []
if isinstance(raw_emb, (list, tuple)):
values = list(raw_emb)
else:
values = []
if not values:
values = [0.0] * fallback_dim
if len(values) < fallback_dim:
values = values + [0.0] * (fallback_dim - len(values))
else:
values = values[:fallback_dim]
return [float(v) for v in values]
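
# Padding/truncation behaviour of parse_user_embedding, for reference:
#   parse_user_embedding("[1, 2, 3]", fallback_dim=4)      -> [1.0, 2.0, 3.0, 0.0]
#   parse_user_embedding(None, fallback_dim=4)              -> [0.0, 0.0, 0.0, 0.0]
#   parse_user_embedding(list(range(20)), fallback_dim=4)   -> [0.0, 1.0, 2.0, 3.0]
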
def normalize_user_embeddings(df: pd.DataFrame, emb_dim: int) -> pd.DataFrame:
if "user_emb" not in df.columns:
raise KeyError("๋ฐ์ดํ„ฐ์…‹์— 'user_emb' ์ปฌ๋Ÿผ์ด ์—†์Šต๋‹ˆ๋‹ค.")
df = df.copy()
df["user_emb"] = df["user_emb"].apply(lambda v: parse_user_embedding(v, emb_dim))
return df
def _resolve_config_name(repo_id: str) -> Optional[str]:
try:
configs = get_dataset_config_names(repo_id)
if configs:
return configs[0]
except Exception:
pass
return None
def _load_split_dataframe(
repo_id: str,
split_name: str,
cache_dir: str,
config_name: Optional[str],
) -> Optional[pd.DataFrame]:
load_kwargs = {
"path": repo_id,
"split": split_name,
"cache_dir": cache_dir,
}
if config_name:
load_kwargs["name"] = config_name
try:
ds = load_dataset(**load_kwargs)
except ValueError as exc:
print(f"โš ๏ธ split '{split_name}' ๊ฑด๋„ˆ๋œ€: {exc}")
return None
    return ds.to_pandas() if hasattr(ds, "to_pandas") else pd.DataFrame(ds)
def load_dataset_from_hub(
repo_id: Optional[str] = None,
split: Optional[str] = None,
cache_dir: Optional[str] = None,
emb_dim: int = 12,
exclude_sessions: Optional[Iterable[str]] = None,
) -> Tuple[pd.DataFrame, List[str]]:
"""
Hugging Face Dataset์—์„œ ๋ฐ์ดํ„ฐ๋ฅผ ๋กœ๋“œํ•ด DataFrame์œผ๋กœ ๋ณ€ํ™˜.
exclude_sessions์— ํฌํ•จ๋œ session_id๋Š” ์ œ์™ธํ•ฉ๋‹ˆ๋‹ค.
"""
repo_id = repo_id or DEFAULT_DATASET_ID
cache_dir = cache_dir or os.getenv("HF_DATASET_CACHE_DIR", "./data/hf_cache")
config_name = _resolve_config_name(repo_id)
if split:
split_names = [split]
else:
try:
split_names = get_dataset_split_names(repo_id, config_name)
except Exception:
split_names = DEFAULT_DATASET_SPLITS
frames: List[pd.DataFrame] = []
for split_name in split_names:
df_part = _load_split_dataframe(
repo_id=repo_id,
split_name=split_name,
cache_dir=cache_dir,
config_name=config_name,
)
if df_part is not None and not df_part.empty:
frames.append(df_part)
if not frames:
raise ValueError("NO_DATA_AVAILABLE")
df = pd.concat(frames, ignore_index=True)
if "session_id" not in df.columns:
raise KeyError("๋ฐ์ดํ„ฐ์…‹์— 'session_id' ์ปฌ๋Ÿผ์ด ์—†์Šต๋‹ˆ๋‹ค.")
exclude_set = set(str(s) for s in (exclude_sessions or []))
if exclude_set:
df = df[~df["session_id"].astype(str).isin(exclude_set)]
session_ids = sorted(df["session_id"].dropna().astype(str).unique().tolist())
df = normalize_user_embeddings(df, emb_dim)
return df, session_ids
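
# A minimal usage sketch for load_dataset_from_hub. Splits and column names
# follow the defaults above; the excluded session id is illustrative.
#
#   df, session_ids = load_dataset_from_hub(emb_dim=12, exclude_sessions=["local_user"])
#   print(len(df), len(session_ids))
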
def load_parquet_dataset(
data_dir: str = "./data",
pattern: str = "user*.parquet",
emb_dim: int = 12,
) -> pd.DataFrame:
"""
๋ฐ์ดํ„ฐ๊ฐ€ ๋กœ์ปฌ์— ์—†์œผ๋ฉด ์ž๋™์œผ๋กœ Hugging Face Dataset์—์„œ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
"""
try:
parquet_files = resolve_parquet_files(data_dir, pattern)
frames = [pd.read_parquet(path) for path in parquet_files]
data = pd.concat(frames, ignore_index=True)
return normalize_user_embeddings(data, emb_dim)
except FileNotFoundError:
        # No local data available; load directly from the HF Dataset instead.
        print("⚠️ No local parquet files found; loading from the Hugging Face Dataset.")
df, _ = load_dataset_from_hub(emb_dim=emb_dim)
return df
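

if __name__ == "__main__":
    # Minimal smoke test (sketch): load whatever data is reachable and report
    # its shape. Assumes network access to the Hub when no local parquet files
    # exist under ./data.
    frame = load_parquet_dataset()
    print(f"Loaded {len(frame)} rows, columns: {list(frame.columns)}")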