local_copilot / data /decision_tree_functions.py
Kash6's picture
Deploy AI Coding Assistant
04653e2
# coding: utf-8
import numpy as np
import pandas as pd
import random
from helper_functions import determine_type_of_feature
# 1. Decision Tree helper functions
# (see "decision tree algorithm flow chart.png")
# 1.1 Data pure?
def check_purity(data):
    """Return True if every row of ``data`` carries the same label.

    The label is read from the last column of the 2-D array ``data``.
    An empty array contains no label at all and is reported as not pure.
    """
    label_column = data[:, -1]
    return len(np.unique(label_column)) == 1
# 1.2 Classify
def classify_data(data):
    """Return the most frequent label found in the last column of ``data``.

    ``np.unique`` returns the labels in sorted order and ``argmax`` picks the
    first maximum, so on a tie the smallest label wins.  An empty array makes
    ``argmax`` raise ``ValueError``.
    """
    labels, frequencies = np.unique(data[:, -1], return_counts=True)
    return labels[np.argmax(frequencies)]
# 1.3 Potential splits?
def get_potential_splits(data, random_subspace):
    """Map each candidate feature column index to that column's unique values.

    Every column except the last (the label column) is a candidate.  When
    ``random_subspace`` is truthy and does not exceed the number of feature
    columns, ``random.sample`` draws that many columns without replacement;
    otherwise all feature columns are used.  Keys appear in the order the
    columns were selected.
    """
    _, n_columns = data.shape
    candidates = [*range(n_columns - 1)]  # the last column holds the label
    draw_subset = bool(random_subspace) and random_subspace <= len(candidates)
    if draw_subset:
        candidates = random.sample(candidates, random_subspace)
    return {col: np.unique(data[:, col]) for col in candidates}
# 1.4 Lowest Overall Entropy?
def calculate_entropy(data):
    """Return the Shannon entropy, in bits, of the labels in ``data``.

    Labels are read from the last column.  Only labels that actually occur
    are counted, so every probability is strictly positive and ``log2`` never
    sees zero.  An empty array yields 0.
    """
    _, class_counts = np.unique(data[:, -1], return_counts=True)
    p = class_counts / class_counts.sum()
    surprisal = -np.log2(p)
    return sum(p * surprisal)
def calculate_overall_entropy(data_below, data_above):
    """Return the row-count-weighted entropy of a two-way partition.

    Each side's ``calculate_entropy`` is weighted by its share of the total
    number of rows.  If both sides are empty the weight computation divides
    by zero and raises ``ZeroDivisionError``.
    """
    total_rows = len(data_below) + len(data_above)
    below_term = len(data_below) / total_rows * calculate_entropy(data_below)
    above_term = len(data_above) / total_rows * calculate_entropy(data_above)
    return below_term + above_term
def determine_best_split(data, potential_splits):
    """Return ``(column_index, split_value)`` with the lowest overall entropy.

    ``potential_splits`` maps a column index to its candidate values (the
    shape produced by ``get_potential_splits``).  Each candidate is passed to
    ``split_data`` and the resulting pair of partitions is scored with
    ``calculate_overall_entropy``.  Because the comparison is ``<=``, the
    candidate examined last wins a tie.

    Raises:
        ValueError: if no candidate was selected -- e.g. ``potential_splits``
            is empty or every column has no candidate values.  (Previously
            this surfaced as an ``UnboundLocalError`` at the return.)
    """
    # Any finite entropy beats +inf; replaces the old 9999 magic sentinel.
    lowest_entropy = np.inf
    best_split_column = None
    best_split_value = None
    for column_index, values in potential_splits.items():
        for value in values:
            data_below, data_above = split_data(data, split_column=column_index, split_value=value)
            current_overall_entropy = calculate_overall_entropy(data_below, data_above)
            if current_overall_entropy <= lowest_entropy:
                lowest_entropy = current_overall_entropy
                best_split_column = column_index
                best_split_value = value

    if best_split_column is None:
        raise ValueError("determine_best_split: no candidate split could be selected")
    return best_split_column, best_split_value