|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
def solve_sparse(A_data, A_indices, A_shape, b, max_iter=2000, tol=1e-6): |
|
|
"""Conjugate Gradient solver with sparse matrix support.""" |
|
|
x = torch.zeros_like(b) |
|
|
r = b - torch.sparse.mm(torch.sparse_coo_tensor(A_indices, A_data, A_shape), x.unsqueeze(1)).squeeze() |
|
|
p = r.clone() |
|
|
rsold = torch.dot(r, r) |
|
|
|
|
|
for _ in range(max_iter): |
|
|
Ap = torch.sparse.mm(torch.sparse_coo_tensor(A_indices, A_data, A_shape), p.unsqueeze(1)).squeeze() |
|
|
alpha = rsold / torch.dot(p, Ap) |
|
|
x += alpha * p |
|
|
r -= alpha * Ap |
|
|
rsnew = torch.dot(r, r) |
|
|
|
|
|
if torch.sqrt(rsnew) < tol: |
|
|
break |
|
|
p = r + (rsnew / rsold) * p |
|
|
rsold = rsnew |
|
|
|
|
|
return x |
|
|
|
|
|
def solver(a): |
|
|
"""Solve Darcy equation with corrected sparse matrix construction.""" |
|
|
batch_size, N, _ = a.shape |
|
|
h = 1.0 / (N - 1) |
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
solutions = np.zeros((batch_size, N, N), dtype=np.float32) |
|
|
|
|
|
|
|
|
i_grid, j_grid = torch.meshgrid(torch.arange(N, device=device), |
|
|
torch.arange(N, device=device), indexing='ij') |
|
|
k = i_grid * N + j_grid |
|
|
interior = (i_grid > 0) & (i_grid < N-1) & (j_grid > 0) & (j_grid < N-1) |
|
|
|
|
|
for b in range(batch_size): |
|
|
a_t = torch.tensor(a[b], dtype=torch.float32, device=device) |
|
|
|
|
|
|
|
|
a_x = 0.5 * (a_t[1:] + a_t[:-1]) |
|
|
a_y = 0.5 * (a_t[:, 1:] + a_t[:, :-1]) |
|
|
|
|
|
|
|
|
diag = torch.zeros(N, N, device=device) |
|
|
diag[1:-1, :] = (a_x[:-1] + a_x[1:])/h**2 |
|
|
diag[:, 1:-1] += (a_y[:, :-1] + a_y[:, 1:])/h**2 |
|
|
diag = -diag |
|
|
|
|
|
rows, cols, values = [], [], [] |
|
|
|
|
|
|
|
|
valid = interior |
|
|
rows.append(k[valid]) |
|
|
cols.append(k[valid]) |
|
|
values.append(diag[valid]) |
|
|
|
|
|
|
|
|
directions = [ |
|
|
(1, 0, a_x, lambda i,j: (i, j)), |
|
|
(-1, 0, a_x, lambda i,j: (i-1, j)), |
|
|
(0, 1, a_y, lambda i,j: (i, j)), |
|
|
(0, -1, a_y, lambda i,j: (i, j-1)) |
|
|
] |
|
|
|
|
|
for di, dj, coeffs, idx_fn in directions: |
|
|
valid_neighbor = (i_grid + di >= 0) & (i_grid + di < N) & (j_grid + dj >= 0) & (j_grid + dj < N) |
|
|
current_valid = interior & valid_neighbor |
|
|
|
|
|
|
|
|
i_vals = i_grid[current_valid] |
|
|
j_vals = j_grid[current_valid] |
|
|
|
|
|
|
|
|
ci, cj = idx_fn(i_vals, j_vals) |
|
|
coeff_vals = coeffs[ci, cj] / h**2 |
|
|
|
|
|
rows.append(k[current_valid]) |
|
|
cols.append(k[current_valid] + di*N + dj) |
|
|
values.append(coeff_vals) |
|
|
|
|
|
|
|
|
boundary = ~interior |
|
|
rows.append(k[boundary]) |
|
|
cols.append(k[boundary]) |
|
|
values.append(torch.ones(boundary.sum(), device=device)) |
|
|
|
|
|
|
|
|
A_indices = torch.stack([torch.cat(rows), torch.cat(cols)]) |
|
|
A_data = torch.cat(values) |
|
|
|
|
|
|
|
|
f = torch.zeros(N*N, device=device) |
|
|
f[interior.flatten()] = -1.0 |
|
|
|
|
|
|
|
|
u = solve_sparse(A_data, A_indices, (N*N, N*N), f) |
|
|
solutions[b] = u.reshape(N, N).cpu().numpy() |
|
|
|
|
|
print(f"Batch {b+1}/{batch_size}", end='\r') |
|
|
|
|
|
print("\nSolver completed with corrected sparse construction.") |
|
|
return solutions |
|
|
|