temporal-twins-code / src /generators /user_generator.py
temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
from __future__ import annotations
import numpy as np
import pandas as pd
from typing import Dict
from src.core.config_loader import Config
# Sampling weight for each user role. Dict insertion order is the category
# order handed to np.random.choice, and the weights sum to 1.0 as that call
# requires for its ``p`` argument.
USER_TYPE_PROBS: Dict[str, float] = {
    "customer": 0.6,
    "merchant": 0.15,
    "supplier": 0.05,
    "employer": 0.1,
    "fraudster": 0.05,
    "mule": 0.05,
}
# KYC tiers paired position-by-position with their sampling weights.
KYC_LEVELS, KYC_PROBS = ["low", "medium", "full"], [0.2, 0.3, 0.5]
# Risk tiers paired position-by-position with their sampling weights.
RISK_LEVELS, RISK_PROBS = ["low", "medium", "high"], [0.6, 0.3, 0.1]
def _sample_user_types(n: int) -> np.ndarray:
    """Draw ``n`` user roles i.i.d. from ``USER_TYPE_PROBS``.

    Uses NumPy's global legacy RNG, so results follow ``np.random.seed``.
    """
    categories, weights = zip(*USER_TYPE_PROBS.items())
    return np.random.choice(list(categories), size=n, p=list(weights))
def _sample_kyc(n: int) -> np.ndarray:
    """Draw ``n`` KYC levels i.i.d. from ``KYC_LEVELS`` weighted by ``KYC_PROBS``."""
    levels = np.random.choice(a=KYC_LEVELS, size=n, p=KYC_PROBS)
    return levels
def _sample_risk(n: int) -> np.ndarray:
    """Draw ``n`` risk profiles i.i.d. from ``RISK_LEVELS`` weighted by ``RISK_PROBS``."""
    profiles = np.random.choice(a=RISK_LEVELS, size=n, p=RISK_PROBS)
    return profiles
def generate_users(config: Config) -> pd.DataFrame:
    """Sample a synthetic user population described by ``config``.

    Per-user columns:

    * ``lambda_u`` -- transaction frequency, log-normal whose median is
      ``user_params.lambda_mean`` (log-scale ``user_params.lambda_std``),
      clipped to at least 1e-6.
    * ``mu_u`` -- amount-distribution location, Normal(``mu_mean``, ``mu_std``).
    * ``sigma_u`` -- amount-distribution scale, uniform on
      ``[max(1e-6, sigma_mean - sigma_std), sigma_mean + sigma_std]``.
    * ``balance`` -- log-normal with log-mean 10.0 and log-scale 1.0.
    * ``user_type``, ``kyc_level``, ``risk_profile`` -- categorical draws from
      the module-level distributions.

    All randomness comes from NumPy's global legacy RNG (``np.random.*``), so
    reproducibility depends on the caller seeding it.

    Args:
        config: Provides ``num_users`` and ``user_params`` (which must expose
            ``lambda_mean``, ``lambda_std``, ``mu_mean``, ``mu_std``,
            ``sigma_mean`` and ``sigma_std``).

    Returns:
        DataFrame with one row per user and columns ``user_id``,
        ``user_type``, ``lambda_u``, ``mu_u``, ``sigma_u``, ``balance``,
        ``kyc_level``, ``risk_profile``.

    Raises:
        ValueError: If ``lambda_mean`` is not > 0, if the ``sigma_u`` interval
            is inverted (``sigma_std < 0`` or ``sigma_mean + sigma_std < 1e-6``),
            or if the generated frame contains NaN/inf or non-positive
            ``lambda_u``/``sigma_u`` values.
    """
    n = config.num_users
    p = config.user_params

    # Fail fast on parameters that would otherwise yield silent garbage.
    # log(lambda_mean) is -inf/NaN for lambda_mean <= 0. The `not ... > 0`
    # form also rejects NaN.
    if not p.lambda_mean > 0:
        raise ValueError(
            f"user_params.lambda_mean must be > 0, got {p.lambda_mean!r}"
        )
    sigma_low = max(1e-6, p.sigma_mean - p.sigma_std)
    sigma_high = p.sigma_mean + p.sigma_std
    # NumPy leaves uniform(low, high) undefined for low > high.
    if not sigma_low <= sigma_high:
        raise ValueError(
            f"Invalid sigma_u interval [{sigma_low!r}, {sigma_high!r}]: "
            "need sigma_std >= 0 and sigma_mean + sigma_std >= 1e-6"
        )

    user_ids = np.arange(n)

    # NOTE: the draw order below is part of the seeded output; keep it fixed.
    # Transaction frequency (lambda_u) ~ LogNormal. mean=log(lambda_mean)
    # makes lambda_mean the median of the draws.
    lambda_u = np.random.lognormal(
        mean=np.log(p.lambda_mean),
        sigma=p.lambda_std,
        size=n,
    )
    # Amount distribution parameters
    mu_u = np.random.normal(loc=p.mu_mean, scale=p.mu_std, size=n)
    sigma_u = np.random.uniform(low=sigma_low, high=sigma_high, size=n)

    # Keep rates/scales strictly positive (lognormal draws can underflow to 0).
    lambda_u = np.clip(lambda_u, 1e-6, None)
    sigma_u = np.clip(sigma_u, 1e-6, None)

    # Balance ~ LogNormal
    balance = np.random.lognormal(mean=10.0, sigma=1.0, size=n)

    user_type = _sample_user_types(n)
    kyc_level = _sample_kyc(n)
    risk_profile = _sample_risk(n)

    df = pd.DataFrame({
        "user_id": user_ids,
        "user_type": user_type,
        "lambda_u": lambda_u,
        "mu_u": mu_u,
        "sigma_u": sigma_u,
        "balance": balance,
        "kyc_level": kyc_level,
        "risk_profile": risk_profile,
    })

    # Basic validation checks
    if df.isnull().any().any():
        raise ValueError("NaNs detected in generated users")
    # isnull() does not catch +/-inf, e.g. from lognormal overflow.
    numeric = df[["lambda_u", "mu_u", "sigma_u", "balance"]].to_numpy()
    if not np.isfinite(numeric).all():
        raise ValueError("Non-finite values detected in generated users")
    if (df["lambda_u"] <= 0).any():
        raise ValueError("Invalid lambda_u values")
    if (df["sigma_u"] <= 0).any():
        raise ValueError("Invalid sigma_u values")
    return df