File size: 2,419 Bytes
82a8f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import numpy as np
from ideal_poly_volume_toolkit.geometry import (
    delaunay_triangulation_indices,
    triangle_volume_from_points_torch,
)

def random_angles(K, rng):
    """Return ``K`` angles drawn uniformly over a full turn.

    ``rng`` is expected to be a NumPy ``Generator`` (anything exposing
    ``random(size)``). Its unit-interval samples are scaled by ``2*pi``.
    """
    unit_samples = rng.random(K)
    full_turn = 2 * np.pi
    return full_turn * unit_samples

def build_Z(thetas: torch.Tensor) -> torch.Tensor:
    """Assemble the complex point set: two fixed anchors, then unit-circle points.

    The result has ``thetas.numel() + 2`` entries of dtype ``complex128`` on the
    same device as ``thetas``. Entry 0 is ``1``, entry 1 is ``0``, and the
    remaining entries are ``exp(i * theta)`` for each angle in ``thetas``.
    The circle entries are built from differentiable torch ops, so the script
    can backpropagate from them to ``thetas``.
    """
    n_points = thetas.numel() + 2
    points = torch.empty(n_points, dtype=torch.complex128, device=thetas.device)
    # Fixed anchors that do not depend on thetas.
    points[0], points[1] = 1 + 0j, 0 + 0j
    angles_c = thetas.to(torch.complex128)
    points[2:] = (angles_c * 1j).exp()
    return points

SERIES_TERMS = 96  # series truncation passed to every triangle-volume call


def _summed_volume(Z, triangles, *, verbose=False):
    """Sum ``triangle_volume_from_points_torch`` over the index triples ``triangles``.

    Each triple ``(a, b, c)`` selects ``Z[a], Z[b], Z[c]``. Accumulation starts
    from a float64 zero scalar. With ``verbose=True`` each triangle's volume is
    printed as it is added.
    """
    acc = torch.zeros((), dtype=torch.float64)
    for t, (a, b, c) in enumerate(triangles):
        vol = triangle_volume_from_points_torch(
            Z[a], Z[b], Z[c], series_terms=SERIES_TERMS
        )
        if verbose:
            print(f"Triangle {t}: volume = {vol.item():.6f}")
        acc = acc + vol
    return acc


# Test with original implementation
rng = np.random.default_rng(0)
K = 3
thetas = torch.tensor(
    random_angles(K, rng), dtype=torch.float64, requires_grad=True
)

print(f"Initial thetas: {thetas}")

# Build complex numbers and triangulate. The triangulation is computed once
# without autograd and then held fixed for both the autograd pass and the
# finite-difference check below.
with torch.no_grad():
    Z_np = build_Z(thetas).detach().numpy()
    idx = delaunay_triangulation_indices(Z_np)
    print(f"Triangulation has {len(idx)} triangles")
    print(f"Indices: {idx}")

# Compute volume with gradients
Z_t = build_Z(thetas)
total = _summed_volume(Z_t, idx, verbose=True)

print(f"\nTotal volume: {total.item()}")

# Compute gradients. loss = -volume, so thetas.grad holds
# d(loss)/d(theta) = -d(volume)/d(theta).
loss = -total
loss.backward()
print(f"Gradients on thetas: {thetas.grad}")

# Check which vertices are involved in triangles (plain ints for clean output).
vertices_used = {int(v) for tri in idx for v in tri}
print(f"\nVertices used in triangulation: {sorted(vertices_used)}")

# Map back to thetas: Z[0] and Z[1] are fixed; Z[k + 2] comes from thetas[k].
print("Vertex mapping:")
print("0: Z[0] = 1+0j (fixed)")
print("1: Z[1] = 0+0j (fixed)")
for theta_idx in range(K):
    vertex = theta_idx + 2
    print(f"{vertex}: Z[{vertex}] = exp(i*thetas[{theta_idx}])")

# Explicitly check gradient flow by perturbing each theta. This takes central
# differences of the *volume* on the same fixed triangulation `idx` and
# compares them with the autograd volume gradient. That gradient is
# -thetas.grad, because backward() was run on loss = -total.
print("\n\nChecking gradient flow by finite differences:")
eps = 1e-6
volume_grad = -thetas.grad
with torch.no_grad():
    for i in range(K):
        step = torch.zeros_like(thetas)
        step[i] = eps
        vol_plus = _summed_volume(build_Z(thetas + step), idx)
        vol_minus = _summed_volume(build_Z(thetas - step), idx)
        finite_diff_grad = (vol_plus.item() - vol_minus.item()) / (2 * eps)
        print(
            f"Finite diff gradient for theta[{i}]: {finite_diff_grad:.6f} "
            f"(autograd: {volume_grad[i].item():.6f})"
        )