import base64
import io
import math
import time
from itertools import pairwise

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

# Define the transition matrix and state information
num_states = 4
states = np.arange(num_states)
state_names = [f"{i}" for i in states]


def generate_p(num_states, seed=42):
    """Return a random ``num_states x num_states`` row-stochastic matrix.

    Every row is an independent draw from a flat Dirichlet(1, ..., 1)
    distribution, so each row sums to 1. The generator is seeded (``seed``
    defaults to 42, the value that used to be hard-coded) for reproducibility.
    """
    rng = np.random.default_rng(seed)
    return rng.dirichlet(alpha=np.ones(num_states), size=num_states)


def generate_sequence(alphabet, P, length=10, seed=42):
    """Sample a path of ``length`` states from the Markov chain ``P``.

    The first state is drawn uniformly from ``alphabet``. Every following state
    is drawn from row ``P[previous_state]``, so ``alphabet`` is expected to be
    the row indices ``0..n-1`` of ``P``. The generator is seeded (``seed``
    defaults to 42, the value that used to be hard-coded), so the same ``P``
    always yields the same path.

    Returns a list of states; an empty list when ``length <= 0``.
    """
    length = int(length)  # gr.Number may deliver a float; range() needs an int
    if length <= 0:
        return []
    rng = np.random.default_rng(seed)
    sequence = rng.choice(a=alphabet, size=1).tolist()
    for _ in range(length - 1):
        sequence.append(rng.choice(a=alphabet, p=P[sequence[-1]]))
    return sequence


def update_p(P, s_prev, s, lambda_=0.9):
    """Apply one exponential-forgetting update to row ``s_prev`` of ``P`` in place.

    The row is scaled by ``lambda_`` and ``1 - lambda_`` is added to the
    observed transition ``s_prev -> s``. A row that sums to 1 therefore keeps
    summing to 1. ``P`` must be a float array. Returns ``P`` for convenience.
    """
    P[s_prev] *= lambda_
    P[s_prev, s] += 1 - lambda_
    return P


def hellinger_distance(p, q):
    """Return the Hellinger distance between distributions along the last axis.

    For two matrices this gives one distance per row.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2, axis=-1))


def generate_colorbar(colormap, normalize):
    """Render a vertical colorbar as an inline ``<img>`` tag.

    The PNG is base64-encoded into a data URI so it can be embedded in a
    ``gr.HTML`` component. An explicit ``Figure`` is used instead of pyplot
    because this runs inside Gradio worker threads, where pyplot's global
    state and GUI backends are unsafe. The figure is garbage-collected, so it
    needs no ``plt.close``.
    """
    fig = Figure(figsize=(1, 4))
    ax = fig.add_subplot()
    fig.subplots_adjust(left=0.5, right=0.6)  # Adjust layout for a narrow colorbar
    colorbar = fig.colorbar(
        plt.cm.ScalarMappable(norm=normalize, cmap=colormap),
        cax=ax,
        orientation="vertical",
    )
    colorbar.set_label("Transition Probability", rotation=90, labelpad=15)

    # Save the colorbar as an image in memory
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", transparent=True)
    base64_colorbar = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        f'<img src="data:image/png;base64,{base64_colorbar}" '
        f'alt="Transition probability colorbar" style="height: 260px;"/>'
    )


def _grid_positions(n, spacing=150, margin=50):
    """Place ``n`` states row-major on a grid with ceil(sqrt(n)) columns.

    For n = 4 this is the 2x2 layout (50, 50), (200, 50), (50, 200), (200, 200).
    """
    cols = max(1, math.ceil(math.sqrt(n)))
    return {
        k: (margin + spacing * (k % cols), margin + spacing * (k // cols))
        for k in range(n)
    }


def _edge_path(x1, y1, x2, y2, radius, egap):
    """Return SVG path data for a straight arrow between two distinct nodes.

    The line starts on the source circle and stops ``egap`` short of the target
    circle. It is also shifted sideways by ``egap``; the reverse edge is shifted
    the other way, so i -> j and j -> i do not overlap.
    """
    dx, dy = x2 - x1, y2 - y1
    dist = math.hypot(dx, dy)
    ux, uy = dx / dist, dy / dist
    nx, ny = -uy * egap, ux * egap  # perpendicular offset
    sx, sy = x1 + ux * radius + nx, y1 + uy * radius + ny
    ex, ey = x2 - ux * (radius + egap) + nx, y2 - uy * (radius + egap) + ny
    return f"M {sx:.1f} {sy:.1f} L {ex:.1f} {ey:.1f}"


def _self_loop_path(x, y, ux, uy, radius, offset, egap):
    """Return SVG path data for a loop from a node back to itself.

    The loop bulges out along the unit vector (ux, uy).
    """

    def rotated(deg):
        a = math.radians(deg)
        return (
            ux * math.cos(a) - uy * math.sin(a),
            ux * math.sin(a) + uy * math.cos(a),
        )

    (ax_, ay_), (bx, by) = rotated(35), rotated(-35)
    reach = radius + 1.6 * offset
    sx, sy = x + radius * ax_, y + radius * ay_
    ex, ey = x + (radius + egap) * bx, y + (radius + egap) * by
    return (
        f"M {sx:.1f} {sy:.1f} "
        f"C {x + reach * ax_:.1f} {y + reach * ay_:.1f}, "
        f"{x + reach * bx:.1f} {y + reach * by:.1f}, "
        f"{ex:.1f} {ey:.1f}"
    )


# Helper function to generate HTML for state diagram with directed edges
def generate_state_diagram_html(P):
    """Render the chain as an SVG graph: one circle per state, one arrow per P[i, j] > 0.

    States sit on a square grid (2x2 for the default four states).
    Edge thickness is ``max(2, 5 * P[i, j])`` and edge opacity also grows with
    the probability. Self-transitions are drawn as loops pointing away from the
    diagram centre.
    """
    P = np.asarray(P, dtype=float)
    n = P.shape[0]
    radius = 20  # node circle radius
    offset = 30  # how far self-loops bulge out from their node
    egap = 5  # gap before the target circle and lateral separation of i->j / j->i
    margin = 50  # distance of the outer node centres from the grid edge
    pad = 30  # extra viewBox room so loops on outer nodes are not clipped

    state_positions = _grid_positions(n, margin=margin)
    labels = state_names if len(state_names) == n else [str(k) for k in range(n)]
    xs = [x for x, _ in state_positions.values()]
    ys = [y for _, y in state_positions.values()]
    cx, cy = sum(xs) / max(n, 1), sum(ys) / max(n, 1)

    # Generate the SVG for edges between states based on the transition matrix
    edges = []
    for i in range(n):
        x1, y1 = state_positions[i]
        # Outward direction for this node's self-loop (up for a centred node)
        ox, oy = x1 - cx, y1 - cy
        norm = math.hypot(ox, oy)
        ux, uy = (ox / norm, oy / norm) if norm > 0 else (0.0, -1.0)
        for j in range(n):
            p = P[i, j]
            if not p > 0:  # Only draw edge if probability > 0
                continue
            if i == j:
                d = _self_loop_path(x1, y1, ux, uy, radius, offset, egap)
            else:
                x2, y2 = state_positions[j]
                d = _edge_path(x1, y1, x2, y2, radius, egap)
            thickness = max(2, 5 * p)  # Set min thickness to 2
            opacity = 0.2 + 0.8 * min(p, 1.0)
            edges.append(
                f'<path d="{d}" fill="none" stroke="#1f4e79" '
                f'stroke-width="{thickness:.2f}" stroke-opacity="{opacity:.2f}" '
                f'marker-end="url(#arrowhead)"/>'
            )

    nodes = [
        f'<circle cx="{x}" cy="{y}" r="{radius}" fill="white" stroke="#333" stroke-width="2"/>'
        f'<text x="{x}" y="{y}" text-anchor="middle" dominant-baseline="central" '
        f'font-size="16">{labels[k]}</text>'
        for k, (x, y) in state_positions.items()
    ]

    # Define the arrowhead marker (fixed size regardless of stroke width)
    arrowhead_marker = """
    <defs>
      <marker id="arrowhead" viewBox="0 0 10 10" refX="9" refY="5"
              markerWidth="10" markerHeight="10" markerUnits="userSpaceOnUse" orient="auto">
        <path d="M 0 0 L 10 5 L 0 10 z" fill="#1f4e79"/>
      </marker>
    </defs>"""

    view_w = (max(xs) if xs else 0) + margin + 2 * pad
    view_h = (max(ys) if ys else 0) + margin + 2 * pad
    edges_svg = "".join(edges)
    nodes_svg = "".join(nodes)
    return f"""
<div style="text-align: center;">
  <h3>State Diagram</h3>
  <svg xmlns="http://www.w3.org/2000/svg" width="{view_w}" height="{view_h}"
       viewBox="{-pad} {-pad} {view_w} {view_h}">
    {arrowhead_marker}
    {edges_svg}
    {nodes_svg}
  </svg>
</div>
"""


def _cell_color(value, normalize, colormap):
    """Return the hex background colour for one matrix cell.

    With a colormap, ``value`` is passed through ``normalize`` when one is
    given. Without a colormap, a grayscale is used in which high probabilities
    are dark. This matches the white text used above ``prob_threshold``.
    """
    if colormap is not None:
        rgba = colormap(normalize(value) if normalize is not None else value)
    else:
        gray = float(np.clip(1.0 - value, 0.0, 1.0))
        rgba = (gray, gray, gray, 1.0)
    return mcolors.to_hex(rgba)


# Helper function to generate HTML for transition matrix heatmap
def generate_transition_matrix_html(P, prob_threshold=0.5, normalize=None, colormap=None, ticks=None):
    """Render ``P`` as an HTML heatmap table with one cell per transition.

    Args:
        P: square matrix of transition probabilities.
        prob_threshold: cells whose value exceeds this use white text.
        normalize: optional value -> [0, 1] mapping applied before ``colormap``.
        colormap: optional matplotlib colormap; grayscale is used when omitted.
        ticks: optional labels for the rows and columns.
    """
    header_size = "30px"  # Set a fixed size for headers
    cell_size = "60px"  # Set a fixed size for cells
    has_ticks = ticks is not None and len(ticks) > 0

    parts = [
        '<div style="text-align: center;">',
        "<h3>Transition Matrix</h3>",
        '<table style="border-collapse: collapse; margin: auto;">',
    ]
    # Add xticks (column labels)
    if has_ticks:
        parts.append("<tr>")
        parts.append(f'<th style="width: {header_size}; height: {header_size};"></th>')  # Top left corner empty
        parts.extend(
            f'<th style="width: {cell_size}; height: {header_size}; text-align: center;">{tick}</th>'
            for tick in ticks
        )
        parts.append("</tr>")

    # Add the transition matrix rows
    for i, row in enumerate(np.asarray(P, dtype=float)):
        parts.append("<tr>")
        # Add yticks (row labels)
        if has_ticks:
            parts.append(
                f'<th style="width: {header_size}; height: {cell_size}; text-align: center;">{ticks[i]}</th>'
            )
        for value in row:
            hex_color = _cell_color(value, normalize, colormap)
            text_color = "white" if value > prob_threshold else "black"
            parts.append(
                f'<td style="width: {cell_size}; height: {cell_size}; background-color: {hex_color}; '
                f'color: {text_color}; text-align: center; border: 1px solid #ddd;">{value:.2f}</td>'
            )
        parts.append("</tr>")
    parts.append("</table></div>")
    return "".join(parts)


def matrix_to_string(matrix):
    """Format a matrix as one line per row with space-separated entries."""
    return "\n".join(" ".join(map(str, row)) for row in matrix)


def string_to_matrix(prob_str):
    """Parse one row per line, with whitespace-separated numbers, into a square float matrix.

    Blank lines and repeated spaces or tabs are tolerated.

    Raises:
        ValueError: on an empty string, non-numeric entries, or rows whose
            lengths do not form a square matrix.
    """
    rows = [line.split() for line in prob_str.strip().splitlines() if line.strip()]
    if not rows:
        raise ValueError("the matrix is empty")
    lengths = [len(row) for row in rows]
    if any(length != len(rows) for length in lengths):
        raise ValueError(f"expected a square matrix, got {len(rows)} rows with lengths {lengths}")
    return np.array([[float(x) for x in row] for row in rows], dtype=float)


def _combined_html(P, colormap, normalize, colorbar_html):
    """Lay out the state diagram, matrix heatmap and colorbar side by side.

    Uses the ``.output_flexbox`` class styled by ``css`` below.
    """
    state_diagram_html = generate_state_diagram_html(P)
    transition_matrix_html = generate_transition_matrix_html(
        P=P, colormap=colormap, normalize=normalize, ticks=state_names
    )
    return f"""
<div class="output_flexbox">
  <div>{state_diagram_html}</div>
  <div>{transition_matrix_html}</div>
  <div>{colorbar_html}</div>
</div>
"""


def initial_html(P):
    """Render the combined view for matrix ``P`` (used before any run)."""
    colormap = plt.cm.Blues
    normalize = mcolors.Normalize(vmin=0, vmax=1)
    return _combined_html(P, colormap, normalize, generate_colorbar(colormap, normalize))


def _hd_line_plot(hd_values):
    """Build the Hellinger-distance line plot from the recorded distances.

    The x axis is the index of each recorded distance.
    """
    hd_data = pd.DataFrame({"time": list(range(len(hd_values))), "hd": list(hd_values)})
    return gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")


def _parse_target_matrix(P_true_str, n):
    """Parse the user's target matrix and make it safe to sample from.

    Raises ``gr.Error`` (shown in the UI) for unparsable, wrongly-sized,
    non-finite, negative or all-zero-row input. Rows that do not sum to 1
    within 1e-8 are renormalised, because ``rng.choice`` rejects them. This
    lets hand-typed, rounded matrices still work.
    """
    try:
        P_true = string_to_matrix(P_true_str)
    except ValueError as err:
        raise gr.Error(f"Could not parse the transition matrix: {err}") from err
    if P_true.shape != (n, n):
        raise gr.Error(f"Expected a {n}x{n} transition matrix, got shape {P_true.shape}.")
    if not np.all(np.isfinite(P_true)) or np.any(P_true < 0):
        raise gr.Error("Transition probabilities must be finite and non-negative.")
    row_sums = P_true.sum(axis=1, keepdims=True)
    if np.any(row_sums <= 0):
        raise gr.Error("Every row of the transition matrix needs at least one positive entry.")
    if not np.allclose(row_sums, 1.0, rtol=0.0, atol=1e-8):
        P_true = P_true / row_sums
    return P_true


def process_sequence(current_P, P_true_str, sequence_length, lambda_, tau, state_HD):
    """Stream the online estimate of the transition matrix while it learns.

    A path of ``sequence_length`` states is sampled from the target matrix. For
    every observed transition, ``update_p`` is applied with forgetting factor
    ``lambda_``. Every ``tau`` transitions, the largest row-wise Hellinger
    distance between the current estimate and the estimate ``tau`` steps
    earlier is appended to ``state_HD``.

    ``current_P`` and ``state_HD`` are mutated in place and are not returned.
    NOTE(review): this relies on Gradio handing over the session's stored
    State objects, so the estimate carries over between runs.

    Yields ``(combined_html, hd_plot)`` after every transition.
    """
    if sequence_length is None or lambda_ is None or tau is None:
        raise gr.Error("Sequence length, lambda and tau are all required.")
    P = current_P
    n = P.shape[0]
    P_true = _parse_target_matrix(P_true_str, n)
    lambda_ = float(lambda_)
    tau = max(1, int(tau))  # tau == 0 would raise ZeroDivisionError in the modulo below
    sequence = generate_sequence(np.arange(n), P_true, length=int(sequence_length))

    # Set up the colormap
    colormap = plt.cm.Blues
    normalize = mcolors.Normalize(vmin=0, vmax=1)
    colorbar_html = generate_colorbar(colormap, normalize)  # Generate the colorbar once
    hd_plot = _hd_line_plot(state_HD)

    if len(sequence) < 2:  # no transitions: still show the current estimate
        yield _combined_html(P, colormap, normalize, colorbar_html), hd_plot
        return

    P_prev = P.copy()  # reference estimate for the next Hellinger distance
    for s_count, (s_prev, s) in enumerate(pairwise(sequence)):
        if s_count > 0 and s_count % tau == 0:
            state_HD.append(float(np.max(hellinger_distance(P_prev, P))))
            P_prev = P.copy()
            hd_plot = _hd_line_plot(state_HD)
        update_p(P, s_prev, s, lambda_)  # Update the transition matrix in place

        # Yield the state diagram + transition matrix + colorbar HTML and the HD plot
        yield _combined_html(P, colormap, normalize, colorbar_html), hd_plot
        time.sleep(0.005)  # Pause before the next state


css = """
.output_flexbox{
    display: flex;
    align-items: center;
    justify-content:center;
    gap: 50px;
    flex-direction: row;
}
@media screen and (max-width: 800px){
    .output_flexbox{
        flex-direction: column;
    }
}
"""

with gr.Blocks(css=css) as demo:
    # The initial estimate is the uniform distribution, kept per session in a State object
    P0 = np.full((num_states, num_states), 1 / num_states)
    state_P = gr.State(P0)
    state_HD = gr.State([0])

    modes = [
        [
            [0.19764149, 0.15019737, 0.5511312, 0.10102994],
            [0.00981767, 0.29327228, 0.01960321, 0.67730684],
            [0.73753637, 0.01300742, 0.16719004, 0.08226616],
            [0.59819186, 0.04902065, 0.35083554, 0.00195195],
        ],
        [
            [0.00739772, 0.05260542, 0.09302236, 0.8469745],
            [0.07016094, 0.79909194, 0.00816012, 0.122587],
            [0.11185849, 0.00218829, 0.86738849, 0.01856473],
            [0.03560788, 0.1767552, 0.72055577, 0.06708115],
        ],
        [
            [0.02901062, 0.0548365, 0.88122331, 0.03492957],
            [0.03180118, 0.08216069, 0.85603785, 0.03000028],
            [0.02280227, 0.09024757, 0.08233428, 0.80461588],
            [0.23941395, 0.05389086, 0.55260164, 0.15409355],
        ],
    ]

    with gr.Column():
        with gr.Column():
            target_p = gr.Textbox(
                value=matrix_to_string(modes[0]),
                label="Enter transition matrix or select a mode.",
                lines=4,
                placeholder="Enter the 4x4 matrix here...",
            )
            with gr.Row():
                mode_0 = gr.Button("Mode 0")
                mode_1 = gr.Button("Mode 1")
                mode_2 = gr.Button("Mode 2")
        with gr.Row():
            # precision=0 delivers ints; tau must be >= 1 because it is used as a modulus
            sequence_length = gr.Number(label="Sequence Length", value=500, minimum=0, precision=0)
            lambda_ = gr.Number(label="Lambda", value=0.95, minimum=0, maximum=1)
            tau = gr.Number(label="Tau", value=25, minimum=1, precision=0)
        run_btn = gr.Button("Run")
        html_output = gr.HTML(value=initial_html(P0), label="State Diagram and Transition Matrix")
        hd_plot = _hd_line_plot(state_HD.value)

    run_btn.click(
        fn=process_sequence,
        inputs=[state_P, target_p, sequence_length, lambda_, tau, state_HD],
        outputs=[html_output, hd_plot],
    )
    mode_0.click(fn=lambda: matrix_to_string(modes[0]), inputs=None, outputs=target_p)
    mode_1.click(fn=lambda: matrix_to_string(modes[1]), inputs=None, outputs=target_p)
    mode_2.click(fn=lambda: matrix_to_string(modes[2]), inputs=None, outputs=target_p)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported
    demo.launch(share=False)