Update src/loss.py
Browse files- src/loss.py +5 -0
src/loss.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
_A=None
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
from tqdm import tqdm
|
|
|
|
| 4 |
class LossSchedulerModel(torch.nn.Module):
|
| 5 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
| 6 |
def forward(A,t,xT,e_prev):
|
|
@@ -29,6 +32,8 @@ class SchedulerWrapper:
|
|
| 29 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 30 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 31 |
def step(B,model_output,timestep,sample,**F):
|
|
|
|
|
|
|
| 32 |
D=sample;E=model_output;A=timestep
|
| 33 |
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0; torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_rob_{image_count}_0.pth")
|
| 34 |
if True:
|
|
|
|
| 1 |
_A=None
|
| 2 |
+
count=0
|
| 3 |
+
image_count=0
|
| 4 |
import torch
|
| 5 |
from tqdm import tqdm
|
| 6 |
+
import numpy as np
|
| 7 |
class LossSchedulerModel(torch.nn.Module):
|
| 8 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
| 9 |
def forward(A,t,xT,e_prev):
|
|
|
|
| 32 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 33 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 34 |
def step(B,model_output,timestep,sample,**F):
|
| 35 |
+
global image_count
|
| 36 |
+
global count
|
| 37 |
D=sample;E=model_output;A=timestep
|
| 38 |
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0; torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_rob_{image_count}_0.pth")
|
| 39 |
if True:
|