Update src/loss.py
Browse files- src/loss.py +1 -1
src/loss.py
CHANGED
|
@@ -98,7 +98,7 @@ class LossScheduler:
|
|
| 98 |
A.e_prev.append(model_output)
|
| 99 |
if timestep==A.timesteps[-1]:
|
| 100 |
with torch.enable_grad():
|
| 101 |
-
orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/
|
| 102 |
A.optimizer.zero_grad()
|
| 103 |
C=A.model(B,A.xT,A.e_prev)
|
| 104 |
L=torch.nn.functional.mse_loss(C, orig_output)
|
|
|
|
| 98 |
A.e_prev.append(model_output)
|
| 99 |
if timestep==A.timesteps[-1]:
|
| 100 |
with torch.enable_grad():
|
| 101 |
+
orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_orig_{image_count}_20.pth")
|
| 102 |
A.optimizer.zero_grad()
|
| 103 |
C=A.model(B,A.xT,A.e_prev)
|
| 104 |
L=torch.nn.functional.mse_loss(C, orig_output)
|