added new functions, fixed data processing:
Browse files- changed to top 5 features
- rounded floats to 2 decimals
utils.py
CHANGED
|
@@ -2,8 +2,6 @@ import streamlit.components.v1 as components
|
|
| 2 |
import streamlit as st
|
| 3 |
from random import randrange, uniform
|
| 4 |
import pandas as pd
|
| 5 |
-
import joblib
|
| 6 |
-
import dill
|
| 7 |
import logging
|
| 8 |
import numpy as np
|
| 9 |
|
|
@@ -52,24 +50,26 @@ def get_explainability_texts(shap_values, feature_texts):
|
|
| 52 |
# Sort dictionaries based on the magnitude of values
|
| 53 |
sorted_positive_indices = [index for index, _ in sorted(positive_dict.items(), key=lambda item: abs(item[1]), reverse=True)]
|
| 54 |
positive_texts = [feature_texts[x] for x in sorted_positive_indices]
|
| 55 |
-
|
| 56 |
-
|
|
|
|
| 57 |
return positive_texts, sorted_positive_indices
|
| 58 |
|
| 59 |
|
| 60 |
def get_explainability_values(pos_indices, datapoint):
|
| 61 |
-
data = datapoint.iloc[0].tolist()
|
| 62 |
-
|
| 63 |
-
|
| 64 |
vals = []
|
| 65 |
for idx in pos_indices:
|
| 66 |
if idx in range(7,11) or idx in range(13,18):
|
| 67 |
-
val = str(bool(
|
| 68 |
else:
|
| 69 |
-
val =
|
| 70 |
vals.append(val)
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
return vals
|
| 74 |
|
| 75 |
def get_fake_certainty():
|
|
@@ -120,10 +120,39 @@ def get_model_url():
|
|
| 120 |
deployment_id = ""
|
| 121 |
return model_url, workspace_id, deployment_id
|
| 122 |
|
| 123 |
-
def
|
| 124 |
cleaned = [x.replace(':', '') for x in explainability_texts]
|
| 125 |
-
fi = [f'{
|
| 126 |
fi.insert(0, 'Important suspicious features: ')
|
| 127 |
result = '\n'.join(fi)
|
| 128 |
comment = f"Model certainty is {certainty}" + '\n''\n' + result
|
| 129 |
-
return comment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import streamlit as st
|
| 3 |
from random import randrange, uniform
|
| 4 |
import pandas as pd
|
|
|
|
|
|
|
| 5 |
import logging
|
| 6 |
import numpy as np
|
| 7 |
|
|
|
|
| 50 |
# Sort dictionaries based on the magnitude of values
|
| 51 |
sorted_positive_indices = [index for index, _ in sorted(positive_dict.items(), key=lambda item: abs(item[1]), reverse=True)]
|
| 52 |
positive_texts = [feature_texts[x] for x in sorted_positive_indices]
|
| 53 |
+
positive_texts = positive_texts[2:]
|
| 54 |
+
if len(positive_texts) > 5:
|
| 55 |
+
positive_texts = positive_texts[:5]
|
| 56 |
return positive_texts, sorted_positive_indices
|
| 57 |
|
| 58 |
|
| 59 |
def get_explainability_values(pos_indices, datapoint):
|
| 60 |
+
data = datapoint.iloc[0].tolist()
|
| 61 |
+
rounded_data = [round(value, 2) if isinstance(value, float) else value for value in data]
|
| 62 |
+
transformed_data = transformation(input=rounded_data, categories=CATEGORIES)
|
| 63 |
vals = []
|
| 64 |
for idx in pos_indices:
|
| 65 |
if idx in range(7,11) or idx in range(13,18):
|
| 66 |
+
val = str(bool(transformed_data[idx])).capitalize()
|
| 67 |
else:
|
| 68 |
+
val = transformed_data[idx]
|
| 69 |
vals.append(val)
|
| 70 |
+
vals = vals[2:]
|
| 71 |
+
if len(vals) > 5:
|
| 72 |
+
vals = vals[:5]
|
| 73 |
return vals
|
| 74 |
|
| 75 |
def get_fake_certainty():
|
|
|
|
| 120 |
deployment_id = ""
|
| 121 |
return model_url, workspace_id, deployment_id
|
| 122 |
|
| 123 |
+
def get_comment_explanation(certainty, explainability_texts, explainability_values):
|
| 124 |
cleaned = [x.replace(':', '') for x in explainability_texts]
|
| 125 |
+
fi = [f'{cleaned[i]} is {x}' for i, x in enumerate(explainability_values)]
|
| 126 |
fi.insert(0, 'Important suspicious features: ')
|
| 127 |
result = '\n'.join(fi)
|
| 128 |
comment = f"Model certainty is {certainty}" + '\n''\n' + result
|
| 129 |
+
return comment
|
| 130 |
+
|
| 131 |
+
def create_data_input_table(datapoint, col_names):
|
| 132 |
+
st.subheader("Flagged Transaction:")
|
| 133 |
+
data = datapoint.iloc[0].tolist()
|
| 134 |
+
data[7:12] = [bool(value) for value in data[7:12]]
|
| 135 |
+
rounded_list = [round(value, 2) if isinstance(value, float) else value for value in data]
|
| 136 |
+
df = pd.DataFrame({"Feature name": col_names, "Value": rounded_list })
|
| 137 |
+
st.dataframe(df, hide_index=True, width=450, height=35*len(df)+38)
|
| 138 |
+
|
| 139 |
+
# Create a function to generate a table
|
| 140 |
+
def create_table(texts, values, title):
|
| 141 |
+
df = pd.DataFrame({"Feature Explanation": texts, 'Value': values})
|
| 142 |
+
st.markdown(f'#### {title}') # Markdown for styling
|
| 143 |
+
st.dataframe(df, hide_index=True, width=450) # Display a simple table
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def ChangeButtonColour(widget_label, font_color, background_color='transparent'):
|
| 147 |
+
htmlstr = f"""
|
| 148 |
+
<script>
|
| 149 |
+
var elements = window.parent.document.querySelectorAll('button');
|
| 150 |
+
for (var i = 0; i < elements.length; ++i) {{
|
| 151 |
+
if (elements[i].innerText == '{widget_label}') {{
|
| 152 |
+
elements[i].style.color ='{font_color}';
|
| 153 |
+
elements[i].style.background = '{background_color}'
|
| 154 |
+
}}
|
| 155 |
+
}}
|
| 156 |
+
</script>
|
| 157 |
+
"""
|
| 158 |
+
components.html(f"{htmlstr}", height=0, width=0)
|