import numpy as np
from scipy import stats

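def chatterjee_phase_to_amp(phi, amp, agg="max"):
    """
    Phase-to-amplitude dependence measured with Chatterjee's xi.

    xi(X, Y) measures how well Y is a (possibly non-monotone) function of X.
    A single projection of the phase is ambiguous (sin(phi) cannot tell phi
    from pi - phi, and cos(phi) cannot tell phi from -phi), so xi is computed
    for both sin(phi) -> amp and cos(phi) -> amp and then combined.

    Requires SciPy >= 1.15 (scipy.stats.chatterjeexi).

    phi : phase in radians (1D)
    amp : amplitude (1D, same length as phi)
    agg : how to combine the two coefficients
          'max'  - larger of the two (NaN-aware)
          'mean' - average of the two (NaN-aware)
          'rss'  - root sum of squares, clipped to [0, 1]
    Returns a float.
    """
    s = np.sin(phi)
    c = np.cos(phi)

    xi_s = stats.chatterjeexi(s, amp).statistic
    xi_c = stats.chatterjeexi(c, amp).statistic

    if agg == "max":
        xi = np.nanmax([xi_s, xi_c])
    elif agg == "mean":
        xi = np.nanmean([xi_s, xi_c])
    elif agg == "rss":
        # squaring discards the sign of (slightly) negative xi values
        xi = np.clip(np.sqrt(xi_s**2 + xi_c**2), 0.0, 1.0)
    else:
        raise ValueError("agg must be 'max', 'mean', or 'rss'")

    return float(xi)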
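
# Cross-check sketch (hypothetical helper, not part of the original module):
# a NumPy-only version of Chatterjee's xi for continuous data without ties,
#     xi_n = 1 - 3 * sum_i |r_{i+1} - r_i| / (n**2 - 1),
# where the pairs are sorted by x and r_i is the rank of y_i. It assumes no
# ties in x or y, so treat scipy.stats.chatterjeexi as the reference; this
# version also runs on SciPy releases older than 1.15. It shows why both
# sin(phi) and cos(phi) are used: for amp = cos(phi), xi(cos(phi), amp) is
# close to 1 while xi(sin(phi), amp) stays well below it, because sin(phi)
# cannot tell phi from pi - phi.
import numpy as np


def _chatterjee_xi_no_ties(x, y):
    """Chatterjee's xi of y on x; assumes continuous data (no ties)."""
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    n = x.size
    order = np.argsort(x)                          # sort the pairs by x
    ranks = np.argsort(np.argsort(y[order])) + 1   # rank of each y in that order, 1..n
    return 1.0 - 3.0 * np.abs(np.diff(ranks)).sum() / (n**2 - 1)

# Usage: _chatterjee_xi_no_ties(np.sin(phi), amp) should roughly match
# stats.chatterjeexi(np.sin(phi), amp).statistic when there are no ties.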

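def circular_correlation(rho, theta, mu=None, tau=None):
    """
    Circular-circular correlation coefficient (Jammalamadaka-Sarma):
        r = sum(sin(rho - mu) * sin(theta - tau))
            / sqrt(sum(sin(rho - mu)**2) * sum(sin(theta - tau)**2))
    rho, theta : angles in radians (1D, same length)
    mu, tau    : reference directions; default to the sample circular means
    Returns a value in [-1, 1].
    """
    rho = np.asarray(rho, dtype=float)
    theta = np.asarray(theta, dtype=float)

    # circular mean = direction of the mean resultant vector
    if mu is None:
        mu = np.angle(np.mean(np.exp(1j * rho)))
    if tau is None:
        tau = np.angle(np.mean(np.exp(1j * theta)))

    x = np.sin(rho - mu)
    y = np.sin(theta - tau)

    # Uncentered sums of squares: np.var only coincides with them when
    # mu and tau are the sample circular means (then x and y have zero mean).
    return np.sum(x * y) / np.sqrt(np.sum(x**2) * np.sum(y**2))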
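
# Significance sketch (hypothetical helper, not part of the original module):
# the large-sample test for the circular-circular correlation coefficient,
# which treats
#     z = r * sqrt(n * l20 * l02 / l22),
#     lij = mean(sin(rho - mu)**i * sin(theta - tau)**j),
# as approximately standard normal when rho and theta are independent.
# The normal approximation is an assumption that needs a reasonably large n.
import math

import numpy as np


def _circular_correlation_test(rho, theta):
    """Return (r, two-sided p-value) using the normal approximation."""
    rho = np.asarray(rho, dtype=float)
    theta = np.asarray(theta, dtype=float)
    # centre each angle on its sample circular mean
    x = np.sin(rho - np.angle(np.mean(np.exp(1j * rho))))
    y = np.sin(theta - np.angle(np.mean(np.exp(1j * theta))))
    r = np.sum(x * y) / np.sqrt(np.sum(x**2) * np.sum(y**2))
    l20, l02, l22 = np.mean(x**2), np.mean(y**2), np.mean(x**2 * y**2)
    z = r * np.sqrt(x.size * l20 * l02 / l22)
    p = math.erfc(abs(z) / math.sqrt(2.0))  # equals 2 * (1 - Phi(|z|))
    return float(r), p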


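def modulation_index(phase, amp, n_bins=18, eps=1e-12):
    """
    Modulation Index of Tort et al. (2010).

    The phase cycle is split into n_bins equal bins, the mean amplitude in
    each bin is normalized into a distribution P, and

        MI = KL(P || U) / log(n_bins),   KL(P || U) = sum_j P_j * log(P_j / U_j)

    where U is the uniform distribution (U_j = 1 / n_bins). MI is 0 for a
    flat amplitude profile and approaches 1 when all amplitude falls in a
    single bin.

    phase  : phase in radians; wrapped to (-pi, pi] before binning
    amp    : amplitude envelope (>= 0)
    n_bins : number of phase bins (18 bins of 20 degrees by default)
    eps    : keeps log() finite for empty bins
    Returns NaN when no finite samples remain or all bin means are zero.
    """
    phase = np.asarray(phase, dtype=float).ravel()
    amp = np.asarray(amp, dtype=float).ravel()

    # drop samples where either signal is NaN or inf
    mask = np.isfinite(phase) & np.isfinite(amp)
    phase = phase[mask]
    amp = amp[mask]

    # wrap to (-pi, pi] so inputs in e.g. [0, 2*pi) are binned correctly
    phase = np.angle(np.exp(1j * phase))

    # bin k covers [edges[k], edges[k+1]); phase == pi goes to the last bin
    edges = np.linspace(-np.pi, np.pi, n_bins + 1)
    bins = np.clip(np.digitize(phase, edges) - 1, 0, n_bins - 1)

    # mean amplitude per bin (empty bins stay 0)
    counts = np.bincount(bins, minlength=n_bins)
    sums = np.bincount(bins, weights=amp, minlength=n_bins)
    mean_amp = np.divide(sums, counts, out=np.zeros(n_bins), where=counts > 0)

    if mean_amp.sum() == 0:
        return np.nan

    p = mean_amp / mean_amp.sum()
    uniform = 1.0 / n_bins

    # KL divergence from uniform, normalized by its maximum log(n_bins)
    kl = np.sum(p * np.log((p + eps) / uniform))
    return kl / np.log(n_bins)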
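

# Test-signal sketch (hypothetical helper, not part of the original module):
# synthetic phase-amplitude coupled data for exercising modulation_index and
# chatterjee_phase_to_amp. The coupling model
#     amp = 1 + coupling * cos(phase - preferred_phase) + Gaussian noise
# is an illustrative assumption, not taken from Tort et al. (2010). With
# coupling = 0 the amplitude is independent of phase, so both measures should
# be close to zero; larger coupling should raise them.
import numpy as np


def _synthetic_pac(n=5000, coupling=0.5, preferred_phase=0.0, noise=0.1, seed=0):
    """Return (phase, amp): phase uniform on [-pi, pi), amp clipped to >= 0."""
    rng = np.random.default_rng(seed)
    phase = rng.uniform(-np.pi, np.pi, n)
    amp = 1.0 + coupling * np.cos(phase - preferred_phase)
    amp = amp + noise * rng.standard_normal(n)
    return phase, np.clip(amp, 0.0, None)

# Usage: phase, amp = _synthetic_pac(coupling=0.5)
#        modulation_index(phase, amp), chatterjee_phase_to_amp(phase, amp)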