Spaces:
Runtime error
Runtime error
Commit
·
ed16a12
1
Parent(s):
861e5b3
Update free_lunch_utils.py
Browse files- free_lunch_utils.py +15 -2
free_lunch_utils.py
CHANGED
|
@@ -234,10 +234,23 @@ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
|
| 234 |
# --------------- FreeU code -----------------------
|
| 235 |
# Only operate on the first two stages
|
| 236 |
if hidden_states.shape[1] == 1280:
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 239 |
if hidden_states.shape[1] == 640:
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 242 |
# ---------------------------------------------------------
|
| 243 |
|
|
|
|
| 234 |
# --------------- FreeU code -----------------------
|
| 235 |
# Only operate on the first two stages
|
| 236 |
if hidden_states.shape[1] == 1280:
|
| 237 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
| 238 |
+
B = hidden_mean.shape[0]
|
| 239 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 240 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 241 |
+
|
| 242 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
| 243 |
+
|
| 244 |
+
hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
|
| 245 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 246 |
if hidden_states.shape[1] == 640:
|
| 247 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
| 248 |
+
B = hidden_mean.shape[0]
|
| 249 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 250 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 251 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
| 252 |
+
|
| 253 |
+
hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
|
| 254 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 255 |
# ---------------------------------------------------------
|
| 256 |
|