Spaces:
Sleeping
Sleeping
Update augment_dataset.py: Generate 20 new users with 500 records each, compatible with dataset commit fa41e8b
Browse files- augment_dataset.py +384 -0
augment_dataset.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import json
|
| 4 |
+
from datetime import datetime, timezone, timedelta
|
| 5 |
+
from typing import Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from datasets import Dataset, DatasetDict, load_dataset
|
| 10 |
+
from huggingface_hub import HfApi
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
TARGET_USERS = 20
|
| 15 |
+
RECORDS_PER_USER = 500
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def require_env(var_name: str) -> str:
|
| 19 |
+
value = os.getenv(var_name)
|
| 20 |
+
if not value:
|
| 21 |
+
raise RuntimeError(f"환경변수 {var_name}가 필요합니다.")
|
| 22 |
+
return value
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def add_noise(value: float, noise_scale: float) -> float:
|
| 26 |
+
"""값에 노이즈 추가"""
|
| 27 |
+
if value is None:
|
| 28 |
+
return None
|
| 29 |
+
return round(value + random.uniform(-noise_scale, noise_scale), 4)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def bounded(value: float, low: float, high: float) -> float:
|
| 33 |
+
"""값을 범위 내로 제한"""
|
| 34 |
+
if value is None:
|
| 35 |
+
return None
|
| 36 |
+
return max(low, min(high, value))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def augment_record(original: dict, noise_scale: float = 0.1) -> dict:
|
| 40 |
+
"""단일 레코드를 증폭 (물리적 관계와 상관관계를 고려한 의미있는 증폭)"""
|
| 41 |
+
augmented = original.copy()
|
| 42 |
+
|
| 43 |
+
# 시간 정보 변형 (연속성 유지)
|
| 44 |
+
if "timestamp_utc" in augmented and augmented["timestamp_utc"]:
|
| 45 |
+
try:
|
| 46 |
+
base_time = datetime.fromisoformat(augmented["timestamp_utc"].replace("Z", "+00:00"))
|
| 47 |
+
time_delta = timedelta(milliseconds=random.randint(-200, 200))
|
| 48 |
+
augmented["timestamp_utc"] = (base_time + time_delta).isoformat()
|
| 49 |
+
except:
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
# window_id와 시간 범위 약간 조정 (연속성 유지)
|
| 53 |
+
if "window_id" in augmented:
|
| 54 |
+
augmented["window_id"] = augmented["window_id"] + random.randint(-1, 1)
|
| 55 |
+
if "window_start_ms" in augmented:
|
| 56 |
+
augmented["window_start_ms"] = augmented["window_start_ms"] + random.randint(-50, 50)
|
| 57 |
+
if "window_end_ms" in augmented:
|
| 58 |
+
augmented["window_end_ms"] = augmented["window_start_ms"] + 2000 # window_size_ms와 일치
|
| 59 |
+
|
| 60 |
+
# 가속도계 데이터 증폭 (x, y, z 간 상관관계 유지)
|
| 61 |
+
acc_noise = random.uniform(-noise_scale * 0.1, noise_scale * 0.1)
|
| 62 |
+
if "acc_x_mean" in augmented and augmented["acc_x_mean"] is not None:
|
| 63 |
+
augmented["acc_x_mean"] = add_noise(augmented["acc_x_mean"], abs(augmented["acc_x_mean"]) * 0.1 + 0.01)
|
| 64 |
+
if "acc_y_mean" in augmented and augmented["acc_y_mean"] is not None:
|
| 65 |
+
augmented["acc_y_mean"] = add_noise(augmented["acc_y_mean"], abs(augmented["acc_y_mean"]) * 0.1 + 0.01)
|
| 66 |
+
if "acc_z_mean" in augmented and augmented["acc_z_mean"] is not None:
|
| 67 |
+
augmented["acc_z_mean"] = add_noise(augmented["acc_z_mean"], abs(augmented["acc_z_mean"]) * 0.1 + 0.01)
|
| 68 |
+
|
| 69 |
+
# 자이로스코프 데이터 증폭
|
| 70 |
+
gyro_noise = random.uniform(-noise_scale * 0.02, noise_scale * 0.02)
|
| 71 |
+
if "gyro_x_mean" in augmented and augmented["gyro_x_mean"] is not None:
|
| 72 |
+
augmented["gyro_x_mean"] = add_noise(augmented["gyro_x_mean"], 0.005)
|
| 73 |
+
if "gyro_y_mean" in augmented and augmented["gyro_y_mean"] is not None:
|
| 74 |
+
augmented["gyro_y_mean"] = add_noise(augmented["gyro_y_mean"], 0.005)
|
| 75 |
+
if "gyro_z_mean" in augmented and augmented["gyro_z_mean"] is not None:
|
| 76 |
+
augmented["gyro_z_mean"] = add_noise(augmented["gyro_z_mean"], 0.005)
|
| 77 |
+
|
| 78 |
+
# 선형 가속도 증폭
|
| 79 |
+
if "linacc_x_mean" in augmented and augmented["linacc_x_mean"] is not None:
|
| 80 |
+
augmented["linacc_x_mean"] = add_noise(augmented["linacc_x_mean"], abs(augmented["linacc_x_mean"]) * 0.1 + 0.01)
|
| 81 |
+
if "linacc_y_mean" in augmented and augmented["linacc_y_mean"] is not None:
|
| 82 |
+
augmented["linacc_y_mean"] = add_noise(augmented["linacc_y_mean"], abs(augmented["linacc_y_mean"]) * 0.1 + 0.01)
|
| 83 |
+
if "linacc_z_mean" in augmented and augmented["linacc_z_mean"] is not None:
|
| 84 |
+
augmented["linacc_z_mean"] = add_noise(augmented["linacc_z_mean"], abs(augmented["linacc_z_mean"]) * 0.1 + 0.01)
|
| 85 |
+
|
| 86 |
+
# 중력 벡터 증폭 (물리적 제약: 크기가 약 9.8에 가까워야 함)
|
| 87 |
+
if all(f in augmented and augmented[f] is not None for f in ["gravity_x_mean", "gravity_y_mean", "gravity_z_mean"]):
|
| 88 |
+
gx = augmented["gravity_x_mean"] + random.uniform(-0.01, 0.01)
|
| 89 |
+
gy = augmented["gravity_y_mean"] + random.uniform(-0.01, 0.01)
|
| 90 |
+
gz = augmented["gravity_z_mean"] + random.uniform(-0.02, 0.02)
|
| 91 |
+
# 중력 벡터 크기 정규화 (약 9.8 유지)
|
| 92 |
+
g_mag = np.sqrt(gx**2 + gy**2 + gz**2)
|
| 93 |
+
if g_mag > 0:
|
| 94 |
+
scale = 9.8 / g_mag
|
| 95 |
+
augmented["gravity_x_mean"] = round(gx * scale, 4)
|
| 96 |
+
augmented["gravity_y_mean"] = round(gy * scale, 4)
|
| 97 |
+
augmented["gravity_z_mean"] = round(gz * scale, 4)
|
| 98 |
+
|
| 99 |
+
# 센서 표준편차 증폭 (RMS와 일관성 유지)
|
| 100 |
+
sensor_std_fields = [
|
| 101 |
+
"acc_x_std", "acc_y_std", "acc_z_std",
|
| 102 |
+
"gyro_x_std", "gyro_y_std", "gyro_z_std",
|
| 103 |
+
]
|
| 104 |
+
for field in sensor_std_fields:
|
| 105 |
+
if field in augmented and augmented[field] is not None:
|
| 106 |
+
augmented[field] = bounded(add_noise(augmented[field], augmented[field] * 0.1), 0.01, 1.0)
|
| 107 |
+
|
| 108 |
+
# RMS 값 증폭 (센서 평균값과 일관성 유지)
|
| 109 |
+
if "rms_acc" in augmented and augmented["rms_acc"] is not None:
|
| 110 |
+
# RMS는 가속도 평균값의 크기와 관련
|
| 111 |
+
acc_mag = np.sqrt(
|
| 112 |
+
(augmented.get("acc_x_mean", 0) or 0)**2 +
|
| 113 |
+
(augmented.get("acc_y_mean", 0) or 0)**2 +
|
| 114 |
+
(augmented.get("acc_z_mean", 0) or 0)**2
|
| 115 |
+
)
|
| 116 |
+
rms_base = augmented["rms_acc"]
|
| 117 |
+
# RMS는 원본과 비슷한 범위 유지
|
| 118 |
+
augmented["rms_acc"] = bounded(add_noise(rms_base, rms_base * 0.1), 0.1, 2.0)
|
| 119 |
+
|
| 120 |
+
if "rms_gyro" in augmented and augmented["rms_gyro"] is not None:
|
| 121 |
+
gyro_mag = np.sqrt(
|
| 122 |
+
(augmented.get("gyro_x_mean", 0) or 0)**2 +
|
| 123 |
+
(augmented.get("gyro_y_mean", 0) or 0)**2 +
|
| 124 |
+
(augmented.get("gyro_z_mean", 0) or 0)**2
|
| 125 |
+
)
|
| 126 |
+
rms_gyro_base = augmented["rms_gyro"]
|
| 127 |
+
augmented["rms_gyro"] = bounded(add_noise(rms_gyro_base, rms_gyro_base * 0.1), 0.01, 0.5)
|
| 128 |
+
|
| 129 |
+
# 주파수 증폭 (RMS와 상관관계 유지)
|
| 130 |
+
if "mean_freq_acc" in augmented and augmented["mean_freq_acc"] is not None:
|
| 131 |
+
# RMS가 높으면 주파수도 약간 높아지는 경향
|
| 132 |
+
freq_factor = 1.0 + (augmented.get("rms_acc", 0) or 0) * 0.1
|
| 133 |
+
augmented["mean_freq_acc"] = round(add_noise(augmented["mean_freq_acc"] * freq_factor, 1.0) / freq_factor, 2)
|
| 134 |
+
|
| 135 |
+
if "mean_freq_gyro" in augmented and augmented["mean_freq_gyro"] is not None:
|
| 136 |
+
freq_factor = 1.0 + (augmented.get("rms_gyro", 0) or 0) * 0.2
|
| 137 |
+
augmented["mean_freq_gyro"] = round(add_noise(augmented["mean_freq_gyro"] * freq_factor, 0.5) / freq_factor, 2)
|
| 138 |
+
|
| 139 |
+
# 엔트로피 증폭 (안정성과 관련)
|
| 140 |
+
if "entropy_acc" in augmented and augmented["entropy_acc"] is not None:
|
| 141 |
+
augmented["entropy_acc"] = bounded(add_noise(augmented["entropy_acc"], 0.02), 0.1, 1.0)
|
| 142 |
+
if "entropy_gyro" in augmented and augmented["entropy_gyro"] is not None:
|
| 143 |
+
augmented["entropy_gyro"] = bounded(add_noise(augmented["entropy_gyro"], 0.02), 0.1, 1.0)
|
| 144 |
+
|
| 145 |
+
# Jerk 증폭 (가속도 변화율)
|
| 146 |
+
if "jerk_mean" in augmented and augmented["jerk_mean"] is not None:
|
| 147 |
+
augmented["jerk_mean"] = add_noise(augmented["jerk_mean"], 0.01)
|
| 148 |
+
if "jerk_std" in augmented and augmented["jerk_std"] is not None:
|
| 149 |
+
augmented["jerk_std"] = bounded(add_noise(augmented["jerk_std"], 0.005), 0.01, 0.2)
|
| 150 |
+
|
| 151 |
+
# 안정성 지수 증폭 (엔트로피와 반비례 관계)
|
| 152 |
+
if "stability_index" in augmented and augmented["stability_index"] is not None:
|
| 153 |
+
# 엔트로피가 높으면 안정성이 낮아짐
|
| 154 |
+
entropy_avg = ((augmented.get("entropy_acc", 0.5) or 0.5) + (augmented.get("entropy_gyro", 0.5) or 0.5)) / 2
|
| 155 |
+
stability_base = 1.0 - entropy_avg * 0.3 # 엔트로피 기반 추정
|
| 156 |
+
augmented["stability_index"] = bounded(add_noise(stability_base, 0.02), 0.4, 0.99)
|
| 157 |
+
|
| 158 |
+
# 피로도 증폭 (RMS, 주파수와 상관관계)
|
| 159 |
+
if "fatigue" in augmented and augmented["fatigue"] is not None:
|
| 160 |
+
# RMS가 높고 주파수가 낮으면 피로도 증가
|
| 161 |
+
rms_factor = (augmented.get("rms_acc", 0) or 0) / (augmented.get("rms_base", 1.0) or 1.0)
|
| 162 |
+
freq_factor = (augmented.get("mean_freq_acc", 40) or 40) / (augmented.get("freq_base", 40) or 40)
|
| 163 |
+
fatigue_delta = (rms_factor - 1.0) * 0.05 - (freq_factor - 1.0) * 0.03 + random.uniform(-0.03, 0.03)
|
| 164 |
+
augmented["fatigue"] = bounded(augmented["fatigue"] + fatigue_delta, 0.05, 0.95)
|
| 165 |
+
augmented["fatigue_level"] = 0 if augmented["fatigue"] < 0.3 else 1 if augmented["fatigue"] < 0.6 else 2
|
| 166 |
+
|
| 167 |
+
# 이전 피로도는 현재 피로도와 연속성 유지
|
| 168 |
+
if "fatigue_prev" in augmented and augmented["fatigue_prev"] is not None:
|
| 169 |
+
if "fatigue" in augmented and augmented["fatigue"] is not None:
|
| 170 |
+
# 이전 피로도는 현재 피로도보다 약간 낮거나 비슷
|
| 171 |
+
augmented["fatigue_prev"] = bounded(augmented["fatigue"] - random.uniform(0, 0.1), 0.05, 0.95)
|
| 172 |
+
else:
|
| 173 |
+
augmented["fatigue_prev"] = bounded(add_noise(augmented["fatigue_prev"], 0.02), 0.05, 0.95)
|
| 174 |
+
|
| 175 |
+
# user_emb 벡터에 작은 노이즈 추가
|
| 176 |
+
if "user_emb" in augmented and augmented["user_emb"] is not None:
|
| 177 |
+
if isinstance(augmented["user_emb"], str):
|
| 178 |
+
try:
|
| 179 |
+
emb_list = json.loads(augmented["user_emb"])
|
| 180 |
+
except:
|
| 181 |
+
emb_list = augmented["user_emb"]
|
| 182 |
+
else:
|
| 183 |
+
emb_list = augmented["user_emb"]
|
| 184 |
+
|
| 185 |
+
if isinstance(emb_list, list) and len(emb_list) > 0:
|
| 186 |
+
augmented["user_emb"] = [round(v + random.uniform(-0.01, 0.01), 4) for v in emb_list]
|
| 187 |
+
|
| 188 |
+
# overlap_rate 약간 변형
|
| 189 |
+
if "overlap_rate" in augmented and augmented["overlap_rate"] is not None:
|
| 190 |
+
augmented["overlap_rate"] = bounded(add_noise(augmented["overlap_rate"], 0.02), 0.3, 0.7)
|
| 191 |
+
|
| 192 |
+
# quality_flag는 가끔 변경
|
| 193 |
+
if "quality_flag" in augmented:
|
| 194 |
+
if random.random() < 0.05: # 5% 확률로 변경
|
| 195 |
+
augmented["quality_flag"] = 0 if augmented["quality_flag"] == 1 else 1
|
| 196 |
+
|
| 197 |
+
# session_id 약간 변형
|
| 198 |
+
if "session_id" in augmented and augmented["session_id"]:
|
| 199 |
+
parts = augmented["session_id"].split("_")
|
| 200 |
+
if len(parts) > 1:
|
| 201 |
+
try:
|
| 202 |
+
session_num = int(parts[-1])
|
| 203 |
+
augmented["session_id"] = "_".join(parts[:-1]) + "_" + str(session_num + random.randint(-5, 5))
|
| 204 |
+
except:
|
| 205 |
+
pass
|
| 206 |
+
|
| 207 |
+
return augmented
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def augment_user_data(df: pd.DataFrame, target_count: int) -> pd.DataFrame:
|
| 211 |
+
"""사용자별 데이터를 증폭하여 목표 개수만큼 생성"""
|
| 212 |
+
current_count = len(df)
|
| 213 |
+
if current_count == 0:
|
| 214 |
+
return df
|
| 215 |
+
|
| 216 |
+
if current_count >= target_count:
|
| 217 |
+
# 이미 충분하면 그대로 반환
|
| 218 |
+
return df.head(target_count)
|
| 219 |
+
|
| 220 |
+
# 증폭이 필요한 개수
|
| 221 |
+
needed = target_count - current_count
|
| 222 |
+
|
| 223 |
+
# 기존 데이터를 복제하고 증폭
|
| 224 |
+
augmented_records = []
|
| 225 |
+
for _ in range(needed):
|
| 226 |
+
# 랜덤하게 원본 레코드 선택
|
| 227 |
+
original_idx = random.randint(0, current_count - 1)
|
| 228 |
+
original = df.iloc[original_idx].to_dict()
|
| 229 |
+
|
| 230 |
+
# 증폭 (노이즈 스케일은 필드에 따라 다르게)
|
| 231 |
+
noise_scale = random.uniform(0.05, 0.15)
|
| 232 |
+
augmented = augment_record(original, noise_scale)
|
| 233 |
+
augmented_records.append(augmented)
|
| 234 |
+
|
| 235 |
+
# 증폭된 데이터를 DataFrame으로 변환
|
| 236 |
+
augmented_df = pd.DataFrame(augmented_records)
|
| 237 |
+
|
| 238 |
+
# 원본과 병합
|
| 239 |
+
result_df = pd.concat([df, augmented_df], ignore_index=True)
|
| 240 |
+
return result_df
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def main():
|
| 244 |
+
load_dotenv()
|
| 245 |
+
|
| 246 |
+
repo_id = require_env("HF_DATA_REPO_ID")
|
| 247 |
+
token = require_env("HF_DATA_TOKEN")
|
| 248 |
+
|
| 249 |
+
print(f"📂 기존 데이터셋 로드 중: {repo_id}")
|
| 250 |
+
|
| 251 |
+
# 개별 parquet 파일을 모두 로드 (user로 시작하지 않는 파일도 포함)
|
| 252 |
+
api = HfApi()
|
| 253 |
+
try:
|
| 254 |
+
files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
|
| 255 |
+
# 모든 parquet 파일 필터링 (user로 시작하지 않는 것도 포함)
|
| 256 |
+
parquet_files = [f for f in files if f.endswith(".parquet")]
|
| 257 |
+
print(f"📊 Parquet 파일 수: {len(parquet_files)}")
|
| 258 |
+
|
| 259 |
+
existing = DatasetDict()
|
| 260 |
+
for file_path in parquet_files:
|
| 261 |
+
try:
|
| 262 |
+
# 파일명에서 사용자 ID 추출
|
| 263 |
+
# 형식: data/user_xxx.parquet 또는 data/user_xxx-00000-of-00001.parquet
|
| 264 |
+
filename = file_path.split("/")[-1] if "/" in file_path else file_path
|
| 265 |
+
# .parquet 확장자 제거
|
| 266 |
+
filename_no_ext = filename.replace(".parquet", "")
|
| 267 |
+
# -00000-of-00001 부분이 있으면 제거, 없으면 그대로 사용
|
| 268 |
+
if "-" in filename_no_ext:
|
| 269 |
+
user_id = filename_no_ext.split("-")[0]
|
| 270 |
+
else:
|
| 271 |
+
user_id = filename_no_ext
|
| 272 |
+
|
| 273 |
+
# 개별 파일을 pandas로 직접 로드
|
| 274 |
+
from huggingface_hub import hf_hub_download
|
| 275 |
+
import tempfile
|
| 276 |
+
|
| 277 |
+
# 파일 다운로드
|
| 278 |
+
local_path = hf_hub_download(
|
| 279 |
+
repo_id=repo_id,
|
| 280 |
+
filename=file_path,
|
| 281 |
+
repo_type="dataset",
|
| 282 |
+
token=token
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# pandas로 직접 읽기
|
| 286 |
+
df = pd.read_parquet(local_path)
|
| 287 |
+
if len(df) > 0:
|
| 288 |
+
existing[user_id] = Dataset.from_pandas(df, preserve_index=False)
|
| 289 |
+
print(f"✅ {user_id}: {len(df)} 레코드 로드")
|
| 290 |
+
else:
|
| 291 |
+
print(f"⚠️ {user_id}: 빈 데이터셋, 건너뜀")
|
| 292 |
+
except Exception as e2:
|
| 293 |
+
print(f"⚠️ {file_path}: 로드 실패 ({str(e2)[:100]}), 건너뜀")
|
| 294 |
+
continue
|
| 295 |
+
except Exception as e3:
|
| 296 |
+
print(f"❌ 데이터셋 로드 완전 실패: {e3}")
|
| 297 |
+
return
|
| 298 |
+
|
| 299 |
+
# 유효한 사용자만 필터링 (데이터가 있는 사용자만)
|
| 300 |
+
valid_users = {}
|
| 301 |
+
for user_id in existing.keys():
|
| 302 |
+
try:
|
| 303 |
+
user_data = existing[user_id]
|
| 304 |
+
if len(user_data) > 0:
|
| 305 |
+
valid_users[user_id] = user_data
|
| 306 |
+
else:
|
| 307 |
+
print(f"⚠️ {user_id}: 빈 데이터셋, 건너뜀")
|
| 308 |
+
except Exception as e:
|
| 309 |
+
print(f"⚠️ {user_id}: 데이터 접근 실패 ({e}), 건너뜀")
|
| 310 |
+
continue
|
| 311 |
+
|
| 312 |
+
if len(valid_users) == 0:
|
| 313 |
+
print("❌ 유효한 사용자 데이터가 없습니다.")
|
| 314 |
+
return
|
| 315 |
+
|
| 316 |
+
print(f"✅ 유효한 사용자 수: {len(valid_users)}명")
|
| 317 |
+
|
| 318 |
+
# 현재 총 레코��� 수 계산
|
| 319 |
+
current_total = sum(len(valid_users[user_id]) for user_id in valid_users)
|
| 320 |
+
print(f"📊 현재 총 레코드 수: {current_total}")
|
| 321 |
+
|
| 322 |
+
# 기존 사용자 목록 가져오기 (샘플링용)
|
| 323 |
+
all_users = list(valid_users.keys())
|
| 324 |
+
|
| 325 |
+
if len(all_users) == 0:
|
| 326 |
+
print("❌ 증폭할 참조 데이터가 없습니다.")
|
| 327 |
+
return
|
| 328 |
+
|
| 329 |
+
# 새로운 사용자 20명 생성 (기존 사용자 데이터를 참조하여 증폭)
|
| 330 |
+
print(f"🎯 새로운 사용자 {TARGET_USERS}명 생성 중...")
|
| 331 |
+
print(f"📋 참조 사용자: {len(all_users)}명")
|
| 332 |
+
print(f"🎯 사용자당 목표 레코드 수: {RECORDS_PER_USER}")
|
| 333 |
+
|
| 334 |
+
# 새로운 사용자 데이터셋 생성
|
| 335 |
+
new_user_datasets = {}
|
| 336 |
+
for i in range(1, TARGET_USERS + 1):
|
| 337 |
+
# 새로운 사용자 ID 생성
|
| 338 |
+
new_user_id = f"augmented_user_{i:03d}"
|
| 339 |
+
|
| 340 |
+
# 기존 사용자 중 랜덤 선택 (참조용)
|
| 341 |
+
reference_user_id = random.choice(all_users)
|
| 342 |
+
reference_df = valid_users[reference_user_id].to_pandas()
|
| 343 |
+
|
| 344 |
+
if len(reference_df) == 0:
|
| 345 |
+
print(f"⚠️ 참조 사용자 {reference_user_id}의 데이터가 비어있어 건너뜀")
|
| 346 |
+
continue
|
| 347 |
+
|
| 348 |
+
try:
|
| 349 |
+
# 참조 데이터를 증폭하여 새로운 사용자 데이터 생성
|
| 350 |
+
new_user_df = augment_user_data(reference_df, RECORDS_PER_USER)
|
| 351 |
+
new_user_datasets[new_user_id] = Dataset.from_pandas(new_user_df, preserve_index=False)
|
| 352 |
+
print(f"📈 {new_user_id}: {RECORDS_PER_USER} 레코드 생성 (참조: {reference_user_id})")
|
| 353 |
+
except Exception as e:
|
| 354 |
+
print(f"❌ {new_user_id}: 생성 실패 ({e}), 건너뜀")
|
| 355 |
+
continue
|
| 356 |
+
|
| 357 |
+
if len(new_user_datasets) == 0:
|
| 358 |
+
print("❌ 새로운 사용자 데이터가 생성되지 않았습니다.")
|
| 359 |
+
return
|
| 360 |
+
|
| 361 |
+
# 기존 데이터셋에 새로운 사용자 데이터 추가
|
| 362 |
+
final_datasets = {}
|
| 363 |
+
# 기존 사용자 데이터 유지
|
| 364 |
+
for user_id in valid_users.keys():
|
| 365 |
+
final_datasets[user_id] = valid_users[user_id]
|
| 366 |
+
# 새로운 사용자 데이터 추가
|
| 367 |
+
for user_id in new_user_datasets.keys():
|
| 368 |
+
final_datasets[user_id] = new_user_datasets[user_id]
|
| 369 |
+
|
| 370 |
+
final_dict = DatasetDict(final_datasets)
|
| 371 |
+
new_users_total = sum(len(new_user_datasets[user_id]) for user_id in new_user_datasets)
|
| 372 |
+
total_records = sum(len(final_dict[user_id]) for user_id in final_dict)
|
| 373 |
+
print(f"📊 새로운 사용자들의 총 레코드 수: {new_users_total}")
|
| 374 |
+
print(f"📊 전체 데이터셋 총 레코드 수: {total_records}")
|
| 375 |
+
print(f"📊 새로운 parquet 파일 수: {len(new_user_datasets)}개")
|
| 376 |
+
|
| 377 |
+
print(f"📤 Hugging Face Hub에 업로드 중: {repo_id}")
|
| 378 |
+
final_dict.push_to_hub(repo_id, token=token, private=True)
|
| 379 |
+
print("✅ 업로드 완료")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
if __name__ == "__main__":
|
| 383 |
+
main()
|
| 384 |
+
|