Spaces:
Runtime error
Runtime error
Commit
·
b1c39d5
1
Parent(s):
64efbd0
Update free_lunch_utils.py
Browse files- free_lunch_utils.py +4 -4
free_lunch_utils.py
CHANGED
|
@@ -97,10 +97,10 @@ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
|
| 97 |
# Only operate on the first two stages
|
| 98 |
if hidden_states.shape[1] == 1280:
|
| 99 |
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
| 100 |
-
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 101 |
if hidden_states.shape[1] == 640:
|
| 102 |
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
| 103 |
-
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 104 |
# ---------------------------------------------------------
|
| 105 |
|
| 106 |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
|
@@ -235,10 +235,10 @@ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
|
| 235 |
# Only operate on the first two stages
|
| 236 |
if hidden_states.shape[1] == 1280:
|
| 237 |
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
| 238 |
-
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 239 |
if hidden_states.shape[1] == 640:
|
| 240 |
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
| 241 |
-
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 242 |
# ---------------------------------------------------------
|
| 243 |
|
| 244 |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
|
|
|
| 97 |
# Only operate on the first two stages
|
| 98 |
if hidden_states.shape[1] == 1280:
|
| 99 |
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
| 100 |
+
# # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 101 |
if hidden_states.shape[1] == 640:
|
| 102 |
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
| 103 |
+
# res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 104 |
# ---------------------------------------------------------
|
| 105 |
|
| 106 |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
|
|
|
| 235 |
# Only operate on the first two stages
|
| 236 |
if hidden_states.shape[1] == 1280:
|
| 237 |
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
| 238 |
+
# res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 239 |
if hidden_states.shape[1] == 640:
|
| 240 |
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
| 241 |
+
# res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 242 |
# ---------------------------------------------------------
|
| 243 |
|
| 244 |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|