import base64
import io
import time
from itertools import pairwise
import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Size of the Markov chain, its integer state labels, and their string names
# (passed as ``ticks`` when the transition matrix is rendered).
num_states = 4
states = np.arange(num_states)
state_names = list(map(str, states))
def generate_p(num_states, seed=42):
    """Draw a random row-stochastic transition matrix.

    Each row is sampled independently from a flat Dirichlet(1, ..., 1)
    distribution, so every row is non-negative and sums to 1.

    Args:
        num_states: number of states; the result is ``num_states x num_states``.
        seed: seed for ``numpy.random.default_rng``. The default (42) keeps the
            previously hard-coded, reproducible draw.

    Returns:
        A ``(num_states, num_states)`` float array.
    """
    rng = np.random.default_rng(seed)
    return rng.dirichlet(alpha=np.ones(num_states), size=num_states)
def generate_sequence(alphabet, P, length=10, seed=42):
    """Sample a trajectory from a discrete-time Markov chain.

    The first state is drawn uniformly from ``alphabet``. Every following
    state is drawn with probabilities ``P[previous_state]``.

    Args:
        alphabet: 1-D sequence of integer states. Each state must also be a
            valid row index into ``P``.
        P: transition matrix whose rows are probability vectors
            (``Generator.choice`` rejects rows that do not sum to 1).
        length: number of states to generate. Non-integral values, such as the
            floats a ``gr.Number`` widget may deliver, are truncated with
            ``int()``. Values <= 0 give an empty list.
        seed: seed for ``numpy.random.default_rng``. It is fixed by default so
            repeated runs replay the same trajectory.

    Returns:
        A list of ``length`` states converted to Python scalars.
    """
    length = int(length)
    if length <= 0:
        return []
    rng = np.random.default_rng(seed)
    sequence = rng.choice(a=alphabet, size=1).tolist()
    for _ in range(length - 1):
        # .item() keeps every element the same Python type as the first one.
        sequence.append(rng.choice(a=alphabet, p=P[sequence[-1]]).item())
    return sequence
def update_p(P, s_prev, s, lambda_=0.9):
    """Apply one exponential-forgetting update for the observed transition s_prev -> s.

    Row ``s_prev`` of ``P`` is scaled by ``lambda_``, and the freed mass
    ``1 - lambda_`` is added to entry ``s``. A row that summed to 1 therefore
    still sums to 1. ``P`` is modified in place and also returned.
    """
    P[s_prev, :] *= lambda_
    P[s_prev, s] += 1 - lambda_
    return P
def hellinger_distance(p, q):
    """Return the Hellinger distance between discrete distributions ``p`` and ``q``.

    H(p, q) = sqrt(0.5 * sum_k (sqrt(p_k) - sqrt(q_k))**2), which lies in
    [0, 1] for probability vectors. The sum always runs over the last axis.
    Passing two ``(n, k)`` matrices therefore gives ``n`` row-wise distances,
    and a single distribution broadcasts against a stack of them.

    Args:
        p, q: array-likes (lists or arrays) of non-negative probabilities with
            broadcast-compatible shapes.

    Returns:
        A scalar, or an array of distances with the last axis reduced away.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    root_diff = np.sqrt(p) - np.sqrt(q)
    return np.sqrt(0.5 * np.sum(root_diff ** 2, axis=-1))
def generate_colorbar(colormap, normalize):
    """Render a vertical colorbar and return it as an inline HTML image.

    Args:
        colormap: matplotlib colormap used for the transition-matrix heatmap.
        normalize: matplotlib normalization paired with ``colormap``.

    Returns:
        An ``<img>`` tag whose ``src`` is a base64-encoded PNG data URI.
        Nothing is written to disk.
    """
    fig, ax = plt.subplots(figsize=(1, 4))
    try:
        fig.subplots_adjust(left=0.5, right=0.6)  # narrow strip for the colorbar
        # Draw on this figure explicitly rather than relying on pyplot's
        # "current figure" global state.
        colorbar = fig.colorbar(
            plt.cm.ScalarMappable(norm=normalize, cmap=colormap),
            cax=ax,
            orientation="vertical",
        )
        colorbar.set_label("Transition Probability", rotation=90, labelpad=15)
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", bbox_inches="tight", transparent=True)
            base64_colorbar = base64.b64encode(buf.getvalue()).decode("utf-8")
    finally:
        plt.close(fig)  # always release the figure, even if rendering failed
    return f'<img src="data:image/png;base64,{base64_colorbar}" alt="Transition probability colorbar">'
def generate_state_diagram_html(P):
    """Render the Markov chain described by ``P`` as an inline SVG state diagram.

    States are drawn as numbered circles laid out row-major on a square grid;
    four states give a 2x2 grid. Every transition with ``P[i, j] > 0`` becomes
    a directed edge whose stroke width grows with its probability (minimum
    2px). Self-transitions are drawn as loops above their node. The edges
    i->j and j->i are shifted to opposite sides of the centre line so they do
    not overlap. Hovering an edge shows its probability.

    Args:
        P: square ``(n, n)`` array-like. ``P[i, j]`` is the probability of
            moving from state i to state j.

    Returns:
        An HTML string containing a captioned ``<svg>`` element.
    """
    P = np.asarray(P, dtype=float)
    n = P.shape[0]

    spacing = 150  # distance between neighbouring node centres
    margin = 50  # distance from the canvas edge to the outer node centres
    radius = 20  # node circle radius
    egap = 5  # sideways shift separating i->j from j->i
    cols = max(1, int(np.ceil(np.sqrt(n))))
    rows = max(1, int(np.ceil(n / cols)))
    state_positions = {
        k: (margin + spacing * (k % cols), margin + spacing * (k // cols))
        for k in range(n)
    }
    width = 2 * margin + spacing * (cols - 1)
    height = 2 * margin + spacing * (rows - 1)

    edges = []
    for i in range(n):
        for j in range(n):
            prob = P[i, j]
            if prob <= 0:  # only draw transitions that can actually occur
                continue
            thickness = max(2, 5 * prob)
            x1, y1 = state_positions[i]
            x2, y2 = state_positions[j]
            if i == j:
                # Self-loop: a cubic curve leaving and re-entering the top of the node.
                top = y1 - radius
                path = (
                    f"M {x1 - 8:.1f} {top:.1f} "
                    f"C {x1 - 25:.1f} {top - 35:.1f} {x1 + 25:.1f} {top - 35:.1f} "
                    f"{x1 + 8:.1f} {top:.1f}"
                )
            else:
                dx, dy = x2 - x1, y2 - y1
                dist = np.hypot(dx, dy)
                ux, uy = dx / dist, dy / dist
                # Perpendicular shift keeps the two directions of a pair apart.
                px, py = -uy * egap, ux * egap
                sx, sy = x1 + ux * radius + px, y1 + uy * radius + py
                ex, ey = x2 - ux * radius + px, y2 - uy * radius + py
                path = f"M {sx:.1f} {sy:.1f} L {ex:.1f} {ey:.1f}"
            edges.append(
                f'<path d="{path}" fill="none" stroke="#4a6fa5" '
                f'stroke-width="{thickness:.2f}" stroke-opacity="0.8" '
                f'marker-end="url(#arrowhead)">'
                f"<title>{i} &#8594; {j}: {prob:.2f}</title></path>"
            )

    # Nodes are emitted after the edges so they are painted on top.
    nodes = [
        f'<circle cx="{x}" cy="{y}" r="{radius}" fill="#deebf7" stroke="#08306b" stroke-width="2"/>'
        f'<text x="{x}" y="{y}" text-anchor="middle" dominant-baseline="central" '
        f'font-family="sans-serif" font-size="14">{k}</text>'
        for k, (x, y) in state_positions.items()
    ]

    # Define the arrowhead marker (fixed size, oriented along each edge).
    arrowhead_marker = (
        '<defs><marker id="arrowhead" markerWidth="10" markerHeight="10" '
        'refX="9" refY="5" orient="auto" markerUnits="userSpaceOnUse">'
        '<path d="M 0 0 L 10 5 L 0 10 z" fill="#4a6fa5"/></marker></defs>'
    )
    edges_svg = "".join(edges)
    nodes_svg = "".join(nodes)
    html_content = (
        "<div style='text-align: center;'>"
        "<h3>State Diagram</h3>"
        f'<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}" '
        'xmlns="http://www.w3.org/2000/svg">'
        f"{arrowhead_marker}{edges_svg}{nodes_svg}</svg></div>"
    )
    return html_content
def generate_transition_matrix_html(P, prob_threshold=0.5, normalize=None, colormap=None, ticks=None):
    """Render a transition matrix as an HTML heatmap table.

    Args:
        P: ``(n, m)`` array-like of transition probabilities.
        prob_threshold: cells whose value exceeds this get white text;
            the others get black text.
        normalize: optional callable (e.g. ``matplotlib.colors.Normalize``)
            that maps a value to a colormap level. Without it, the raw value
            is used.
        colormap: optional matplotlib colormap. When omitted, cells are
            shaded in grayscale, darker for larger levels.
        ticks: optional state labels, used for both row and column headers.
            Defaults to the row/column indices.

    Returns:
        An HTML string containing a ``<table>``.
    """
    header_size = "30px"  # fixed size for the header row and column
    cell_size = "60px"  # fixed size for the probability cells
    P = np.asarray(P, dtype=float)
    n_rows, n_cols = P.shape
    row_labels = list(ticks) if ticks is not None else [str(i) for i in range(n_rows)]
    col_labels = list(ticks) if ticks is not None else [str(j) for j in range(n_cols)]

    header_style = (
        f"width: {header_size}; height: {header_size}; "
        "text-align: center; font-weight: bold;"
    )
    cell_style = (
        f"width: {cell_size}; height: {cell_size}; text-align: center; "
        "vertical-align: middle; border: 1px solid #ddd;"
    )

    parts = [
        "<table style='border-collapse: collapse;'>",
        f"<tr><th style='{header_style}'></th>",  # empty top-left corner
    ]
    parts.extend(f"<th style='{header_style}'>{label}</th>" for label in col_labels)
    parts.append("</tr>")
    for i, row in enumerate(P):
        parts.append(f"<tr><th style='{header_style}'>{row_labels[i]}</th>")
        for value in row:
            # Map the transition probability to a color intensity.
            level = normalize(value) if normalize is not None else value
            if colormap is not None:
                rgba_color = colormap(level)
            else:
                # Grayscale fallback: larger values give darker cells, so the
                # white text used above the threshold stays readable.
                shade = 1.0 - min(max(float(level), 0.0), 1.0)
                rgba_color = (shade, shade, shade, 1.0)
            hex_color = mcolors.to_hex(rgba_color)
            text_color = "white" if value > prob_threshold else "black"
            parts.append(
                f"<td style='{cell_style} background-color: {hex_color}; "
                f"color: {text_color};'>{value:.2f}</td>"
            )
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)
def matrix_to_string(matrix):
    """Serialise a 2-D matrix as text.

    The output has one line per row, with entries converted by ``str`` and
    separated by single spaces.
    """
    return "\n".join(" ".join(str(entry) for entry in row) for row in matrix)
def string_to_matrix(prob_str):
    """Parse text with one matrix row per line into a 2-D float array.

    Entries may be separated by any run of whitespace (spaces or tabs), and
    blank lines are ignored. The input may therefore be the space-separated
    text produced by ``matrix_to_string`` or a hand-typed variant of it.

    Raises:
        ValueError: if the input has no rows, if rows have different lengths,
            or if an entry is not a number.
    """
    rows = [line.split() for line in prob_str.strip().splitlines()]
    rows = [row for row in rows if row]  # skip blank lines
    if not rows:
        raise ValueError("transition matrix input is empty")
    width = len(rows[0])
    if any(len(row) != width for row in rows):
        raise ValueError("every row of the transition matrix must have the same number of entries")
    return np.array([[float(entry) for entry in row] for row in rows])
def initial_html(P):
    """Build the dashboard HTML (state diagram, heatmap, colorbar) for matrix ``P``.

    The three panels are wrapped in a ``div`` with class ``output_flexbox``,
    the class styled by this page's CSS, so they are laid out side by side.

    Args:
        P: transition matrix to display.

    Returns:
        A single HTML string suitable for ``gr.HTML``.
    """
    colormap = plt.cm.Blues
    normalize = mcolors.Normalize(vmin=0, vmax=1)
    state_diagram_html = generate_state_diagram_html(P)
    transition_matrix_html = generate_transition_matrix_html(
        P=P,
        colormap=colormap,
        normalize=normalize,
        ticks=state_names,
    )
    colorbar_html = generate_colorbar(colormap, normalize)
    return (
        '<div class="output_flexbox">'
        f"{state_diagram_html}{transition_matrix_html}{colorbar_html}"
        "</div>"
    )
def process_sequence(current_P, P_true_str, sequence_length, lambda_, tau, state_HD):
    """Stream the online estimation of a transition matrix from a sampled trajectory.

    A trajectory of ``sequence_length`` states is sampled from the matrix
    typed into the UI (``P_true_str``). Each observed transition updates the
    estimate ``current_P`` in place via ``update_p`` with forgetting factor
    ``lambda_``. Before every transition whose index is a positive multiple of
    ``tau``, the largest row-wise Hellinger distance between the estimate and
    the snapshot taken ``tau`` transitions earlier is appended to ``state_HD``.

    ``current_P`` and ``state_HD`` are mutated in place. A later run therefore
    continues from where this one stopped; this is presumably intended because
    both come from ``gr.State`` (confirm).

    Yields:
        ``(html, gr.LinePlot)`` after every transition. If the trajectory has
        fewer than two states, a single frame with the unchanged estimate is
        yielded.

    Raises:
        gr.Error: if the matrix text cannot be parsed or sampled from.
    """
    # gr.Number may deliver floats; range() and % need ints, and tau is a modulus.
    sequence_length = int(sequence_length)
    tau = max(1, int(tau))
    try:
        P_true = string_to_matrix(P_true_str)
        sequence = generate_sequence(states, P_true, length=sequence_length)
    except ValueError as err:
        raise gr.Error(f"Invalid transition matrix: {err}") from err
    P = current_P  # updated in place so the estimate persists in the session state

    # Set up the colormap. The colorbar never changes, so render it only once.
    colormap = plt.cm.Blues
    normalize = mcolors.Normalize(vmin=0, vmax=1)
    colorbar_html = generate_colorbar(colormap, normalize)

    def _hd_plot():
        # Line plot of every Hellinger distance recorded so far.
        hd_data = pd.DataFrame({"time": list(range(len(state_HD))), "hd": list(state_HD)})
        return gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")

    def _dashboard_html(matrix):
        # State diagram, heatmap and colorbar side by side in the flex container.
        state_diagram_html = generate_state_diagram_html(matrix)
        transition_matrix_html = generate_transition_matrix_html(
            P=matrix,
            colormap=colormap,
            normalize=normalize,
            ticks=state_names,
        )
        return (
            '<div class="output_flexbox">'
            f"{state_diagram_html}{transition_matrix_html}{colorbar_html}"
            "</div>"
        )

    hd_plot = _hd_plot()
    transitions = list(pairwise(sequence))
    if not transitions:
        yield _dashboard_html(P), hd_plot
        return

    P_prev = P.copy()
    for s_count, (s_prev, s) in enumerate(transitions):
        if s_count > 0 and s_count % tau == 0:
            state_HD.append(np.max(hellinger_distance(P_prev, P)))
            P_prev = P.copy()
            hd_plot = _hd_plot()
        P = update_p(P, s_prev, s, lambda_)  # update the transition matrix
        yield _dashboard_html(P), hd_plot
        time.sleep(0.005)  # pause so the UI can animate each step
# Page stylesheet passed to gr.Blocks. Children of an element with class
# "output_flexbox" are centred in a row with 50px gaps, and they stack
# vertically on screens at most 800px wide.
css = """
.output_flexbox{
display: flex;
align-items: center;
justify-content:center;
gap: 50px;
flex-direction: row;
}
@media screen and (max-width: 800px){
.output_flexbox{
flex-direction: column;
}
}
"""
with gr.Blocks(css=css) as demo:
    # The estimate starts from the uniform distribution and is kept per session
    # in a gr.State; process_sequence updates it in place.
    P0 = np.full((num_states, num_states), 1 / num_states)
    state_P = gr.State(P0)
    # History of Hellinger distances, seeded with 0 so the plot has a first point.
    state_HD = gr.State([0])
    # Preset "true" transition matrices selectable with the Mode buttons.
    modes = [
        [
            [0.19764149, 0.15019737, 0.5511312 , 0.10102994],
            [0.00981767, 0.29327228, 0.01960321, 0.67730684],
            [0.73753637, 0.01300742, 0.16719004, 0.08226616],
            [0.59819186, 0.04902065, 0.35083554, 0.00195195]
        ],
        [
            [0.00739772, 0.05260542, 0.09302236, 0.8469745 ],
            [0.07016094, 0.79909194, 0.00816012, 0.122587  ],
            [0.11185849, 0.00218829, 0.86738849, 0.01856473],
            [0.03560788, 0.1767552 , 0.72055577, 0.06708115]
        ],
        [
            [0.02901062, 0.0548365 , 0.88122331, 0.03492957],
            [0.03180118, 0.08216069, 0.85603785, 0.03000028],
            [0.02280227, 0.09024757, 0.08233428, 0.80461588],
            [0.23941395, 0.05389086, 0.55260164, 0.15409355]
        ]
    ]
    with gr.Column():
        with gr.Column():
            target_p = gr.Textbox(
                value=matrix_to_string(modes[0]),
                label="Enter transition matrix or select a mode.",
                lines=4,
                placeholder="Enter the 4x4 matrix here...",
            )
            with gr.Row():
                mode_0 = gr.Button("Mode 0")
                mode_1 = gr.Button("Mode 1")
                mode_2 = gr.Button("Mode 2")
        with gr.Row():
            # precision=0 makes Gradio deliver ints (range() needs them).
            # Tau is used as a modulus, so it must be at least 1.
            sequence_length = gr.Number(label="Sequence Length", value=500, minimum=0, precision=0)
            lambda_ = gr.Number(label="Lambda", value=0.95, minimum=0, maximum=1)
            tau = gr.Number(label="Tau", value=25, minimum=1, precision=0)
        run_btn = gr.Button("Run")
        html_output = gr.HTML(value=initial_html(P0), label="State Diagram and Transition Matrix")
        hd_data = pd.DataFrame({"time": list(range(len(state_HD.value))), "hd": list(state_HD.value)})
        hd_plot = gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")
    run_btn.click(
        fn=process_sequence,
        inputs=[state_P, target_p, sequence_length, lambda_, tau, state_HD],
        outputs=[html_output, hd_plot]
    )
    mode_0.click(fn=lambda: matrix_to_string(modes[0]), inputs=None, outputs=target_p)
    mode_1.click(fn=lambda: matrix_to_string(modes[1]), inputs=None, outputs=target_p)
    mode_2.click(fn=lambda: matrix_to_string(modes[2]), inputs=None, outputs=target_p)
demo.launch(share=False)