manbeast3b commited on
Commit
2500952
·
verified ·
1 Parent(s): ce845a0

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +5 -0
src/loss.py CHANGED
@@ -1,6 +1,9 @@
1
  _A=None
 
 
2
  import torch
3
  from tqdm import tqdm
 
4
  class LossSchedulerModel(torch.nn.Module):
5
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
6
  def forward(A,t,xT,e_prev):
@@ -29,6 +32,8 @@ class SchedulerWrapper:
29
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
30
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
31
  def step(B,model_output,timestep,sample,**F):
 
 
32
  D=sample;E=model_output;A=timestep
33
  if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0; torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_rob_{image_count}_0.pth")
34
  if True:
 
1
  _A=None
2
+ count=0
3
+ image_count=0
4
  import torch
5
  from tqdm import tqdm
6
+ import numpy as np
7
  class LossSchedulerModel(torch.nn.Module):
8
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
9
  def forward(A,t,xT,e_prev):
 
32
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
33
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
34
  def step(B,model_output,timestep,sample,**F):
35
+ global image_count
36
+ global count
37
  D=sample;E=model_output;A=timestep
38
  if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0; torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_rob_{image_count}_0.pth")
39
  if True: