telecomadm1145 committed on
Commit
3d1ec01
·
verified ·
1 Parent(s): e7b69a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -9
app.py CHANGED
@@ -263,6 +263,9 @@ class FluxLatentDINOFlow(nn.Module):
263
  self.type_emb_pixel = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
264
  self.type_emb_dino = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
265
  self.initialize_weights()
 
 
 
266
 
267
  def initialize_weights(self):
268
  for name, m in self.named_modules():
@@ -301,15 +304,30 @@ class FluxLatentDINOFlow(nn.Module):
301
  pixel_tokens = pixel_tokens.flatten(2).transpose(1, 2)
302
  pixel_tokens = pixel_tokens + self.type_emb_pixel
303
 
304
- with torch.no_grad():
305
- mean = torch.tensor([0.485, 0.456, 0.406], device=lq_img.device).view(1, 3, 1, 1)
306
- std = torch.tensor([0.229, 0.224, 0.225], device=lq_img.device).view(1, 3, 1, 1)
307
- dino_in = (lq_img * 0.5 + 0.5 - mean) / std
308
- dino_feats = self.dino.forward_features(dino_in)
309
- if getattr(self.dino, "num_prefix_tokens", 0) > 0:
310
- dino_feats = dino_feats[:, self.dino.num_prefix_tokens:]
311
- d_h = d_w = int(dino_feats.shape[1] ** 0.5)
312
- dino_map = dino_feats.transpose(1, 2).reshape(B, -1, d_h, d_w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
315
  dino_tokens = self.dino_adapter(dino_map_resized)
 
263
  self.type_emb_pixel = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
264
  self.type_emb_dino = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
265
  self.initialize_weights()
266
+ # New: cache for the DINO token computation
267
+ self._cached_dino_map = None
268
+ self._cached_lq_hash = None # 可选:缓存输入哈希
269
 
270
  def initialize_weights(self):
271
  for name, m in self.named_modules():
 
304
  pixel_tokens = pixel_tokens.flatten(2).transpose(1, 2)
305
  pixel_tokens = pixel_tokens + self.type_emb_pixel
306
 
307
+ # 计算输入 hash
308
+ lq_hash = hash(lq_img.data_ptr()) # 简单用指针做哈希,也可用 tensor.sum().item() 更精确
309
+ if self._cached_dino_map is None or self._cached_lq_hash != lq_hash:
310
+ print("recalculating hash...")
311
+ # 只在缓存不存在或输入变化时计算 DINO
312
+ with torch.no_grad():
313
+ mean = torch.tensor([0.485, 0.456, 0.406], device=lq_img.device).view(1, 3, 1, 1)
314
+ std = torch.tensor([0.229, 0.224, 0.225], device=lq_img.device).view(1, 3, 1, 1)
315
+ dino_in = (lq_img * 0.5 + 0.5 - mean) / std
316
+ dino_feats = self.dino.forward_features(dino_in)
317
+ if getattr(self.dino, "num_prefix_tokens", 0) > 0:
318
+ dino_feats = dino_feats[:, self.dino.num_prefix_tokens:]
319
+ d_h = d_w = int(dino_feats.shape[1] ** 0.5)
320
+ dino_map = dino_feats.transpose(1, 2).reshape(B, -1, d_h, d_w)
321
+ dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
322
+ dino_tokens = self.dino_adapter(dino_map_resized)
323
+ dino_tokens = dino_tokens.flatten(2).transpose(1, 2)
324
+ dino_tokens = dino_tokens + self.type_emb_dino
325
+
326
+ # 更新缓存
327
+ self._cached_dino_map = dino_tokens
328
+ self._cached_lq_hash = lq_hash
329
+ else:
330
+ dino_tokens = self._cached_dino_map
331
 
332
  dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
333
  dino_tokens = self.dino_adapter(dino_map_resized)