Spaces:
Runtime error
Runtime error
| import random | |
| from statistics import mean | |
| from typing import List, Tuple | |
| import torch as th | |
| import pytorch_lightning as pl | |
| from jaxtyping import Float, Int | |
| import numpy as np | |
| from torch_geometric.nn.conv import GATv2Conv | |
| from models.SAP.dpsr import DPSR | |
| from models.SAP.model import PSR2Mesh | |
# Constants
# Seed every RNG this module draws from: torch, NumPy, and Python's ``random``
# (``random.random()`` picks the point keep-ratio in ``Model.training_step``),
# so that runs are reproducible.
th.manual_seed(0)
np.random.seed(0)
random.seed(0)
BATCH_SIZE = 1  # BS: fields per forward pass; FormOptimizer's reshape assumes this
IN_DIM = 1  # per-voxel input feature size
OUT_DIM = 1  # per-voxel output size
LATENT_DIM = 32  # hidden feature size of the GATv2 graph convolutions
DROPOUT_PROB = 0.1  # passed as GATv2Conv(dropout=...) (attention dropout)
GRID_SIZE = 128  # voxels per axis of the DPSR grid
def generate_grid_edge_list(gs: int = 128, as_array: bool = False):
    """Build the directed edge list of a 6-connected ``gs x gs x gs`` grid graph.

    Node ``(i, j, k)`` has the flat index ``i + gs*j + gs*gs*k`` (``i`` varies
    fastest). Nodes are visited in increasing index order. For each node, an
    edge is emitted to every in-bounds axis neighbour in the order
    ``i-1, i+1, j-1, j+1, k-1, k+1``. Every undirected grid edge therefore
    appears twice, once per direction.

    The construction is vectorised with NumPy. A per-edge Python loop would
    append ~12.5M small lists one by one for ``gs=128``.

    Args:
        gs: Number of nodes along each axis. Non-positive sizes yield no edges.
        as_array: If True, return an ``int64`` array of shape ``(E, 2)``
            instead of a nested list, which avoids millions of Python objects.

    Returns:
        ``[[src, dst], ...]`` by default, or the equivalent ``(E, 2)`` array.
    """
    if gs <= 0:
        empty = np.empty((0, 2), dtype=np.int64)
        return empty if as_array else empty.tolist()

    # Per-node coordinates in flat-index order (k slowest, i fastest), so that
    # position p in these arrays is node p.
    k, j, i = (axis.ravel() for axis in np.indices((gs, gs, gs), dtype=np.int64))
    src = np.arange(gs ** 3, dtype=np.int64)

    # Neighbour index offsets and their in-bounds masks, in emission order.
    offsets = np.array([-1, 1, -gs, gs, -gs * gs, gs * gs], dtype=np.int64)
    valid = np.stack(
        [i > 0, i < gs - 1, j > 0, j < gs - 1, k > 0, k < gs - 1], axis=1
    )
    dst = src[:, None] + offsets

    # Boolean masking walks the (node, neighbour) table row-major. This
    # reproduces the node-major, neighbour-minor ordering described above.
    src_table = np.broadcast_to(src[:, None], dst.shape)
    edges = np.stack([src_table[valid], dst[valid]], axis=1)
    return edges if as_array else edges.tolist()
# Directed edges of the 6-connected grid graph in PyG's COO ``edge_index``
# layout: shape [2, E], where row 0 holds source nodes and row 1 target nodes.
# PyG message passing expects int64 (torch.long) indices; int32 can fail in
# its index/scatter operations. ``.contiguous()`` materialises the transpose
# instead of keeping a strided view. Integer tensors never require grad, so
# no ``requires_grad`` handling is needed. The tensor is created on the CPU;
# consumers must move it to the device of the features they pair it with.
GRID_EDGE_LIST = th.tensor(
    generate_grid_edge_list(GRID_SIZE), dtype=th.long
).T.contiguous()
class FormOptimizer(th.nn.Module):
    """Two-layer GATv2 graph network that refines a scalar field on the grid.

    Every voxel of the ``GRID_SIZE``^3 field is a graph node, connected to its
    axis neighbours through ``GRID_EDGE_LIST``. The network predicts a
    per-voxel correction, adds it to the input field (residual update), and
    clamps the result to ``[-0.5, 0.5]``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)
        # Keep the grid connectivity as a buffer so that ``.to(device)`` (and
        # Lightning) moves it together with the weights. The module-level
        # tensor stays on the CPU and would clash with CUDA features.
        # PyG requires int64 indices.
        # ``persistent=False`` keeps it out of ``state_dict``, so existing
        # checkpoints still load unchanged.
        self.register_buffer(
            "edge_index",
            GRID_EDGE_LIST.to(th.long).contiguous(),
            persistent=False,
        )

    def forward(self,
                field: Float[th.Tensor, "GS GS GS"]) -> Float[th.Tensor, "BS GS GS GS"]:
        """Refine ``field`` with the graph network.

        Args:
            field: Scalar field on the grid (in ``Model`` this is the DPSR
                output). It must hold ``GRID_SIZE**3 * IN_DIM`` values.
                NOTE(review): presumably shaped [GS, GS, GS] or [1, GS, GS, GS];
                confirm against DPSR.

        Returns:
            ``field`` plus the predicted correction, reshaped to
            ``[BATCH_SIZE, GS, GS, GS]`` and clamped to ``[-0.5, 0.5]``.
        """
        num_nodes = GRID_SIZE * GRID_SIZE * GRID_SIZE
        # One graph node per voxel; reshape only creates a view, and nothing
        # below writes into it, so no defensive clone is needed.
        vertex_features = field.reshape(num_nodes, IN_DIM)
        # NOTE(review): no nonlinearity between the two graph convolutions.
        # Confirm that this is intended.
        vertex_features = self.gconv1(x=vertex_features, edge_index=self.edge_index)
        vertex_features = self.gconv2(x=vertex_features, edge_index=self.edge_index)
        field_delta = self.head(self.actv(vertex_features))
        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
        # Residual connection: the network only learns a correction to the
        # input field. The add is out-of-place to keep autograd graphs simple.
        refined = field_delta + field
        return th.clamp(refined, min=-0.5, max=0.5)
class Model(pl.LightningModule):
    """Lightning module: points + normals -> DPSR field -> GNN refinement -> mesh.

    Training and validation minimise the MSE between the refined field and the
    ground-truth field supplied by the batch.
    """

    def __init__(self, lr: float = 1e-3):
        """
        Args:
            lr: Adam learning rate used by ``configure_optimizers``.
                NOTE(review): 1e-3 is Adam's default value; confirm the intended
                learning rate.
        """
        super().__init__()
        self.lr = lr
        self.form_optimizer = FormOptimizer()
        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        self.field2mesh = PSR2Mesh().apply
        self.metric = th.nn.MSELoss()
        # Per-step loss values, averaged and reset by the epoch-end hooks.
        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        """Write one frame of points and normals as gzip-compressed float16 datasets.

        Each dataset is named after the current ``self.h5_frame`` counter,
        which is then incremented.

        NOTE(review): ``log_points_file``, ``log_normals_file`` and
        ``h5_frame`` are never initialised in this class. They must be set
        elsewhere (presumably to h5py ``File`` objects and ``0``) before this
        method is called.
        """
        dset = self.log_points_file.create_dataset(
            name=str(self.h5_frame),
            shape=points.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
            name=str(self.h5_frame),
            shape=normals.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = normals
        self.h5_frame += 1

    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],  # v - vertices
                                                        Int[th.Tensor, "2 E"],  # f - faces
                                                        Float[th.Tensor, "BS N 3"],  # n - vertices normals
                                                        Float[th.Tensor, "BS GR GR GR"]]:  # field:
        """Rasterise the points into a field, refine it, and extract a mesh.

        Returns ``(vertices, faces, normals, refined_field)``.
        """
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
        """Run one training step with random point dropout; return the MSE loss."""
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        # Augmentation: keep a random fraction in [0.5, 1.0) of the input
        # points. The mask is created on the points' own device rather than a
        # hard-coded CUDA device, so CPU runs also work.
        keep_ratio = random.random() / 2.0 + 0.5
        mask = th.rand((vertices.shape[1],), device=vertices.device) < keep_ratio
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]
        # Call this module, not a global ``model`` variable.
        vr, fr, nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.train_losses.append(loss.item())
        return loss

    def on_train_epoch_end(self):
        """Log the epoch-mean training loss and reset the accumulator."""
        if not self.train_losses:
            # statistics.mean([]) raises StatisticsError; there is nothing to log.
            return
        mean_train_per_epoch_loss = mean(self.train_losses)
        self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
        self.train_losses = []

    def validation_step(self, batch, batch_idx):
        """Compute the field MSE on a validation batch, without point dropout."""
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        vr, fr, nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.val_losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        """Log the epoch-mean validation loss (monitored by the LR scheduler)."""
        if not self.val_losses:
            # statistics.mean([]) raises StatisticsError; there is nothing to log.
            return
        mean_val_per_epoch_loss = mean(self.val_losses)
        self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau, which halves the LR after 5 stale validation epochs."""
        optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "mean_val_per_epoch_loss",
                "interval": "epoch",
                "frequency": 1,
                # If set to `True`, will enforce that the value specified 'monitor'
                # is available when the scheduler is updated, thus stopping
                # training if not found. If set to `False`, it will only produce a warning
                "strict": True,
                # If using the `LearningRateMonitor` callback to monitor the
                # learning rate progress, this keyword can be used to specify
                # a custom logged name
                "name": None,
            }
        }