# hicxai-condition-2 / src / load_adult_data.py
# Author: Suvh
# Update to v1.1-chatty-luna (2025-12-07)
# Commit: 070061f
import pandas as pd
import numpy as np
import os
import json
def load_adult_data(data_dir, balance=False, discretize=True):
    """Load the Adult dataset and its JSON metadata.

    Adapted from XAgent/Agent/utils.py.

    Reads ``adult.data`` from ``data_dir``, drops every row that contains a
    missing value (marked ``'?'`` in the file), coerces the numeric columns,
    optionally one-hot encodes the categorical columns and maps the
    ``income`` target to 0/1.

    Args:
        data_dir: Directory containing ``adult.data``. The metadata file is
            read from ``dataset_info/adult.json`` under
            ``os.path.dirname(data_dir)``.
        balance: Accepted for interface compatibility; this function does
            not currently use it.
        discretize: When true, replace the categorical columns with one-hot
            indicator columns via ``pandas.get_dummies``.

    Returns:
        A ``(df, meta)`` tuple. ``df`` is the processed DataFrame. ``meta``
        is the dict loaded from the JSON file, extended with
        ``num_features`` and ``cat_features`` (only if absent),
        ``feature_values`` (sorted observed values per categorical feature)
        and ``feature_ranges`` (``(min, max)`` floats per numeric feature).

    Raises:
        FileNotFoundError: If the data file or the JSON metadata file does
            not exist.
    """
    data_path = os.path.join(data_dir, 'adult.data')
    json_path = os.path.join(os.path.dirname(data_dir), 'dataset_info', 'adult.json')
    columns = [
        'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status',
        'occupation', 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss',
        'hours_per_week', 'native_country', 'income',
    ]
    num_cols = ['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
    cat_cols = [
        'workclass', 'education', 'marital_status', 'occupation',
        'relationship', 'race', 'sex', 'native_country',
    ]

    # skipinitialspace strips the blank after each comma, so missing entries
    # read as exactly '?' and can be matched literally.
    df = pd.read_csv(data_path, names=columns, skipinitialspace=True)

    # Remove rows with missing values (marked as '?').
    df = df.replace('?', np.nan).dropna()

    # Convert numerical columns to appropriate types.
    for col in num_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')

    # Capture the observed category values BEFORE one-hot encoding:
    # get_dummies drops the original categorical columns, so reading them
    # afterwards would yield nothing.
    observed_values = {
        cat: sorted(df[cat].dropna().unique().tolist()) for cat in cat_cols
    }

    # Optionally encode categorical variables using one-hot encoding.
    if discretize:
        df = pd.get_dummies(df, columns=cat_cols)

    # Encode target; a substring test also matches labels written as '>50K.'.
    df['income'] = df['income'].apply(lambda x: 1 if '>50K' in str(x) else 0)

    # Load metadata.
    with open(json_path, 'r', encoding='utf-8') as f:
        meta = json.load(f)

    # Add feature names if missing; data-derived values and ranges always
    # replace whatever the JSON file carried for the same feature.
    meta.setdefault('num_features', num_cols)
    meta.setdefault('cat_features', cat_cols)
    meta.setdefault('feature_values', {}).update(observed_values)

    feature_ranges = meta.setdefault('feature_ranges', {})
    for num in num_cols:
        feature_ranges[num] = (float(df[num].min()), float(df[num].max()))

    return df, meta
if __name__ == '__main__':
    # Manual smoke run: load the dataset from the sibling ``data`` directory
    # (relative to this file) and print a short summary.
    here = os.path.dirname(__file__)
    frame, info = load_adult_data(os.path.join(here, '..', 'data'))
    print('Data shape:', frame.shape)
    print('Metadata:', info)