# Preprocessing utilities for the NBA dataset / CatGCN pipeline.
# NOTE(review): the original file carried paste-tool residue here ("Spaces:",
# "Runtime error") — replaced with this header comment.
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import scipy.sparse as sp | |
| from fainress_component import disparate_impact_remover, reweighting, sample | |
def nba_CatGCN_pre_process(df, df_edge_list, sens_attr, label, special_case, onehot_bin_columns, onehot_cat_columns, debaising_approach=None):
    """Prepare the NBA dataset for CatGCN: optionally debias, encode features,
    and write the edge/field/label artifacts to CSV/NPY files in ``./``.

    :param df: node feature DataFrame; must contain 'user_id', 'AGE', 'MP',
        'FG', 'country', 'SALARY', 'player_height', 'player_weight' columns.
    :param df_edge_list: DataFrame with 'source'/'target' user-id columns.
    :param sens_attr: sensitive attribute column name (passed to the debiasers).
    :param label: target column name; -1 labels are remapped to 0.
    :param special_case: unused here; kept for interface compatibility.
    :param onehot_bin_columns: columns to cast to int (boolean one-hots), or None.
    :param onehot_cat_columns: columns to one-hot encode, or None.
    :param debaising_approach: 'disparate_impact_remover', 'reweighting',
        'sample', or None for no debiasing.
    :return: tuple of file paths
        (user_edge_path, user_field_new_path, user_salary_path, user_label_path).
    """
    if onehot_bin_columns is not None:
        df = apply_bin_columns(df, onehot_bin_columns)
    if onehot_cat_columns is not None:
        df = apply_cat_columns(df, onehot_cat_columns)
    # nba case: the raw label uses -1 for the negative class; remap it to 0.
    if -1 in df[label].unique():
        df[label] = df[label].replace(-1, 0)
    if debaising_approach is not None:
        if debaising_approach == 'disparate_impact_remover':
            df = disparate_impact_remover(df, sens_attr, label)
        elif debaising_approach == 'reweighting':
            df = reweighting(df, sens_attr, label)
        elif debaising_approach == 'sample':
            df = sample(df, sens_attr, label)
        #if debaising_approach == 'sample':
        #    df = df.reset_index()
        #    df = df.drop(['index'], axis=1)
        #    df = df.drop_duplicates()
        # These two debiasers return float-typed frames; restore integer dtypes.
        if debaising_approach == 'disparate_impact_remover' or debaising_approach == 'reweighting':
            df.AGE = df.AGE.astype(int)
            df.country = df.country.astype(int)
            df.SALARY = df.SALARY.astype(int)
            df['user_id'] = pd.to_numeric(df['user_id'])
            df = df.astype({'user_id': int})
    df.AGE = df.AGE.astype(str)
    df.MP = df.MP.astype(str)
    df.FG = df.FG.astype(str)
    df['AGE'] = df['AGE'].astype(str).astype(int)
    # For the nba dataset we choose AGE as the mapping option to the user id.
    uid_age = df[['user_id', 'AGE']].copy()
    uid_age.dropna(inplace=True)
    uid_age2 = df[['user_id', 'AGE']].copy()
    # create uid2id: user_id -> dense 0-based id, in row order
    uid2id = {num: i for i, num in enumerate(df['user_id'])}
    # create age2id: AGE value -> dense 0-based id, in order of first appearance
    age2id = {num: i for i, num in enumerate(pd.unique(uid_age['AGE']))}
    # create user_field
    user_field = col_map(uid_age, 'user_id', uid2id)
    user_field = col_map(user_field, 'AGE', age2id)
    ## new part for disparate remover: rebuild user_id from the frame index
    if debaising_approach == 'disparate_impact_remover':
        user_field = user_field.reset_index()
        user_field = user_field.drop(['user_id'], axis=1)
        user_field = user_field.rename(columns={"index": "user_id"})
        user_field['user_id'] = user_field['user_id'].astype(str).astype(int)
    # create user_label
    user_label = df[df['user_id'].isin(uid_age2['user_id'])]
    user_label = col_map(user_label, 'user_id', uid2id)
    user_label = label_map(user_label, user_label.columns[1:])
    print('User label size', user_label.size)
    # save_path = "./input_ali_data"
    save_path = "./"
    # Process the edge list: keep only edges whose endpoints exist in df and
    # translate the raw user ids to df index labels.
    if df_edge_list['source'].dtype != 'int64':
        df_edge_list['source'] = df_edge_list['source'].astype(str).astype(np.int64)
        df_edge_list['target'] = df_edge_list['target'].astype(str).astype(np.int64)
    source = []
    target = []
    for i in range(df_edge_list.shape[0]):
        # FIX: replaced 'any(...) == True' with the direct Series.any() test.
        if (df.user_id == df_edge_list.source[i]).any() and (df.user_id == df_edge_list.target[i]).any():
            index = df.user_id[df.user_id == df_edge_list.source[i]].index.tolist()[0]
            source.append(index)
            index2 = df.user_id[df.user_id == df_edge_list.target[i]].index.tolist()[0]
            target.append(index2)
    user_edge_new = pd.DataFrame({'uid': source, 'uid2': target})
    user_edge_new.to_csv(os.path.join(save_path, 'user_edge.csv'), index=False)
    user_field.to_csv(os.path.join(save_path, 'user_field.csv'), index=False)
    user_label.to_csv(os.path.join(save_path, 'user_labels.csv'), index=False)
    user_label[['user_id', 'SALARY']].to_csv(os.path.join(save_path, 'user_salary.csv'), index=False)
    user_salary = user_label[['user_id', 'SALARY']]
    print('User salary size', user_salary.size)
    user_label[['user_id', 'AGE']].to_csv(os.path.join(save_path, 'user_age.csv'), index=False)
    user_label[['user_id', 'MP']].to_csv(os.path.join(save_path, 'user_mp.csv'), index=False)
    user_label[['user_id', 'FG']].to_csv(os.path.join(save_path, 'user_fg.csv'), index=False)
    user_label[['user_id', 'country']].to_csv(os.path.join(save_path, 'user_country.csv'), index=False)
    user_label[['user_id', 'player_height']].to_csv(os.path.join(save_path, 'user_player_height.csv'), index=False)
    user_label[['user_id', 'player_weight']].to_csv(os.path.join(save_path, 'user_player_weight.csv'), index=False)
    # Number of field neighbours sampled per user.
    NUM_FIELD = 10
    # np.random.seed(42)  # NOTE(review): original had 'np.random_seed' (a typo)
    # load user_field.csv back as a sparse user x field indicator matrix
    user_field = field_reader(os.path.join(save_path, 'user_field.csv'))
    print("Shapes of user with field:", user_field.shape)
    print("Number of user with field:", np.count_nonzero(np.sum(user_field, axis=1)))
    neighs = get_neighs(user_field)
    sample_neighs = []
    for i in range(len(neighs)):
        sample_neighs.append(list(sample_neigh(neighs[i], NUM_FIELD)))
    sample_neighs = np.array(sample_neighs)
    np.save(os.path.join(save_path, 'user_field.npy'), sample_neighs)
    user_edge_path = './user_edge.csv'
    user_field_new_path = './user_field.npy'
    user_salary_path = './user_salary.csv'
    user_label_path = './user_labels.csv'
    return user_edge_path, user_field_new_path, user_salary_path, user_label_path
def get_count(tp, id):
    """Return the number of rows in *tp* for each distinct value of column *id*.

    :param tp: DataFrame of triplets.
    :param id: column name to group by.
    :return: Series indexed by the distinct values of *id* with group sizes.
    """
    grouped = tp[[id]].groupby(id, as_index=True)
    return grouped.size()
def filter_triplets(tp, user, item, min_uc=0, min_sc=0):
    """Filter interaction triplets by minimum user/item activity.

    :param tp: DataFrame of triplets.
    :param user: name of the user column.
    :param item: name of the item column.
    :param min_uc: keep only users with at least this many rows (0 = no filter).
    :param min_sc: keep only items with at least this many rows (0 = no filter).
    :return: (filtered tp, per-user counts, per-item counts) — counts are
        computed after filtering.
    """
    def _counts(frame, col):
        # Row count per distinct value of *col* (same as module-level get_count).
        return frame[[col]].groupby(col, as_index=True).size()

    # Only keep rows for users who appear at least min_uc times.
    if min_uc > 0:
        per_user = _counts(tp, user)
        tp = tp[tp[user].isin(per_user.index[per_user >= min_uc])]
    # Only keep rows for items that appear at least min_sc times.
    if min_sc > 0:
        per_item = _counts(tp, item)
        tp = tp[tp[item].isin(per_item.index[per_item >= min_sc])]
    # Recompute both counts on the filtered frame.
    return tp, _counts(tp, user), _counts(tp, item)
def col_map(df, col, num2id):
    """Replace every value of *df[col]* with its id from *num2id*, in place.

    :param df: DataFrame to modify (also returned for chaining).
    :param col: column name to remap.
    :param num2id: dict mapping original values to ids; a value missing from
        the dict raises KeyError (same as the original applymap behavior).
    :return: the mutated *df*.
    """
    # FIX: DataFrame.applymap is deprecated (pandas 2.1+); for a single column
    # Series.map with a callable is equivalent and keeps KeyError-on-missing.
    df[col] = df[col].map(lambda x: num2id[x])
    return df
def label_map(label_df, label_list):
    """Integer-encode each column named in *label_list*, in place.

    Ids are assigned per column in order of first appearance of each distinct
    value (mirrors col_map applied with a freshly built value->id dict).

    :param label_df: DataFrame to encode (also returned).
    :param label_list: iterable of column names to encode.
    :return: the mutated *label_df*.
    """
    for name in label_list:
        value2id = {value: idx for idx, value in enumerate(pd.unique(label_df[name]))}
        label_df[[name]] = label_df[[name]].applymap(lambda v: value2id[v])
    return label_df
def field_reader(path):
    """
    Reading the sparse field matrix stored as csv from the disk.

    :param path: Path to the csv file with 'user_id' and 'AGE' columns.
    :return field: csr matrix with a 1 at (user_id, AGE) for every row,
        shaped (max user_id + 1, max AGE id + 1).
    """
    table = pd.read_csv(path)
    rows = table["user_id"].values.tolist()
    cols = table["AGE"].values.tolist()
    shape = (max(rows) + 1, max(cols) + 1)
    ones = np.ones_like(rows)
    return sp.csr_matrix((ones, (rows, cols)), shape=shape)
| #user_field = field_reader(os.path.join(save_path, 'user_field.csv')) | |
| #print("Shapes of user with field:", user_field.shape) | |
| #print("Number of user with field:", np.count_nonzero(np.sum(user_field, axis=1))) | |
def get_neighs(csr):
    """Return, for each row of *csr*, the array of column indices whose
    entries are positive (the row's "neighbour" fields).

    :param csr: scipy sparse matrix (users x fields).
    :return: list of 1-D numpy index arrays, one per row.
    """
    col_ids = np.arange(csr.shape[1])
    neighs = []
    for row in range(csr.shape[0]):
        positive = csr[row, :].toarray()[0] > 0
        neighs.append(col_ids[positive])
    return neighs
def sample_neigh(neigh, num_sample):
    """Draw exactly *num_sample* entries from *neigh*: without replacement
    when enough neighbours exist, with replacement otherwise.

    :param neigh: 1-D array-like of candidate indices (non-empty).
    :param num_sample: number of samples to draw.
    :return: numpy array of length *num_sample*.
    """
    with_replacement = len(neigh) < num_sample
    return np.random.choice(neigh, num_sample, replace=with_replacement)
def apply_bin_columns(df, onehot_bin_columns):
    """Cast every listed one-hot/boolean column that exists in *df* to int,
    in place; columns not present in *df* are silently ignored.

    :param df: DataFrame to modify (also returned).
    :param onehot_bin_columns: container of column names to cast.
    :return: the mutated *df*.
    """
    for name in df.columns:
        if name in onehot_bin_columns:
            df[name] = df[name].astype(int)
    return df
def apply_cat_columns(df, onehot_cat_columns):
    """One-hot encode the listed categorical columns of *df*.

    :param df: input DataFrame (not modified).
    :param onehot_cat_columns: list of column names to expand into dummies.
    :return: new DataFrame with the listed columns replaced by indicator columns.
    """
    return pd.get_dummies(df, columns=onehot_cat_columns)