import base64
import io
import time
from itertools import pairwise

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# State space: four states, labelled "0" to "3" in the UI
num_states = 4
states = np.arange(num_states)
state_names = [str(i) for i in states]

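def generate_p(num_states):
    """Sample a random row-stochastic matrix; each row is drawn from Dirichlet(1, ..., 1).

    Not called by the UI; the fixed seed makes the result reproducible.
    """
    rng = np.random.default_rng(42)
    return rng.dirichlet(alpha=np.ones(num_states), size=num_states)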
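
# Sketch (illustrative only; the UI never calls generate_p): a Dirichlet whose
# concentration parameters are all 1 is the uniform distribution over the
# probability simplex. A standard way to sample it without NumPy is to draw n
# independent Exp(1) variables and normalise them by their sum. The helper name
# `random_stochastic_row` is hypothetical, and `rng` is assumed to be a
# `random.Random` instance (any object with an `expovariate` method works).
def random_stochastic_row(n, rng):
    """Draw one row from Dirichlet(1, ..., 1) via normalised exponentials."""
    draws = [rng.expovariate(1.0) for _ in range(n)]
    total = sum(draws)
    return [d / total for d in draws]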

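def generate_sequence(alphabet, P, length=10):
    """Sample a Markov chain of `length` states: a uniformly random start state,
    then each next state drawn from the row of P for the current state.

    The generator is reseeded with 42 on every call, so the same P always
    yields the same sequence.
    """
    rng = np.random.default_rng(42)
    sequence = rng.choice(a=alphabet, size=1).tolist()
    for _ in range(int(length) - 1):  # int() guards against float input from gr.Number
        next_state = rng.choice(a=alphabet, p=P[sequence[-1]])
        sequence.append(next_state)
    return sequence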
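
# Sketch (hypothetical helper, not part of the app): the maximum-likelihood
# estimate of a Markov chain's transition matrix is the row-normalised count of
# observed transitions, which gives a reference point for the exponentially
# forgetting estimate learned by update_p. Rows for states with no observed
# outgoing transition are left uniform here; that fallback is an assumption.
def empirical_transition_matrix(sequence, n_states):
    counts = [[0] * n_states for _ in range(n_states)]
    for a, b in zip(sequence, sequence[1:]):
        counts[a][b] += 1  # count the transition a -> b
    matrix = []
    for row in counts:
        total = sum(row)
        if total == 0:
            matrix.append([1 / n_states] * n_states)  # no data for this state
        else:
            matrix.append([c / total for c in row])
    return matrix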

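def update_p(P, s_prev, s, lambda_=0.9):
    """Exponential-forgetting update after observing the transition s_prev -> s.

    Row s_prev is scaled by lambda_ and the freed mass (1 - lambda_) is added to
    column s, so the row still sums to 1. P is modified in place and returned.
    """
    P[s_prev, :] *= lambda_
    P[s_prev, s] += 1 - lambda_
    return P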
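
# Sketch (hypothetical helper, not used by the app): update_p is the convex
# combination
#     P[s_prev, :] <- lambda_ * P[s_prev, :] + (1 - lambda_) * e_s,
# where e_s is the one-hot vector of the observed next state s. Both terms sum
# to 1, so every row stays a probability distribution. Observing the same
# transition k times in a row gives the closed form
#     P_k[s_prev, s] = 1 - lambda_**k * (1 - P_0[s_prev, s]),
# so lambda_ sets how quickly old evidence fades (roughly a memory of the last
# 1 / (1 - lambda_) transitions out of s_prev). This pure-Python version mirrors
# one update step for a single row.
def update_row_reference(row, s, lambda_):
    new_row = [lambda_ * p for p in row]  # decay all old evidence
    new_row[s] += 1 - lambda_  # move the freed mass onto the observed state
    return new_row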

# Hellinger distance between probability distributions p and q, taken along the
# last axis (so for two matrices it returns one distance per row)
def hellinger_distance(p, q):
    return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2, axis=-1))
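
# Sketch (hypothetical helper, not used by the app): a loop-based reference for
# a single pair of distributions, useful for cross-checking the vectorised
# version. H(p, q) = sqrt(0.5 * sum_i (sqrt(p_i) - sqrt(q_i))**2) lies in [0, 1]:
# it is 0 for identical distributions and 1 for distributions with disjoint
# support.
def hellinger_reference(p, q):
    total = sum((pi ** 0.5 - qi ** 0.5) ** 2 for pi, qi in zip(p, q))
    return (0.5 * total) ** 0.5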

# Render a vertical colorbar and return it as an inline <img> tag (base64-encoded PNG)
def generate_colorbar(colormap, normalize):
    fig, ax = plt.subplots(figsize=(1, 4))
    fig.subplots_adjust(left=0.5, right=0.6)  # Adjust layout for a narrow colorbar
    colorbar = plt.colorbar(
        plt.cm.ScalarMappable(norm=normalize, cmap=colormap),
        cax=ax,
        orientation='vertical',
    )
    colorbar.set_label("Transition Probability", rotation=90, labelpad=15)

    # Save colorbar as an image in memory
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", transparent=True)
    buf.seek(0)
    base64_colorbar = base64.b64encode(buf.read()).decode("utf-8")
    plt.close(fig)  # Close the figure to free up memory
    return f"<img src='data:image/png;base64,{base64_colorbar}' style='height:250px; margin-top:30px'>"
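
# Sketch (hypothetical helper, not used by the app): generate_colorbar inlines
# the PNG as a base64 data URI, so the image travels inside the gr.HTML string
# and no separate file has to be served. The same idea, factored out for any
# PNG bytes (relies on the `base64` module imported at the top of this file):
def png_to_img_tag(png_bytes, height_px=250):
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return f"<img src='data:image/png;base64,{encoded}' style='height:{height_px}px'>"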


# Helper function to generate HTML for state diagram with directed edges
def generate_state_diagram_html(P):

    # Coordinates for the 2x2 grid positions of the states
    state_positions = {
        0: (50, 50),
        1: (200, 50),
        2: (50, 200),
        3: (200, 200),
    }
    
    # Generate the SVG for edges between states based on the transition matrix
    edges = ""
    for i in range(num_states):
        for j in range(num_states):
            if P[i, j] > 0:  # Only draw edge if probability > 0

                # Get positions of the states (nodes)
                x1, y1 = state_positions[i]
                x2, y2 = state_positions[j]

                # Transition weight determines the thickness of the edge
                thickness = max(2, 5 * P[i, j])  # Set min thickness to 2

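                # Offsets: `offset` pulls each endpoint in from the node centre
                # (past the node radius of 25), and `egap` shifts the edges i->j
                # and j->i sideways so they do not overlap
                offset = 30
                egap = 5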
                
                if i == 0 and j == 0:
                    edges += f"""
                        <path d="M25 40 C0 30 30 0 40 25"
                    """
                if i == 0 and j == 1:
                    edges += f"""
                        <path d="M{x1+offset} {y1-egap} L{x2-offset} {y2-egap}"
                    """
                if i == 0 and j == 2:
                    edges += f"""
                        <path d="M{x1-egap} {y1+offset} L{x2-egap} {y2-offset}"
                    """
                if i == 0 and j == 3:
                    edges += f"""
                        <path d="M{x1+offset+egap} {y1+offset-egap} L{x2-offset+egap} {y2-offset-egap}"
                    """
                # ----------------
                if i == 1 and j == 1:
                    edges += f"""
                        <path d="M225 40 C250 30 220 0 210 25"
                    """
                if i == 1 and j == 0:
                    edges += f"""
                        <path d="M{x1-offset} {y1+egap} L{x2+offset} {y2+egap}"
                    """
                if i == 1 and j == 2:
                    edges += f"""
                        <path d="M{x1-offset+egap} {y1+offset+egap} L{x2+offset+egap} {y2-offset+egap}"
                    """
                if i == 1 and j == 3:
                    edges += f"""
                        <path d="M{x1+egap} {y1+offset} L{x2+egap} {y2-offset}"
                    """
                # ----------------
                if i == 2 and j == 0:
                    edges += f"""
                        <path d="M{x1+egap} {y1-offset} L{x2+egap} {y2+offset}"
                    """
                if i == 2 and j == 1:
                    edges += f"""
                        <path d="M{x1+offset-egap} {y1-offset-egap} L{x2-offset-egap} {y2+offset-egap}"
                    """
                if i == 2 and j == 2:
                    edges += f"""
                        <path d="M25 210 C0 220 30 250 40 225"
                    """
                if i == 2 and j == 3:
                    edges += f"""
                        <path d="M{x1+offset} {y1-egap} L{x2-offset} {y2-egap}"
                    """
                # ----------------
                if i == 3 and j == 0:
                    edges += f"""
                        <path d="M{x1-offset-egap} {y1-offset+egap} L{x2+offset-egap} {y2+offset+egap}"
                    """
                if i == 3 and j == 1:
                    edges += f"""
                        <path d="M{x1-egap} {y1-offset} L{x2-egap} {y2+offset}"
                    """
                if i == 3 and j == 2:
                    edges += f"""
                        <path d="M{x1-offset} {y1+egap} L{x2+offset} {y2+egap}"
                    """
                if i == 3 and j == 3:
                    edges += f"""
                        <path d="M225 210 C250 220 220 250 210 225"
                    """

                edges += f"""
                    style="stroke: #F97316; stroke-width: {thickness}; opacity: {0.05+0.95*P[i, j]}; fill: none; marker-end: url(#arrowhead);" />
                """
    
    # Define the arrowhead marker
    arrowhead_marker = """
        <defs>
            <marker id="arrowhead" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="5" markerHeight="10" orient="auto" markerUnits="strokeWidth">
                <polygon points="0,0 10,5 0,10" fill="#F97316" />
            </marker>
        </defs>
    """
    
    # Create the state diagram (2x2 grid)
    html_content = f"""
        <div style="text-align:center;">State Diagram</div>
        <svg width="250" height="250">
            {arrowhead_marker}
            {edges}
            <!-- Nodes -->
            <circle cx="50" cy="50" r="25" fill="lightgrey" />
            <text x="50" y="50" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[0]}</text>

            <circle cx="200" cy="50" r="25" fill="lightgrey" />
            <text x="200" y="50" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[1]}</text>

            <circle cx="50" cy="200" r="25" fill="lightgrey" />
            <text x="50" y="200" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[2]}</text>

            <circle cx="200" cy="200" r="25" fill="lightgrey" />
            <text x="200" y="200" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[3]}</text>
        </svg>
    """
    return html_content
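
# Sketch (hypothetical helper, not used by the app): the hand-written paths above
# only fit the fixed 2x2 layout. A layout-independent alternative computes each
# straight edge from the two node centres: trim both ends by the node radius
# along the edge direction, then shift the segment sideways by `shift` pixels
# (perpendicular to the edge) so that i->j and j->i land on opposite sides.
# Self-loops need a curve instead, as in the code above, and are rejected here.
def edge_endpoints(p1, p2, radius=25, shift=5):
    (x1, y1), (x2, y2) = p1, p2
    dx, dy = x2 - x1, y2 - y1
    length = (dx * dx + dy * dy) ** 0.5
    if length == 0:
        raise ValueError("self-loops are not straight edges")
    ux, uy = dx / length, dy / length  # unit vector along the edge
    nx, ny = -uy, ux  # unit normal (direction rotated by 90 degrees)
    start = (x1 + ux * radius + nx * shift, y1 + uy * radius + ny * shift)
    end = (x2 - ux * radius + nx * shift, y2 - uy * radius + ny * shift)
    return start, end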

# Helper function to generate HTML for transition matrix heatmap
def generate_transition_matrix_html(P, prob_threshold=0.5, normalize=None, colormap=None, ticks=None):
    header_size = "30px"  # Set a fixed size for headers
    cell_size = "60px"  # Set a fixed size for cells
    html_content = f"<div style='text-align: center; margin-left: {header_size}; margin-bottom: 10px;'>Transition Matrix</div>"
    html_content += f"<table style='border: none; margin-bottom: {header_size} !important;'>"
    
    # Add xticks (column labels)
    if ticks:
        html_content += f"<tr style='border: none; height: {header_size};'>"
        html_content += "<td style='border: none; padding: 0;'></td>"  # Top left corner empty
        for tick in ticks:
            html_content += f"<td style='border: none; padding: 0; text-align: center; font-weight: bold;'>{tick}</td>"
        html_content += "</tr>"

    # Add the transition matrix rows
    for i, row in enumerate(P):
        html_content += "<tr style='border: none;'>"
        
        # Add yticks (row labels)
        if ticks:
            html_content += f"<td style='border: none; padding: 0; text-align: center; font-weight: bold; width: {header_size};'>{ticks[i]}</td>"
        
        for j, value in enumerate(row):
            # Map the transition probability to a color intensity
            rgba_color = colormap(normalize(value)) if normalize else (value, value, value, 1)
            hex_color = mcolors.to_hex(rgba_color)
            text_color = "white" if value > prob_threshold else "black"
            html_content += f"<td style='width: {cell_size}; height: {cell_size}; padding: 0; background-color: {hex_color}; " \
                            f"color: {text_color}; text-align: center; border: none; font-size: 18px;'>" \
                            f"{value:.2f}</td>"
        html_content += "</tr>"
    html_content += "</table>"
    
    return html_content
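
# Sketch (hypothetical helper, not used by the app): the fixed prob_threshold
# assumes a particular colormap. A colormap-independent alternative, swapped in
# here rather than taken from the code above, picks whichever of black or white
# text has the higher WCAG 2.x contrast ratio against the cell's "#rrggbb"
# colour (the format returned by mcolors.to_hex).
def contrasting_text_color(hex_color):
    def linear(channel):
        # Convert an 8-bit sRGB channel to linear light
        c = channel / 255
        return c / 12.92 if c <= 0.03928 else ((c + 0.055) / 1.055) ** 2.4

    r, g, b = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
    luminance = 0.2126 * linear(r) + 0.7152 * linear(g) + 0.0722 * linear(b)
    contrast_with_white = 1.05 / (luminance + 0.05)
    contrast_with_black = (luminance + 0.05) / 0.05
    return "white" if contrast_with_white > contrast_with_black else "black"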

def matrix_to_string(matrix):
    return "\n".join([" ".join(map(str, row)) for row in matrix])

def string_to_matrix(prob_str):
    # Parse a whitespace-separated matrix (one row per line) into a NumPy array
    rows = [row for row in prob_str.strip().split("\n") if row.strip()]  # skip blank lines
    matrix = [list(map(float, row.split())) for row in rows]  # split() tolerates repeated spaces
    return np.array(matrix)
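
# Sketch (hypothetical helper, not wired into the UI): rng.choice in
# generate_sequence raises a ValueError when a row of P_true is negative, has
# the wrong length, or does not sum to 1, so a typed-in matrix could be checked
# up front with a clearer message. The `n_states` argument and the tolerance
# `atol` are assumptions of this sketch.
def validate_transition_matrix(matrix, n_states, atol=1e-6):
    if len(matrix) != n_states:
        raise ValueError(f"expected {n_states} rows, got {len(matrix)}")
    for i, row in enumerate(matrix):
        if len(row) != n_states:
            raise ValueError(f"row {i} has {len(row)} entries, expected {n_states}")
        if any(p < 0 for p in row):
            raise ValueError(f"row {i} contains a negative probability")
        if abs(sum(row) - 1.0) > atol:
            raise ValueError(f"row {i} sums to {sum(row):.6f}, not 1")
    return True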

def initial_html(P):
    colormap = plt.cm.Blues
    normalize = mcolors.Normalize(vmin=0, vmax=1)
    state_diagram_html = generate_state_diagram_html(P)
    transition_matrix_html = generate_transition_matrix_html(
        P=P,
        colormap=colormap,
        normalize=normalize,
        ticks=state_names
    )
    colorbar_html = generate_colorbar(colormap, normalize) 
    combined_html = f"""
        <div class="output_flexbox">
            <div>{state_diagram_html}</div>
            <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
                <div>{transition_matrix_html}</div>
                <div>{colorbar_html}</div>
            </div>
        </div>
    """
    return combined_html

def process_sequence(current_P, P_true_str, sequence_length, lambda_, tau, state_HD):

    P_true = string_to_matrix(P_true_str)
    sequence = generate_sequence(states, P_true, length=sequence_length)

    P = current_P

    # Set up the colormap
    colormap = plt.cm.Blues
    normalize = mcolors.Normalize(vmin=0, vmax=1)
    colorbar_html = generate_colorbar(colormap, normalize)  # Generate the colorbar once

    hd_data = pd.DataFrame({"time": list(range(len(state_HD))), "hd": list(state_HD)})
    hd_plot = gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")

    for s_count, (s_prev, s) in enumerate(pairwise(sequence)):

        if s_count == 0:
            P_prev = P.copy()
        elif s_count % tau == 0:
            hd = np.max(hellinger_distance(P_prev, P))
            state_HD.append(hd)
            P_prev = P.copy()
            hd_data = pd.DataFrame({"time": list(range(len(state_HD))), "hd": list(state_HD)})
            hd_plot = gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")

        P = update_p(P, s_prev, s, lambda_)  # Update the transition matrix

        # Generate HTML for the current state and transition matrix
        state_diagram_html = generate_state_diagram_html(P)
        transition_matrix_html = generate_transition_matrix_html(
            P=P,
            colormap=colormap,
            normalize=normalize,
            ticks=state_names
        )
        
        # Combine matrix and colorbar HTML side-by-side in a flex container
        combined_html = f"""
            <div class="output_flexbox">
                <div>{state_diagram_html}</div>
                <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
                    <div>{transition_matrix_html}</div>
                    <div>{colorbar_html}</div>
                </div>
            </div>
        """
        
        # Yield the combined diagram/matrix/colorbar HTML and the Hellinger-distance plot
        yield combined_html, hd_plot

        time.sleep(0.005)  # Pause before the next state
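
# Sketch (hypothetical helper, not used by the app): the learning loop of
# process_sequence without the Gradio rendering, written with plain lists.
# Every tau transitions it records the largest row-wise Hellinger distance
# between the current estimate and the snapshot taken tau transitions earlier,
# then refreshes the snapshot; small values suggest the estimate has settled.
# Like state_HD, the history starts with a 0.
def headless_learning_loop(sequence, P0, lambda_, tau):
    def row_hellinger(p, q):
        return (0.5 * sum((a ** 0.5 - b ** 0.5) ** 2 for a, b in zip(p, q))) ** 0.5

    P = [list(row) for row in P0]
    P_prev = [list(row) for row in P]
    hd_history = [0.0]
    for s_count, (s_prev, s) in enumerate(zip(sequence, sequence[1:])):
        if s_count > 0 and s_count % tau == 0:
            hd_history.append(max(row_hellinger(p, q) for p, q in zip(P_prev, P)))
            P_prev = [list(row) for row in P]
        row = P[s_prev]
        for k in range(len(row)):
            row[k] *= lambda_  # decay the old evidence for this row
        row[s] += 1 - lambda_  # add weight to the observed transition
    return P, hd_history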

css = """
    .output_flexbox{
        display: flex;
        align-items: center;
        justify-content:center;
        gap: 50px;
        flex-direction: row;
    }
    @media screen and (max-width: 800px){
        .output_flexbox{
            flex-direction: column;
        }
    }
"""

with gr.Blocks(css=css) as demo:

    # The initial transition matrix is uniform; it and the Hellinger-distance
    # history are kept in gr.State objects
    P0 = np.full((num_states, num_states), 1 / num_states)
    state_P = gr.State(P0)
    state_HD = gr.State([0])

    modes = [
        [
            [0.19764149, 0.15019737, 0.5511312 , 0.10102994],
            [0.00981767, 0.29327228, 0.01960321, 0.67730684],
            [0.73753637, 0.01300742, 0.16719004, 0.08226616],
            [0.59819186, 0.04902065, 0.35083554, 0.00195195]
        ],
        [
            [0.00739772, 0.05260542, 0.09302236, 0.8469745 ],
            [0.07016094, 0.79909194, 0.00816012, 0.122587  ],
            [0.11185849, 0.00218829, 0.86738849, 0.01856473],
            [0.03560788, 0.1767552 , 0.72055577, 0.06708115]
        ],
        [
            [0.02901062, 0.0548365 , 0.88122331, 0.03492957],
            [0.03180118, 0.08216069, 0.85603785, 0.03000028],
            [0.02280227, 0.09024757, 0.08233428, 0.80461588],
            [0.23941395, 0.05389086, 0.55260164, 0.15409355]
        ]
    ]

    with gr.Column():
        with gr.Column():
            target_p = gr.Textbox(
                value=matrix_to_string(modes[0]),
                label="Enter transition matrix or select a mode.",
                lines=4,
                placeholder="Enter the 4x4 matrix here...",
            )
            with gr.Row():
                mode_0 = gr.Button("Mode 0")
                mode_1 = gr.Button("Mode 1")
                mode_2 = gr.Button("Mode 2")
            with gr.Row():
                sequence_length = gr.Number(label="Sequence Length", value=500, minimum=0, precision=0)
                lambda_ = gr.Number(label="Lambda", value=0.95, minimum=0, maximum=1)
                tau = gr.Number(label="Tau", value=25, minimum=1, precision=0)  # tau = 0 would divide by zero
        run_btn = gr.Button("Run")
        html_output = gr.HTML(value=initial_html(P0), label="State Diagram and Transition Matrix")
        hd_data = pd.DataFrame({"time": list(range(len(state_HD.value))), "hd": list(state_HD.value)})
        hd_plot = gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")

    run_btn.click(
        fn=process_sequence, 
        inputs=[state_P, target_p, sequence_length, lambda_, tau, state_HD], 
        outputs=[html_output, hd_plot]
    )
    mode_0.click(fn=lambda: matrix_to_string(modes[0]), inputs=None, outputs=target_p)
    mode_1.click(fn=lambda: matrix_to_string(modes[1]), inputs=None, outputs=target_p)
    mode_2.click(fn=lambda: matrix_to_string(modes[2]), inputs=None, outputs=target_p)

demo.launch(share=False)