Merry99 commited on
Commit
500a872
·
1 Parent(s): e0a5f34

Update augment_dataset.py: Generate 20 new users with 500 records each, compatible with dataset commit fa41e8b

Browse files
Files changed (1) hide show
  1. augment_dataset.py +384 -0
augment_dataset.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import json
4
+ from datetime import datetime, timezone, timedelta
5
+ from typing import Dict, List, Optional
6
+
7
+ import pandas as pd
8
+ import numpy as np
9
+ from datasets import Dataset, DatasetDict, load_dataset
10
+ from huggingface_hub import HfApi
11
+ from dotenv import load_dotenv
12
+
13
+
14
+ TARGET_USERS = 20
15
+ RECORDS_PER_USER = 500
16
+
17
+
18
+ def require_env(var_name: str) -> str:
19
+ value = os.getenv(var_name)
20
+ if not value:
21
+ raise RuntimeError(f"환경변수 {var_name}가 필요합니다.")
22
+ return value
23
+
24
+
25
+ def add_noise(value: float, noise_scale: float) -> float:
26
+ """값에 노이즈 추가"""
27
+ if value is None:
28
+ return None
29
+ return round(value + random.uniform(-noise_scale, noise_scale), 4)
30
+
31
+
32
+ def bounded(value: float, low: float, high: float) -> float:
33
+ """값을 범위 내로 제한"""
34
+ if value is None:
35
+ return None
36
+ return max(low, min(high, value))
37
+
38
+
39
+ def augment_record(original: dict, noise_scale: float = 0.1) -> dict:
40
+ """단일 레코드를 증폭 (물리적 관계와 상관관계를 고려한 의미있는 증폭)"""
41
+ augmented = original.copy()
42
+
43
+ # 시간 정보 변형 (연속성 유지)
44
+ if "timestamp_utc" in augmented and augmented["timestamp_utc"]:
45
+ try:
46
+ base_time = datetime.fromisoformat(augmented["timestamp_utc"].replace("Z", "+00:00"))
47
+ time_delta = timedelta(milliseconds=random.randint(-200, 200))
48
+ augmented["timestamp_utc"] = (base_time + time_delta).isoformat()
49
+ except:
50
+ pass
51
+
52
+ # window_id와 시간 범위 약간 조정 (연속성 유지)
53
+ if "window_id" in augmented:
54
+ augmented["window_id"] = augmented["window_id"] + random.randint(-1, 1)
55
+ if "window_start_ms" in augmented:
56
+ augmented["window_start_ms"] = augmented["window_start_ms"] + random.randint(-50, 50)
57
+ if "window_end_ms" in augmented:
58
+ augmented["window_end_ms"] = augmented["window_start_ms"] + 2000 # window_size_ms와 일치
59
+
60
+ # 가속도계 데이터 증폭 (x, y, z 간 상관관계 유지)
61
+ acc_noise = random.uniform(-noise_scale * 0.1, noise_scale * 0.1)
62
+ if "acc_x_mean" in augmented and augmented["acc_x_mean"] is not None:
63
+ augmented["acc_x_mean"] = add_noise(augmented["acc_x_mean"], abs(augmented["acc_x_mean"]) * 0.1 + 0.01)
64
+ if "acc_y_mean" in augmented and augmented["acc_y_mean"] is not None:
65
+ augmented["acc_y_mean"] = add_noise(augmented["acc_y_mean"], abs(augmented["acc_y_mean"]) * 0.1 + 0.01)
66
+ if "acc_z_mean" in augmented and augmented["acc_z_mean"] is not None:
67
+ augmented["acc_z_mean"] = add_noise(augmented["acc_z_mean"], abs(augmented["acc_z_mean"]) * 0.1 + 0.01)
68
+
69
+ # 자이로스코프 데이터 증폭
70
+ gyro_noise = random.uniform(-noise_scale * 0.02, noise_scale * 0.02)
71
+ if "gyro_x_mean" in augmented and augmented["gyro_x_mean"] is not None:
72
+ augmented["gyro_x_mean"] = add_noise(augmented["gyro_x_mean"], 0.005)
73
+ if "gyro_y_mean" in augmented and augmented["gyro_y_mean"] is not None:
74
+ augmented["gyro_y_mean"] = add_noise(augmented["gyro_y_mean"], 0.005)
75
+ if "gyro_z_mean" in augmented and augmented["gyro_z_mean"] is not None:
76
+ augmented["gyro_z_mean"] = add_noise(augmented["gyro_z_mean"], 0.005)
77
+
78
+ # 선형 가속도 증폭
79
+ if "linacc_x_mean" in augmented and augmented["linacc_x_mean"] is not None:
80
+ augmented["linacc_x_mean"] = add_noise(augmented["linacc_x_mean"], abs(augmented["linacc_x_mean"]) * 0.1 + 0.01)
81
+ if "linacc_y_mean" in augmented and augmented["linacc_y_mean"] is not None:
82
+ augmented["linacc_y_mean"] = add_noise(augmented["linacc_y_mean"], abs(augmented["linacc_y_mean"]) * 0.1 + 0.01)
83
+ if "linacc_z_mean" in augmented and augmented["linacc_z_mean"] is not None:
84
+ augmented["linacc_z_mean"] = add_noise(augmented["linacc_z_mean"], abs(augmented["linacc_z_mean"]) * 0.1 + 0.01)
85
+
86
+ # 중력 벡터 증폭 (물리적 제약: 크기가 약 9.8에 가까워야 함)
87
+ if all(f in augmented and augmented[f] is not None for f in ["gravity_x_mean", "gravity_y_mean", "gravity_z_mean"]):
88
+ gx = augmented["gravity_x_mean"] + random.uniform(-0.01, 0.01)
89
+ gy = augmented["gravity_y_mean"] + random.uniform(-0.01, 0.01)
90
+ gz = augmented["gravity_z_mean"] + random.uniform(-0.02, 0.02)
91
+ # 중력 벡터 크기 정규화 (약 9.8 유지)
92
+ g_mag = np.sqrt(gx**2 + gy**2 + gz**2)
93
+ if g_mag > 0:
94
+ scale = 9.8 / g_mag
95
+ augmented["gravity_x_mean"] = round(gx * scale, 4)
96
+ augmented["gravity_y_mean"] = round(gy * scale, 4)
97
+ augmented["gravity_z_mean"] = round(gz * scale, 4)
98
+
99
+ # 센서 표준편차 증폭 (RMS와 일관성 유지)
100
+ sensor_std_fields = [
101
+ "acc_x_std", "acc_y_std", "acc_z_std",
102
+ "gyro_x_std", "gyro_y_std", "gyro_z_std",
103
+ ]
104
+ for field in sensor_std_fields:
105
+ if field in augmented and augmented[field] is not None:
106
+ augmented[field] = bounded(add_noise(augmented[field], augmented[field] * 0.1), 0.01, 1.0)
107
+
108
+ # RMS 값 증폭 (센서 평균값과 일관성 유지)
109
+ if "rms_acc" in augmented and augmented["rms_acc"] is not None:
110
+ # RMS는 가속도 평균값의 크기와 관련
111
+ acc_mag = np.sqrt(
112
+ (augmented.get("acc_x_mean", 0) or 0)**2 +
113
+ (augmented.get("acc_y_mean", 0) or 0)**2 +
114
+ (augmented.get("acc_z_mean", 0) or 0)**2
115
+ )
116
+ rms_base = augmented["rms_acc"]
117
+ # RMS는 원본과 비슷한 범위 유지
118
+ augmented["rms_acc"] = bounded(add_noise(rms_base, rms_base * 0.1), 0.1, 2.0)
119
+
120
+ if "rms_gyro" in augmented and augmented["rms_gyro"] is not None:
121
+ gyro_mag = np.sqrt(
122
+ (augmented.get("gyro_x_mean", 0) or 0)**2 +
123
+ (augmented.get("gyro_y_mean", 0) or 0)**2 +
124
+ (augmented.get("gyro_z_mean", 0) or 0)**2
125
+ )
126
+ rms_gyro_base = augmented["rms_gyro"]
127
+ augmented["rms_gyro"] = bounded(add_noise(rms_gyro_base, rms_gyro_base * 0.1), 0.01, 0.5)
128
+
129
+ # 주파수 증폭 (RMS와 상관관계 유지)
130
+ if "mean_freq_acc" in augmented and augmented["mean_freq_acc"] is not None:
131
+ # RMS가 높으면 주파수도 약간 높아지는 경향
132
+ freq_factor = 1.0 + (augmented.get("rms_acc", 0) or 0) * 0.1
133
+ augmented["mean_freq_acc"] = round(add_noise(augmented["mean_freq_acc"] * freq_factor, 1.0) / freq_factor, 2)
134
+
135
+ if "mean_freq_gyro" in augmented and augmented["mean_freq_gyro"] is not None:
136
+ freq_factor = 1.0 + (augmented.get("rms_gyro", 0) or 0) * 0.2
137
+ augmented["mean_freq_gyro"] = round(add_noise(augmented["mean_freq_gyro"] * freq_factor, 0.5) / freq_factor, 2)
138
+
139
+ # 엔트로피 증폭 (안정성과 관련)
140
+ if "entropy_acc" in augmented and augmented["entropy_acc"] is not None:
141
+ augmented["entropy_acc"] = bounded(add_noise(augmented["entropy_acc"], 0.02), 0.1, 1.0)
142
+ if "entropy_gyro" in augmented and augmented["entropy_gyro"] is not None:
143
+ augmented["entropy_gyro"] = bounded(add_noise(augmented["entropy_gyro"], 0.02), 0.1, 1.0)
144
+
145
+ # Jerk 증폭 (가속도 변화율)
146
+ if "jerk_mean" in augmented and augmented["jerk_mean"] is not None:
147
+ augmented["jerk_mean"] = add_noise(augmented["jerk_mean"], 0.01)
148
+ if "jerk_std" in augmented and augmented["jerk_std"] is not None:
149
+ augmented["jerk_std"] = bounded(add_noise(augmented["jerk_std"], 0.005), 0.01, 0.2)
150
+
151
+ # 안정성 지수 증폭 (엔트로피와 반비례 관계)
152
+ if "stability_index" in augmented and augmented["stability_index"] is not None:
153
+ # 엔트로피가 높으면 안정성이 낮아짐
154
+ entropy_avg = ((augmented.get("entropy_acc", 0.5) or 0.5) + (augmented.get("entropy_gyro", 0.5) or 0.5)) / 2
155
+ stability_base = 1.0 - entropy_avg * 0.3 # 엔트로피 기반 추정
156
+ augmented["stability_index"] = bounded(add_noise(stability_base, 0.02), 0.4, 0.99)
157
+
158
+ # 피로도 증폭 (RMS, 주파수와 상관관계)
159
+ if "fatigue" in augmented and augmented["fatigue"] is not None:
160
+ # RMS가 높고 주파수가 낮으면 피로도 증가
161
+ rms_factor = (augmented.get("rms_acc", 0) or 0) / (augmented.get("rms_base", 1.0) or 1.0)
162
+ freq_factor = (augmented.get("mean_freq_acc", 40) or 40) / (augmented.get("freq_base", 40) or 40)
163
+ fatigue_delta = (rms_factor - 1.0) * 0.05 - (freq_factor - 1.0) * 0.03 + random.uniform(-0.03, 0.03)
164
+ augmented["fatigue"] = bounded(augmented["fatigue"] + fatigue_delta, 0.05, 0.95)
165
+ augmented["fatigue_level"] = 0 if augmented["fatigue"] < 0.3 else 1 if augmented["fatigue"] < 0.6 else 2
166
+
167
+ # 이전 피로도는 현재 피로도와 연속성 유지
168
+ if "fatigue_prev" in augmented and augmented["fatigue_prev"] is not None:
169
+ if "fatigue" in augmented and augmented["fatigue"] is not None:
170
+ # 이전 피로도는 현재 피로도보다 약간 낮거나 비슷
171
+ augmented["fatigue_prev"] = bounded(augmented["fatigue"] - random.uniform(0, 0.1), 0.05, 0.95)
172
+ else:
173
+ augmented["fatigue_prev"] = bounded(add_noise(augmented["fatigue_prev"], 0.02), 0.05, 0.95)
174
+
175
+ # user_emb 벡터에 작은 노이즈 추가
176
+ if "user_emb" in augmented and augmented["user_emb"] is not None:
177
+ if isinstance(augmented["user_emb"], str):
178
+ try:
179
+ emb_list = json.loads(augmented["user_emb"])
180
+ except:
181
+ emb_list = augmented["user_emb"]
182
+ else:
183
+ emb_list = augmented["user_emb"]
184
+
185
+ if isinstance(emb_list, list) and len(emb_list) > 0:
186
+ augmented["user_emb"] = [round(v + random.uniform(-0.01, 0.01), 4) for v in emb_list]
187
+
188
+ # overlap_rate 약간 변형
189
+ if "overlap_rate" in augmented and augmented["overlap_rate"] is not None:
190
+ augmented["overlap_rate"] = bounded(add_noise(augmented["overlap_rate"], 0.02), 0.3, 0.7)
191
+
192
+ # quality_flag는 가끔 변경
193
+ if "quality_flag" in augmented:
194
+ if random.random() < 0.05: # 5% 확률로 변경
195
+ augmented["quality_flag"] = 0 if augmented["quality_flag"] == 1 else 1
196
+
197
+ # session_id 약간 변형
198
+ if "session_id" in augmented and augmented["session_id"]:
199
+ parts = augmented["session_id"].split("_")
200
+ if len(parts) > 1:
201
+ try:
202
+ session_num = int(parts[-1])
203
+ augmented["session_id"] = "_".join(parts[:-1]) + "_" + str(session_num + random.randint(-5, 5))
204
+ except:
205
+ pass
206
+
207
+ return augmented
208
+
209
+
210
+ def augment_user_data(df: pd.DataFrame, target_count: int) -> pd.DataFrame:
211
+ """사용자별 데이터를 증폭하여 목표 개수만큼 생성"""
212
+ current_count = len(df)
213
+ if current_count == 0:
214
+ return df
215
+
216
+ if current_count >= target_count:
217
+ # 이미 충분하면 그대로 반환
218
+ return df.head(target_count)
219
+
220
+ # 증폭이 필요한 개수
221
+ needed = target_count - current_count
222
+
223
+ # 기존 데이터를 복제하고 증폭
224
+ augmented_records = []
225
+ for _ in range(needed):
226
+ # 랜덤하게 원본 레코드 선택
227
+ original_idx = random.randint(0, current_count - 1)
228
+ original = df.iloc[original_idx].to_dict()
229
+
230
+ # 증폭 (노이즈 스케일은 필드에 따라 다르게)
231
+ noise_scale = random.uniform(0.05, 0.15)
232
+ augmented = augment_record(original, noise_scale)
233
+ augmented_records.append(augmented)
234
+
235
+ # 증폭된 데이터를 DataFrame으로 변환
236
+ augmented_df = pd.DataFrame(augmented_records)
237
+
238
+ # 원본과 병합
239
+ result_df = pd.concat([df, augmented_df], ignore_index=True)
240
+ return result_df
241
+
242
+
243
+ def main():
244
+ load_dotenv()
245
+
246
+ repo_id = require_env("HF_DATA_REPO_ID")
247
+ token = require_env("HF_DATA_TOKEN")
248
+
249
+ print(f"📂 기존 데이터셋 로드 중: {repo_id}")
250
+
251
+ # 개별 parquet 파일을 모두 로드 (user로 시작하지 않는 파일도 포함)
252
+ api = HfApi()
253
+ try:
254
+ files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
255
+ # 모든 parquet 파일 필터링 (user로 시작하지 않는 것도 포함)
256
+ parquet_files = [f for f in files if f.endswith(".parquet")]
257
+ print(f"📊 Parquet 파일 수: {len(parquet_files)}")
258
+
259
+ existing = DatasetDict()
260
+ for file_path in parquet_files:
261
+ try:
262
+ # 파일명에서 사용자 ID 추출
263
+ # 형식: data/user_xxx.parquet 또는 data/user_xxx-00000-of-00001.parquet
264
+ filename = file_path.split("/")[-1] if "/" in file_path else file_path
265
+ # .parquet 확장자 제거
266
+ filename_no_ext = filename.replace(".parquet", "")
267
+ # -00000-of-00001 부분이 있으면 제거, 없으면 그대로 사용
268
+ if "-" in filename_no_ext:
269
+ user_id = filename_no_ext.split("-")[0]
270
+ else:
271
+ user_id = filename_no_ext
272
+
273
+ # 개별 파일을 pandas로 직접 로드
274
+ from huggingface_hub import hf_hub_download
275
+ import tempfile
276
+
277
+ # 파일 다운로드
278
+ local_path = hf_hub_download(
279
+ repo_id=repo_id,
280
+ filename=file_path,
281
+ repo_type="dataset",
282
+ token=token
283
+ )
284
+
285
+ # pandas로 직접 읽기
286
+ df = pd.read_parquet(local_path)
287
+ if len(df) > 0:
288
+ existing[user_id] = Dataset.from_pandas(df, preserve_index=False)
289
+ print(f"✅ {user_id}: {len(df)} 레코드 로드")
290
+ else:
291
+ print(f"⚠️ {user_id}: 빈 데이터셋, 건너뜀")
292
+ except Exception as e2:
293
+ print(f"⚠️ {file_path}: 로드 실패 ({str(e2)[:100]}), 건너뜀")
294
+ continue
295
+ except Exception as e3:
296
+ print(f"❌ 데이터셋 로드 완전 실패: {e3}")
297
+ return
298
+
299
+ # 유효한 사용자만 필터링 (데이터가 있는 사용자만)
300
+ valid_users = {}
301
+ for user_id in existing.keys():
302
+ try:
303
+ user_data = existing[user_id]
304
+ if len(user_data) > 0:
305
+ valid_users[user_id] = user_data
306
+ else:
307
+ print(f"⚠️ {user_id}: 빈 데이터셋, 건너뜀")
308
+ except Exception as e:
309
+ print(f"⚠️ {user_id}: 데이터 접근 실패 ({e}), 건너뜀")
310
+ continue
311
+
312
+ if len(valid_users) == 0:
313
+ print("❌ 유효한 사용자 데이터가 없습니다.")
314
+ return
315
+
316
+ print(f"✅ 유효한 사용자 수: {len(valid_users)}명")
317
+
318
+ # 현재 총 레코��� 수 계산
319
+ current_total = sum(len(valid_users[user_id]) for user_id in valid_users)
320
+ print(f"📊 현재 총 레코드 수: {current_total}")
321
+
322
+ # 기존 사용자 목록 가져오기 (샘플링용)
323
+ all_users = list(valid_users.keys())
324
+
325
+ if len(all_users) == 0:
326
+ print("❌ 증폭할 참조 데이터가 없습니다.")
327
+ return
328
+
329
+ # 새로운 사용자 20명 생성 (기존 사용자 데이터를 참조하여 증폭)
330
+ print(f"🎯 새로운 사용자 {TARGET_USERS}명 생성 중...")
331
+ print(f"📋 참조 사용자: {len(all_users)}명")
332
+ print(f"🎯 사용자당 목표 레코드 수: {RECORDS_PER_USER}")
333
+
334
+ # 새로운 사용자 데이터셋 생성
335
+ new_user_datasets = {}
336
+ for i in range(1, TARGET_USERS + 1):
337
+ # 새로운 사용자 ID 생성
338
+ new_user_id = f"augmented_user_{i:03d}"
339
+
340
+ # 기존 사용자 중 랜덤 선택 (참조용)
341
+ reference_user_id = random.choice(all_users)
342
+ reference_df = valid_users[reference_user_id].to_pandas()
343
+
344
+ if len(reference_df) == 0:
345
+ print(f"⚠️ 참조 사용자 {reference_user_id}의 데이터가 비어있어 건너뜀")
346
+ continue
347
+
348
+ try:
349
+ # 참조 데이터를 증폭하여 새로운 사용자 데이터 생성
350
+ new_user_df = augment_user_data(reference_df, RECORDS_PER_USER)
351
+ new_user_datasets[new_user_id] = Dataset.from_pandas(new_user_df, preserve_index=False)
352
+ print(f"📈 {new_user_id}: {RECORDS_PER_USER} 레코드 생성 (참조: {reference_user_id})")
353
+ except Exception as e:
354
+ print(f"❌ {new_user_id}: 생성 실패 ({e}), 건너뜀")
355
+ continue
356
+
357
+ if len(new_user_datasets) == 0:
358
+ print("❌ 새로운 사용자 데이터가 생성되지 않았습니다.")
359
+ return
360
+
361
+ # 기존 데이터셋에 새로운 사용자 데이터 추가
362
+ final_datasets = {}
363
+ # 기존 사용자 데이터 유지
364
+ for user_id in valid_users.keys():
365
+ final_datasets[user_id] = valid_users[user_id]
366
+ # 새로운 사용자 데이터 추가
367
+ for user_id in new_user_datasets.keys():
368
+ final_datasets[user_id] = new_user_datasets[user_id]
369
+
370
+ final_dict = DatasetDict(final_datasets)
371
+ new_users_total = sum(len(new_user_datasets[user_id]) for user_id in new_user_datasets)
372
+ total_records = sum(len(final_dict[user_id]) for user_id in final_dict)
373
+ print(f"📊 새로운 사용자들의 총 레코드 수: {new_users_total}")
374
+ print(f"📊 전체 데이터셋 총 레코드 수: {total_records}")
375
+ print(f"📊 새로운 parquet 파일 수: {len(new_user_datasets)}개")
376
+
377
+ print(f"📤 Hugging Face Hub에 업로드 중: {repo_id}")
378
+ final_dict.push_to_hub(repo_id, token=token, private=True)
379
+ print("✅ 업로드 완료")
380
+
381
+
382
+ if __name__ == "__main__":
383
+ main()
384
+