Spaces:
Sleeping
Sleeping
add generate_dataset
Browse files- generate_dataset.py +197 -0
generate_dataset.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from datetime import datetime, timezone, timedelta
|
| 5 |
+
from typing import Dict, List, Tuple
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from datasets import Dataset, DatasetDict
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
TOTAL_USERS = 50
|
| 13 |
+
RECORDS_PER_USER = 50
|
| 14 |
+
USER_EMB_DIM = 12
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class UserProfile:
|
| 19 |
+
user_id: str
|
| 20 |
+
session_prefix: str
|
| 21 |
+
base_time: datetime
|
| 22 |
+
acc_mean: Tuple[float, float, float]
|
| 23 |
+
gyro_mean: Tuple[float, float, float]
|
| 24 |
+
linacc_mean: Tuple[float, float, float]
|
| 25 |
+
gravity_mean: Tuple[float, float, float]
|
| 26 |
+
acc_std: Tuple[float, float, float]
|
| 27 |
+
gyro_std: Tuple[float, float, float]
|
| 28 |
+
rms_base: float
|
| 29 |
+
rms_gyro_base: float
|
| 30 |
+
mean_freq_acc: float
|
| 31 |
+
mean_freq_gyro: float
|
| 32 |
+
entropy_acc: float
|
| 33 |
+
entropy_gyro: float
|
| 34 |
+
jerk_mean: float
|
| 35 |
+
jerk_std: float
|
| 36 |
+
stability_index: float
|
| 37 |
+
freq_base: float
|
| 38 |
+
user_emb: List[float]
|
| 39 |
+
fatigue_base: float
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def require_env(var_name: str) -> str:
|
| 43 |
+
value = os.getenv(var_name)
|
| 44 |
+
if not value:
|
| 45 |
+
raise RuntimeError(f"νκ²½λ³μ {var_name}κ° νμν©λλ€.")
|
| 46 |
+
return value
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def random_vector(dim: int, scale: float = 1.0) -> List[float]:
|
| 50 |
+
return [round(random.uniform(-scale, scale), 4) for _ in range(dim)]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def generate_user_profile(user_idx: int, start_time: datetime) -> UserProfile:
|
| 54 |
+
user_id = f"user_{user_idx:03d}"
|
| 55 |
+
session_prefix = f"{user_id}_session"
|
| 56 |
+
|
| 57 |
+
def triple(base_scale: float) -> Tuple[float, float, float]:
|
| 58 |
+
return tuple(round(random.uniform(-base_scale, base_scale), 4) for _ in range(3))
|
| 59 |
+
|
| 60 |
+
def positive_triple(low: float, high: float) -> Tuple[float, float, float]:
|
| 61 |
+
return tuple(round(random.uniform(low, high), 4) for _ in range(3))
|
| 62 |
+
|
| 63 |
+
profile = UserProfile(
|
| 64 |
+
user_id=user_id,
|
| 65 |
+
session_prefix=session_prefix,
|
| 66 |
+
base_time=start_time + timedelta(minutes=random.uniform(0, 5)),
|
| 67 |
+
acc_mean=triple(0.2),
|
| 68 |
+
gyro_mean=triple(0.05),
|
| 69 |
+
linacc_mean=triple(0.3),
|
| 70 |
+
gravity_mean=(round(random.uniform(-0.05, 0.05), 4),
|
| 71 |
+
round(random.uniform(-0.05, 0.05), 4),
|
| 72 |
+
round(random.uniform(0.9, 1.1), 4)),
|
| 73 |
+
acc_std=positive_triple(0.2, 0.6),
|
| 74 |
+
gyro_std=positive_triple(0.02, 0.08),
|
| 75 |
+
rms_base=round(random.uniform(0.3, 1.0), 4),
|
| 76 |
+
rms_gyro_base=round(random.uniform(0.05, 0.2), 4),
|
| 77 |
+
mean_freq_acc=round(random.uniform(25, 55), 2),
|
| 78 |
+
mean_freq_gyro=round(random.uniform(10, 25), 2),
|
| 79 |
+
entropy_acc=round(random.uniform(0.3, 0.8), 4),
|
| 80 |
+
entropy_gyro=round(random.uniform(0.3, 0.7), 4),
|
| 81 |
+
jerk_mean=round(random.uniform(-0.2, 0.2), 4),
|
| 82 |
+
jerk_std=round(random.uniform(0.02, 0.08), 4),
|
| 83 |
+
stability_index=round(random.uniform(0.6, 0.95), 4),
|
| 84 |
+
freq_base=round(random.uniform(30, 55), 2),
|
| 85 |
+
user_emb=random_vector(USER_EMB_DIM, scale=0.5),
|
| 86 |
+
fatigue_base=round(random.uniform(0.25, 0.6), 4),
|
| 87 |
+
)
|
| 88 |
+
return profile
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def add_noise(value: float, noise_scale: float) -> float:
|
| 92 |
+
return round(value + random.uniform(-noise_scale, noise_scale), 4)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def bounded(value: float, low: float, high: float) -> float:
|
| 96 |
+
return max(low, min(high, value))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def random_record(
|
| 100 |
+
profile: UserProfile,
|
| 101 |
+
record_idx: int,
|
| 102 |
+
prev_fatigue: float,
|
| 103 |
+
) -> Tuple[dict, float]:
|
| 104 |
+
window_start_ms = record_idx * 2000
|
| 105 |
+
window_end_ms = window_start_ms + 2000
|
| 106 |
+
base_time = profile.base_time + timedelta(milliseconds=window_start_ms)
|
| 107 |
+
|
| 108 |
+
def rand_float(scale: float = 1.0) -> float:
|
| 109 |
+
return round(random.uniform(-scale, scale), 4)
|
| 110 |
+
|
| 111 |
+
fatigue_delta = random.uniform(-0.05, 0.1)
|
| 112 |
+
fatigue = bounded(prev_fatigue + fatigue_delta, 0.05, 0.95)
|
| 113 |
+
|
| 114 |
+
record = {
|
| 115 |
+
"user_id": profile.user_id,
|
| 116 |
+
"session_id": f"{profile.session_prefix}_{record_idx:03d}",
|
| 117 |
+
"window_id": record_idx,
|
| 118 |
+
"window_start_ms": window_start_ms,
|
| 119 |
+
"window_end_ms": window_end_ms,
|
| 120 |
+
"timestamp_utc": base_time.replace(tzinfo=timezone.utc).isoformat(),
|
| 121 |
+
"acc_x_mean": add_noise(profile.acc_mean[0], 0.05),
|
| 122 |
+
"acc_y_mean": add_noise(profile.acc_mean[1], 0.05),
|
| 123 |
+
"acc_z_mean": add_noise(profile.acc_mean[2], 0.05),
|
| 124 |
+
"gyro_x_mean": add_noise(profile.gyro_mean[0], 0.01),
|
| 125 |
+
"gyro_y_mean": add_noise(profile.gyro_mean[1], 0.01),
|
| 126 |
+
"gyro_z_mean": add_noise(profile.gyro_mean[2], 0.01),
|
| 127 |
+
"linacc_x_mean": add_noise(profile.linacc_mean[0], 0.07),
|
| 128 |
+
"linacc_y_mean": add_noise(profile.linacc_mean[1], 0.07),
|
| 129 |
+
"linacc_z_mean": add_noise(profile.linacc_mean[2], 0.07),
|
| 130 |
+
"gravity_x_mean": add_noise(profile.gravity_mean[0], 0.005),
|
| 131 |
+
"gravity_y_mean": add_noise(profile.gravity_mean[1], 0.005),
|
| 132 |
+
"gravity_z_mean": add_noise(profile.gravity_mean[2], 0.02),
|
| 133 |
+
"acc_x_std": add_noise(profile.acc_std[0], 0.05),
|
| 134 |
+
"acc_y_std": add_noise(profile.acc_std[1], 0.05),
|
| 135 |
+
"acc_z_std": add_noise(profile.acc_std[2], 0.05),
|
| 136 |
+
"gyro_x_std": add_noise(profile.gyro_std[0], 0.005),
|
| 137 |
+
"gyro_y_std": add_noise(profile.gyro_std[1], 0.005),
|
| 138 |
+
"gyro_z_std": add_noise(profile.gyro_std[2], 0.005),
|
| 139 |
+
"rms_acc": add_noise(profile.rms_base, 0.1),
|
| 140 |
+
"rms_gyro": add_noise(profile.rms_gyro_base, 0.02),
|
| 141 |
+
"mean_freq_acc": round(add_noise(profile.mean_freq_acc, 1.5), 2),
|
| 142 |
+
"mean_freq_gyro": round(add_noise(profile.mean_freq_gyro, 0.8), 2),
|
| 143 |
+
"entropy_acc": add_noise(profile.entropy_acc, 0.05),
|
| 144 |
+
"entropy_gyro": add_noise(profile.entropy_gyro, 0.05),
|
| 145 |
+
"jerk_mean": add_noise(profile.jerk_mean, 0.02),
|
| 146 |
+
"jerk_std": add_noise(profile.jerk_std, 0.01),
|
| 147 |
+
"stability_index": bounded(add_noise(profile.stability_index, 0.03), 0.4, 0.99),
|
| 148 |
+
"rms_base": profile.rms_base,
|
| 149 |
+
"freq_base": profile.freq_base,
|
| 150 |
+
"user_emb": profile.user_emb,
|
| 151 |
+
"fatigue_prev": round(prev_fatigue, 4),
|
| 152 |
+
"fatigue": round(fatigue, 4),
|
| 153 |
+
"fatigue_level": 0 if fatigue < 0.3 else 1 if fatigue < 0.6 else 2,
|
| 154 |
+
"quality_flag": 1 if random.random() > 0.05 else 0,
|
| 155 |
+
"window_size_ms": 2000,
|
| 156 |
+
"overlap_rate": 0.5 + rand_float(0.05),
|
| 157 |
+
}
|
| 158 |
+
return record, fatigue
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def generate_dataset_dict() -> DatasetDict:
|
| 162 |
+
datasets_by_user: Dict[str, Dataset] = {}
|
| 163 |
+
start_time = datetime.utcnow()
|
| 164 |
+
|
| 165 |
+
for user_idx in range(1, TOTAL_USERS + 1):
|
| 166 |
+
profile = generate_user_profile(user_idx, start_time)
|
| 167 |
+
|
| 168 |
+
rows = []
|
| 169 |
+
prev_fatigue = profile.fatigue_base
|
| 170 |
+
for record_idx in range(RECORDS_PER_USER):
|
| 171 |
+
record, prev_fatigue = random_record(profile, record_idx, prev_fatigue)
|
| 172 |
+
rows.append(record)
|
| 173 |
+
|
| 174 |
+
df = pd.DataFrame(rows)
|
| 175 |
+
datasets_by_user[profile.user_id] = Dataset.from_pandas(df, preserve_index=False)
|
| 176 |
+
|
| 177 |
+
return DatasetDict(datasets_by_user)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def main():
|
| 181 |
+
load_dotenv()
|
| 182 |
+
|
| 183 |
+
repo_id = require_env("HF_DATA_REPO_ID")
|
| 184 |
+
token = require_env("HF_DATA_TOKEN")
|
| 185 |
+
|
| 186 |
+
print(f"π¦ Generating synthetic dataset: users={TOTAL_USERS}, records/user={RECORDS_PER_USER}")
|
| 187 |
+
dataset_dict = generate_dataset_dict()
|
| 188 |
+
total_records = sum(len(dataset_dict[user_id]) for user_id in dataset_dict)
|
| 189 |
+
print(f"π’ Total records: {total_records}")
|
| 190 |
+
|
| 191 |
+
print(f"π€ Pushing DatasetDict ({len(dataset_dict)} users) to Hugging Face: {repo_id}")
|
| 192 |
+
dataset_dict.push_to_hub(repo_id, token=token, private=True)
|
| 193 |
+
print("β
Upload complete")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
|
| 197 |
+
main()
|