manbeast3b commited on
Commit
24e51a4
·
verified ·
1 Parent(s): e37d677

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +6 -4
src/loss.py CHANGED
@@ -1,6 +1,6 @@
 
1
  count=0
2
  image_count=0
3
- _A=None
4
  import torch
5
  from tqdm import tqdm
6
  import numpy as np
@@ -107,11 +107,13 @@ class SchedulerWrapper:
107
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
108
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
109
  def step(B,model_output,timestep,sample,**F):
 
 
110
  D=sample;E=model_output;A=timestep;
111
- if timestep==1: print("resetting"); image_count+=1;count=0
112
  if True:
113
- C=B.scheduler.step(E,A,D,**F);A=A.tolist(); print(timestep, B.timesteps)
114
- if timestep+1==len(B.timesteps): torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_{count}.pth")
115
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
116
  B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
117
  else:C=B.loss_scheduler.step(E,A,D,**F);return C
 
1
+ _A=None
2
  count=0
3
  image_count=0
 
4
  import torch
5
  from tqdm import tqdm
6
  import numpy as np
 
107
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
108
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
109
  def step(B,model_output,timestep,sample,**F):
110
+ global image_count
111
+ global count
112
  D=sample;E=model_output;A=timestep;
113
+ if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0
114
  if True:
115
+ C=B.scheduler.step(E,A,D,**F);A=A.tolist();
116
+ if timestep==1: torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_{count}.pth")
117
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
118
  B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
119
  else:C=B.loss_scheduler.step(E,A,D,**F);return C