Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import tomllib | |
# Integer encodings for categorical demographic columns:
# column name -> {raw string value -> integer code}.
# Not currently applied by CSVDataset.
value_mapping = dict(
    his_SEX=dict(female=0, male=1),
    his_HISPANIC=dict(no=0, yes=1),
    his_NACCNIHR=dict(whi=0, blk=1, asi=2, ind=3, haw=4, mul=5),
)

# Target columns; these are excluded from the feature set and returned
# separately by CSVDataset.
label_names = [
    'amy_label',
    'tau_label',
]
class CSVDataset:
    """Row-wise access to a feature/label table loaded from CSV files.

    Each item is a ``(features, labels)`` pair of dicts built from one row of
    the data file, with missing (NaN) cells omitted.
    """

    def __init__(self, dat_file, cnf_file, label_cols=None):
        """Load the data table and select the configured feature columns.

        Args:
            dat_file: Path or file-like object for the data CSV.
            cnf_file: Path or file-like object for a CSV whose ``Name`` column
                lists the expected column names (features and labels).
            label_cols: Label column names. Defaults to the module-level
                ``label_names``.

        Config features that are absent from the data file are skipped, with
        a printed warning. Labels are kept only when every label column is
        present; otherwise ``df_labels`` is an empty DataFrame.
        """
        # Only touch the module default when no explicit list is given.
        labels = list(label_names if label_cols is None else label_cols)

        df = pd.read_csv(dat_file)
        # NOTE: `value_mapping` is not applied here; categorical columns are
        # used exactly as read from the CSV.

        cnf = pd.read_csv(cnf_file)
        label_set = set(labels)
        expected_features = [col for col in cnf['Name'] if col not in label_set]

        # Only use features that exist in both the config and the data.
        data_cols = set(df.columns)
        available_features = [col for col in expected_features if col in data_cols]
        missing_features = [col for col in expected_features if col not in data_cols]
        if missing_features:
            print(f"Warning: {len(missing_features)} features missing from data file:")
            # Only show the ellipsis when the list was actually truncated.
            suffix = "..." if len(missing_features) > 10 else ""
            print(f"Missing: {missing_features[:10]}{suffix}")
            print(f"Using {len(available_features)} out of {len(expected_features)} expected features")

        self.df = df
        self.df_features = df[available_features]
        self.df_labels = df[labels] if all(col in data_cols for col in labels) else pd.DataFrame()

    def __len__(self):
        """Return the number of rows in the data table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(features, labels)`` dicts for the row at position ``idx``.

        NaN cells are dropped from both dicts. When the data file has no label
        columns, the label dict is empty instead of raising.
        """
        features = self._row_dict(self.df_features, idx)
        # df_labels is a column-less DataFrame when labels are unavailable;
        # positional indexing into it would raise IndexError.
        if len(self.df_labels.columns) == 0:
            labels = {}
        else:
            labels = self._row_dict(self.df_labels, idx)
        return features, labels

    @staticmethod
    def _row_dict(frame, idx):
        """Return positional row ``idx`` of ``frame`` as a dict without NaN cells."""
        return frame.iloc[idx].dropna().to_dict()
if __name__ == '__main__':
    # Smoke check: build the dataset from the local data/config pair and
    # print the sample at position 1.
    data_path = "./test.csv"
    meta_path = "./input_meta_info.csv"
    dataset = CSVDataset(dat_file=data_path, cnf_file=meta_path)
    sample = dataset[1]
    print(sample)