Updates
Browse files- brlp_lite.py +29 -6
brlp_lite.py
CHANGED
|
@@ -32,10 +32,12 @@
|
|
| 32 |
# }
|
| 33 |
|
| 34 |
import os
|
|
|
|
| 35 |
from typing import Optional, Union
|
| 36 |
import pandas as pd
|
| 37 |
import argparse
|
| 38 |
import numpy as np
|
|
|
|
| 39 |
import warnings
|
| 40 |
import torch
|
| 41 |
import torch.nn as nn
|
|
@@ -43,7 +45,7 @@ from torch import Tensor
|
|
| 43 |
from torch.optim.optimizer import Optimizer
|
| 44 |
from torch.nn import L1Loss
|
| 45 |
from torch.utils.data import DataLoader
|
| 46 |
-
from torch.
|
| 47 |
from torch.amp import GradScaler
|
| 48 |
|
| 49 |
from generative.networks.nets import (
|
|
@@ -57,8 +59,12 @@ from monai import transforms
|
|
| 57 |
from monai.utils import set_determinism
|
| 58 |
from monai.data.meta_tensor import MetaTensor
|
| 59 |
import torch.serialization
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
torch.serialization.add_safe_globals([MetaTensor])
|
|
|
|
|
|
|
| 62 |
|
| 63 |
from tqdm import tqdm
|
| 64 |
import matplotlib.pyplot as plt
|
|
@@ -381,6 +387,22 @@ def train(
|
|
| 381 |
train_df = dataset_df[dataset_df.split == 'train']
|
| 382 |
trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
train_loader = DataLoader(
|
| 385 |
dataset=trainset,
|
| 386 |
num_workers=num_workers,
|
|
@@ -440,13 +462,14 @@ def train(
|
|
| 440 |
total_counter = 0
|
| 441 |
|
| 442 |
for epoch in range(n_epochs):
|
|
|
|
| 443 |
autoencoder.train()
|
| 444 |
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
|
| 445 |
-
progress_bar.set_description(f'Epoch {epoch
|
| 446 |
|
| 447 |
for step, batch in progress_bar:
|
| 448 |
# Generator Training
|
| 449 |
-
with autocast(enabled=True):
|
| 450 |
images = batch["image"].to(device)
|
| 451 |
reconstruction, z_mu, z_sigma = autoencoder(images)
|
| 452 |
|
|
@@ -462,7 +485,7 @@ def train(
|
|
| 462 |
gradacc_g.step(loss_g, step)
|
| 463 |
|
| 464 |
# Discriminator Training
|
| 465 |
-
with autocast(enabled=True):
|
| 466 |
logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
|
| 467 |
d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
|
| 468 |
logits_real = discriminator(images.contiguous().detach())[-1]
|
|
@@ -604,4 +627,4 @@ def main():
|
|
| 604 |
|
| 605 |
|
| 606 |
if __name__ == '__main__':
|
| 607 |
-
main()
|
|
|
|
| 32 |
# }
|
| 33 |
|
| 34 |
import os
|
| 35 |
+
os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
|
| 36 |
from typing import Optional, Union
|
| 37 |
import pandas as pd
|
| 38 |
import argparse
|
| 39 |
import numpy as np
|
| 40 |
+
|
| 41 |
import warnings
|
| 42 |
import torch
|
| 43 |
import torch.nn as nn
|
|
|
|
| 45 |
from torch.optim.optimizer import Optimizer
|
| 46 |
from torch.nn import L1Loss
|
| 47 |
from torch.utils.data import DataLoader
|
| 48 |
+
from torch.amp import autocast
|
| 49 |
from torch.amp import GradScaler
|
| 50 |
|
| 51 |
from generative.networks.nets import (
|
|
|
|
| 59 |
from monai.utils import set_determinism
|
| 60 |
from monai.data.meta_tensor import MetaTensor
|
| 61 |
import torch.serialization
|
| 62 |
+
from numpy.core.multiarray import _reconstruct
|
| 63 |
+
from numpy import ndarray, dtype
|
| 64 |
+
torch.serialization.add_safe_globals([_reconstruct])
|
| 65 |
torch.serialization.add_safe_globals([MetaTensor])
|
| 66 |
+
torch.serialization.add_safe_globals([ndarray])
|
| 67 |
+
torch.serialization.add_safe_globals([dtype])
|
| 68 |
|
| 69 |
from tqdm import tqdm
|
| 70 |
import matplotlib.pyplot as plt
|
|
|
|
| 387 |
train_df = dataset_df[dataset_df.split == 'train']
|
| 388 |
trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
|
| 389 |
|
| 390 |
+
print(f"[DEBUG] Using cache_dir={cache_dir}")
|
| 391 |
+
print(f"[DEBUG] trainset length={len(trainset)}")
|
| 392 |
+
|
| 393 |
+
try:
|
| 394 |
+
sample_debug = trainset[0] # Force a transform on the first record
|
| 395 |
+
print("[DEBUG] Successfully loaded sample 0 from trainset.")
|
| 396 |
+
except Exception as e:
|
| 397 |
+
print("[DEBUG] Error loading sample 0:", e)
|
| 398 |
+
|
| 399 |
+
import glob
|
| 400 |
+
|
| 401 |
+
hashfiles = glob.glob(os.path.join(cache_dir, "*.pt"))
|
| 402 |
+
print(f"[DEBUG] Found {len(hashfiles)} cached .pt files in {cache_dir}")
|
| 403 |
+
if hashfiles:
|
| 404 |
+
print("[DEBUG] Example cache file:", hashfiles[0])
|
| 405 |
+
|
| 406 |
train_loader = DataLoader(
|
| 407 |
dataset=trainset,
|
| 408 |
num_workers=num_workers,
|
|
|
|
| 462 |
total_counter = 0
|
| 463 |
|
| 464 |
for epoch in range(n_epochs):
|
| 465 |
+
print(f"[DEBUG] Starting epoch {epoch}/{n_epochs-1}")
|
| 466 |
autoencoder.train()
|
| 467 |
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
|
| 468 |
+
progress_bar.set_description(f'Epoch {epoch}')
|
| 469 |
|
| 470 |
for step, batch in progress_bar:
|
| 471 |
# Generator Training
|
| 472 |
+
with autocast(device, enabled=True):
|
| 473 |
images = batch["image"].to(device)
|
| 474 |
reconstruction, z_mu, z_sigma = autoencoder(images)
|
| 475 |
|
|
|
|
| 485 |
gradacc_g.step(loss_g, step)
|
| 486 |
|
| 487 |
# Discriminator Training
|
| 488 |
+
with autocast(device, enabled=True):
|
| 489 |
logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
|
| 490 |
d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
|
| 491 |
logits_real = discriminator(images.contiguous().detach())[-1]
|
|
|
|
| 627 |
|
| 628 |
|
| 629 |
if __name__ == '__main__':
|
| 630 |
+
main()
|