Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -263,6 +263,9 @@ class FluxLatentDINOFlow(nn.Module):
|
|
| 263 |
self.type_emb_pixel = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 264 |
self.type_emb_dino = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 265 |
self.initialize_weights()
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
def initialize_weights(self):
|
| 268 |
for name, m in self.named_modules():
|
|
@@ -301,15 +304,30 @@ class FluxLatentDINOFlow(nn.Module):
|
|
| 301 |
pixel_tokens = pixel_tokens.flatten(2).transpose(1, 2)
|
| 302 |
pixel_tokens = pixel_tokens + self.type_emb_pixel
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
|
| 315 |
dino_tokens = self.dino_adapter(dino_map_resized)
|
|
|
|
| 263 |
self.type_emb_pixel = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 264 |
self.type_emb_dino = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 265 |
self.initialize_weights()
|
| 266 |
+
# Added: cache for the DINO-branch tokens (None until first computed)
|
| 267 |
+
self._cached_dino_map = None
|
| 268 |
+
self._cached_lq_hash = None # 可选:缓存输入哈希
|
| 269 |
|
| 270 |
def initialize_weights(self):
|
| 271 |
for name, m in self.named_modules():
|
|
|
|
| 304 |
pixel_tokens = pixel_tokens.flatten(2).transpose(1, 2)
|
| 305 |
pixel_tokens = pixel_tokens + self.type_emb_pixel
|
| 306 |
|
| 307 |
+
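# Compute a cache key for the input. NOTE(review): the key below is derived from data_ptr(), i.e. the storage address rather than the tensor contents, so in-place updates or a reused allocation could produce a stale cache hit — confirm this is acceptable.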
|
| 308 |
+
lq_hash = hash(lq_img.data_ptr()) # 简单用指针做哈希,也可用 tensor.sum().item() 更精确
|
| 309 |
+
if self._cached_dino_map is None or self._cached_lq_hash != lq_hash:
|
| 310 |
+
print("recalculating hash...")
|
| 311 |
+
# Run DINO only when the cache is empty or the input key has changed
|
| 312 |
+
with torch.no_grad():
|
| 313 |
+
mean = torch.tensor([0.485, 0.456, 0.406], device=lq_img.device).view(1, 3, 1, 1)
|
| 314 |
+
std = torch.tensor([0.229, 0.224, 0.225], device=lq_img.device).view(1, 3, 1, 1)
|
| 315 |
+
dino_in = (lq_img * 0.5 + 0.5 - mean) / std
|
| 316 |
+
dino_feats = self.dino.forward_features(dino_in)
|
| 317 |
+
if getattr(self.dino, "num_prefix_tokens", 0) > 0:
|
| 318 |
+
dino_feats = dino_feats[:, self.dino.num_prefix_tokens:]
|
| 319 |
+
d_h = d_w = int(dino_feats.shape[1] ** 0.5)
|
| 320 |
+
dino_map = dino_feats.transpose(1, 2).reshape(B, -1, d_h, d_w)
|
| 321 |
+
dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
|
| 322 |
+
dino_tokens = self.dino_adapter(dino_map_resized)
|
| 323 |
+
dino_tokens = dino_tokens.flatten(2).transpose(1, 2)
|
| 324 |
+
dino_tokens = dino_tokens + self.type_emb_dino
|
| 325 |
+
|
| 326 |
+
# Update the cache
|
| 327 |
+
self._cached_dino_map = dino_tokens
|
| 328 |
+
self._cached_lq_hash = lq_hash
|
| 329 |
+
else:
|
| 330 |
+
dino_tokens = self._cached_dino_map
|
| 331 |
|
| 332 |
dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
|
| 333 |
dino_tokens = self.dino_adapter(dino_map_resized)
|