# MuscleCare-FastAPI / augment_dataset.py
# Author: Merry99
# Commit e2b68a6: "add augment py"
import os
import random
import json
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional
import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi
from dotenv import load_dotenv
# Number of synthetic "augmented_user_NNN" users that main() creates.
TARGET_USERS: int = 20
# Number of records generated for each synthetic user.
RECORDS_PER_USER: int = 500
def require_env(var_name: str) -> str:
    """Return the value of the environment variable *var_name*.

    Raises:
        RuntimeError: if the variable is unset or set to an empty string.
    """
    resolved = os.environ.get(var_name)
    if resolved:
        return resolved
    raise RuntimeError(f"ํ™˜๊ฒฝ๋ณ€์ˆ˜ {var_name}๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.")
def jitter(value: float, scale: float = 0.02) -> float:
    """Multiply *value* by a random factor drawn uniformly from [1 - scale, 1 + scale].

    ``None`` is passed through unchanged so missing measurements stay missing.
    """
    if value is not None:
        factor = 1 + random.uniform(-scale, scale)
        return value * factor
    return None
def jitter_abs(value: float, amount: float) -> float:
    """Shift *value* by a random offset drawn uniformly from [-amount, amount].

    ``None`` is passed through unchanged.
    """
    if value is not None:
        offset = random.uniform(-amount, amount)
        return value + offset
    return None
def augment_sensor_vector(x: float, y: float, z: float, noise: float = 0.02) -> tuple:
    """Perturb a 3-axis sensor reading in a physically plausible way.

    All three axes share a single random scale factor drawn from
    [1 - noise, 1 + noise], which keeps the vector's direction. Each axis then
    gets its own additive offset drawn from [-0.01, 0.01]. Results are rounded
    to 4 decimals.

    If any axis is ``None``, the input triple is returned unchanged.
    """
    axes = (x, y, z)
    if any(axis is None for axis in axes):
        return axes
    shared_scale = 1 + random.uniform(-noise, noise)
    return tuple(
        round(axis * shared_scale + random.uniform(-0.01, 0.01), 4)
        for axis in axes
    )
def compute_rms(x: float, y: float, z: float, base_noise: float = 0.02) -> float:
    """Estimate a magnitude from three per-axis mean values.

    Despite the name, this returns the Euclidean norm sqrt(x² + y² + z²) of the
    mean vector. The norm is multiplied by a random factor drawn from
    [1 - base_noise, 1 + base_noise] and rounded to 4 decimals.

    Returns ``None`` if any axis is ``None``.
    """
    if any(component is None for component in (x, y, z)):
        return None
    magnitude = np.sqrt(x**2 + y**2 + z**2)
    noise_factor = 1 + random.uniform(-base_noise, base_noise)
    return round(magnitude * noise_factor, 4)
def _is_usable(record: dict, key: str) -> bool:
    """Return True if *record* holds a non-missing value under *key*.

    A value counts as missing when the key is absent, the value is ``None``, or
    the value is a float NaN. Rows built with ``DataFrame.iloc[i].to_dict()``
    encode missing numeric cells as NaN rather than ``None``.
    """
    if key not in record:
        return False
    value = record[key]
    if value is None:
        return False
    return not (isinstance(value, float) and np.isnan(value))


def _rms_ratio(new: dict, row: dict, key: str) -> Optional[float]:
    """Return ``new[key] / row[key]`` if both are usable and ``row[key] > 0``; otherwise None."""
    if _is_usable(new, key) and _is_usable(row, key) and row[key] > 0:
        return new[key] / row[key]
    return None


def augment_record_strict(row: dict) -> dict:
    """Return an augmented copy of one sensor-window record.

    *row* itself is not modified. A field is perturbed only when it is present
    and non-missing (not None or NaN):

    * ``timestamp_utc`` (ISO string): shifted by up to ±150 ms.
    * ``window_id``: shifted by ±1.
    * ``window_start_ms``: shifted by ±50 ms. ``window_end_ms`` is re-derived
      as start + 2000 ms when a start value exists.
    * acc / gyro / linacc 3-axis means: one shared ±3 % scale plus small
      per-axis noise.
    * gravity means: perturbed, then renormalized to magnitude 9.8.
    * ``rms_acc`` / ``rms_gyro``: recomputed from the perturbed means via
      ``compute_rms``, or jittered ±3 % when the means are unavailable.
    * Per-axis std and entropy: scaled by the RMS change.
    * Mean frequency, jerk and overlap rate: jittered.
    * ``stability_index``: always written, derived from the average entropy
      (0.5 is assumed when no entropy is available).
    * ``fatigue`` / ``fatigue_level`` / ``fatigue_prev``: always written with
      provisional values. ``augment_user_data`` overwrites them to keep a
      continuous fatigue trend.
    * ``quality_flag``: flipped with 5 % probability.
    * ``session_id``: a numeric ``_N`` suffix is shifted by ±5.

    Other fields (e.g. ``rms_base``, ``freq_base``, ``user_emb``,
    ``window_size_ms``) are carried over unchanged by the initial copy.
    """
    new = row.copy()

    # --- timestamp jitter (unparseable strings are kept verbatim) ---
    timestamp = row.get("timestamp_utc")
    if isinstance(timestamp, str):
        try:
            parsed = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            shifted = parsed + timedelta(milliseconds=random.randint(-150, 150))
        except (ValueError, OverflowError):
            pass
        else:
            new["timestamp_utc"] = shifted.isoformat()

    # --- window jitter ---
    if _is_usable(row, "window_id"):
        new["window_id"] = int(row["window_id"] + random.randint(-1, 1))
    if _is_usable(row, "window_start_ms"):
        new["window_start_ms"] = row["window_start_ms"] + random.randint(-50, 50)
    # Keep a 2000 ms window. Without a usable start value the end stays as-is.
    if _is_usable(row, "window_end_ms") and _is_usable(new, "window_start_ms"):
        new["window_end_ms"] = new["window_start_ms"] + 2000

    # --- accelerometer / gyro / linear-acceleration means ---
    for sensor in ("acc", "gyro", "linacc"):
        mean_cols = [f"{sensor}_{axis}_mean" for axis in "xyz"]
        if all(_is_usable(row, col) for col in mean_cols):
            augmented = augment_sensor_vector(*(row[col] for col in mean_cols), noise=0.03)
            new.update(zip(mean_cols, augmented))

    # --- gravity vector: physical constraint, magnitude fixed at 9.8 ---
    gravity_cols = ["gravity_x_mean", "gravity_y_mean", "gravity_z_mean"]
    if all(_is_usable(row, col) for col in gravity_cols):
        gx, gy, gz = augment_sensor_vector(*(row[col] for col in gravity_cols), noise=0.01)
        g_mag = np.sqrt(gx**2 + gy**2 + gz**2)
        if g_mag > 0:
            scale = 9.8 / g_mag
            for col, component in zip(gravity_cols, (gx, gy, gz)):
                new[col] = round(component * scale, 4)

    # --- recompute RMS from the (augmented) sensor means ---
    for sensor in ("acc", "gyro"):
        mean_cols = [f"{sensor}_{axis}_mean" for axis in "xyz"]
        rms_col = f"rms_{sensor}"
        if all(_is_usable(new, col) for col in mean_cols):
            new[rms_col] = compute_rms(*(new[col] for col in mean_cols), base_noise=0.03)
        elif _is_usable(row, rms_col):
            new[rms_col] = jitter(row[rms_col], 0.03)

    # --- std values scale with the RMS change (with a small positive floor) ---
    for sensor, floor in (("acc", 0.01), ("gyro", 0.001)):
        ratio = _rms_ratio(new, row, f"rms_{sensor}")
        if ratio is None:
            continue
        for axis in "xyz":
            std_col = f"{sensor}_{axis}_std"
            if _is_usable(row, std_col):
                new[std_col] = max(floor, row[std_col] * ratio * jitter(1, 0.1))

    # --- mean frequency: noise amplitude grows with RMS ---
    for sensor in ("acc", "gyro"):
        freq_col = f"mean_freq_{sensor}"
        rms_col = f"rms_{sensor}"
        if not _is_usable(row, freq_col):
            continue
        if _is_usable(new, rms_col):
            new[freq_col] = round(jitter_abs(row[freq_col], new[rms_col] * 0.3), 2)
        else:
            new[freq_col] = round(jitter(row[freq_col], 0.02), 2)

    # --- entropy: follows the RMS change, clamped to [0.05, 1.0] ---
    for sensor in ("acc", "gyro"):
        entropy_col = f"entropy_{sensor}"
        if not _is_usable(row, entropy_col):
            continue
        ratio = _rms_ratio(new, row, f"rms_{sensor}")
        if ratio is not None:
            raw_entropy = row[entropy_col] * ratio * jitter(1, 0.1)
        else:
            raw_entropy = jitter(row[entropy_col], 0.02)
        new[entropy_col] = min(1.0, max(0.05, raw_entropy))

    # --- jerk: noise amplitude proportional to the original acc_x_std ---
    acc_x_std = row["acc_x_std"] if _is_usable(row, "acc_x_std") else None
    if _is_usable(row, "jerk_mean"):
        if acc_x_std is not None:
            new["jerk_mean"] = round(jitter_abs(row["jerk_mean"], acc_x_std * 0.3), 4)
        else:
            new["jerk_mean"] = round(jitter(row["jerk_mean"], 0.02), 4)
    if _is_usable(row, "jerk_std"):
        if acc_x_std is not None:
            new["jerk_std"] = max(0.001, round(jitter_abs(row["jerk_std"], acc_x_std * 0.1), 4))
        else:
            new["jerk_std"] = max(0.001, round(jitter(row["jerk_std"], 0.01), 4))

    # --- stability index: inversely related to the average entropy ---
    entropies = [new[col] for col in ("entropy_acc", "entropy_gyro") if _is_usable(new, col)]
    entropy_avg = sum(entropies) / len(entropies) if entropies else 0.5
    new["stability_index"] = round(max(0.4, min(0.99, 1 - entropy_avg * 0.3)), 4)

    # --- provisional fatigue (RMS / frequency based) ---
    # augment_user_data overwrites this to keep fatigue continuous over time.
    if _is_usable(row, "fatigue"):
        rms_factor = 1.0
        if _is_usable(new, "rms_acc") and _is_usable(row, "rms_acc") and row["rms_acc"] > 0.1:
            rms_factor = new["rms_acc"] / row["rms_acc"]
        freq_factor = 1.0
        if (
            _is_usable(new, "mean_freq_acc")
            and _is_usable(row, "mean_freq_acc")
            and row["mean_freq_acc"] > 1
            and new["mean_freq_acc"] > 0  # avoid division by zero / sign flip
        ):
            freq_factor = row["mean_freq_acc"] / new["mean_freq_acc"]
        fatigue_delta = rms_factor * 0.05 - freq_factor * 0.03
        fatigue = min(0.95, max(0.05, row["fatigue"] + fatigue_delta + random.uniform(-0.02, 0.02)))
    else:
        fatigue = 0.1  # default when the source row has no fatigue value
    new["fatigue"] = fatigue
    new["fatigue_level"] = 0 if fatigue < 0.3 else 1 if fatigue < 0.6 else 2
    new["fatigue_prev"] = row["fatigue_prev"] if _is_usable(row, "fatigue_prev") else 0.05

    # --- other fields ---
    if _is_usable(row, "overlap_rate"):
        new["overlap_rate"] = max(0.3, min(0.7, jitter(row["overlap_rate"], 0.02)))
    if "quality_flag" in row and random.random() < 0.05:  # flip with 5 % probability
        new["quality_flag"] = 0 if row["quality_flag"] == 1 else 1

    # session_id: shift a numeric "_N" suffix by ±5; anything else is kept.
    session_id = row.get("session_id")
    if session_id:
        prefix, separator, suffix = str(session_id).rpartition("_")
        if separator:
            try:
                session_num = int(suffix)
            except ValueError:
                pass
            else:
                new["session_id"] = f"{prefix}_{session_num + random.randint(-5, 5)}"

    return new
def _parse_reference_time(value) -> datetime:
    """Parse a reference ``timestamp_utc`` value into a datetime.

    Accepts anything whose ``str()`` is ISO-8601; a trailing ``Z`` is read as
    UTC. Falls back to the current UTC time when *value* is falsy or cannot be
    parsed.
    """
    if value:
        try:
            return datetime.fromisoformat(str(value).replace("Z", "+00:00"))
        except ValueError:
            pass
    return datetime.now(timezone.utc)


def _value_or_default(value, default):
    """Return *value*, or *default* when *value* is None or a float NaN."""
    if value is None or (isinstance(value, float) and np.isnan(value)):
        return default
    return value


def _synthesize_rows(
    df: pd.DataFrame,
    count: int,
    *,
    first_window_id: int,
    first_window_start_ms: int,
    first_time: datetime,
    first_fatigue_prev: float,
    new_user_id: Optional[str] = None,
    renumber_sessions: bool = False,
    window_interval_ms: int = 2000,
) -> List[dict]:
    """Generate *count* augmented rows laid out on one continuous timeline.

    Each row is ``augment_record_strict`` applied to a randomly sampled row of
    *df*. Its timeline fields are then overwritten so that row ``i`` has:

    * ``window_id = first_window_id + i``;
    * a window starting at ``first_window_start_ms + i * window_interval_ms``;
    * ``timestamp_utc = first_time + i * window_interval_ms`` milliseconds.

    Fatigue follows a small upward random walk starting from
    *first_fatigue_prev*. ``session_id`` is renumbered in blocks of 10 rows
    when *renumber_sessions* is set. ``measure_date`` is derived from the new
    timestamp when the sampled row has that column.
    """
    rows: List[dict] = []
    fatigue_prev = first_fatigue_prev
    for i in range(count):
        sample = df.iloc[random.randint(0, len(df) - 1)].to_dict()
        row = augment_record_strict(sample)
        if new_user_id:
            row["user_id"] = new_user_id

        start_ms = first_window_start_ms + i * window_interval_ms
        row["window_id"] = first_window_id + i
        row["window_start_ms"] = start_ms
        row["window_end_ms"] = start_ms + window_interval_ms
        timestamp = first_time + timedelta(milliseconds=i * window_interval_ms)
        row["timestamp_utc"] = timestamp.isoformat()

        # Fatigue continuity: the previous fatigue is the preceding row's fatigue,
        # and the current value drifts slightly upward from it.
        row["fatigue_prev"] = fatigue_prev
        if row.get("fatigue") is not None:
            fatigue = min(0.95, max(0.05, fatigue_prev + random.uniform(0, 0.02) + random.uniform(-0.01, 0.01)))
            row["fatigue"] = fatigue
            row["fatigue_level"] = 0 if fatigue < 0.3 else 1 if fatigue < 0.6 else 2
            fatigue_prev = fatigue

        if renumber_sessions and "session_id" in row:
            row["session_id"] = f"session_{i // 10 + 1:03d}"  # 10 records per session
        if "measure_date" in sample:
            row["measure_date"] = timestamp.strftime("%Y-%m-%d")
        rows.append(row)
    return rows


def augment_user_data(df: pd.DataFrame, target_count: int, new_user_id: Optional[str] = None) -> pd.DataFrame:
    """Grow a user's records, or synthesize a new user, up to *target_count* rows.

    There are two modes:

    * Existing user (``new_user_id is None``): if *df* already has at least
      *target_count* rows, its first *target_count* rows are returned.
      Otherwise augmented rows are appended after the last row, continuing its
      ``window_id``, ``window_end_ms``, ``timestamp_utc`` and ``fatigue``.
    * New user (``new_user_id`` given): exactly *target_count* fresh rows are
      generated from samples of *df*, regardless of how many rows *df* has.
      They use a new timeline: window ids start at 1, window starts at 0 ms,
      timestamps start at *df*'s first row, and sessions hold 10 rows each.
      ``user_id`` is set to *new_user_id* when it is non-empty.

    Args:
        df: Reference records; rows are sampled at random.
        target_count: Desired number of rows. ``<= 0`` yields an empty frame.
        new_user_id: Id for a synthetic user, or None to extend *df*.

    Returns:
        A DataFrame with *target_count* rows. In the truncation case this is a
        ``head`` of *df*.

    Raises:
        ValueError: if rows must be generated but *df* is empty.
    """
    if target_count <= 0:
        return df.head(0)
    is_new_user = new_user_id is not None
    # The truncation shortcut only applies when extending an existing user;
    # a new user must never be a verbatim copy of the reference user's rows.
    if not is_new_user and len(df) >= target_count:
        return df.head(target_count)
    if len(df) == 0:
        raise ValueError("augment_user_data: cannot generate records from an empty DataFrame")
    window_interval_ms = 2000  # window length in ms (window_size_ms)

    if is_new_user:
        base_row = df.iloc[0].to_dict()
        base_fatigue = _value_or_default(base_row.get("fatigue"), 0.1)
        rows = _synthesize_rows(
            df,
            target_count,
            first_window_id=1,
            first_window_start_ms=0,
            first_time=_parse_reference_time(base_row.get("timestamp_utc")),
            # The first record starts slightly below the reference fatigue.
            first_fatigue_prev=max(0.05, base_fatigue - random.uniform(0, 0.05)),
            new_user_id=new_user_id,
            renumber_sessions=True,
            window_interval_ms=window_interval_ms,
        )
        return pd.DataFrame(rows)

    # Existing user: continue directly after the last recorded window.
    last_row = df.iloc[-1].to_dict()
    rows = _synthesize_rows(
        df,
        target_count - len(df),
        first_window_id=int(_value_or_default(last_row.get("window_id"), 0)) + 1,
        first_window_start_ms=int(_value_or_default(last_row.get("window_end_ms"), 0)),
        first_time=_parse_reference_time(last_row.get("timestamp_utc")) + timedelta(milliseconds=window_interval_ms),
        first_fatigue_prev=_value_or_default(last_row.get("fatigue"), 0.1),
        window_interval_ms=window_interval_ms,
    )
    return pd.concat([df, pd.DataFrame(rows)], ignore_index=True)
def _user_id_from_parquet_path(file_path: str) -> str:
    """Derive the user (split) id from a Hub parquet path.

    Handles ``data/user_xxx.parquet`` as well as sharded names such as
    ``data/user_xxx-00000-of-00001.parquet``. Everything from the first ``-``
    in the file stem onwards is dropped.
    """
    filename = file_path.rsplit("/", 1)[-1]
    stem = filename[: -len(".parquet")] if filename.endswith(".parquet") else filename
    return stem.partition("-")[0]


def main():
    """Add synthetic users to the Hub dataset and re-upload it.

    Steps:
      1. Load ``.env`` and read ``HF_DATA_REPO_ID`` / ``HF_DATA_TOKEN``.
      2. Download every ``*.parquet`` file in the dataset repo, skipping ids
         that start with ``local_user``. Shards of the same user are
         concatenated instead of overwriting each other.
      3. Create ``TARGET_USERS`` users (``augmented_user_NNN``) of
         ``RECORDS_PER_USER`` rows each. Each is built by ``augment_user_data``
         from a randomly chosen existing user.
      4. Align the new users' columns (set and order) with the first existing
         user's columns.
      5. Delete ``local_user*`` parquet files from the repo, then push the
         existing and new users as a private ``DatasetDict``.
    """
    from huggingface_hub import hf_hub_download  # hoisted out of the download loop

    load_dotenv()
    repo_id = require_env("HF_DATA_REPO_ID")
    token = require_env("HF_DATA_TOKEN")
    print(f"๐Ÿ“‚ ๊ธฐ์กด ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์ค‘: {repo_id}")

    api = HfApi()
    parquet_files: List[str] = []
    valid_users: Dict[str, Dataset] = {}
    try:
        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
        # All parquet files, including ones whose name does not start with "user".
        parquet_files = [f for f in files if f.endswith(".parquet")]
        print(f"๐Ÿ“Š Parquet ํŒŒ์ผ ์ˆ˜: {len(parquet_files)}")

        # A split may be stored as several shards; collect them per user so that
        # later shards do not overwrite earlier ones.
        frames_by_user: Dict[str, List[pd.DataFrame]] = {}
        for file_path in parquet_files:
            user_id = _user_id_from_parquet_path(file_path)
            if user_id.startswith("local_user"):
                print(f"โญ๏ธ {user_id}: local_user๋กœ ์‹œ์ž‘ํ•˜๋Š” ํŒŒ์ผ์€ ์ œ์™ธ")
                continue
            try:
                local_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=file_path,
                    repo_type="dataset",
                    token=token,
                )
                df = pd.read_parquet(local_path)
            except Exception as e2:  # best effort: one bad file must not abort the run
                print(f"โš ๏ธ {file_path}: ๋กœ๋“œ ์‹คํŒจ ({str(e2)[:100]}), ๊ฑด๋„ˆ๋œ€")
                continue
            if len(df) > 0:
                frames_by_user.setdefault(user_id, []).append(df)
                print(f"โœ… {user_id}: {len(df)} ๋ ˆ์ฝ”๋“œ ๋กœ๋“œ")
            else:
                print(f"โš ๏ธ {user_id}: ๋นˆ ๋ฐ์ดํ„ฐ์…‹, ๊ฑด๋„ˆ๋œ€")

        for user_id, frames in frames_by_user.items():
            try:
                combined = pd.concat(frames, ignore_index=True)
                valid_users[user_id] = Dataset.from_pandas(combined, preserve_index=False)
            except Exception as e2:
                print(f"โš ๏ธ {user_id}: ๋กœ๋“œ ์‹คํŒจ ({str(e2)[:100]}), ๊ฑด๋„ˆ๋œ€")
    except Exception as e3:
        print(f"โŒ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์™„์ „ ์‹คํŒจ: {e3}")
        return

    # Only non-empty, non-local_user users reach valid_users (filtered above).
    if not valid_users:
        print("โŒ ์œ ํšจํ•œ ์‚ฌ์šฉ์ž ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
        return
    print(f"โœ… ์œ ํšจํ•œ ์‚ฌ์šฉ์ž ์ˆ˜: {len(valid_users)}๋ช…")
    current_total = sum(len(ds) for ds in valid_users.values())
    print(f"๐Ÿ“Š ํ˜„์žฌ ์ด ๋ ˆ์ฝ”๋“œ ์ˆ˜: {current_total}")

    # Generate the synthetic users, each derived from a random existing user.
    all_users = list(valid_users)
    print(f"๐ŸŽฏ ์ƒˆ๋กœ์šด ์‚ฌ์šฉ์ž {TARGET_USERS}๋ช… ์ƒ์„ฑ ์ค‘...")
    print(f"๐Ÿ“‹ ์ฐธ์กฐ ์‚ฌ์šฉ์ž: {len(all_users)}๋ช…")
    print(f"๐ŸŽฏ ์‚ฌ์šฉ์ž๋‹น ๋ชฉํ‘œ ๋ ˆ์ฝ”๋“œ ์ˆ˜: {RECORDS_PER_USER}")
    new_user_frames: Dict[str, pd.DataFrame] = {}
    for i in range(1, TARGET_USERS + 1):
        new_user_id = f"augmented_user_{i:03d}"
        reference_user_id = random.choice(all_users)
        reference_df = valid_users[reference_user_id].to_pandas()
        try:
            new_user_df = augment_user_data(reference_df, RECORDS_PER_USER, new_user_id=new_user_id)
            new_user_df["user_id"] = new_user_id
        except Exception as e:
            print(f"โŒ {new_user_id}: ์ƒ์„ฑ ์‹คํŒจ ({e}), ๊ฑด๋„ˆ๋œ€")
            continue
        new_user_frames[new_user_id] = new_user_df
        actual_count = len(new_user_df)
        print(f"๐Ÿ“ˆ {new_user_id}: {actual_count} ๋ ˆ์ฝ”๋“œ ์ƒ์„ฑ (์ฐธ์กฐ: {reference_user_id}, ๋ชฉํ‘œ: {RECORDS_PER_USER})")
        if actual_count != RECORDS_PER_USER:
            print(f" โš ๏ธ ๊ฒฝ๊ณ : ์ƒ์„ฑ๋œ ๋ ˆ์ฝ”๋“œ ์ˆ˜({actual_count})๊ฐ€ ๋ชฉํ‘œ({RECORDS_PER_USER})์™€ ๋‹ค๋ฆ…๋‹ˆ๋‹ค!")

    if not new_user_frames:
        print("โŒ ์ƒˆ๋กœ์šด ์‚ฌ์šฉ์ž ๋ฐ์ดํ„ฐ๊ฐ€ ์ƒ์„ฑ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
        return

    # The schema reference is the first existing user's column set and order.
    print("๐Ÿ”ง ๊ธฐ์กด ๋ฐ์ดํ„ฐ ์Šคํ‚ค๋งˆ ํ™•์ธ ์ค‘...")
    reference_columns = list(valid_users[next(iter(valid_users))].column_names)
    existing_columns = set(reference_columns)
    print(f" ๐Ÿ“‹ ๊ธฐ์กด ๋ฐ์ดํ„ฐ ์ปฌ๋Ÿผ ์ˆ˜: {len(existing_columns)}")
    print(f" ๐Ÿ“‹ ๊ธฐ์กด ๋ฐ์ดํ„ฐ ์ปฌ๋Ÿผ: {sorted(existing_columns)}")

    print("๐Ÿ”ง ์ƒˆ๋กœ์šด ์‚ฌ์šฉ์ž ๋ฐ์ดํ„ฐ๋ฅผ ๊ธฐ์กด ์Šคํ‚ค๋งˆ์— ๋งž์ถ”๋Š” ์ค‘...")
    new_user_datasets: Dict[str, Dataset] = {}
    for user_id, df in new_user_frames.items():
        columns_to_remove = set(df.columns) - existing_columns
        if columns_to_remove:
            df = df.drop(columns=list(columns_to_remove))
            print(f" โš ๏ธ {user_id}: ๋ถˆํ•„์š”ํ•œ ์ปฌ๋Ÿผ ์ œ๊ฑฐ: {columns_to_remove}")
        columns_to_add = existing_columns - set(df.columns)
        if columns_to_add:
            for col in columns_to_add:
                df[col] = None
            print(f" โž• {user_id}: ๋ˆ„๋ฝ๋œ ์ปฌ๋Ÿผ ์ถ”๊ฐ€: {columns_to_add}")
        df = df[reference_columns]
        try:
            new_user_datasets[user_id] = Dataset.from_pandas(df, preserve_index=False)
        except Exception as e:
            print(f"โŒ {user_id}: ์ƒ์„ฑ ์‹คํŒจ ({e}), ๊ฑด๋„ˆ๋œ€")
            continue
        print(f" โœ… {user_id}: ์Šคํ‚ค๋งˆ ์ •๊ทœํ™” ์™„๋ฃŒ")

    if not new_user_datasets:
        print("โŒ ์ƒˆ๋กœ์šด ์‚ฌ์šฉ์ž ๋ฐ์ดํ„ฐ๊ฐ€ ์ƒ์„ฑ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
        return

    # Existing users are kept as-is; new users are added alongside them.
    final_dict = DatasetDict({**valid_users, **new_user_datasets})
    new_users_total = sum(len(ds) for ds in new_user_datasets.values())
    total_records = sum(len(ds) for ds in final_dict.values())
    print(f"๐Ÿ“Š ์ƒˆ๋กœ์šด ์‚ฌ์šฉ์ž๋“ค์˜ ์ด ๋ ˆ์ฝ”๋“œ ์ˆ˜: {new_users_total}")
    print(f"๐Ÿ“Š ์ „์ฒด ๋ฐ์ดํ„ฐ์…‹ ์ด ๋ ˆ์ฝ”๋“œ ์ˆ˜: {total_records}")
    print(f"๐Ÿ“Š ์ƒˆ๋กœ์šด parquet ํŒŒ์ผ ์ˆ˜: {len(new_user_datasets)}๊ฐœ")

    # Remove local_user* files from the repo; they are not part of the pushed dict.
    print("๐Ÿ—‘๏ธ local_user๋กœ ์‹œ์ž‘ํ•˜๋Š” ํŒŒ์ผ ์‚ญ์ œ ์ค‘...")
    files_to_delete = [
        path for path in parquet_files
        if _user_id_from_parquet_path(path).startswith("local_user")
    ]
    for file_path in files_to_delete:
        try:
            api.delete_file(path_in_repo=file_path, repo_id=repo_id, repo_type="dataset", token=token)
            print(f" โœ… ์‚ญ์ œ: {file_path}")
        except Exception as e:
            print(f" โš ๏ธ ์‚ญ์ œ ์‹คํŒจ ({file_path}): {str(e)[:100]}")
    if files_to_delete:
        print(f"๐Ÿ—‘๏ธ {len(files_to_delete)}๊ฐœ ํŒŒ์ผ ์‚ญ์ œ ์™„๋ฃŒ")
    else:
        print("โ„น๏ธ ์‚ญ์ œํ•  local_user ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค")

    print(f"๐Ÿ“ค Hugging Face Hub์— ์—…๋กœ๋“œ ์ค‘: {repo_id}")
    final_dict.push_to_hub(repo_id, token=token, private=True)
    print("โœ… ์—…๋กœ๋“œ ์™„๋ฃŒ")
# Run the augmentation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()