Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import time | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import copy | |
| import scipy | |
| import scipy.signal | |
| from scipy.stats import norm | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| from scipy.spatial.distance import jensenshannon | |
| from scipy.optimize import curve_fit | |
| import multiprocessing | |
| from multiprocessing import Pool, Queue, Manager | |
# Use matplotlib's constrained-layout engine for every figure created below.
plt.rcParams['figure.constrained_layout.use'] = True
# Number of simultaneously open figures above which matplotlib emits a warning.
plt.rcParams['figure.max_open_warning'] = 10
# Turn matplotlib's interactive mode off.
matplotlib.rcParams['interactive'] = False
# Shared discretised domain: start, end and number of intervals.
g_st, g_et, g_num = -2.3, 2.3, 460
# Width of one interval of the shared domain.
g_res = (g_et-g_st)/g_num
# Width and height of a single plot panel (passed to matplotlib as figsize).
g_fw, g_fh = 3, 3.2
| ################################################################################### | |
| # common function | |
| ################################################################################### | |
def rs(str_label):
    '''
    Rewrite a label so the step-0 variable is displayed as "x".
    str_label : label text that may contain "z_{0}" or "z_0"
    return    : the text with both spellings replaced by "x" (braced form first)
    '''
    renamed = str_label
    for zero_name in ("z_{0}", "z_0"):
        renamed = renamed.replace(zero_name, "x")
    return renamed
def set_axis(axis, x_range, y_range, x_label, y_label):
    '''
    Apply the shared limits, labels and tick layout to one axes.
    axis    : matplotlib axes to configure
    x_range : (min, max) of the x axis, or None to leave the limits untouched
    y_range : (min, max) of the y axis, or None to leave the limits untouched
    x_label : x axis label, or None for no label
    y_label : y axis label, or None for no label
    '''
    # NOTE: this updates the global rcParams on every call, not only this axes.
    matplotlib.rcParams.update({'font.size': 10, "axes.linewidth": 0.5, "lines.linewidth": 0.7, "figure.dpi": 100})
    if x_range is not None:
        axis.set_xlim(*x_range)
    if y_range is not None:
        axis.set_ylim(*y_range)
    if x_label is not None:
        axis.set_xlabel(x_label)
    if y_label is not None:
        axis.set_ylabel(y_label)
    # major ticks on every integer
    axis.xaxis.set_major_locator(plt.MultipleLocator(1))
    # Minor ticks every 0.2 need explicit bounds; without an x_range there is
    # nothing to derive them from, so skip them instead of indexing None.
    if x_range is not None:
        st, et = x_range[0]//0.2*0.2, x_range[1]//0.2*0.2
        count = int((et - st)/0.2)
        axis.set_xticks(np.linspace(st, et, count+1), minor=True)
    return
def plot_pdf(x, x_pdf, max_y=3.2, title=None, titlesize=10,
             label=None, xlabel="domain", ylabel="pdf", style="solid", color="blue"):
    '''
    Create a new single-panel figure and draw one pdf curve on it.
    All arguments are forwarded unchanged to axis_pdf.
    return : the new figure
    '''
    figure = plt.figure(figsize=(g_fw, g_fh))
    panel = figure.add_subplot(111)
    axis_pdf(panel, x, x_pdf, max_y=max_y, title=title, titlesize=titlesize,
             label=label, xlabel=xlabel, ylabel=ylabel, style=style, color=color)
    return figure
def plot_2d_pdf(x, y, pdf, cond_val=None, label=None, title=None, titlesize=10, xlabel="x", ylabel="y"):
    '''
    Create a new single-panel figure and draw a 2-D pdf on it.
    All arguments are forwarded unchanged to axis_2d_pdf.
    return : the new figure
    '''
    figure = plt.figure(figsize=(g_fw, g_fh))
    panel = figure.add_subplot(111)
    axis_2d_pdf(panel, x, y, pdf, cond_val=cond_val, title=title, titlesize=titlesize,
                label=label, xlabel=xlabel, ylabel=ylabel)
    return figure
def axis_pdf(ax, x, x_pdf, max_y=3.2, title=None, titlesize=10,
             label=None, xlabel="domain", ylabel="pdf", style="solid", color="blue"):
    '''
    Draw one pdf curve on an existing axes.
    ax     : axes to draw on
    x      : domain samples; the first and last value become the x limits
    x_pdf  : pdf values at x
    max_y  : upper y limit (the lower limit is 0)
    title  : optional title, drawn with font size titlesize
    label  : legend label of the curve
    style, color : line style and colour of the curve
    '''
    x_limits = (x[0], x[-1])
    y_limits = (0, max_y)
    set_axis(ax, x_limits, y_limits, xlabel, ylabel)
    ax.plot(x, x_pdf, label=label, color=color, linestyle=style)
    if title is not None:
        ax.set_title(title, fontsize=titlesize)
    ax.legend()
    return
def axis_2d_pdf(ax, x, y, pdf, cond_val=None, title=None, titlesize=10, label=None, xlabel="x", ylabel="y"):
    '''
    Draw a 2-D pdf as a filled contour plot on an existing axes.
    ax       : axes to draw on
    x, y     : domain samples of the horizontal / vertical axis
    pdf      : 2-D array of values handed to contourf
    cond_val : optional x position marked with a vertical orange line
    title    : optional title, drawn with font size titlesize
    label    : label passed on to contourf
    '''
    x_limits = (x[0], x[-1])
    y_limits = (y[0], y[-1])
    set_axis(ax, x_limits, y_limits, xlabel, ylabel)
    ax.contourf(x, y, pdf, label=label)
    if title is not None:
        ax.set_title(title, fontsize=titlesize)
    if cond_val is not None:
        # vertical marker spanning the whole y range at x = cond_val
        ax.plot([cond_val, cond_val], [y[-1], y[0]], color="orange")
    ax.legend()
    return
def add_random_noise(x_pdf, noise_ratio, seed, st, et, num, res):
    '''
    Blend a pdf with a randomly generated pdf and renormalise the result.
    x_pdf       : pdf to perturb
    noise_ratio : weight of the random pdf; x_pdf keeps weight 1 - noise_ratio
    seed        : seed of the random pdf (passed to init_x_pdf)
    st, et, num : domain description passed to init_x_pdf
    res         : resolution of the domain, used for normalising
    return      : blended pdf whose sum times res equals 1
    '''
    noise_pdf = init_x_pdf(st, et, num, seed=seed)[1]
    mixed = (1-noise_ratio)*x_pdf + noise_ratio*noise_pdf
    return mixed/(res*mixed.sum())
def power_range(st, et, num, coeff=2):
    '''
    Build num strictly increasing integer nodes spread over [st, et] by a power law.
    st, et : first node and nominal last node
    num    : number of nodes
    coeff  : exponent applied to the uniform fractions in [0, 1]
    return : list of integers; any node that would not exceed its predecessor
             is bumped to predecessor + 1
    '''
    fractions = np.linspace(0, 1, num=num) ** coeff
    nodes = (st + np.ceil(fractions * (et - st))).astype(int)
    for pos in range(1, len(nodes)):
        # enforce strict monotonicity
        nodes[pos] = max(nodes[pos], nodes[pos - 1] + 1)
    return list(nodes)
def init_x_pdf(st, et, num, modal_count=16, shape_type=0, seed=200):
    '''
    Generate a random Gaussian-mixture pdf on a uniform grid.
    st, et      : first and last point of the domain
    num         : number of intervals; the grid has num + 1 points
    modal_count : number of Gaussian components
    shape_type  : 0, 1 or 2 select a preset range for component means and
                  stds; any other value uses the fallback preset
    seed        : seed of the random generator (converted with int())
    return      : (x, x_pdf), where x_pdf times the grid resolution sums to 1
    '''
    # (mean_lo, mean_hi, std_lo, std_hi) per shape type
    presets = {
        0: (-1.1, 1.1, 0.03, 0.20),
        1: (-1.5, 1.5, 0.01, 0.09),
        2: (-1.5, 1.5, 0.05, 0.35),
    }
    mean_lo, mean_hi, std_lo, std_hi = presets.get(shape_type, (-0.8, 0.8, 0.05, 0.35))

    rng = np.random.RandomState(int(seed))
    step = (et - st) / num
    # draw order (means, stds, weights) fixes the result for a given seed
    centers = mean_lo + rng.random(modal_count) * (mean_hi - mean_lo)
    widths = std_lo + rng.random(modal_count) * (std_hi - std_lo)
    weights = 1 + rng.random(modal_count) * 10
    weights = weights / weights.sum()

    x = np.linspace(st, et, num + 1, dtype=np.float64)
    x_pdf = np.zeros_like(x, dtype=np.float64)
    for center, width, weight in zip(centers, widths, weights):
        x_pdf += weight * norm.pdf(x, center, width)
    # small floor so the pdf is never exactly zero
    x_pdf += 1E-8
    x_pdf = x_pdf / (x_pdf * step).sum()    # normalized to 1
    return x, x_pdf
def forward_next_pdf(x, x_pdf, alpha, res):
    '''
    Push a pdf through one forward step z = sqrt(alpha) * x + sqrt(1 - alpha) * noise.
    x      : input domain
    x_pdf  : input pdf of the continuous variable
    alpha  : step coefficient; when it is close to 1 the input is returned as is
    res    : resolution of x's domain
    return : (y, y_pdf, xy_pdf, xcy_pdf, ycx_pdf) = output domain, output pdf,
             joint pdf indexed [x, y], xy_pdf divided by y_pdf, and xy_pdf
             divided by x_pdf; the last three are None when alpha is close to 1
    Two ways to understand normalizing to 1:
        convert to a discrete variable and sum
        approximate the integral of the continuous variable
    '''
    if np.isclose(alpha, 1.0):
        return x, x_pdf, None, None, None

    y = copy.deepcopy(x)
    mean_scale = np.sqrt(alpha)
    noise_std = np.sqrt(1 - alpha)
    joint = np.zeros([*x.shape, *y.shape], dtype=np.float64)
    for row, x_val in enumerate(x):
        cond = norm.pdf(y, x_val * mean_scale, noise_std)
        cond = cond/(cond*res).sum()
        # No epsilon is added here: it would distort the posterior near zero.
        joint[row] = x_pdf[row] * cond

    joint = joint / (joint * res * res).sum()    # normalize to 1
    y_pdf = (joint * res).sum(axis=0)
    xcy_pdf = joint / (y_pdf[None, :] + 1E-10)
    ycx_pdf = joint / (x_pdf[:, None] + 1E-10)
    return y, y_pdf, joint, xcy_pdf, ycx_pdf
| ################################################################################### | |
| # transform block function | |
| ################################################################################### | |
def shrink(x, x_pdf, alpha, st, res):
    '''
    Pdf of y = sqrt(alpha) * x, sampled on the same grid as x.
    x      : input domain (uniform grid starting at st with spacing res)
    x_pdf  : input pdf of the continuous variable
    alpha  : squared scale factor, must be > 0
    st     : first point of the domain
    res    : resolution of the domain
    return : array shaped like x_pdf; entries whose pre-image falls outside
             the grid are 0
    raises : ValueError when alpha <= 0 (the inverse map is undefined)

    function         : y = sqrt(alpha) * x
    inverse function : x = y / sqrt(alpha)
    derivative       : y' = sqrt(alpha)
    '''
    if alpha <= 0:
        raise ValueError("alpha must be positive, got %r" % (alpha,))
    sqrt_alpha = np.sqrt(alpha)
    # y's domain is the same as x's; map every y back to its source grid index.
    # Round to the nearest index: truncating with int() picks the wrong sample
    # when float error leaves the quotient just below an integer
    # (e.g. 0.3 / 0.1 == 2.9999999999999996) and maps values in (-1, 0) to 0.
    source = np.asarray(x, dtype=np.float64) / sqrt_alpha
    src_idx = np.rint((source - st) / res).astype(int)
    inside = (src_idx >= 0) & (src_idx < len(x_pdf))
    shrink_pdf = np.zeros_like(x_pdf, dtype=np.float64)
    # scale with the reciprocal of the derivative of y
    shrink_pdf[inside] = (1 / sqrt_alpha) * np.asarray(x_pdf)[src_idx[inside]]
    return shrink_pdf
def conv(x, x_pdf, alpha, res):
    '''
    Pdf after adding independent zero-mean Gaussian noise of variance 1 - alpha.
    x      : domain samples; the noise pdf is evaluated on them
    x_pdf  : input pdf of the continuous variable
    alpha  : the noise standard deviation is sqrt(1 - alpha)
    res    : resolution of the domain
    return : convolved pdf with the same length as x_pdf
    '''
    noise_std = np.sqrt(1 - alpha)
    # Multiplying by res turns the densities into discrete probabilities ...
    kernel_mass = norm.pdf(x, 0, noise_std) * res
    input_mass = x_pdf * res
    blended_mass = scipy.signal.convolve(input_mass, kernel_mass, "same")
    # ... and dividing by res turns the result back into a density.
    return blended_mass / res
def shrink_conv(x, x_pdf, shrink_alpha, conv_alpha, st, res):
    '''
    Apply the linear transform (shrink) and then the noise convolution (conv).
    shrink_alpha : alpha used by the linear transform
    conv_alpha   : alpha used by the noise convolution
    return       : pdf after both sub transforms
    '''
    # Adding independent noise to the scaled variable equals a convolution.
    return conv(x, shrink(x, x_pdf, shrink_alpha, st, res), conv_alpha, res)
def plot_init_pdf(seed, st, et, num):
    '''
    Generate the shape-type-0 random input pdf for the seed and plot it.
    return : (figure, x, x_pdf)
    '''
    domain, density = init_x_pdf(st, et, num, shape_type=0, seed=seed)
    figure = plot_pdf(domain, density, label="x", title="input variable's pdf")
    figure.axes[0].title.set_size(9)
    return figure, domain, density
def plot_shrink_pdf(x, x_pdf, alpha, st, res):
    '''
    Plot the pdf obtained from the linear transform y = sqrt(alpha) * x.
    return : the figure, or None when x or x_pdf is missing
    '''
    if any(arg is None for arg in (x, x_pdf)):
        return None
    transformed = shrink(x, x_pdf, alpha, st, res)
    return plot_pdf(x, transformed, label=r"$y=\sqrt{\alpha}x$", title="pdf after linear transform", titlesize=9)
def plot_conv_pdf(x, x_pdf, alpha, res):
    '''
    Plot the pdf obtained after adding Gaussian noise (see conv).
    return : the figure, or None when x or x_pdf is missing
    '''
    if any(arg is None for arg in (x, x_pdf)):
        return None
    noisy = conv(x, x_pdf, alpha, res)
    return plot_pdf(x, noisy, label=r"$y=x+\sqrt{1-\alpha}\epsilon$", title="pdf after add noises", titlesize=9)
def plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, st, res):
    '''
    Plot the pdf obtained after both sub transforms (see shrink_conv).
    return : the figure, or None when x or x_pdf is missing
    '''
    if any(arg is None for arg in (x, x_pdf)):
        return None
    combined = shrink_conv(x, x_pdf, shrink_alpha, conv_alpha, st, res)
    return plot_pdf(
        x, combined,
        label=r"$y=\sqrt{\alpha_s}x + \sqrt{1-\alpha_e}\epsilon$",
        title=r"pdf after two sub transforms",
        titlesize=9,
    )
def init_change(seed, shrink_alpha, conv_alpha):
    '''
    Rebuild the input pdf for a new seed and redraw all transform plots.
    return : (init_fig, x, x_pdf, shrink_fig, conv_fig, shrink_conv_fig)
    '''
    init_fig, x, x_pdf = plot_init_pdf(seed, g_st, g_et, g_num)
    # figures are created in the same order as they are returned
    shrink_fig = plot_shrink_pdf(x, x_pdf, shrink_alpha, g_st, g_res)
    conv_fig = plot_conv_pdf(x, x_pdf, conv_alpha, g_res)
    both_fig = plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res)
    return init_fig, x, x_pdf, shrink_fig, conv_fig, both_fig
def shrink_change(x, x_pdf, shrink_alpha, conv_alpha):
    '''
    Redraw the plots that depend on the linear-transform alpha.
    return : (shrink_fig, shrink_conv_fig)
    '''
    shrink_fig = plot_shrink_pdf(x, x_pdf, shrink_alpha, g_st, g_res)
    both_fig = plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res)
    return shrink_fig, both_fig
def conv_change(x, x_pdf, shrink_alpha, conv_alpha):
    '''
    Redraw the plots that depend on the noise alpha.
    return : (conv_fig, shrink_conv_fig)
    '''
    conv_fig = plot_conv_pdf(x, x_pdf, conv_alpha, g_res)
    both_fig = plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res)
    return conv_fig, both_fig
| ################################################################################### | |
| # cond prob block function | |
| ################################################################################### | |
def cond_prob_init_change(seed, alpha, cond_val):
    '''
    Rebuild the input pdf for a new seed and redraw the conditional-probability plots.
    return : (x, x_pdf, z, xcz_pdf, fig_x, fig_z, fig_zcx, fig_xcz, fig_fix_xcz)
    '''
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=0, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    fig_x = plot_pdf(x, x_pdf, xlabel="x domain", ylabel="pdf", title="input variable's pdf", titlesize=9)
    # the second item returned by cond_prob_alpha_change is the backward conditional pdf
    z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz = cond_prob_alpha_change(x, x_pdf, alpha, cond_val)
    return x, x_pdf, z, xcz_pdf, fig_x, fig_z, fig_zcx, fig_xcz, fig_fix_xcz
def cond_prob_alpha_change(x, x_pdf, alpha, cond_val):
    '''
    Run one forward step and plot the output pdf and both conditional pdfs.
    return : (z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz)
    '''
    z, z_pdf, _, xcz_pdf, zcx_pdf = forward_next_pdf(x, x_pdf, alpha, g_res)

    fig_z = plot_pdf(z, z_pdf, label=r"$z=\sqrt{\alpha}x + \sqrt{1-\alpha}\epsilon$",
                     title=r"output variable's pdf", titlesize=9, xlabel="z domain", ylabel="pdf")
    # the forward conditional is transposed before plotting
    fig_zcx = plot_2d_pdf(x, z, zcx_pdf.transpose(), label="$q(z|x)$",
                          title=r"forward conditional pdf", titlesize=9,
                          xlabel="x domain(cond)", ylabel="z domain")
    fig_xcz, fig_fix_xcz = cond_prob_cond_change(x, x_pdf, z, xcz_pdf, alpha, cond_val)
    return z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz
def cond_prob_cond_change(x, x_pdf, z, xcz_pdf, alpha, cond_val):
    '''
    Plot the backward conditional pdf and its slice at a fixed condition value.
    x, x_pdf : input domain and pdf
    z        : output domain
    xcz_pdf  : backward conditional pdf, indexed [x, z]
    alpha    : coefficient of the forward step
    cond_val : condition value of z whose column of xcz_pdf is plotted
    return   : (fig_xcz, fig_fix_xcz)
    '''
    # grid column of the condition value (truncated towards zero)
    column = int((cond_val - g_st) / g_res)
    slice_pdf = xcz_pdf[:, column]

    fig_xcz = plot_2d_pdf(x, z, xcz_pdf, cond_val, label="$q(x|z)$",
                          title="backward conditional pdf", xlabel="z domain(cond)", ylabel="x domain")
    fig_xcz.axes[0].title.set_size(9)

    # Gaussian in x with mean cond_val / sqrt(alpha) and std sqrt((1 - alpha) / alpha)
    gauss = norm.pdf(x, cond_val / np.sqrt(alpha), np.sqrt((1 - alpha) / alpha))

    fig_fix_xcz = plt.figure(figsize=(g_fw, g_fh))
    panel = fig_fix_xcz.add_subplot(111)
    axis_pdf(panel, x, gauss, max_y=5, label="$gauss$", style="dashed", color="green")
    axis_pdf(panel, x, x_pdf, max_y=5, label="$q(x)$", style="dashed", color="blue")
    axis_pdf(panel, x, slice_pdf, max_y=5, label="$q(x|z=%s)$" % cond_val,
             title="posterior with fixed condition", titlesize=9, xlabel="x domain", color="orange")
    # two legends: the first two curves on the left, the remaining one on the right
    handles, labels = panel.get_legend_handles_labels()
    left_legend = panel.legend(handles[:2], labels[:2], handlelength=0.8, loc="upper left")
    panel.add_artist(left_legend)
    right_legend = panel.legend(handles[2:], labels[2:], handlelength=0.8, loc="upper right")
    panel.add_artist(right_legend)
    return fig_xcz, fig_fix_xcz
| ################################################################################### | |
| # forward block function | |
| ################################################################################### | |
def plot_first_pdf(x, x_pdf, ax):
    '''
    Draw the original variable's pdf on the given axes.
    '''
    curve_label = r"forward q(x)"
    axis_pdf(ax, x, x_pdf, title=r"origin var pdf", label=curve_label,
             xlabel=rs(r"x domain"), ylabel="pdf", color="blue")
    ax.legend(handlelength=1.2, labels=[curve_label])
    return
def forward_init_change(seed):
    '''
    Rebuild the input pdf for a new seed and draw it in the first of eight panels.
    return : (x, x_pdf, fig, None)
    '''
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    panel_count = 8
    fig, panels = plt.subplots(nrows=1, ncols=panel_count, figsize=(panel_count * g_fw, 1 * g_fh))
    plot_first_pdf(x, x_pdf, panels.flatten()[0])
    return x, x_pdf, fig, None
def plot_forward_pdf(axes, seq_info, color, pidx=-1):
    '''
    Draw the forward pdf of every step, each on the axes indexed by its step number.
    axes     : flat sequence of axes
    seq_info : records [cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha]
    color    : colour of the curves
    pidx     : when >= 0 only that part of the sequence is drawn
               (parts hold int(len(seq_info)/3+1) records); negative draws everything
    '''
    count = len(seq_info)
    step = int(count/3+1)
    if pidx >= 0:
        seq_info = seq_info[pidx*step:(pidx+1)*step]

    last = count - 1
    for record in seq_info:
        nz, nz_pdf, nidx, alpha = record[2], record[3], record[6], record[7]
        if nidx == 0:
            title = "origin var pdf"
            label = r"forward $q(x)$"
        else:
            title = rs(r"$q(z_{%d})\ \alpha=%0.3f$"%(nidx, alpha))
            label = r"forward $q(z_{%d})$"%nidx
        xlabel = rs(r"$z_{%d}\ domain$"%nidx)
        axis_pdf(axes[nidx], nz, nz_pdf, title=title, label=label, xlabel=xlabel, ylabel="pdf", color=color)
        axes[nidx].legend(handlelength=1.2)
        if nidx == last:
            # overlay the standard normal on the final step for comparison
            axes[last].plot(nz, norm.pdf(nz, 0, 1), label=r"$\mathcal{N}\/(0, 1)$", color="green")
            axes[last].legend()
    return
def plot_backward_pdf(axes, fore_seq_info, back_seq_info, label_prefix, res, color, pidx=-1):
    '''
    Draw the restored pdf of every step and report its divergence from the forward pdf.
    axes          : flat sequence of axes, indexed by the step number
    fore_seq_info : forward records [cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha]
    back_seq_info : backward records with the same layout
    label_prefix  : text placed in front of the legend label
    res           : resolution of the domain
    color         : colour of the curves
    pidx          : when >= 0 only one part is drawn, counted from the end of
                    the sequence (pidx 0 is the last part); negative draws everything
    '''
    count = len(fore_seq_info)
    step = int(count/3 + 1)
    if pidx >= 0:
        part = slice((2-pidx)*step, (2-pidx+1)*step)    # reverse
        fore_seq_info = fore_seq_info[part]
        back_seq_info = back_seq_info[part]

    for fore_info, back_info in zip(fore_seq_info, back_seq_info):
        nz, fore_nz_pdf, nidx = fore_info[2], fore_info[3], fore_info[6]
        back_nz_pdf = back_info[3]
        div = jensenshannon(back_nz_pdf*res, fore_nz_pdf*res)
        # specific name at end point
        if nidx == count-1:
            name = r"$\mathcal{N}\/(0,1)$"
        else:
            name = "revert"
        label = rs(label_prefix + name + " div=%0.2f"%div)
        xlabel = rs(r"$z_{%d}\ domain$" % nidx)
        axis_pdf(axes[nidx], nz, back_nz_pdf, label=label, xlabel=xlabel, ylabel="pdf", color=color)
        axes[nidx].legend(handlelength=1.2)
    return
def plot_backward_cond_pdf(axes, seq_info, reverse=True, pidx=-1):
    '''
    Draw the backward conditional pdf of every step that has one.
    axes     : flat sequence of axes, indexed by the step number
    seq_info : records [cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha]
    reverse  : when True the parts are counted from the end of the sequence
    pidx     : when >= 0 only that part is drawn; negative draws everything
    '''
    count = len(seq_info)
    step = int(count/3+1)
    if pidx >= 0:
        part = (2-pidx) if reverse else pidx
        seq_info = seq_info[part*step:(part+1)*step]

    for cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha in seq_info:
        if bc_pdf is None:
            # this record carries no conditional pdf
            continue
        title = rs(r"$q(z_{%d}|z_{%d})\ \alpha=%0.3f$" % (cidx, nidx, alpha))
        xlabel = rs(r"$z_{%d}$" % nidx)
        ylabel = rs(r"$z_{%d}$" % cidx)
        axis_2d_pdf(axes[nidx], cz, nz, bc_pdf, title=title, xlabel=xlabel, ylabel=ylabel)
    return
def get_back_seq_info(ez, ez_pdf, fore_seq_info, res):
    '''
    Propagate an end pdf backwards through the conditional pdfs of a forward sequence.
    ez, ez_pdf    : domain and pdf the backward pass starts from
    fore_seq_info : records [cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha];
                    the input is deep-copied and not modified
    res           : resolution of the domain
    return        : copied sequence whose first four items per record hold the
                    restored (cz, cz_pdf, nz, nz_pdf); records whose bc_pdf is
                    None only get nz and nz_pdf replaced
    '''
    back_seq_info = copy.deepcopy(fore_seq_info)
    later_z, later_pdf = ez, ez_pdf
    for record in reversed(back_seq_info):
        posterior = record[4]
        if posterior is None:
            record[2:4] = later_z, later_pdf
            continue
        # pdf of the earlier variable: sum over the later one of q(c|n) * q(n) * res
        earlier_pdf = (np.matmul(posterior, later_pdf[:, None]) * res).flatten()
        record[:4] = later_z, earlier_pdf, later_z, later_pdf
        later_pdf = earlier_pdf
    return back_seq_info
def forward_seq_apply(x, x_pdf, st_alpha, et_alpha, step):
    '''
    Run a sequence of forward steps and prepare the figures they are drawn on.
    x, x_pdf : input domain and pdf; x_pdf may be None when nothing is initialised yet
    st_alpha : alpha of the first step
    et_alpha : alpha of the last step
    step     : number of steps; alphas are spaced linearly between st_alpha and et_alpha
    return   : (seq_info, forward_plot_state), or (None, None) when x_pdf is None.
               seq_info holds one record per variable,
               [cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha], starting with
               the input variable (nidx 0, no conditional pdf).
               forward_plot_state is the tuple
               (fig, axes, pos_fig, pos_axes, seq_info, g_res, "blue").
    '''
    if x_pdf is None:
        # Keep the arity identical to the normal return (this used to return 4 values).
        return None, None

    alphas = np.linspace(st_alpha, et_alpha, step)
    col_count = 8
    # one panel per variable: the input plus one per step
    row_count = int(np.ceil((step+1)/col_count))
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    axes = axes.flatten()
    pos_fig, pos_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    pos_axes = pos_axes.flatten()

    seq_info = [[None, None, x, x_pdf, None, -1, 0, None]]
    cz, cz_pdf = x, x_pdf
    for ii, alpha in enumerate(alphas):
        nz, nz_pdf, _, bc_pdf, _ = forward_next_pdf(cz, cz_pdf, alpha, g_res)
        seq_info.append([cz, cz_pdf, nz, nz_pdf, bc_pdf, ii, ii+1, alpha])
        # the output of this step is the input of the next one
        cz, cz_pdf = nz, nz_pdf

    forward_plot_state = fig, axes, pos_fig, pos_axes, seq_info, g_res, "blue"
    return seq_info, forward_plot_state
def forward_plot_part(plot_state, pidx):
    '''
    Draw one part of the forward pdfs and conditional pdfs onto the prepared figures.
    plot_state : tuple (fig, axes, pos_fig, pos_axes, seq_info, res, color), or None
    pidx       : index of the part to draw
    return     : (fig, pos_fig), or (None, None) when plot_state is None
    '''
    if plot_state is None:
        return None, None
    fig, axes, pos_fig, pos_axes, seq_info, _, color = plot_state
    plot_forward_pdf(axes, seq_info, color, pidx)
    plot_backward_cond_pdf(pos_axes, seq_info, False, pidx)
    return fig, pos_fig
def backward_seq_apply(fore_seq_info, is_forward_pdf, is_backward_pdf, noise_seed, noise_ratio):
    '''
    Prepare the backward pass: restore the pdfs from a standard normal end
    point and, optionally, from a noise-perturbed one.
    fore_seq_info   : forward records (see forward_seq_apply), or None
    is_forward_pdf  : draw the forward pdfs immediately
    is_backward_pdf : compute the backward sequences
    noise_seed      : seed of the random pdf mixed into the end point
    noise_ratio     : mixing weight of that random pdf; only used when > 0
    return          : (fig, plot_state), or (None, None) when fore_seq_info is None;
                      plot_state is (fig, axes, fore_seq_info, std_back_seq_info,
                      noise_back_seq_info, g_res)
    '''
    if fore_seq_info is None:
        return None, None

    col_count = 8
    step = len(fore_seq_info)-1
    row_count = int(np.ceil((step+1)/8))
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    axes = axes.flatten()

    x, _ = fore_seq_info[0][2:4]
    if is_forward_pdf:
        plot_forward_pdf(axes, fore_seq_info, "blue")

    # the backward pass starts from a standard normal on the last domain
    ez = fore_seq_info[-1][2]
    ez_pdf = norm.pdf(x, 0, 1)

    std_back_seq_info = None
    noise_back_seq_info = None
    if is_backward_pdf:
        std_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)
        if noise_ratio > 0:
            ez_pdf = add_random_noise(ez_pdf, noise_ratio, noise_seed, g_st, g_et, g_num, g_res)
            noise_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)

    plot_state = fig, axes, fore_seq_info, std_back_seq_info, noise_back_seq_info, g_res
    return fig, plot_state
def backward_plot_part(plot_state, pidx=-1):
    '''
    Draw one part of the restored pdfs onto the prepared figure.
    plot_state : tuple built by backward_seq_apply, or None
    pidx       : index of the part to draw (negative draws everything)
    return     : the figure, or None when plot_state is None
    '''
    if plot_state is None:
        return None
    fig, axes, fore_seq_info, std_back_seq_info, noise_back_seq_info, res = plot_state
    variants = (
        (std_back_seq_info, "std ", "green"),
        (noise_back_seq_info, "noise ", "red"),
    )
    for back_seq_info, prefix, color in variants:
        if back_seq_info is not None:
            plot_backward_pdf(axes, fore_seq_info, back_seq_info, prefix, res, color, pidx)
    return fig
def fit_pos_with_gauss(idx, x, bc_pdf, queue):
    '''
    Replace every column of a conditional pdf by its best-fitting Gaussian.
    idx    : unused
    x      : domain samples the columns are defined on
    bc_pdf : 2-D array, overwritten in place column by column
    queue  : unused
    return : the same bc_pdf array
    '''
    # Rows of the transposed view are the columns of bc_pdf, so writing to a
    # row updates bc_pdf itself.
    for column in bc_pdf.T:
        fitted, _ = curve_fit(norm.pdf, x, column, p0=[0, 1])
        mu, std = fitted
        column[:] = norm.pdf(x, mu, std)
    return bc_pdf
def seq_fit_pos_with_gauss(fore_seq_info):
    '''
    Copy a forward sequence and replace every conditional pdf by its Gaussian fit.
    fore_seq_info : records [cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha];
                    the input is deep-copied and not modified
    return        : the copied sequence with fitted bc_pdf entries; records
                    whose bc_pdf is None are left as they are
    '''
    fit_seq_info = copy.deepcopy(fore_seq_info)
    for position, record in enumerate(fit_seq_info):
        domain, posterior = record[0], record[4]
        if posterior is None:
            continue
        # fitted sequentially, one record after the other
        record[4] = fit_pos_with_gauss(position, domain, posterior, None)
    return fit_seq_info
def fit_and_backward_apply(fore_seq_info, is_forward_pdf, is_backward_pdf):
    '''
    Prepare the backward pass that uses Gaussian fits of the conditional pdfs.
    fore_seq_info   : forward records (see forward_seq_apply), or None
    is_forward_pdf  : draw the forward pdfs immediately
    is_backward_pdf : also draw the pdfs restored with the unfitted conditional pdfs
    return          : (fig, pos_fig, fit_plot_state), or (None, None, None)
                      when fore_seq_info is None; fit_plot_state is (fig, axes,
                      pos_fig, pos_axes, fore_seq_info, fit_back_seq_info, g_res)
    '''
    if fore_seq_info is None:
        return None, None, None

    col_count = 8
    step = len(fore_seq_info)-1
    row_count = int(np.ceil((step+1) / 8))
    panel_size = (col_count*g_fw, row_count*g_fh)
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=panel_size)
    axes = axes.flatten()
    pos_fig, pos_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=panel_size)
    pos_axes = pos_axes.flatten()

    x, _ = fore_seq_info[0][2:4]
    if is_forward_pdf:
        plot_forward_pdf(axes, fore_seq_info, "blue")

    # the backward pass starts from a standard normal on the last domain
    ez = fore_seq_info[-1][2]
    ez_pdf = norm.pdf(x, 0, 1)

    if is_backward_pdf:
        std_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)
        plot_backward_pdf(axes, fore_seq_info, std_back_seq_info, "std ", g_res, "green")

    # fit first, then propagate backwards through the fitted conditional pdfs
    fitted_seq_info = seq_fit_pos_with_gauss(fore_seq_info)
    fit_back_seq_info = get_back_seq_info(ez, ez_pdf, fitted_seq_info, g_res)

    fit_plot_state = fig, axes, pos_fig, pos_axes, fore_seq_info, fit_back_seq_info, g_res
    return fig, pos_fig, fit_plot_state
def fit_plot_part(plot_state, is_show_pos, pidx=-1):
    '''
    Draw one part of the fit-based restored pdfs onto the prepared figures.
    plot_state  : tuple built by fit_and_backward_apply, or None
    is_show_pos : also draw the fitted conditional pdfs
    pidx        : index of the part to draw (negative draws everything)
    return      : (fig, pos_fig), or (None, None) when plot_state is None
    '''
    if plot_state is None:
        return None, None
    fig, axes, pos_fig, pos_axes, fore_seq_info, fit_back_seq_info, res = plot_state
    plot_backward_pdf(axes, fore_seq_info, fit_back_seq_info, "fit ", res, "orange", pidx)
    if is_show_pos:
        plot_backward_cond_pdf(pos_axes, fit_back_seq_info, True, pidx)
    return fig, pos_fig
| ################################################################################### | |
| # contraction block function | |
| ################################################################################### | |
def contraction_init_change(seed, alpha, two_inputs_seed):
    '''
    Rebuild the input pdf for a new seed and redraw all contraction plots.
    seed            : seed of the input pdf; it also selects the shape type
    alpha           : coefficient of the forward step
    two_inputs_seed : seed of the two random inputs used by contraction_apply
    return          : (fig, x, x_pdf, fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2)
    '''
    shape_type = np.random.RandomState(int(seed)).randint(0, 4)
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=shape_type, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    # hijack may have changed the pdf, so normalise again
    x_pdf = x_pdf / (x_pdf * g_res).sum()    # normalized to 1

    fig = plot_pdf(x, x_pdf, title="input variable pdf", titlesize=9)
    fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2 = contraction_alpha_change(x, x_pdf, alpha, two_inputs_seed)
    return fig, x, x_pdf, fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2
def contraction_alpha_change(x, x_pdf, alpha, two_inputs_seed):
    '''
    Run one forward step, plot the output and posterior pdfs, and compute the
    second-largest eigenvalue magnitude of the column-normalised posterior.
    return : (fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2)
    '''
    z, z_pdf, _, xcz_pdf, _ = forward_next_pdf(x, x_pdf, alpha, g_res)

    fig_z = plot_pdf(z, z_pdf, label=r"$z=\sqrt{\alpha}x + \sqrt{1-\alpha}\epsilon$",
                     title=r"output variable pdf", titlesize=9, xlabel="z domain", ylabel="pdf")
    fig_xcz = plot_2d_pdf(x, z, xcz_pdf, None, label="$q(x|z)$",
                          title=r"posterior pdf", titlesize=9, xlabel="z domain(cond)", ylabel="x domain")

    # make every column sum to 1 before taking the eigenvalues
    column_normalised = xcz_pdf/xcz_pdf.sum(axis=0, keepdims=True)
    magnitudes = np.absolute(np.linalg.eigvals(column_normalised))
    lambda_2 = sorted(magnitudes, reverse=True)[1]

    fig_inp_out = contraction_apply(x, x_pdf, xcz_pdf, two_inputs_seed)
    return fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2
def change_two_inputs_seed():
    '''
    Draw a new random seed for the two random inputs.
    return : integer in [0, 1000000], both ends included
    '''
    # Integer bounds are required: random.randint rejects float arguments
    # such as 1E6 with a TypeError on Python 3.12 and later.
    return random.randint(0, 10**6)
def contraction_apply(x, x_pdf, bc_pdf, seed):
    '''
    Push two random pdfs through the posterior and plot inputs and outputs
    together with their Jensen-Shannon distances.
    x, x_pdf : unused
    bc_pdf   : posterior, applied as a matrix to each input pdf
    seed     : seed that selects both random inputs
    return   : figure with two panels (inputs on the left, outputs on the right)
    '''
    def _draw_pair(ax, first, second, names, title, xlabel, div_label):
        # two curves plus an empty curve that only carries the divergence label
        axis_pdf(ax, first[0], first[1], max_y=3.8, label=names[0],
                 title=title, titlesize=9, xlabel=xlabel, ylabel="pdf", color="orange")
        axis_pdf(ax, second[0], second[1], max_y=3.8, label=names[1], xlabel=xlabel, ylabel="pdf", color="green")
        ax.plot([], [], label=div_label, color="blue")
        handles, labels = ax.get_legend_handles_labels()
        ax.add_artist(ax.legend(handles[:2], labels[:2], handlelength=1.0, loc="upper left"))
        ax.add_artist(ax.legend(handles[2:], labels[2:], handlelength=0, loc="upper right"))

    rng = np.random.RandomState(int(seed))
    modals = [1, 2, 8, 12, 16, 16, 16, 16, 16, 20]
    # the draw order below fixes the result for a given seed
    count1 = rng.choice(modals)
    count2 = rng.choice(modals)
    seed1, seed2 = rng.randint(0, 1E6, 2)
    shape1, shape2 = rng.randint(0, 4, 2)

    z1, z1_pdf = init_x_pdf(g_st, g_et, g_num, count1, shape_type=shape1, seed=seed1)
    z2, z2_pdf = init_x_pdf(g_st, g_et, g_num, count2, shape_type=shape2, seed=seed2)
    div_z = jensenshannon(z1_pdf*g_res, z2_pdf*g_res)

    x1_pdf = (np.matmul(bc_pdf, z1_pdf[:, None]) * g_res).flatten()
    x2_pdf = (np.matmul(bc_pdf, z2_pdf[:, None]) * g_res).flatten()
    div_x = jensenshannon(x1_pdf*g_res, x2_pdf*g_res)

    div_in_label = r"$div_{in}=%0.3f$"%div_z
    div_out_label = r"$div_{out}=%0.3f$"%div_x

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(2*g_fw, 1*g_fh))
    _draw_pair(axes[0], (z1, z1_pdf), (z2, z2_pdf), ("input1", "input2"),
               "two random input", "z domain", div_in_label)
    _draw_pair(axes[1], (z1, x1_pdf), (z2, x2_pdf), ("output1", "output2"),
               "two output", "x domain", div_out_label)
    return fig
def fixed_point_init_change(seed, x, x_pdf):
    '''
    Generate a random input pdf for the inverse transform and plot it next to
    the converging pdf in the first of eight panels.
    seed     : seed selecting shape type, component count and the pdf itself
    x, x_pdf : domain and pdf of the converging distribution
    return   : (fig, z, z_pdf, None)
    '''
    rng = np.random.RandomState(int(seed))
    # the draw order (shape type first, then component count) fixes the result
    shape_type = rng.randint(0, 4)
    count = rng.choice([1, 2, 8, 12, 16, 16, 16, 16, 16, 20])
    z, z_pdf = init_x_pdf(g_st, g_et, g_num, modal_count=count, shape_type=shape_type, seed=seed)
    div = jensenshannon(z_pdf*g_res, x_pdf*g_res)

    fig, panels = plt.subplots(nrows=1, ncols=8, figsize=(8*g_fw, 1*g_fh))
    first = panels.flatten()[0]
    axis_pdf(first, x, x_pdf, label="converging pdf", color="blue")
    axis_pdf(first, z, z_pdf, title="random input of inverse transform", label="random input", color="green")
    # empty curve that only carries the divergence in the legend
    first.plot([], [], label="div=%0.2f"%div, color="orange")
    first.legend(handlelength=1.2)
    return fig, z, z_pdf, None
def matrix_power(in_mat, n):
    '''
    Raise a square matrix to the n-th power by repeated squaring, rescaling
    every intermediate product so that each column sums to (almost) 1.
    in_mat : square 2-D array
    n      : exponent; 0 yields the identity matrix
    return : the column-normalised power
    '''
    def _normalised_product(lhs, rhs):
        # operands are scaled by 100 before multiplying and the product is
        # scaled back, then each column is divided by its sum
        product = np.matmul(lhs * 100, rhs * 100) / 10000
        return product / (product.sum(axis=0, keepdims=True) + 1E-9)

    if n == 0:
        return np.eye(in_mat.shape[0])

    half = matrix_power(in_mat, int(n / 2))
    squared = _normalised_product(half, half)
    if n % 2 == 0:
        return squared
    # odd exponent: one more factor of the base matrix
    return _normalised_product(in_mat, squared)
def fixed_point_apply_iterate(x, x_pdf, zt, zt_pdf, xcz_pdf, iterate_num, is_show_pow):
    '''
    Apply the posterior repeatedly to a random input and prepare the figures
    on which the iterates are drawn.
    x, x_pdf    : domain and pdf of the converging distribution
    zt, zt_pdf  : domain and pdf of the random input
    xcz_pdf     : posterior; xcz_pdf * g_res is the matrix that gets powered
    iterate_num : number of iterations
    is_show_pow : also prepare a figure for the powered matrices
    return      : (fig, pow_fig, plot_state), or (None, None, None) when
                  x_pdf, zt_pdf or xcz_pdf is None; plot_state is
                  (fig, pow_fig, axes, pow_axes, pdfs, pow_mats, g_res)
    '''
    if any(arg is None for arg in (x_pdf, zt_pdf, xcz_pdf)):
        return None, None, None

    col_count, max_row_count = 8, 3
    # the first panel is reserved for the inputs
    max_ax_count = max_row_count*col_count - 1
    ax_count = min(iterate_num, max_ax_count)
    row_count = int(np.ceil((ax_count+1)/col_count))
    panel_size = (col_count*g_fw, row_count*g_fh)

    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=panel_size)
    axes = axes.flatten()
    axis_pdf(axes[0], x, x_pdf, label="converging point", color="blue")
    axis_pdf(axes[0], zt, zt_pdf, title="random input", label="random input", color="green")
    div = jensenshannon(zt_pdf*g_res, x_pdf*g_res)
    axes[0].plot([], [], label="div=%0.2f"%div, color="green")
    axes[0].legend(handlelength=1.2)

    if iterate_num > max_ax_count:
        # more iterations than panels: show the first six, then a power-law selection
        idxs = np.arange(6).tolist() + power_range(6, iterate_num-1, max_ax_count-6, 2.5)
    else:
        idxs = np.arange(iterate_num).tolist()

    transition = xcz_pdf*g_res
    pow_mats, pdfs = [], []
    for ax_idx, idx in enumerate(idxs, start=1):
        pow_idx = idx + 1
        pow_mat = matrix_power(transition, pow_idx)
        pz_pdf = np.matmul(pow_mat, zt_pdf[:, None]).flatten()
        pow_mats.append([pow_mat, pow_idx, ax_idx])
        pdfs.append([x, x_pdf, pz_pdf, pow_idx, ax_idx])

    pow_fig, pow_axes = None, None
    if is_show_pow:
        pow_fig, pow_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=panel_size)
        pow_axes = pow_axes.flatten()

    plot_state = (fig, pow_fig, axes, pow_axes, pdfs, pow_mats, g_res)
    return fig, pow_fig, plot_state
def fixed_plot_part(plot_state, pidx):
    '''
    Draw one part of the iterates (and, if prepared, the powered matrices).
    plot_state : tuple built by fixed_point_apply_iterate, or None
    pidx       : index of the part to draw; a part holds int(len(pdfs)/3) + 1 entries
    return     : (fig, pow_fig); pow_fig is None when no matrix figure was
                 prepared; (None, None) when plot_state is None
    '''
    if plot_state is None:
        return None, None
    fig, pow_fig, axes, pow_axes, pdfs, pow_mats, res = plot_state
    step = int(len(pdfs)/3) + 1
    part = slice(pidx*step, (pidx+1)*step)

    for x, x_pdf, pz_pdf, pow_idx, ax_idx in pdfs[part]:
        panel = axes[ax_idx]
        axis_pdf(panel, x, x_pdf, label="converging pdf", color="blue")
        axis_pdf(panel, x, pz_pdf, title=r"the %dth iterate" % pow_idx, label="transform result", color="green")
        div = jensenshannon(pz_pdf*res, x_pdf*res)
        # empty curve that only carries the divergence in the legend
        panel.plot([], [], label="div=%0.3f"%div, color="green")
        panel.legend(handlelength=1.2)

    if pow_axes is None:
        return fig, None

    # x still refers to the domain of the last entry drawn above
    for pow_mat, pow_idx, ax_idx in pow_mats[part]:
        axis_2d_pdf(pow_axes[ax_idx], x, x, pow_mat, title="power(mat,%d)"%(pow_idx), xlabel="z", ylabel="x")
    return fig, pow_fig
def hijack(seed, x, x_pdf):
    '''
    Replace the pdf by a hand-crafted one for a few special seeds.
    seed     : 16002 or 16003 give a fixed random pdf with a zeroed band around 0;
               seeds in [int(17500 + g_st*100), int(17500 + g_et*100)) give a
               single Gaussian whose mean and std are encoded in the seed
    x, x_pdf : domain and pdf to fall back on
    return   : the replacement pdf, or x_pdf unchanged for any other seed
    '''
    if seed in (16002, 16003):
        x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=2, seed=100)
        if seed == 16002:
            left, right = -0.5, 0.5
        else:
            left, right = -0.7, 0.7
        # cut a hole into the middle of the pdf (it is not renormalised here)
        hole = np.logical_and(x > left, x < right)
        x_pdf[hole] = 0

    base = 17500
    first_seed = int(base+g_st*100)
    end_seed = int(base+g_et*100)
    if seed in range(first_seed, end_seed):
        # the tens and above select the mean, the last digit selects the std
        mu = g_st + (seed//10*10 - first_seed)*0.01
        std = (seed%10+1)*0.02
        x_pdf = norm.pdf(x, mu, std)
    return x_pdf