File size: 6,778 Bytes
24e51a4
2500952
 
e73d351
 
2500952
07e13fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e73d351
07e13fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ece2e8f
1edb460
 
 
07e13fd
e73d351
07e13fd
 
 
 
 
086e3dd
 
 
e73d351
07e13fd
 
e73d351
 
086e3dd
1edb460
e73d351
1edb460
 
e73d351
 
086e3dd
 
 
 
f34974b
cfca40b
086e3dd
 
bb312c6
086e3dd
e98e299
086e3dd
 
 
 
9dddae7
ea3e51b
 
086e3dd
 
 
cd9e403
086e3dd
 
07e13fd
 
 
 
086e3dd
 
 
 
 
 
 
07e13fd
 
 
 
 
 
e73d351
07e13fd
 
 
 
 
e73d351
0e4c4ae
e73d351
 
 
07e13fd
f90950e
ce845a0
e73d351
 
 
 
 
 
 
 
 
07e13fd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
_A=None  # sentinel alias for None, used throughout to reset cached per-image state
count=0  # global step counter within the current image (reset on the first timestep)
image_count=0  # global counter of images processed so far
import torch
from tqdm import tqdm
import numpy as np


class ExponentialScalingMatrix(torch.nn.Module):
	"""Learnable (size, size) matrix scaled by an exp(-timestep / 10) decay.

	The base matrix is registered as a true Parameter so it is trained,
	appears in state_dict, and is moved by ``module.to(device)``.
	"""
	def __init__(self, size):
		super().__init__()
		# BUG FIX: the original did ``Parameter(...).to('cuda')``, which returns a
		# plain non-leaf Tensor — the weight was never registered as a parameter
		# (invisible to optimizers and state_dict) and required a CUDA device.
		self.base = torch.nn.Parameter(torch.ones(size, size))

	def forward(self, timestep):
		"""Return base * exp(-timestep / 10), computed on the parameter's device."""
		# Device-agnostic: works on CPU or GPU; move the module with ``.to(device)``.
		decay = math.exp(-float(timestep) / 10.0)
		return self.base * decay

class RecurrentDynamicScaling(torch.nn.Module):
	"""Scales a correction via a GRU-tracked hidden gate before adding it to x."""

	def __init__(self, size):
		super().__init__()
		self.rnn = torch.nn.GRUCell(size, size)
		self.state = torch.nn.Parameter(torch.zeros(size))  # learnable initial hidden state
		self._hidden = None  # running hidden state (plain tensor, NOT a parameter)

	def forward(self, timestep, x, correction):
		"""Return x + correction * sigmoid(hidden), advancing the GRU hidden state."""
		# BUG FIX: the original wrote ``self.state = self.rnn(...)``, which raises
		# TypeError — PyTorch forbids rebinding a registered Parameter to a plain
		# Tensor. The running hidden state now lives in a separate attribute; the
		# Parameter supplies the learnable initial state, and gradients still flow
		# through the GRU each call.
		hidden = self.state if self._hidden is None else self._hidden
		# GRUCell's signature is (input, hx); the original passed them reversed.
		self._hidden = self.rnn(correction, hidden)
		scaled_correction = correction * torch.sigmoid(self._hidden)  # gate in (0, 1)
		return x + scaled_correction

class AttentionDynamicScaling(torch.nn.Module):
	def __init__(self, size):
		super().__init__()
		self.attention = torch.nn.MultiheadAttention(size, num_heads=1)

	def forward(self, timestep, x, correction):
		query = correction.unsqueeze(0)  # Shape (1, batch, size)
		key = x.unsqueeze(0)
		value = correction.unsqueeze(0)
		scaled_correction, _ = self.attention(query, key, value)
		return x + scaled_correction.squeeze(0)
		
class LossSchedulerModel(torch.nn.Module):
	def __init__(A,wx,we,scale_factor=None,strategy="exponential_scaling"):
		super(LossSchedulerModel,A).__init__()
		assert len(wx.shape)==1 and len(we.shape)==2
		B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1]
		A.register_parameter('wx',torch.nn.Parameter(wx))
		A.register_parameter('we',torch.nn.Parameter(we))
		size=13
		if type(scale_factor)!=None:
			A.register_parameter('scale_factor',torch.nn.Parameter(scale_factor))
		else:
			A.scale_factor = torch.nn.Parameter(torch.ones(size)) 
		A.decay_rate = 0.1
		if strategy == "exponential_scaling":
			A.matrix_generator = ExponentialScalingMatrix(size)
		else:
			raise ValueError("Unknown strategy")

	def forward_other(A,t,xT,e_prev):
		B=e_prev
		assert t-len(B)+1==0;C=xT*A.wx[t]
		jitter=A.matrix_generator.forward(t)
		for(D,E,J)in zip(B,A.we[t],jitter[t]):
			dynamic_scale = torch.tanh(A.decay_rate * torch.tensor(t)).to(D.device)
			correction=D*(E+J.to(D.device))
			C += correction * dynamic_scale
		return C.to(xT.dtype) #C.to(xT.dtype)
	def forward(A,t,xT,e_prev):
		B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
		for(D,E)in zip(B,A.we[t]):C+=D*E
		return C.to(xT.dtype)

class LossScheduler:
	"""Drop-in diffusion-scheduler replacement that predicts the next latent with
	a LossSchedulerModel and can optimise that model online.

	NOTE(review): ``step`` loads reference latents from hard-coded
	/home/mbhat/... paths at the magic timestep 51 and saves weights at the end
	of each image — training-rig specific; confirm before reuse.
	"""

	def __init__(A, timesteps, model, lr=0.01):
		A.timesteps = timesteps
		A.model = model
		A.init_noise_sigma = 1.  # diffusers-compatible attribute
		A.order = 1
		# scale_factor is intentionally left out of the optimiser for now.
		A.optimizer = torch.optim.SGD([A.model.wx, A.model.we], lr=lr)
		A.model.train()

	@staticmethod
	def load(path):
		"""Rebuild a scheduler from a (timesteps, wx, we, scale_factor) checkpoint."""
		A, B, C, E = torch.load(path, map_location='cpu')
		D = LossSchedulerModel(B, C, scale_factor=E)
		return LossScheduler(A, D)

	def save(A, path):
		"""Persist (timesteps, wx, we, scale_factor) to ``path``."""
		torch.save((A.timesteps, A.model.wx, A.model.we, A.model.scale_factor), path)

	def set_timesteps(A, num_inference_steps, device='cuda'):
		"""Reset per-image state and move model/timesteps to ``device``.

		``num_inference_steps`` is ignored: the timestep list is fixed at
		construction/load time.
		"""
		A.xT = _A
		A.e_prev = []
		A.t_prev = -1
		A.model = A.model.to(device)
		A.timesteps = A.timesteps.to(device)

	def scale_model_input(A, sample, *B, **C):
		return sample  # no input scaling required by this scheduler

	def step_orig(self, model_output, timestep, sample, *D, **E):
		"""Pure inference step: cache the noise prediction and combine all so far."""
		A = self
		B = A.timesteps.tolist().index(timestep)
		assert A.t_prev == -1 or B == A.t_prev + 1  # steps must arrive in order
		if A.t_prev == -1:
			A.xT = sample  # first call of an image: remember the initial latent
		A.e_prev.append(model_output)
		C = A.model(B, A.xT, A.e_prev)
		if B + 1 == len(A.timesteps):
			# final step of this image: reset per-image state
			A.xT = _A
			A.e_prev = []
			A.t_prev = -1
		else:
			A.t_prev = B
		return C,

	def step(self, model_output, timestep, sample, *D, **E):
		"""Inference step that also optimises against a stored reference latent at timestep 51."""
		global image_count, count
		A = self
		B = A.timesteps.tolist().index(timestep)
		assert A.t_prev == -1 or B == A.t_prev + 1
		count += 1
		if timestep == A.timesteps[0]:
			print("resetting")
			image_count += 1
			count = 0
		print(timestep)
		if A.t_prev == -1:
			A.xT = sample
		A.e_prev.append(model_output)
		if timestep == 51:  # magic timestep with a saved reference latent on disk
			with torch.enable_grad():
				orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_orig_{image_count}_20.pth")
				A.optimizer.zero_grad()
				C = A.model(B, A.xT, A.e_prev)
				L = torch.nn.functional.mse_loss(C, orig_output)
				L.backward()
				A.optimizer.step()
				print("optimized")
		else:
			C = A.model(B, A.xT, A.e_prev)
		if B + 1 == len(A.timesteps):
			A.xT = _A
			A.e_prev = []
			A.t_prev = -1
			A.save(f"/home/mbhat/weights/weights_latest_{image_count}_{int(timestep)}.pth")
		else:
			A.t_prev = B
		return C,

	def step_other(self, model_output, timestep, sample, *D, **E):
		"""Like ``step`` but optimises against the raw model output at the final timestep."""
		# BUG FIX: the original mutated the module-level counter without this
		# declaration, raising UnboundLocalError on ``count += 1``.
		global count
		A = self
		B = A.timesteps.tolist().index(timestep)
		assert A.t_prev == -1 or B == A.t_prev + 1
		count += 1
		if A.t_prev == -1:
			A.xT = sample
		A.e_prev.append(model_output)
		if timestep == A.timesteps.tolist()[-1]:
			with torch.enable_grad():
				A.optimizer.zero_grad()
				C = A.model(B, A.xT, A.e_prev)
				L = torch.nn.functional.mse_loss(C, model_output)
				L.backward()
				A.optimizer.step()
		else:
			# BUG FIX: the original had no else branch, so C was unbound (and the
			# method crashed) on every non-final timestep; mirror ``step``.
			C = A.model(B, A.xT, A.e_prev)
		if B + 1 == len(A.timesteps):
			A.xT = _A
			A.e_prev = []
			A.t_prev = -1
		else:
			A.t_prev = B
		return C,
class SchedulerWrapper:
	"""Wraps a diffusers-style scheduler, caching per-timestep (sample, noise
	prediction, output) triples, and can later swap in a trained LossScheduler.

	Until ``load_loss_params`` is called, all work is delegated to the wrapped
	scheduler while tensors are recorded in catch_x/catch_e/catch_x_ for
	training-data extraction via ``get_path``.
	"""
	def __init__(A,scheduler,loss_params_path='/home/mbhat/weights.pth'):
		A.scheduler=scheduler
		# Per-timestep caches: input sample, noise prediction, output sample.
		A.catch_x,A.catch_e,A.catch_x_={},{},{}
		A.loss_scheduler=_A  # None until load_loss_params() installs one
		A.loss_params_path=loss_params_path
	def set_timesteps(A,num_inference_steps,**C):
		"""Delegate to whichever scheduler is active and mirror its attributes.

		NOTE(review): the requested ``num_inference_steps`` is overridden by the
		hard-coded 11 — presumably a fixed accelerated schedule; confirm.
		"""
		D=11
		if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
		else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
	def step(B,model_output,timestep,sample,**F):
		"""Run one scheduler step; while no loss scheduler is active, also cache
		detached CPU copies of (sample, model_output, result) keyed by the
		scalar timestep value."""
		D=sample;E=model_output;A=timestep; 
		if B.loss_scheduler is _A:
			C=B.scheduler.step(E,A,D,**F);A=A.tolist();
			if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
			B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
		else:C=B.loss_scheduler.step(E,A,D,**F);return C
	def scale_model_input(A,sample,timestep):return sample  # no scaling required
	def add_noise(A,original_samples,noise,timesteps):B=A.scheduler.add_noise(original_samples,noise,timesteps);return B  # pure delegation to the wrapped scheduler
	def get_path(C):
		"""Assemble the caches into training tensors.

		Returns (timesteps_desc, xs, es): xs stacks the concatenated input
		samples for each timestep in descending order plus the final step's
		output; es stacks the matching noise predictions.
		"""
		A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
		for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
		H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
	# Load a (timesteps, wx, we, scale_factor) checkpoint and switch all
	# subsequent calls over to a LossScheduler built from it.
	def load_loss_params(A):B,C,D,E=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D,scale_factor=E);A.loss_scheduler=LossScheduler(B,A.loss_model)
	# num_accelerate_steps is currently unused; kept for caller compatibility.
	def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()