CodePDE / solvers /darcy /beta_1.0 /seeds /implementation_3.py
LDA1020's picture
feat: code release
56c4b9b verified
import numpy as np
import torch
def solve_sparse(A_data, A_indices, A_shape, b, max_iter=2000, tol=1e-6):
    """Solve ``A x = b`` with the Conjugate Gradient method on a sparse COO matrix.

    The iteration starts from ``x = 0`` and stops once the absolute residual
    norm ``||b - A x||_2`` falls below ``tol`` or after ``max_iter`` steps,
    whichever comes first; the last iterate is returned in both cases.

    Args:
        A_data: 1-D tensor holding the non-zero values of ``A``.
        A_indices: ``(2, nnz)`` integer tensor of (row, column) indices.
        A_shape: ``(n, n)`` shape of ``A``.
        b: 1-D right-hand-side tensor of length ``n``. It is not modified.
        max_iter: Maximum number of CG iterations.
        tol: Absolute tolerance on the residual 2-norm.

    Returns:
        1-D tensor with the same shape, dtype and device as ``b``.

    Note:
        CG is only guaranteed to converge for symmetric definite operators;
        no check of that property is made here.
    """
    # Build the operator once. Coalescing up front means the repeated
    # products below work on an already sorted/merged index set.
    A = torch.sparse_coo_tensor(A_indices, A_data, A_shape).coalesce()

    def matvec(v):
        # squeeze(1) rather than squeeze(): keeps a 1-D result even when n == 1.
        return torch.sparse.mm(A, v.unsqueeze(1)).squeeze(1)

    x = torch.zeros_like(b)
    # With x0 = 0 the initial residual is b itself; clone so that the in-place
    # updates below never touch the caller's tensor.
    r = b.clone()
    rsold = torch.dot(r, r)
    # Already converged (e.g. b == 0): returning here avoids the 0/0 division
    # in the first step that would otherwise fill x with NaN.
    if torch.sqrt(rsold) < tol:
        return x

    p = r.clone()
    for _ in range(max_iter):
        Ap = matvec(p)
        pAp = torch.dot(p, Ap)
        if pAp == 0:
            # Breakdown: the search direction carries no energy, a step
            # would divide by zero. Return the best iterate so far.
            break
        alpha = rsold / pAp
        x += alpha * p
        r -= alpha * Ap
        rsnew = torch.dot(r, r)
        if torch.sqrt(rsnew) < tol:
            break
        p = r + (rsnew / rsold) * p
        rsold = rsnew
    return x
def solver(a):
    """Solve a batch of Darcy problems with a 5-point finite-difference scheme.

    For every coefficient field ``a[b]`` on an ``N x N`` grid with spacing
    ``h = 1 / (N - 1)`` a sparse linear system is assembled and handed to
    ``solve_sparse``:

    * interior rows discretise ``div(a grad u) = -1`` using coefficients
      averaged onto the cell faces between neighbouring grid points;
    * boundary rows are identity rows with a zero right-hand side, i.e.
      ``u = 0`` on the boundary.

    Args:
        a: Indexable array of shape ``(batch_size, N, N)``; each ``a[b]`` is
            converted with ``torch.tensor`` (presumably a NumPy array --
            confirm against the caller). ``N`` must be at least 2 because
            ``h`` divides by ``N - 1``.

    Returns:
        ``np.ndarray`` of dtype ``float32`` and shape ``(batch_size, N, N)``
        holding the solution for each sample.

    Side effects:
        Prints a progress line per sample and a completion message.
    """
    batch_size, N, _ = a.shape
    h = 1.0 / (N - 1)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    solutions = np.zeros((batch_size, N, N), dtype=np.float32)

    # ------------------------------------------------------------------
    # Batch-invariant structure: the sparsity pattern and the right-hand
    # side depend only on N, so they are built once instead of per sample.
    # ------------------------------------------------------------------
    i_grid, j_grid = torch.meshgrid(torch.arange(N, device=device),
                                    torch.arange(N, device=device), indexing='ij')
    k = i_grid * N + j_grid  # flattened (row-major) unknown index
    interior = (i_grid > 0) & (i_grid < N-1) & (j_grid > 0) & (j_grid < N-1)
    boundary = ~interior

    i_int = i_grid[interior]
    j_int = j_grid[interior]
    k_int = k[interior]
    k_bnd = k[boundary]

    # Neighbour offsets (di, dj). Every interior point has all four
    # neighbours inside the grid, so no extra validity mask is required.
    # The order here must match the order of the value blocks below.
    offsets = [(1, 0), (-1, 0), (0, 1), (0, -1)]
    rows = torch.cat([k_int] + [k_int for _ in offsets] + [k_bnd])
    cols = torch.cat([k_int] + [k_int + di*N + dj for di, dj in offsets] + [k_bnd])
    A_indices = torch.stack([rows, cols])
    boundary_ones = torch.ones(k_bnd.numel(), device=device)

    # Right-hand side: -1 at interior points, 0 on the boundary.
    f = torch.zeros(N*N, device=device)
    f[interior.flatten()] = -1.0

    for b in range(batch_size):
        a_t = torch.tensor(a[b], dtype=torch.float32, device=device)

        # Face coefficients: arithmetic mean of the two adjacent grid values.
        a_x = 0.5 * (a_t[1:] + a_t[:-1])        # (N-1, N), face between (i, j) and (i+1, j)
        a_y = 0.5 * (a_t[:, 1:] + a_t[:, :-1])  # (N, N-1), face between (i, j) and (i, j+1)

        # Diagonal: minus the sum of the face coefficients around each point.
        diag = torch.zeros(N, N, device=device)
        diag[1:-1, :] = (a_x[:-1] + a_x[1:])/h**2           # faces along the first axis
        diag[:, 1:-1] += (a_y[:, :-1] + a_y[:, 1:])/h**2    # faces along the second axis
        diag = -diag

        # Value blocks in the same order as the index blocks above:
        # diagonal, (i+1, j), (i-1, j), (i, j+1), (i, j-1), boundary identity.
        A_data = torch.cat([
            diag[interior],
            a_x[i_int, j_int] / h**2,
            a_x[i_int - 1, j_int] / h**2,
            a_y[i_int, j_int] / h**2,
            a_y[i_int, j_int - 1] / h**2,
            boundary_ones,
        ])

        # NOTE: the assembled matrix is not symmetric as a whole (boundary
        # rows are identity rows), but with a zero initial guess and zero
        # boundary right-hand side the boundary components of the CG
        # iterates stay zero, so the iteration acts on the interior block.
        u = solve_sparse(A_data, A_indices, (N*N, N*N), f)
        solutions[b] = u.reshape(N, N).cpu().numpy()
        print(f"Batch {b+1}/{batch_size}", end='\r')

    print("\nSolver completed with corrected sparse construction.")
    return solutions