Update processing_penguinvl.py
#2
by Cyril666 - opened
- processing_penguinvl.py +93 -104
processing_penguinvl.py
CHANGED
|
@@ -204,120 +204,109 @@ def floor_by_factor(number: int, factor: int) -> int:
|
|
| 204 |
return math.floor(number / factor) * factor
|
| 205 |
|
| 206 |
def smart_resize(
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
| 211 |
"""
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 217 |
-
|
| 218 |
-
3. The aspect ratio of the image is maintained as closely as possible.
|
| 219 |
"""
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
raise ValueError(
|
| 223 |
-
f"
|
| 224 |
)
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
if
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
elif
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
return max(
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
# 分块处理
|
| 261 |
-
patch1 = rearrange(
|
| 262 |
-
frame1_tensor, "c (h p1) (w p2) -> h w (c p1 p2)", p1=patch_size, p2=patch_size).float()
|
| 263 |
-
patch2 = rearrange(
|
| 264 |
-
frame2_tensor, "c (h p1) (w p2) -> h w (c p1 p2)", p1=patch_size, p2=patch_size).float()
|
| 265 |
|
| 266 |
norm1 = torch.norm(patch1, p=2, dim=-1, keepdim=True) + epsilon
|
| 267 |
norm2 = torch.norm(patch2, p=2, dim=-1, keepdim=True) + epsilon
|
| 268 |
-
|
| 269 |
-
normalized1 = patch1 / norm1
|
| 270 |
-
normalized2 = patch2 / norm2
|
| 271 |
-
cos_sim = (normalized1 * normalized2).sum(dim=-1)
|
| 272 |
-
|
| 273 |
-
zero_vector_mask = (norm1.squeeze() < 0.01) & (norm2.squeeze() < 0.01) # 全黑图
|
| 274 |
-
|
| 275 |
-
similar = torch.ones_like(cos_sim) # 默认全部相似
|
| 276 |
-
|
| 277 |
-
non_zero_mask = ~zero_vector_mask
|
| 278 |
-
similar[non_zero_mask] = (cos_sim[non_zero_mask] > threshold).float()
|
| 279 |
-
|
| 280 |
-
return similar[non_zero_mask].float().mean().item()
|
| 281 |
-
|
| 282 |
-
def extract_slow_fast_frames(frames, threshold = 0.95):
|
| 283 |
-
def _extract_slow_indices(frames):
|
| 284 |
-
assert frames.dim() == 4, "输入必须是4D张量 [N, C, H, W]"
|
| 285 |
-
|
| 286 |
-
# 首帧一定是Slow
|
| 287 |
-
slow_indices = [0]
|
| 288 |
-
# 定位这里,检查和image[0]报错是不是同一视频
|
| 289 |
-
last_key_frame = frames[0]
|
| 290 |
-
for i in range(1, frames.size(0)):
|
| 291 |
-
current_frame = frames[i]
|
| 292 |
-
sim = get_frame_sim(last_key_frame, current_frame)
|
| 293 |
-
|
| 294 |
-
if sim < threshold:
|
| 295 |
-
slow_indices.append(i)
|
| 296 |
-
last_key_frame = current_frame # 更新关键帧
|
| 297 |
-
|
| 298 |
-
return slow_indices
|
| 299 |
-
|
| 300 |
-
_, _, height, width = frames.shape
|
| 301 |
-
resized_height, resized_width = smart_resize(
|
| 302 |
-
height,
|
| 303 |
-
width,
|
| 304 |
-
factor=14,
|
| 305 |
-
min_pixels=10 * 14 * 14,
|
| 306 |
-
max_pixels=10240 * 14 * 14,
|
| 307 |
-
)
|
| 308 |
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
antialias=True,
|
| 314 |
-
).float()
|
| 315 |
|
| 316 |
-
slow_indices = _extract_slow_indices(resized_frames)
|
| 317 |
-
frame_types = torch.ones(size=(frames.size(0), ), dtype=torch.int32)
|
| 318 |
-
frame_types[slow_indices] = 0
|
| 319 |
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
|
| 323 |
class ChatTemplateKwargs(TypedDict, total=False):
|
|
|
|
| 204 |
return math.floor(number / factor) * factor
|
| 205 |
|
| 206 |
def smart_resize(
    height: int,
    width: int,
    factor: int = 14,
    min_pixels: int = 0,
    max_pixels: int = 16384,
) -> tuple:
    """
    Compute a target (height, width) for resizing an image.

    The result satisfies, as closely as the rounding allows:
    - Both dimensions are multiples of ``factor`` (and at least ``factor``).
    - The total pixel count lies in [min_pixels, max_pixels].
    - The aspect ratio of the input is preserved.

    Args:
        height: Input height in pixels; must be positive.
        width: Input width in pixels; must be positive.
        factor: Granularity both output dimensions are snapped to.
        min_pixels: Lower bound on ``h * w`` of the result.
        max_pixels: Upper bound on ``h * w`` of the result.

    Returns:
        A ``(h, w)`` tuple of ints.

    Raises:
        ValueError: If a dimension is not positive, or if the ratio between
            the long and short side exceeds 200.
    """
    # Local helpers use private names so they do not shadow the module-level
    # floor_by_factor defined just above this function.
    def _round_to(number, step: int) -> int:
        """Closest multiple of ``step`` to ``number``."""
        return round(number / step) * step

    def _ceil_to(number, step: int) -> int:
        """Smallest multiple of ``step`` that is >= ``number``."""
        return math.ceil(number / step) * step

    def _floor_to(number, step: int) -> int:
        """Largest multiple of ``step`` that is <= ``number``."""
        return math.floor(number / step) * step

    # Reject degenerate sizes up front: a zero dimension would otherwise
    # surface as a bare ZeroDivisionError in the aspect-ratio check below.
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )

    max_ratio = 200
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > max_ratio:
        raise ValueError(
            f"Aspect ratio must be < {max_ratio}, got {aspect_ratio}"
        )

    h = max(factor, _round_to(height, factor))
    w = max(factor, _round_to(width, factor))
    if h * w > max_pixels:
        # Too large: shrink both sides by the same scale, rounding down so the
        # pixel budget is not exceeded by the rounding itself.
        scale = math.sqrt((height * width) / max_pixels)
        h = _floor_to(height / scale, factor)
        w = _floor_to(width / scale, factor)
    elif h * w < min_pixels:
        # Too small: grow both sides by the same scale, rounding up so the
        # minimum is reached.
        scale = math.sqrt(min_pixels / (height * width))
        h = _ceil_to(height * scale, factor)
        w = _ceil_to(width * scale, factor)
    # Flooring can produce 0 for a very thin side; never go below one factor.
    return max(h, factor), max(w, factor)
|
| 245 |
+
|
| 246 |
+
# Adapted from Keye-VL: https://github.com/Kwai-Keye/Keye
|
| 247 |
+
def get_frame_sim(
    frame1: torch.Tensor,
    frame2: torch.Tensor,
    patch_size: int = 14,
    threshold: float = 0.7,
    epsilon: float = 1e-8,
) -> float:
    """
    Patch-wise similarity between two frames, measured in HSV space.

    Each frame is converted from RGB to HSV and cut into non-overlapping
    ``patch_size`` x ``patch_size`` patches. The cosine similarity between
    corresponding patches is computed, and the function returns the fraction
    of patches whose cosine similarity exceeds ``threshold``. Patches that are
    near-zero in BOTH frames are excluded from that fraction.

    Args:
        frame1: First frame, shape [C, H, W]. H and W must be divisible by
            ``patch_size`` (required by the rearrange pattern).
        frame2: Second frame, same shape as ``frame1``.
        patch_size: Side length of the square patches.
        threshold: Cosine-similarity cut-off above which a patch counts as
            similar.
        epsilon: Added to patch norms to avoid division by zero.

    Returns:
        A float in [0, 1]. Returns 1.0 when every patch is near-zero in both
        frames (nothing distinguishes them).
    """
    assert frame1.dim() == 3 and frame2.dim() == 3, "Frames must be 3D tensors [C, H, W]"

    def to_hsv_tensor(tensor: torch.Tensor) -> torch.Tensor:
        """Convert a [C, H, W] RGB tensor to a float [C, H, W] HSV tensor via OpenCV."""
        arr = tensor.cpu().permute(1, 2, 0).numpy()
        if arr.dtype in (np.float32, np.float64):
            # Clip before the cast: converting an out-of-range float to uint8
            # is not well defined and can wrap around. In-range values are
            # truncated exactly as before.
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        hsv = cv2.cvtColor(arr, cv2.COLOR_RGB2HSV)
        return torch.from_numpy(hsv).permute(2, 0, 1).to(tensor.device).float()

    f1 = to_hsv_tensor(frame1)
    f2 = to_hsv_tensor(frame2)
    # Result shape: [h_patches, w_patches, C * patch_size * patch_size].
    patch1 = rearrange(f1, "c (h p1) (w p2) -> h w (c p1 p2)", p1=patch_size, p2=patch_size).float()
    patch2 = rearrange(f2, "c (h p1) (w p2) -> h w (c p1 p2)", p1=patch_size, p2=patch_size).float()

    norm1 = torch.norm(patch1, p=2, dim=-1, keepdim=True) + epsilon
    norm2 = torch.norm(patch2, p=2, dim=-1, keepdim=True) + epsilon
    cos_sim = (patch1 / norm1 * patch2 / norm2).sum(dim=-1)

    # Squeeze only the trailing keepdim axis. A bare squeeze() would also drop
    # a patch-grid axis of size 1 (frame one patch tall or wide) and make the
    # mask shape disagree with cos_sim.
    both_near_zero = (norm1.squeeze(-1) < 0.01) & (norm2.squeeze(-1) < 0.01)
    informative = ~both_near_zero
    if not bool(informative.any()):
        # No patch carries signal in either frame; the mean of an empty
        # selection would be NaN, so report the frames as fully similar.
        return 1.0
    return (cos_sim[informative] > threshold).float().mean().item()
|
|
|
|
|
|
|
| 277 |
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
+
# Settings for key-frame / intermediate-frame ("KI", formerly slow/fast) labelling.
# Frame type 0 marks a key frame, 1 marks an intermediate frame.
K_PATCH = 14  # resize granularity in pixels
K_MIN_PIXELS = 10 * K_PATCH * K_PATCH  # lower pixel budget: 10 patches
K_MAX_PIXELS = 10240 * K_PATCH * K_PATCH  # upper pixel budget: 10240 patches
|
| 283 |
+
|
| 284 |
+
def extract_ki_frames(
    frames: torch.Tensor,
    threshold: float = MIN_FRAME_SIMILARITY,
) -> list:
    """
    Label each frame as a key frame (0) or an intermediate frame (1).

    The first frame is always a key frame. Every later frame is compared with
    the most recent key frame using ``get_frame_sim``; when the similarity
    drops below ``threshold`` that frame becomes the new key frame.
    Comparison is done on a resized copy of the frames whose size is chosen
    by ``smart_resize`` with the K_* constants.

    Args:
        frames: Video frames, shape [N, C, H, W].
        threshold: Similarity below which a frame starts a new key frame.

    Returns:
        A list of N ints (0 = key frame, 1 = intermediate frame). An empty
        batch yields an empty list.
    """
    assert frames.dim() == 4, "Frames must be 4D tensor [N, C, H, W]"

    num_frames = frames.size(0)
    if num_frames == 0:
        # Nothing to label; indexing the first frame below would fail.
        return []

    _, _, h, w = frames.shape
    rh, rw = smart_resize(h, w, factor=K_PATCH, min_pixels=K_MIN_PIXELS, max_pixels=K_MAX_PIXELS)
    resized = nn.functional.interpolate(frames, (rh, rw), mode="bilinear", antialias=True).float()

    # Every frame starts as intermediate (1); key frames are set to 0 below.
    frame_types = torch.ones(num_frames, dtype=torch.int32)
    frame_types[0] = 0
    key = resized[0]
    for i in range(1, num_frames):
        candidate = resized[i]
        if get_frame_sim(key, candidate) < threshold:
            frame_types[i] = 0
            key = candidate  # later frames are compared against this one
    return frame_types.tolist()
|
| 310 |
|
| 311 |
|
| 312 |
class ChatTemplateKwargs(TypedDict, total=False):
|