# coding: utf-8

import random

import numpy as np
import pandas as pd

from helper_functions import determine_type_of_feature


# 1. Decision Tree helper functions
# (see "decision tree algorithm flow chart.png")

# 1.1 Data pure?
def check_purity(data):
    """Return True if all rows of *data* carry the same label.

    The label is read from the last column (``data[:, -1]``). An empty
    array has no classes at all and therefore is reported as not pure.
    """
    label_column = data[:, -1]
    unique_classes = np.unique(label_column)
    return len(unique_classes) == 1


# 1.2 Classify
def classify_data(data):
    """Return the most frequent label in the last column of *data*.

    ``np.unique`` returns the labels sorted and ``argmax`` picks the first
    maximum, so on a tie the smallest label wins. Raises ``ValueError``
    (from ``argmax``) when *data* has no rows.
    """
    label_column = data[:, -1]
    unique_classes, counts_unique_classes = np.unique(label_column, return_counts=True)

    index = counts_unique_classes.argmax()
    classification = unique_classes[index]

    return classification


# 1.3 Potential splits?
def get_potential_splits(data, random_subspace):
    """Map each candidate feature column index to its sorted unique values.

    Parameters
    ----------
    data : 2-D array whose last column is the label (excluded from splits).
    random_subspace : int or falsy
        When truthy and not larger than the number of feature columns,
        only that many feature columns are considered, sampled with
        ``random.sample`` (so results depend on the global ``random``
        state). Otherwise every feature column is considered.

    Returns
    -------
    dict[int, numpy.ndarray]
    """
    _, n_columns = data.shape
    column_indices = list(range(n_columns - 1))  # excluding the last column which is the label
    if random_subspace and random_subspace <= len(column_indices):
        column_indices = random.sample(population=column_indices, k=random_subspace)

    return {
        column_index: np.unique(data[:, column_index])
        for column_index in column_indices
    }


# 1.4 Lowest Overall Entropy?
def calculate_entropy(data):
    """Return the Shannon entropy (in bits) of the labels in *data*.

    The label is read from the last column (``data[:, -1]``). An array
    with no rows yields 0.0.
    """
    label_column = data[:, -1]
    _, counts = np.unique(label_column, return_counts=True)

    probabilities = counts / counts.sum()
    # np.sum performs the reduction in NumPy rather than a Python-level loop.
    entropy = np.sum(probabilities * -np.log2(probabilities))

    return entropy


def calculate_overall_entropy(data_below, data_above):
    """Return the size-weighted average entropy of the two halves of a split.

    Raises ``ZeroDivisionError`` when both halves are empty.
    """
    n = len(data_below) + len(data_above)
    p_data_below = len(data_below) / n
    p_data_above = len(data_above) / n

    overall_entropy = (p_data_below * calculate_entropy(data_below)
                       + p_data_above * calculate_entropy(data_above))

    return overall_entropy


def determine_best_split(data, potential_splits):
    """Return ``(column_index, value)`` of the split with the lowest overall entropy.

    Parameters
    ----------
    data : 2-D array whose last column is the label.
    potential_splits : dict mapping column index -> iterable of candidate
        split values (as produced by ``get_potential_splits``).

    Relies on a module-level ``split_data(data, split_column=..., split_value=...)``
    returning ``(data_below, data_above)``.

    Raises
    ------
    ValueError
        If *potential_splits* offers no candidate split at all (previously
        this surfaced as an ``UnboundLocalError``).
    """
    best_split_column = None
    best_split_value = None
    overall_entropy = np.inf
    for column_index, split_values in potential_splits.items():
        for value in split_values:
            data_below, data_above = split_data(data, split_column=column_index, split_value=value)
            current_overall_entropy = calculate_overall_entropy(data_below, data_above)

            # "<=" (not "<") is kept on purpose: on ties the last candidate wins.
            if current_overall_entropy <= overall_entropy:
                overall_entropy = current_overall_entropy
                best_split_column = column_index
                best_split_value = value

    if best_split_column is None:
        raise ValueError("potential_splits contains no candidate split values")

    return best_split_column, best_split_value