prakharg24's picture
Update my_pages/ica.py
1588366 verified
raw
history blame
5.42 kB
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text
# Module-level side effect: switch matplotlib's global style to its built-in
# "dark_background" style, so every figure created afterwards uses it.
plt.style.use("dark_background")
# Instruction text shown at the top of the page (HTML line break included).
_INSTRUCTIONS = """
Explore the intention-convention-arbitrariness (ICA) framework.<br>
Use different sliders to uncover examples in the ICA triangle.
"""

# Triangle corners, in the same order as the weight keys used below.
_ICA_VERTICES = np.array([
    [0.5, np.sqrt(3) / 2],  # Intentional
    [0, 0],                 # Conventional
    [1, 0],                 # Arbitrary
])

# Example decisions placed inside the triangle.
# Each entry: (x, y, label, explanation, horizontal align, vertical align).
# Adjacent string literals are used instead of backslash continuations so no
# source indentation leaks into the text and every word boundary keeps its space.
_ICA_EXAMPLES = [
    (0.9, 0.1, "Random Seeds",
     "Random Seeds are highly arbitrary, without any convention or intentionality.",
     "left", "bottom"),
    (0.35, 0.06, "Neural networks for Tabular Data",
     "Using neural networks of some arbitrary size (hidden layers) for a setting where "
     "they are not needed is highly conventional, a bit arbitrary, and has very low "
     "intentionality.",
     "left", "bottom"),
    (0.4, 0.5, "Pre-trained LLM for a Complex Task",
     "Using a high performing LLM for a complex task is intentional, however, it also has "
     "conventionality to it, as a specialized model could have worked, depending on context. "
     "No arbitrariness.",
     "right", "bottom"),
    (0.5, 0.7, "Best Bias Mitigation for a Particular Setup",
     "Choosing the most appropriate bias mitigation technique, "
     "specialized for the particular context, is highly intentional",
     "center", "bottom"),
    (0.7, 0.5, "Randomly chosen Regularization Technique",
     "Adding regularization to improve robustness, but choosing the regularization "
     "technique randomly, creates a decision that is intentional and arbitrary, while "
     "avoiding conventionality.",
     "left", "bottom"),
    (0.1, 0.1, "ReLU Activation as Default",
     "Choosing ReLU activation without testing what other activations might also work, "
     "is a highly conventional decision.",
     "right", "bottom"),
]

# How far (in data units) the "torch" around the user's point illuminates examples.
_TORCH_RADIUS = 0.177


def _rebalance_weights(weights, changed_key, new_value):
    """Set ``weights[changed_key]`` to *new_value* and rebalance the rest in place.

    The other weights absorb the change in proportion to their current
    values. If all other weights are zero, the change is split evenly among
    them instead. Without that, a weight sitting at 1.0 could never be
    lowered: the renormalisation below would snap it straight back to 1.0.

    Finally every weight is clamped to [0, 1] and renormalised to sum to 1.
    Both steps round to 4 decimals.

    Args:
        weights: dict of dimension name -> weight, presumably summing to ~1;
            mutated in place.
        changed_key: key the user adjusted.
        new_value: requested value for ``changed_key`` (slider range is [0, 1]).

    Returns:
        The same ``weights`` dict.
    """
    diff = new_value - weights[changed_key]
    others = [k for k in weights if k != changed_key]
    total_other = sum(weights[k] for k in others)
    if total_other > 0:
        for k in others:
            weights[k] -= diff * (weights[k] / total_other)
    elif others:
        share = diff / len(others)
        for k in others:
            weights[k] -= share
    weights[changed_key] = new_value

    # Clamp small floating point errors (e.g. tiny negatives) into [0, 1].
    for k in weights:
        weights[k] = max(0.0, min(1.0, round(weights[k], 4)))

    # Normalize back to sum=1.
    total = sum(weights.values())
    if total != 0:
        for k in weights:
            weights[k] = round(weights[k] / total, 4)
    return weights


def _barycentric_point(weights):
    """Return the 2-D point given by *weights* as barycentric coordinates of the triangle."""
    return (
        weights["Intentional"] * _ICA_VERTICES[0]
        + weights["Conventional"] * _ICA_VERTICES[1]
        + weights["Arbitrary"] * _ICA_VERTICES[2]
    )


def _draw_ica_figure(point):
    """Draw the ICA triangle, the user's "torch" and the example decisions.

    Args:
        point: (x, y) position of the user's weights inside the triangle.

    Returns:
        ``(fig, explanations)``. ``explanations`` is a list of
        ``(label, explanation)`` pairs, one for each example lying within
        ``_TORCH_RADIUS`` of ``point``. The caller owns ``fig`` and should
        close it.
    """
    fig, ax = plt.subplots()

    # Closed outline: repeat the first vertex at the end so the last edge is drawn.
    outline = np.append(_ICA_VERTICES, _ICA_VERTICES[:1], axis=0)
    ax.plot(outline[:, 0], outline[:, 1])

    corner_labels = (
        ("Intentional", "center", "bottom"),
        ("Conventional", "right", "top"),
        ("Arbitrary", "left", "top"),
    )
    for (vx, vy), (name, ha, va) in zip(_ICA_VERTICES, corner_labels):
        ax.text(vx, vy, name, ha=ha, va=va, color="green", weight="heavy")

    # The "torch": a white disc with a translucent orange disc drawn above it.
    ax.scatter(point[0], point[1], c="white", s=10000)
    ax.scatter(point[0], point[1], c="orange", s=10000, zorder=5, alpha=0.3)

    ax.set_aspect("equal")
    ax.axis("off")
    # Transparent figure and axes backgrounds.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    # Illuminate examples near the point; dim the rest.
    explanations = []
    for x, y, label, explanation, ha, va in _ICA_EXAMPLES:
        dist = np.linalg.norm([x - point[0], y - point[1]])
        if dist <= _TORCH_RADIUS:
            ax.scatter(x, y, c="red", s=50, zorder=6)
            ax.text(x, y + 0.03, label, ha=ha, va=va, color="red", zorder=6, weight="heavy")
            explanations.append((label, explanation))
        else:
            ax.scatter(x, y, c="red", s=50, zorder=6, alpha=0.3)
    return fig, explanations


def render():
    """Render the interactive ICA-triangle page.

    The page shows navigation and instructions, then a radio/slider pair
    that adjusts one of the three ICA weights. The weights are kept in
    ``st.session_state.weights`` and rebalanced so they sum to 1. Below
    that is the triangle plot with the user's position, followed by
    explanations of the examples the "torch" illuminates.
    """
    add_navigation("txt_ica", "txt_multiverse")
    add_instruction_text(_INSTRUCTIONS)

    if "weights" not in st.session_state:
        st.session_state.weights = {
            "Intentional": 0.33,
            "Conventional": 0.33,
            "Arbitrary": 0.34,
        }

    col_choice, col_slider = st.columns([0.6, 0.4])
    with col_choice:
        control_choice = st.radio(
            "Select dimension to adjust",
            ["Intentional", "Conventional", "Arbitrary"],
            horizontal=True,
            label_visibility="collapsed",
        )

    weights = st.session_state.weights
    current_value = weights[control_choice]
    with col_slider:
        new_value = st.slider(
            control_choice, 0.0, 1.0, current_value, 0.01,
            label_visibility="collapsed",
        )

    _rebalance_weights(weights, control_choice, new_value)
    point = _barycentric_point(weights)
    fig, explanations = _draw_ica_figure(point)

    _, col_plot, _ = st.columns([0.3, 1, 0.3])
    with col_plot:
        st.pyplot(fig)
    # st.pyplot has already rendered the figure; close it so pyplot's global
    # figure registry does not grow by one figure on every Streamlit rerun.
    plt.close(fig)

    if explanations:
        add_red_text("".join(
            f"<b>{label}:</b> {explanation}<br>" for label, explanation in explanations
        ))