kcoskun commited on
Commit
47ff8ac
·
verified ·
1 Parent(s): c62c865

initial commit

Browse files
Files changed (1) hide show
  1. app.py +362 -0
app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import time
4
+ from itertools import pairwise
5
+
6
+ import gradio as gr
7
+ import matplotlib.colors as mcolors
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ # Define the transition matrix and state information
13
+ num_states = 4
14
+ states = np.arange(num_states)
15
+ state_names = [f"{i}" for i in states]
16
+
17
+ def generate_p(num_states):
18
+ rng = np.random.default_rng(42)
19
+ return rng.dirichlet(alpha=np.repeat(1, num_states), size=np.repeat(num_states, 1))
20
+
21
+ def generate_sequence(alphabet, P, length=10):
22
+ rng = np.random.default_rng(42)
23
+ sequence = rng.choice(a=alphabet, size=1).tolist()
24
+ for i in range(length - 1):
25
+ next_state = rng.choice(a=alphabet, p=P[sequence[-1]])
26
+ sequence.extend([next_state])
27
+ return sequence
28
+
29
+ def update_p(P, s_prev, s, lambda_=0.9):
30
+ P[s_prev,] *= lambda_
31
+ P[s_prev,][s] += (1 - lambda_)
32
+ return P
33
+
34
+ # compute the Hellinger distance between two probability distributions p and q
35
+ def hellinger_distance(p, q):
36
+ return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2, axis=len(p.shape)-1))
37
+
38
+ # Generate a colorbar as an image
39
+ def generate_colorbar(colormap, normalize):
40
+ fig, ax = plt.subplots(figsize=(1, 4))
41
+ fig.subplots_adjust(left=0.5, right=0.6) # Adjust layout for a narrow colorbar
42
+ colorbar = plt.colorbar(
43
+ plt.cm.ScalarMappable(norm=normalize, cmap=colormap),
44
+ cax=ax,
45
+ orientation='vertical',
46
+ )
47
+ colorbar.set_label("Transition Probability", rotation=90, labelpad=15)
48
+
49
+ # Save colorbar as an image in memory
50
+ buf = io.BytesIO()
51
+ plt.savefig(buf, format="png", bbox_inches="tight", transparent=True)
52
+ buf.seek(0)
53
+ base64_colorbar = base64.b64encode(buf.read()).decode("utf-8")
54
+ plt.close(fig) # Close the figure to free up memory
55
+ return f"<img src='data:image/png;base64,{base64_colorbar}' style='height:250px;'>"
56
+
57
+
58
+ # Helper function to generate HTML for state diagram with directed edges
59
+ def generate_state_diagram_html(P):
60
+
61
+ # Coordinates for the 2x2 grid positions of the states
62
+ state_positions = {
63
+ 0: (50, 50),
64
+ 1: (200, 50),
65
+ 2: (50, 200),
66
+ 3: (200, 200),
67
+ }
68
+
69
+ # Generate the SVG for edges between states based on the transition matrix
70
+ edges = ""
71
+ for i in range(num_states):
72
+ for j in range(num_states):
73
+ if P[i, j] > 0: # Only draw edge if probability > 0
74
+
75
+ # Get positions of the states (nodes)
76
+ x1, y1 = state_positions[i]
77
+ x2, y2 = state_positions[j]
78
+
79
+ # Transition weight determines the thickness of the edge
80
+ thickness = max(2, 5 * P[i, j]) # Set min thickness to 2
81
+
82
+ # Add slight offsets to avoid overlap
83
+ offset = 30
84
+ egap = 5
85
+
86
+ if i == 0 and j == 0:
87
+ edges += f"""
88
+ <path d="M{25} {40} C{0} {30} {30} {0} {40} {25}"
89
+ """
90
+ if i == 0 and j == 1:
91
+ edges += f"""
92
+ <path d="M{x1+offset} {y1-egap} L{x2-offset} {y2-egap}"
93
+ """
94
+ if i == 0 and j == 2:
95
+ edges += f"""
96
+ <path d="M{x1-egap} {y1+offset} L{x2-egap} {y2-offset}"
97
+ """
98
+ if i == 0 and j == 3:
99
+ edges += f"""
100
+ <path d="M{x1+offset+egap} {y1+offset-egap} L{x2-offset+egap} {y2-offset-egap}"
101
+ """
102
+ # ----------------
103
+ if i == 1 and j == 1:
104
+ edges += f"""
105
+ <path d="M225 40 C250 30 220 0 210 25"
106
+ """
107
+ if i == 1 and j == 0:
108
+ edges += f"""
109
+ <path d="M{x1-offset} {y1+egap} L{x2+offset} {y2+egap}"
110
+ """
111
+ if i == 1 and j == 2:
112
+ edges += f"""
113
+ <path d="M{x1-offset+egap} {y1+offset+egap} L{x2+offset+egap} {y2-offset+egap}"
114
+ """
115
+ if i == 1 and j == 3:
116
+ edges += f"""
117
+ <path d="M{x1+egap} {y1+offset} L{x2+egap} {y2-offset}"
118
+ """
119
+ # ----------------
120
+ if i == 2 and j == 0:
121
+ edges += f"""
122
+ <path d="M{x1+egap} {y1-offset} L{x2+egap} {y2+offset}"
123
+ """
124
+ if i == 2 and j == 1:
125
+ edges += f"""
126
+ <path d="M{x1+offset-egap} {y1-offset-egap} L{x2-offset-egap} {y2+offset-egap}"
127
+ """
128
+ if i == 2 and j == 2:
129
+ edges += f"""
130
+ <path d="M25 210 C0 220 30 250 40 225"
131
+ """
132
+ if i == 2 and j == 3:
133
+ edges += f"""
134
+ <path d="M{x1+offset} {y1-egap} L{x2-offset} {y2-egap}"
135
+ """
136
+ # ----------------
137
+ if i == 3 and j == 0:
138
+ edges += f"""
139
+ <path d="M{x1-offset-egap} {y1-offset+egap} L{x2+offset-egap} {y2+offset+egap}"
140
+ """
141
+ if i == 3 and j == 1:
142
+ edges += f"""
143
+ <path d="M{x1-egap} {y1-offset} L{x2-egap} {y2+offset}"
144
+ """
145
+ if i == 3 and j == 2:
146
+ edges += f"""
147
+ <path d="M{x1-offset} {y1+egap} L{x2+offset} {y2+egap}"
148
+ """
149
+ if i == 3 and j == 3:
150
+ edges += f"""
151
+ <path d="M225 210 C250 220 220 250 210 225"
152
+ """
153
+
154
+ edges += f"""
155
+ style="stroke: #F97316; stroke-width: {thickness}; opacity: {0.05+0.95*P[i, j]}; fill: none; marker-end: url(#arrowhead);" />
156
+ """
157
+
158
+ # Define the arrowhead marker
159
+ arrowhead_marker = """
160
+ <defs>
161
+ <marker id="arrowhead" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="5" markerHeight="10" orient="auto" markerUnits="strokeWidth">
162
+ <polygon points="0,0 10,5 0,10" fill="#F97316" />
163
+ </marker>
164
+ </defs>
165
+ """
166
+
167
+ # Create the state diagram (2x2 grid)
168
+ html_content = f"""
169
+ <svg width="250" height="250">
170
+ {arrowhead_marker}
171
+ {edges}
172
+ <!-- Nodes -->
173
+ <circle cx="50" cy="50" r="25" fill="{'lightgrey'}" />
174
+ <text x="50" y="50" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[0]}</text>
175
+
176
+ <circle cx="200" cy="50" r="25" fill="{'lightgrey'}" />
177
+ <text x="200" y="50" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[1]}</text>
178
+
179
+ <circle cx="50" cy="200" r="25" fill="{'lightgrey'}" />
180
+ <text x="50" y="200" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[2]}</text>
181
+
182
+ <circle cx="200" cy="200" r="25" fill="{'lightgrey'}" />
183
+ <text x="200" y="200" text-anchor="middle" dy="5" font-size="16" font-weight="bold" fill="black">{state_names[3]}</text>
184
+ </svg>
185
+ """
186
+ return html_content
187
+
188
+ # Helper function to generate HTML for transition matrix heatmap
189
+ def generate_transition_matrix_html(P, prob_threshold=0.5, normalize=None, colormap=None, ticks=None):
190
+ header_size = "30px" # Set a fixed size for headers
191
+ cell_size = "60px" # Set a fixed size for cells
192
+ html_content = f"<table style='border: none; margin-bottom: {header_size} !important;'>"
193
+
194
+ # Add xticks (column labels)
195
+ if ticks:
196
+ html_content += f"<tr style='border: none; height: {header_size};'>"
197
+ html_content += "<td style='border: none; padding: 0;'></td>" # Top left corner empty
198
+ for tick in ticks:
199
+ html_content += f"<td style='border: none; padding: 0; text-align: center; font-weight: bold;'>{tick}</td>"
200
+ html_content += "</tr>"
201
+
202
+ # Add the transition matrix rows
203
+ for i, row in enumerate(P):
204
+ html_content += "<tr style='border: none;'>"
205
+
206
+ # Add yticks (row labels)
207
+ if ticks:
208
+ html_content += f"<td style='border: none; padding: 0; text-align: center; font-weight: bold; width: {header_size};'>{ticks[i]}</td>"
209
+
210
+ for j, value in enumerate(row):
211
+ # Map the transition probability to a color intensity
212
+ rgba_color = colormap(normalize(value)) if normalize else (value, value, value, 1)
213
+ hex_color = mcolors.to_hex(rgba_color)
214
+ text_color = "white" if value > prob_threshold else "black"
215
+ html_content += f"<td style='width: {cell_size}; height: {cell_size}; padding: 0; background-color: {hex_color}; " \
216
+ f"color: {text_color}; text-align: center; border: none; font-size: 18px;'>" \
217
+ f"{value:.2f}</td>"
218
+ html_content += "</tr>"
219
+ html_content += "</table>"
220
+
221
+ return html_content
222
+
223
+ def matrix_to_string(matrix):
224
+ return "\n".join([" ".join(map(str, row)) for row in matrix])
225
+
226
+ def string_to_matrix(prob_str):
227
+ # Convert the input string into a 4x4 matrix
228
+ rows = prob_str.strip().split('\n') # Split into rows by newline
229
+ matrix = [list(map(float, row.strip().split(" "))) for row in rows] # Convert each row into a list of floats
230
+ matrix = np.array(matrix) # Convert to a numpy array for convenience
231
+ return matrix
232
+
233
+ def initial_html(P):
234
+ colormap = plt.cm.Blues
235
+ normalize = mcolors.Normalize(vmin=0, vmax=1)
236
+ state_diagram_html = generate_state_diagram_html(P)
237
+ transition_matrix_html = generate_transition_matrix_html(
238
+ P=P,
239
+ colormap=colormap,
240
+ normalize=normalize,
241
+ ticks=state_names
242
+ )
243
+ colorbar_html = generate_colorbar(colormap, normalize)
244
+ combined_html = f"""
245
+ <div style="display: flex; align-items: center; justify-content: center; gap: 50px;">
246
+ <div>{state_diagram_html}</div>
247
+ <div>{transition_matrix_html}</div>
248
+ <div>{colorbar_html}</div
249
+ </div>
250
+ """
251
+ return combined_html
252
+
253
+ def process_sequence(current_P, P_true_str, sequence_length, lambda_, tau, state_HD):
254
+
255
+ P_true = string_to_matrix(P_true_str)
256
+ sequence = generate_sequence(states, P_true, length=sequence_length)
257
+
258
+ P = current_P
259
+
260
+ # Set up the colormap
261
+ colormap = plt.cm.Blues
262
+ normalize = mcolors.Normalize(vmin=0, vmax=1)
263
+ colorbar_html = generate_colorbar(colormap, normalize) # Generate the colorbar once
264
+
265
+ hd_data = pd.DataFrame({"time": list(range(len(state_HD))), "hd": [hd for hd in state_HD]})
266
+ hd_plot = gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")
267
+
268
+ for s_count, (s_prev, s) in enumerate(pairwise(sequence)):
269
+
270
+ if s_count == 0:
271
+ P_prev = P.copy()
272
+ elif s_count % tau == 0:
273
+ hd = np.max(hellinger_distance(P_prev, P))
274
+ state_HD.append(hd)
275
+ P_prev = P.copy()
276
+ hd_data = pd.DataFrame({"time": list(range(len(state_HD))), "hd": [hd for hd in state_HD]})
277
+ hd_plot = gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")
278
+
279
+ P = update_p(P, s_prev, s, lambda_) # Update the transition matrix
280
+
281
+ # Generate HTML for the current state and transition matrix
282
+ state_diagram_html = generate_state_diagram_html(P)
283
+ transition_matrix_html = generate_transition_matrix_html(
284
+ P=P,
285
+ colormap=colormap,
286
+ normalize=normalize,
287
+ ticks=state_names
288
+ )
289
+
290
+ # Combine matrix and colorbar HTML side-by-side in a flex container
291
+ combined_html = f"""
292
+ <div style="display: flex; align-items: center; justify-content: center; gap: 50px;">
293
+ <div>{state_diagram_html}</div>
294
+ <div>{transition_matrix_html}</div>
295
+ <div>{colorbar_html}</div>
296
+ </div>
297
+ """
298
+
299
+ # Yield the state diagram HTML and combined transition matrix + colorbar HTML
300
+ yield combined_html, hd_plot
301
+
302
+ time.sleep(0.005) # Pause before the next state
303
+
304
+ with gr.Blocks() as demo:
305
+
306
+ # initial probability distribution is a uniform distribution, saved in a State object
307
+ P0 = 1/num_states * np.ones(shape=np.repeat(num_states, 2))
308
+ state_P = gr.State(P0)
309
+ state_HD = gr.State([0])
310
+
311
+ modes = [
312
+ [
313
+ [0.19764149, 0.15019737, 0.5511312 , 0.10102994],
314
+ [0.00981767, 0.29327228, 0.01960321, 0.67730684],
315
+ [0.73753637, 0.01300742, 0.16719004, 0.08226616],
316
+ [0.59819186, 0.04902065, 0.35083554, 0.00195195]
317
+ ],
318
+ [
319
+ [0.00739772, 0.05260542, 0.09302236, 0.8469745 ],
320
+ [0.07016094, 0.79909194, 0.00816012, 0.122587 ],
321
+ [0.11185849, 0.00218829, 0.86738849, 0.01856473],
322
+ [0.03560788, 0.1767552 , 0.72055577, 0.06708115]
323
+ ],
324
+ [
325
+ [0.02901062, 0.0548365 , 0.88122331, 0.03492957],
326
+ [0.03180118, 0.08216069, 0.85603785, 0.03000028],
327
+ [0.02280227, 0.09024757, 0.08233428, 0.80461588],
328
+ [0.23941395, 0.05389086, 0.55260164, 0.15409355]
329
+ ]
330
+ ]
331
+
332
+ with gr.Column():
333
+ with gr.Column():
334
+ target_p = gr.Textbox(
335
+ value=matrix_to_string(modes[0]),
336
+ label="Enter Transition Matrix (Rows separated by newlines, Values by spaces)",
337
+ lines=4,
338
+ placeholder="Enter the 4x4 matrix here...",
339
+ )
340
+ with gr.Row():
341
+ mode_0 = gr.Button("Mode 0")
342
+ mode_1 = gr.Button("Mode 1")
343
+ mode_2 = gr.Button("Mode 2")
344
+ with gr.Row():
345
+ sequence_length = gr.Number(label="Sequence Length", value=500, minimum=0)
346
+ lambda_ = gr.Number(label="Lambda", value=0.95, minimum=0, maximum=1)
347
+ tau = gr.Number(label="Tau", value=25, minimum=0)
348
+ run_btn = gr.Button("Run")
349
+ html_output = gr.HTML(value=initial_html(P0), label="State Diagram and Transition Matrix")
350
+ hd_data = pd.DataFrame({"time": list(range(len(state_HD.value))), "hd": [hd for hd in state_HD.value]})
351
+ hd_plot = gr.LinePlot(hd_data, x="time", y="hd", x_title="Time", y_title="Hellinger Distance")
352
+
353
+ run_btn.click(
354
+ fn=process_sequence,
355
+ inputs=[state_P, target_p, sequence_length, lambda_, tau, state_HD],
356
+ outputs=[html_output, hd_plot]
357
+ )
358
+ mode_0.click(fn=lambda: matrix_to_string(modes[0]), inputs=None, outputs=target_p)
359
+ mode_1.click(fn=lambda: matrix_to_string(modes[1]), inputs=None, outputs=target_p)
360
+ mode_2.click(fn=lambda: matrix_to_string(modes[2]), inputs=None, outputs=target_p)
361
+
362
+ demo.launch(share=False)