Commit ·
75536b4
1
Parent(s): 93ef35c
first commit
Browse files- app.py +907 -0
- requirements.txt +4 -0
app.py
ADDED
|
@@ -0,0 +1,907 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import itertools
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Callable, Dict, List, Tuple
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ============================================================
|
| 13 |
+
# Rearrangement Algorithm visualizer
|
| 14 |
+
# ------------------------------------------------------------
|
| 15 |
+
# We represent a joint distribution by an n x d matrix X.
|
| 16 |
+
# Each column is one marginal sample. Rearrangement means:
|
| 17 |
+
# - values inside a column may be permuted;
|
| 18 |
+
# - the multiset in every column is unchanged;
|
| 19 |
+
# - hence all empirical marginal distributions are preserved.
|
| 20 |
+
# The algorithm decreases E[psi(X)] = E[f(sum_j X_j)]
|
| 21 |
+
# by rearranging columns while preserving empirical marginals.
|
| 22 |
+
# ============================================================
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class Step:
    """One snapshot of the rearrangement process.

    Instances are stored in gr.State as plain dicts (via ``__dict__``) and
    rebuilt with ``Step(**d)``, so every field is JSON-friendly.
    """

    step: int                               # 0-based index (0 = initial coupling)
    matrix: List[List[float]]               # the n x d matrix after this step
    objective: float                        # E[psi(X)] evaluated on `matrix`
    action: str                             # human-readable description of the move
    column: int | None = None               # column rearranged in this step, if any
    before_order: List[int] | None = None   # argsort of that column before the move
    after_order: List[int] | None = None    # argsort of that column after the move
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parse_matrix(text: str) -> np.ndarray:
    """Parse an n x d matrix from a Python-list literal or CSV-like rows.

    Accepts either a literal such as ``[[1,4,7],[2,5,8]]`` or one
    comma-separated row per line.  Raises ValueError (Japanese UI message)
    for empty, malformed, too-small, or non-finite input.
    """
    stripped = text.strip()
    if not stripped:
        raise ValueError("行列を入力してください。")

    try:
        if stripped.startswith("["):
            parsed = np.array(ast.literal_eval(stripped), dtype=float)
        else:
            row_values = [
                [float(cell.strip()) for cell in raw.split(",") if cell.strip()]
                for raw in stripped.splitlines()
                if raw.strip()
            ]
            parsed = np.array(row_values, dtype=float)
    except Exception as exc:
        raise ValueError("行列として解釈できませんでした。例: [[1,4,7],[2,5,8],[3,6,9]]") from exc

    if parsed.ndim != 2:
        raise ValueError("2次元の行列を入力してください。")
    if parsed.shape[0] < 2 or parsed.shape[1] < 2:
        raise ValueError("少なくとも 2 行 2 列が必要です。")
    if not np.isfinite(parsed).all():
        raise ValueError("NaN や Inf は使えません。")
    return parsed
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def make_initial_matrix(n: int, d: int, seed: int, distribution: str) -> np.ndarray:
    """Draw n samples for each of d marginals, then shuffle every column.

    Each column j is sampled from the requested family (spread grows mildly
    in j), sorted, then independently permuted so the initial coupling is
    random while the per-column multisets stay well defined.

    Raises ValueError (Japanese UI message) for an unknown distribution name.
    """
    rng = np.random.default_rng(seed)

    # Dispatch table; each sampler consumes the rng exactly once per column,
    # in the same order as before, so outputs are seed-for-seed identical.
    samplers = {
        "normal": lambda j: rng.normal(loc=0.0, scale=1.0 + 0.25 * j, size=n),
        "uniform": lambda j: rng.uniform(-1.0 - j * 0.2, 1.0 + j * 0.2, size=n),
        "lognormal": lambda j: rng.lognormal(mean=0.0, sigma=0.35 + 0.1 * j, size=n),
        "integer": lambda j: rng.integers(low=0, high=10 + 2 * j + 1, size=n),
    }
    if distribution not in samplers:
        raise ValueError("未知の分布です。")

    draw = samplers[distribution]
    x = np.column_stack([np.sort(draw(j)) for j in range(d)])

    # Independent per-column shuffles destroy the comonotone coupling that
    # sorting alone would induce.
    for j in range(d):
        x[:, j] = rng.permutation(x[:, j])

    return x
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def matrix_to_text(x: np.ndarray) -> str:
    """Render a matrix as comma-separated rows, one row per line (4 sig figs)."""
    rendered_rows = []
    for row in x:
        rendered_rows.append(", ".join(f"{v:.4g}" for v in row))
    return "\n".join(rendered_rows)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def build_f(name: str, theta: float, custom_expr: str) -> Tuple[Callable[[np.ndarray], np.ndarray], str]:
    """
    Return (f, label) for

        psi(x_1, ..., x_d) = f(sum_j x_j)

    Parameters
    ----------
    name : one of "square", "absolute", "exponential", "positive_part", "custom".
    theta : rate parameter, used only by "exponential".
    custom_expr : expression in the variable ``s``, used only by "custom".

    The returned f maps a length-n array of row sums to a length-n array.
    Raises ValueError for an unknown name or an empty custom expression.
    """
    if name == "square":
        return lambda s: s**2, "ψ(x₁, ..., x_d) = f(Σxᵢ), f(s) = s²"

    if name == "absolute":
        return lambda s: np.abs(s), "ψ(x₁, ..., x_d) = f(Σxᵢ), f(s) = |s|"

    if name == "exponential":
        return lambda s: np.exp(theta * s), f"ψ(x₁, ..., x_d) = f(Σxᵢ), f(s) = exp({theta:g} · s)"

    if name == "positive_part":
        return lambda s: np.maximum(s, 0.0), "ψ(x₁, ..., x_d) = f(Σxᵢ), f(s) = max(s, 0)"

    if name == "custom":
        expr = custom_expr.strip()
        if not expr:
            raise ValueError("custom を使う場合は f(s) の式を入力してください。")

        # Only these names (plus ``s``) are visible to the expression.
        allowed = {
            "np": np,
            "abs": np.abs,
            "sqrt": np.sqrt,
            "exp": np.exp,
            "log": np.log,
            "maximum": np.maximum,
            "minimum": np.minimum,
        }

        # SECURITY NOTE: eval() runs user-supplied text.  Builtins are
        # stripped and only the whitelist above is exposed, but ``np`` is a
        # rich object, so this is *not* a hardened sandbox — do not expose
        # the app to untrusted users without replacing eval with a parser.
        # Compiling once up front reports syntax errors at build time
        # (caught by run_algorithm's error channel) instead of on first call.
        code = compile(expr, "<custom_f>", "eval")

        def custom_f(s: np.ndarray) -> np.ndarray:
            y = eval(code, {"__builtins__": {}}, {**allowed, "s": s})
            y = np.asarray(y, dtype=float)
            if y.ndim == 0:
                # A constant expression broadcasts to one value per row.
                y = np.full(s.shape[0], float(y))
            return y

        return custom_f, f"ψ(x₁, ..., x_d) = f(Σxᵢ), f(s) = {expr}"

    raise ValueError("未知の f です。")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def build_psi(name: str, theta: float, custom_expr: str) -> Tuple[Callable[[np.ndarray], np.ndarray], str]:
    """Return (psi, label), where psi applies f to each row sum of an n x d X.

    psi(x_1, ..., x_d) = f(sum_j x_j); the output has length n.
    """
    f, f_label = build_f(name, theta, custom_expr)

    def psi(x: np.ndarray) -> np.ndarray:
        return f(x.sum(axis=1))

    return psi, f_label
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def objective(x: np.ndarray, psi: Callable[[np.ndarray], np.ndarray]) -> float:
    """Mean of psi over the rows of x, with shape and finiteness validation."""
    values = psi(x)
    if values.shape[0] != x.shape[0]:
        raise ValueError("ψ は各行ごとに1つの値を返す必要があります。")
    if not np.isfinite(values).all():
        raise ValueError("ψ の値に NaN または Inf が出ました。")
    return float(values.mean())
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def is_better(new_value: float, old_value: float, tol: float = 1e-12) -> bool:
    """True when new_value improves (decreases) old_value by more than tol."""
    improvement = old_value - new_value
    return improvement > tol
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def greedy_sort_rearrangement(
    x: np.ndarray,
    psi_name: str,
    theta: float,
    custom_expr: str,
    max_iter: int,
    random_tie_break: bool,
    seed: int,
) -> Tuple[List[Step], str]:
    """
    Heuristic RA.

    For one chosen column j, keep all other columns fixed.
    We search for a permutation of column j that improves the selected objective.

    The objective is to minimize E[psi(X)] = E[f(sum_j X_j)].

    For common convex increasing-in-sum losses, the classical RA idea is to
    put a selected column in opposite order to the partial row sum of the other
    columns. For more general f, we use this as a proposal and also run a
    small pairwise local improvement pass.

    Returns (steps, objective_label): the list of recorded Step snapshots
    (index 0 is the initial coupling) and the human-readable ψ label.
    """
    psi, objective_label = build_psi(psi_name, theta, custom_expr)
    objective_fn = lambda z: objective(z, psi)
    rng = np.random.default_rng(seed)
    # Work on a copy: the caller's matrix is never mutated.
    x = x.copy()
    n, d = x.shape

    steps = [Step(0, x.tolist(), objective_fn(x), "初期カップリング")]
    current = steps[-1].objective

    for it in range(1, max_iter + 1):
        improved_any = False
        columns = list(range(d))
        if random_tie_break:
            # Randomize the column sweep order so ties are broken differently
            # across runs with different seeds.
            rng.shuffle(columns)

        for j in columns:
            old_col = x[:, j].copy()
            old_order = np.argsort(old_col, kind="mergesort").tolist()

            # Proposal 1: anti-monotone arrangement against partial sums.
            # Pair the largest values of column j with the smallest sums of
            # the remaining columns (classical RA move).
            rest_sum = np.sum(x, axis=1) - x[:, j]
            row_order = np.argsort(rest_sum, kind="mergesort")
            col_sorted_desc = np.sort(old_col)[::-1]

            candidate = x.copy()
            candidate[row_order, j] = col_sorted_desc
            cand_obj = objective_fn(candidate)

            best = candidate
            best_obj = cand_obj
            best_action = f"列 {j + 1}: 他列の行和に対して反対順に rearrange"

            # Proposal 2: if anti-monotone does not help enough, try pair swaps.
            # This makes the app useful for non-smooth/custom f as well.
            pair_best = x.copy()
            pair_best_obj = current
            pair_action = None

            # Cap the number of evaluated swaps; subsample the pair set when
            # n is large so one sweep stays cheap.
            max_pair_checks = min(n * (n - 1) // 2, 2500)
            pairs = list(itertools.combinations(range(n), 2))

            if len(pairs) > max_pair_checks:
                pairs_idx = rng.choice(len(pairs), size=max_pair_checks, replace=False)
                pairs = [pairs[k] for k in pairs_idx]

            for a, b in pairs:
                tmp = x.copy()
                tmp[a, j], tmp[b, j] = tmp[b, j], tmp[a, j]
                tmp_obj = objective_fn(tmp)

                if is_better(tmp_obj, pair_best_obj):
                    pair_best = tmp
                    pair_best_obj = tmp_obj
                    pair_action = f"列 {j + 1}: 行 {a + 1} と行 {b + 1} を swap"

            if is_better(pair_best_obj, best_obj):
                best = pair_best
                best_obj = pair_best_obj
                best_action = pair_action or f"列 {j + 1}: pairwise swap"

            # Accept the better of the two proposals only if it strictly
            # improves the current objective (tolerance inside is_better).
            if is_better(best_obj, current):
                x = best
                current = best_obj
                improved_any = True

                new_col = x[:, j]
                new_order = np.argsort(new_col, kind="mergesort").tolist()

                steps.append(
                    Step(
                        len(steps),
                        x.tolist(),
                        current,
                        best_action,
                        column=j,
                        before_order=old_order,
                        after_order=new_order,
                    )
                )

        # A full sweep over all columns with no accepted move means we are
        # at a local optimum for this neighborhood; record it and stop.
        if not improved_any:
            steps.append(Step(len(steps), x.tolist(), current, "改善なし: 局所解として停止"))
            break

    return steps, objective_label
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def make_matrix_df(step: Step):
    """Styled DataFrame for one step: row index, columns X1..Xd, and row sums.

    The "sum" cell of the maximum row sum is tinted red and that of the
    minimum blue (the minimum wins when they coincide).
    """
    x = np.array(step.matrix)
    n_rows, n_cols = x.shape

    df = pd.DataFrame(x, columns=[f"X{j + 1}" for j in range(n_cols)])
    df.insert(0, "row", np.arange(1, n_rows + 1))
    df["sum"] = x.sum(axis=1)
    df = df.round(6)

    min_sum = df["sum"].min()
    max_sum = df["sum"].max()

    def highlight_sum_extremes(row):
        # Style only the "sum" cell of each row.
        styles = [""] * len(row)
        sum_idx = row.index.get_loc("sum")
        if row["sum"] == max_sum:
            styles[sum_idx] = "background-color: #fecaca; color: #7f1d1d; font-weight: bold;"
        if row["sum"] == min_sum:
            styles[sum_idx] = "background-color: #bfdbfe; color: #1e3a8a; font-weight: bold;"
        return styles

    return df.style.apply(highlight_sum_extremes, axis=1)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def make_marginal_check_df(initial: np.ndarray, current: np.ndarray) -> pd.DataFrame:
    """Per-column sanity check that rearrangement preserved each marginal.

    A column passes ("YES") when its sorted values match the initial sorted
    values (np.allclose); min/max/mean are shown for a quick visual check.
    """
    records = []
    for j in range(initial.shape[1]):
        col_before = initial[:, j]
        col_after = current[:, j]
        preserved = np.allclose(np.sort(col_before), np.sort(col_after))
        records.append(
            {
                "列": f"X{j + 1}",
                "周辺分布 preserved?": "YES" if preserved else "NO",
                "初期 min": np.min(col_before),
                "現在 min": np.min(col_after),
                "初期 max": np.max(col_before),
                "現在 max": np.max(col_after),
                "初期 mean": np.mean(col_before),
                "現在 mean": np.mean(col_after),
            }
        )
    return pd.DataFrame(records).round(6)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def make_heatmap(step: Step) -> go.Figure:
    """Heatmap of the step's matrix, titled with the action and objective."""
    x = np.array(step.matrix)
    n_rows, n_cols = x.shape

    heat = go.Heatmap(
        z=x,
        x=[f"X{j + 1}" for j in range(n_cols)],
        y=[f"row {i + 1}" for i in range(n_rows)],
        colorbar=dict(title="value"),
    )
    fig = go.Figure(data=heat)
    fig.update_layout(
        title=f"Step {step.step}: {step.action}<br>目的関数 = {step.objective:.6g}",
        height=450,
        margin=dict(l=70, r=30, t=80, b=40),
    )
    return fig
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def make_trace(steps: List[Step]) -> go.Figure:
    """Line chart of the objective value across all recorded steps."""
    step_numbers = [s.step for s in steps]
    objectives = [s.objective for s in steps]

    fig = go.Figure(
        go.Scatter(
            x=step_numbers,
            y=objectives,
            mode="lines+markers",
            hovertemplate="step %{x}<br>目的関数=%{y:.6g}<extra></extra>",
        )
    )
    fig.update_layout(
        title="目的関数の推移",
        xaxis_title="step",
        yaxis_title="目的関数",
        height=360,
        margin=dict(l=50, r=20, t=60, b=40),
    )
    return fig
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def _discrete_entropy_from_counts(counts: np.ndarray) -> float:
    """Shannon entropy (nats) of the empirical distribution given by counts.

    Zero counts are dropped; an all-zero or empty input yields 0.0.
    """
    positive = np.asarray(counts, dtype=float)
    positive = positive[positive > 0]
    if positive.size == 0:
        return 0.0
    p = positive / positive.sum()
    return float(-(p * np.log(p)).sum())
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _rank_bin_matrix(x: np.ndarray, n_bins: int) -> np.ndarray:
    """Map each column of x to integer rank-bin labels in [0, n_bins).

    Only the within-column ordering survives, which approximates the
    empirical-copula coordinates U_j = F_j(X_j) on a finite sample.
    Ties are broken by row order (stable sort), so for heavily tied integer
    data this is a visualization aid, not a copula-density estimator.
    n_bins is clamped to [2, n].
    """
    x = np.asarray(x, dtype=float)
    n, d = x.shape
    n_bins = int(max(2, min(int(n_bins), n)))

    labels = np.zeros((n, d), dtype=int)
    for j in range(d):
        # Stable ordinal ranks 0..n-1 for column j.
        order = np.argsort(x[:, j], kind="mergesort")
        ranks = np.empty(n, dtype=int)
        ranks[order] = np.arange(n)
        # floor(rank * n_bins / n) over non-negative ints == integer division.
        binned = (ranks * n_bins) // n
        labels[:, j] = np.clip(binned, 0, n_bins - 1)

    return labels
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def copula_entropy_summary(x: np.ndarray, n_bins: int) -> Dict[str, float]:
    """Rank-binned empirical entropy summary of the coupling.

    With B_j the rank-bin label of column j, returns:
      - "H_joint":            H(B_1, ..., B_d)
      - "H_marginal_sum":     sum_j H(B_j)  (nearly constant under RA,
                              since every marginal multiset is preserved,
                              up to ties/binning)
      - "H_copula":           H_joint - H_marginal_sum
      - "mutual_information": H_marginal_sum - H_joint = -H_copula

    These are finite-sample, rank-binned analogues of the continuous
    identities H_c = h(X) - sum_j h(X_j) and I = -H_c.
    """
    bins = _rank_bin_matrix(np.asarray(x, dtype=float), int(n_bins))

    # Joint entropy over the d-tuples of bin labels.
    _, joint_counts = np.unique(bins, axis=0, return_counts=True)
    h_joint = _discrete_entropy_from_counts(joint_counts)

    # Sum of the per-column (marginal) bin entropies.
    h_marginal_sum = 0.0
    for j in range(bins.shape[1]):
        _, marg_counts = np.unique(bins[:, j], return_counts=True)
        h_marginal_sum += _discrete_entropy_from_counts(marg_counts)
    h_marginal_sum = float(h_marginal_sum)

    return {
        "H_joint": h_joint,
        "H_marginal_sum": h_marginal_sum,
        "H_copula": h_joint - h_marginal_sum,
        "mutual_information": h_marginal_sum - h_joint,
    }
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def make_entropy_trace(steps: List[Step], n_bins: int) -> go.Figure:
    """Plot joint, marginal-sum, and copula entropies across all steps."""
    records = [
        {"step": s.step, **copula_entropy_summary(np.array(s.matrix), int(n_bins))}
        for s in steps
    ]
    df = pd.DataFrame(records)

    # (column, legend name, hover template) for each of the three series.
    series = [
        ("H_joint", "joint entropy H(X)", "step %{x}<br>H(X)=%{y:.6g}<extra></extra>"),
        ("H_marginal_sum", "marginal entropy sum ΣH(Fj)", "step %{x}<br>ΣH(Fj)=%{y:.6g}<extra></extra>"),
        ("H_copula", "copula entropy Hc=H(X)-ΣH(Fj)", "step %{x}<br>Hc=%{y:.6g}<extra></extra>"),
    ]

    fig = go.Figure()
    for column, legend_name, hover in series:
        fig.add_trace(
            go.Scatter(
                x=df["step"],
                y=df[column],
                mode="lines+markers",
                name=legend_name,
                hovertemplate=hover,
            )
        )

    fig.update_layout(
        title=f"rank-binned copula entropy の推移(各列 {int(n_bins)} rank bins)",
        xaxis_title="step",
        yaxis_title="entropy / nats",
        height=420,
        margin=dict(l=50, r=20, t=70, b=80),
        legend=dict(orientation="h", yanchor="bottom", y=-0.35, xanchor="left", x=0),
    )
    return fig
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def make_entropy_history_df(steps: List[Step], n_bins: int) -> pd.DataFrame:
    """Tabulate the entropy summary of every step, rounded for display."""
    bins = int(n_bins)

    def row_for(s: Step) -> Dict[str, float]:
        # One table row per recorded step.
        ent = copula_entropy_summary(np.array(s.matrix), bins)
        return {
            "step": s.step,
            "H(X) joint": ent["H_joint"],
            "ΣH(Fj) marginal sum": ent["H_marginal_sum"],
            "Hc = H(X)-ΣH(Fj)": ent["H_copula"],
            "I = ΣH(Fj)-H(X)": ent["mutual_information"],
        }

    return pd.DataFrame([row_for(s) for s in steps]).round(6)
|
| 521 |
+
|
| 522 |
+
def make_sum_hist(step: Step) -> go.Figure:
    """Histogram of the row sums S = X1 + ... + Xd for one step."""
    sums = np.array(step.matrix).sum(axis=1)
    # Between 5 and 20 bins, scaled with the sample size.
    bin_count = min(20, max(5, len(sums) // 2))

    fig = go.Figure(go.Histogram(x=sums, nbinsx=bin_count))
    fig.update_layout(
        title="行和 S = X1 + ... + Xd の分布",
        xaxis_title="S",
        yaxis_title="count",
        height=320,
        margin=dict(l=50, r=20, t=60, b=40),
    )
    return fig
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def state_to_steps(state: List[Dict]) -> List[Step]:
    """Rebuild Step objects from the JSON-friendly dicts held in gr.State."""
    steps = []
    for record in state:
        steps.append(Step(**record))
    return steps
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def run_algorithm(
    matrix_text: str,
    psi_name: str,
    theta: float,
    custom_expr: str,
    max_iter: int,
    random_tie_break: bool,
    seed: int,
    entropy_bins: int,
):
    """Top-level UI callback: parse the matrix, run the RA, build all outputs.

    Returns a fixed 12-tuple matching the Gradio output components:
    (state dicts, current step index, last step index, styled matrix table,
    heatmap, objective trace, entropy trace, row-sum histogram, marginal
    check table, history DataFrame, summary text, error text).
    On any failure the same 12-slot shape is returned with empty components
    and the error message in the last slot, so the UI never raises.
    """
    try:
        x0 = parse_matrix(matrix_text)
        steps, objective_label = greedy_sort_rearrangement(
            x0,
            psi_name,
            theta,
            custom_expr,
            int(max_iter),
            bool(random_tie_break),
            int(seed),
        )

        # Steps are stored as plain dicts so gr.State stays JSON-friendly.
        state = [s.__dict__ for s in steps]
        first = steps[0]
        last = steps[-1]

        # Positive improvement = the objective decreased.
        improvement = steps[0].objective - last.objective
        objective_name = "E[ψ(X)]"

        summary = (
            f"{objective_label}\n"
            f"初期 {objective_name} = {steps[0].objective:.8g}\n"
            f"最終 {objective_name} = {last.objective:.8g}\n"
            f"改善量 = {improvement:.8g}\n"
            f"ステップ数 = {len(steps) - 1}\n"
            f"エントロピー: 各列を {int(entropy_bins)} 個のrank binに変換し、経験copula上で計算"
        )

        entropy_history = make_entropy_history_df(steps, int(entropy_bins))
        history = pd.DataFrame(
            {
                "step": [s.step for s in steps],
                "目的関数": [s.objective for s in steps],
                "操作": [s.action for s in steps],
            }
        ).merge(entropy_history, on="step", how="left")

        # The viewer starts at step 0, so the initial display shows `first`.
        # NOTE(review): the marginal check below compares step 0 with itself
        # (first is steps[0]), so it is trivially "YES" here; the per-step
        # check in show_step does the meaningful comparison — confirm intent.
        return (
            state,
            0,
            len(steps) - 1,
            make_matrix_df(first),
            make_heatmap(first),
            make_trace(steps),
            make_entropy_trace(steps, int(entropy_bins)),
            make_sum_hist(first),
            make_marginal_check_df(np.array(steps[0].matrix), np.array(first.matrix)),
            history,
            summary,
            "",
        )

    # Broad catch is deliberate: every error (parse, ψ evaluation, entropy)
    # is routed to the UI's error textbox instead of crashing the app.
    except Exception as exc:
        return (
            [],
            0,
            0,
            pd.DataFrame(),
            go.Figure(),
            go.Figure(),
            go.Figure(),
            go.Figure(),
            pd.DataFrame(),
            pd.DataFrame(),
            "",
            f"エラー: {exc}",
        )
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def show_step(state: List[Dict], step_no: int):
    """Render one recorded step: styled table, heatmap, sum histogram,
    marginal-preservation check, and a status message.

    step_no is clamped to the valid range; an empty state yields empty
    components and a prompt to run the algorithm first.
    """
    if not state:
        return pd.DataFrame(), go.Figure(), go.Figure(), pd.DataFrame(), "先に実行してください。"

    steps = state_to_steps(state)
    idx = int(max(0, min(step_no, len(steps) - 1)))
    step = steps[idx]

    initial = np.array(steps[0].matrix)
    current = np.array(step.matrix)

    message = "\n".join(
        [
            f"Step {step.step} / {len(steps) - 1}",
            step.action,
            f"目的関数 = {step.objective:.8g}",
        ]
    )

    return (
        make_matrix_df(step),
        make_heatmap(step),
        make_sum_hist(step),
        make_marginal_check_df(initial, current),
        message,
    )
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def move_step(state: List[Dict], current_step: int, delta: int):
    """Move the step viewer by delta (e.g. -1/+1), clamped to the valid range.

    Returns (new step index, *show_step outputs) — six values in total.
    """
    if not state:
        return 0, pd.DataFrame(), go.Figure(), go.Figure(), pd.DataFrame(), "先に実行してください。"

    last_index = len(state_to_steps(state)) - 1
    new_step = int(max(0, min(int(current_step) + delta, last_index)))

    return (new_step, *show_step(state, new_step))
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def autoplay_tick(state: List[Dict], current_step: int, interval_sec: float):
    """Advance one step on each timer tick and stop automatically at the final step.

    Returns the seven autoplay outputs: step index, the five show_step
    components, and a gr.Timer that is deactivated once playback ends.
    """
    if not state:
        return (
            0,
            pd.DataFrame(),
            go.Figure(),
            go.Figure(),
            pd.DataFrame(),
            "先に実行してください。",
            gr.Timer(active=False),
        )

    steps = state_to_steps(state)
    position = int(current_step)

    # Already at (or past) the last step: re-render it and stop the timer.
    if position >= len(steps) - 1:
        matrix_df, heatmap, hist, marginal_df, message = show_step(state, position)
        return (
            position,
            matrix_df,
            heatmap,
            hist,
            marginal_df,
            message + "再生終了。",
            gr.Timer(active=False),
        )

    next_position = position + 1
    matrix_df, heatmap, hist, marginal_df, message = show_step(state, next_position)
    return (
        next_position,
        matrix_df,
        heatmap,
        hist,
        marginal_df,
        message,
        gr.Timer(value=float(interval_sec), active=True),
    )
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
def generate_matrix(n: int, d: int, seed: int, distribution: str):
    """UI callback: fill the matrix textbox with a random sample.

    Returns (matrix text, error text); exactly one of the two is non-empty.
    """
    try:
        sample = make_initial_matrix(int(n), int(d), int(seed), distribution)
        return matrix_to_text(sample), ""
    except Exception as exc:
        # Route every failure to the UI error field rather than raising.
        return "", f"エラー: {exc}"
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
CSS = """
|
| 709 |
+
#title { text-align: center; }
|
| 710 |
+
.note { color: #475569; font-size: 0.95rem; }
|
| 711 |
+
"""
|
| 712 |
+
|
| 713 |
+
# ---------------------------------------------------------------------------
# Gradio UI definition. Components are created first, then event handlers are
# wired; creation order matters because every handler references components by
# the variables bound here.
# ---------------------------------------------------------------------------
with gr.Blocks(css=CSS, title="Rearrangement Algorithm Visualizer") as demo:
    # Header / explanation (Japanese UI text is user-facing and kept as-is).
    gr.Markdown(
        """
# Rearrangement Algorithm Visualizer

各列を周辺分布の標本とみなし、**列内の並べ替えだけ**で結合分布を変えます。
そのため、各周辺分布は保存されたまま、目的関数 `E[ψ(X)] = E[f(X1 + ... + Xd)]` が小さくなるようにヒューリスティックに rearrange します。

<p class="note">
行 = 同時実現シナリオ、列 = 周辺分布。列の値の multiset は不変です。
</p>
""",
        elem_id="title",
    )

    # Step history produced by run_algorithm; drives all playback widgets.
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=2):
            # Raw input matrix: one row per line, comma-separated columns.
            matrix_input = gr.Textbox(
                label="X: n 行 d 列",
                value=(
                    "0.12, 1.80, -0.40\n"
                    "-1.10, 0.35, 1.25\n"
                    "0.65, -0.90, 0.10\n"
                    "1.40, 0.05, -1.30\n"
                    "-0.55, -1.20, 0.75\n"
                    "0.90, 1.10, -0.85"
                ),
                lines=8,
                placeholder="例:\n1, 4, 7\n2, 5, 8\n3, 6, 9",
            )

        with gr.Column(scale=1):
            # Parameters for random initial-matrix generation.
            n_input = gr.Slider(label="生成 n", minimum=4, maximum=100, step=1, value=12)
            d_input = gr.Slider(label="生成 d", minimum=2, maximum=10, step=1, value=3)
            seed_input = gr.Number(label="seed", value=1, precision=0)
            distribution_input = gr.Dropdown(
                label="生成する周辺分布",
                choices=["integer", "normal", "uniform", "lognormal"],
                value="integer",
            )
            generate_matrix_btn = gr.Button("ランダム初期値を生成")

    with gr.Row():
        # Objective-function controls: f applied to the row sum.
        psi_input = gr.Dropdown(
            label="f(ψ(x₁,...,x_d)=f(Σxᵢ))",
            choices=["square", "absolute", "exponential", "positive_part", "custom"],
            value="square",
        )
        theta_input = gr.Number(label="exponential の θ", value=1.0)
        custom_expr_input = gr.Textbox(
            label="custom f(s) 式",
            value="s**2",
            placeholder="例: s**2, np.exp(0.5*s), maximum(s, 0) / 変数: s",
        )

    with gr.Row():
        # Algorithm controls: sweep budget, column-order randomization,
        # and binning granularity for the copula-entropy diagnostic.
        max_iter_input = gr.Slider(label="最大 sweep 数", minimum=1, maximum=100, step=1, value=20)
        random_tie_input = gr.Checkbox(label="列順をランダム化", value=False)
        entropy_bins_input = gr.Slider(
            label="copula entropy 用の rank bin 数",
            minimum=2,
            maximum=10,
            step=1,
            value=2,
            info="各列をrank binに変換し、周辺を固定したcopula側の経験エントロピーを近似します。少ない行数では2〜4を推奨。",
        )
    run_btn = gr.Button("RA を実行", variant="primary")

    error_box = gr.Textbox(label="メッセージ", interactive=False)
    summary_box = gr.Textbox(label="サマリー", lines=5, interactive=False)

    # Inactive until play is pressed; each tick advances playback by one step.
    playback_timer = gr.Timer(value=0.8, active=False)

    with gr.Row():
        prev_btn = gr.Button("← 前へ")
        play_btn = gr.Button("▶ 再生", variant="secondary")
        stop_btn = gr.Button("⏸ 停止")
        next_btn = gr.Button("次へ →")

    with gr.Row():
        step_slider = gr.Slider(label="Step", minimum=0, maximum=100, step=1, value=0)
        interval_slider = gr.Slider(label="再生間隔 秒/step", minimum=0.1, maximum=3.0, step=0.1, value=0.8)

    step_message = gr.Textbox(label="現在の状態", lines=3, interactive=False)

    with gr.Row():
        matrix_df = gr.Dataframe(label="現在の X", interactive=False, wrap=True)
        marginal_df = gr.Dataframe(label="周辺分布チェック", interactive=False, wrap=True)

    with gr.Row():
        heatmap = gr.Plot(label="X のヒートマップ")
        with gr.Column():
            trace_plot = gr.Plot(label="目的関数の推移")
            entropy_plot = gr.Plot(label="rank-binned copula entropy の推移")

    sum_hist = gr.Plot(label="行和の分布")
    history_df = gr.Dataframe(label="履歴", interactive=False, wrap=True)

    # --- Event wiring -----------------------------------------------------

    generate_matrix_btn.click(
        generate_matrix,
        inputs=[n_input, d_input, seed_input, distribution_input],
        outputs=[matrix_input, error_box],
    )

    run_btn.click(
        run_algorithm,
        inputs=[
            matrix_input,
            psi_input,
            theta_input,
            custom_expr_input,
            max_iter_input,
            random_tie_input,
            seed_input,
            entropy_bins_input,
        ],
        # NOTE(review): step_slider appears twice in outputs — presumably
        # run_algorithm returns two slider updates (e.g. a new maximum and a
        # reset value); confirm against run_algorithm's return tuple.
        outputs=[
            state,
            step_slider,
            step_slider,
            matrix_df,
            heatmap,
            trace_plot,
            entropy_plot,
            sum_hist,
            marginal_df,
            history_df,
            summary_box,
            error_box,
        ],
    )

    # Re-render the selected step whenever the slider moves (manually or via
    # the prev/next/play handlers, which output to the slider).
    step_slider.change(
        show_step,
        inputs=[state, step_slider],
        outputs=[matrix_df, heatmap, sum_hist, marginal_df, step_message],
    )

    prev_btn.click(
        lambda s, c: move_step(s, c, -1),
        inputs=[state, step_slider],
        outputs=[step_slider, matrix_df, heatmap, sum_hist, marginal_df, step_message],
    )

    next_btn.click(
        lambda s, c: move_step(s, c, 1),
        inputs=[state, step_slider],
        outputs=[step_slider, matrix_df, heatmap, sum_hist, marginal_df, step_message],
    )

    # Play: activate the timer with the currently selected tick interval.
    play_btn.click(
        lambda interval: gr.Timer(value=float(interval), active=True),
        inputs=[interval_slider],
        outputs=[playback_timer],
    )

    # Stop: deactivate the timer without touching the current step.
    stop_btn.click(
        lambda: gr.Timer(active=False),
        inputs=None,
        outputs=[playback_timer],
    )

    # Each tick advances one step; autoplay_tick also returns a timer update
    # so playback self-terminates at the final step.
    playback_timer.tick(
        autoplay_tick,
        inputs=[state, step_slider, interval_slider],
        outputs=[
            step_slider,
            matrix_df,
            heatmap,
            sum_hist,
            marginal_df,
            step_message,
            playback_timer,
        ],
    )

    gr.Markdown(
        """
# Acknowledgments

以下文献を参考にしました。

[1] 小池, 南, 白石, 「[再配列アルゴリズムを用いたVaR境界の算出](https://www.jstage.jst.go.jp/article/jjssj/45/2/45_353/_pdf/-char/ja)」, 2016.

""",
        elem_id="title",
    )
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
# Launch the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.44.0
|
| 2 |
+
numpy>=1.26.0
|
| 3 |
+
pandas>=2.2.0
|
| 4 |
+
plotly>=5.20.0
|