| from __future__ import annotations |
|
|
| import math |
| import re |
| from dataclasses import dataclass |
| from typing import Any, Callable |
|
|
| import torch |
|
|
# Any non-empty bracketed span; group 1 is the text between the brackets.
# Whether the span is really a timing marker is decided by the marker parser.
_RELAY_MARKER_RE = re.compile(
    r"\[([^\]]+)\]",
)

# Unsigned decimal number with an optional fractional part (no sign, no exponent).
_NUMERIC_RE = re.compile(
    r"^\d+(?:\.\d+)?$",
)

# Unsigned decimal number immediately followed by a seconds suffix
# (s / sec / secs / second / seconds, any letter case); group 1 is the number.
_SECONDS_RE = re.compile(
    r"^(\d+(?:\.\d+)?)(?:s|sec|secs|second|seconds)$",
    flags=re.IGNORECASE,
)
|
|
# Public names exported on star-import. The underscore-prefixed helpers in
# this file are not listed.
__all__ = [
    "PromptRelayBound",
    "PromptRelayConditioning",
    "PromptRelayMaskBuilder",
    "PromptRelayPlan",
    "PromptRelaySegment",
    "encode_prompt_relay",
    "parse_prompt_relay",
]
|
|
|
|
@dataclass(frozen=True)
class PromptRelayBound:
    """One endpoint of a relay marker, stored in the unit it was written in.

    ``unit`` is compared against ``"percent"`` and ``"frame"``; any other
    value is resolved as seconds.
    """

    value: float
    unit: str

    def resolve(self, total_seconds: float, total_frames: int, inclusive_end: bool = False) -> float:
        """Convert the bound into a position in ``[0.0, 1.0]``.

        Args:
            total_seconds: Duration used as the denominator for second-based bounds.
            total_frames: Frame count; frame bounds are divided by ``total_frames - 1``.
            inclusive_end: Only affects frame bounds. When false, ``value`` is
                shifted down by one before scaling; when true it is used as is.

        Returns:
            The clamped position. Percent bounds are only clamped (``value`` is
            taken as already being a fraction). Frame bounds return ``0.0`` when
            ``total_frames <= 1`` and second bounds return ``0.0`` when
            ``total_seconds <= 0``.
        """
        if self.unit == "percent":
            position = self.value
        elif self.unit == "frame":
            if total_frames <= 1:
                return 0.0
            index = self.value if inclusive_end else self.value - 1.0
            # Negative indices are pinned to the first frame before scaling.
            position = max(index, 0.0) / float(total_frames - 1)
        else:
            if total_seconds <= 0:
                return 0.0
            position = self.value / total_seconds
        return max(0.0, min(1.0, position))
|
|
|
|
@dataclass(frozen=True)
class PromptRelaySegment:
    """A prompt fragment together with the bounds of the marker that introduced it."""

    # Bound at which the segment begins.
    start: PromptRelayBound
    # Bound at which the segment ends; None when the marker had no end bound.
    end: PromptRelayBound | None
    # Segment text.
    prompt: str
    # NOTE(review): the parser in this file builds segments with only the three
    # fields above, so these stay at 0 there; token ranges are carried in a
    # separate list. Confirm whether other callers populate them.
    key_start: int = 0
    key_end: int = 0
|
|
|
|
@dataclass(frozen=True)
class PromptRelayPlan:
    """Parsed relay prompt: the shared text plus the ordered timed segments."""

    # Text found before the first recognised marker, stripped; may be empty.
    global_prompt: str
    # Segments in the order their markers appear in the prompt.
    segments: tuple[PromptRelaySegment, ...]
|
|
|
|
@dataclass(frozen=True)
class _RuntimeSegment:
    """Segment reduced to plain numbers for use by the mask builder."""

    # Start/end as fractions of the clip. The builder helper in this file
    # passes values clamped to [0, 1] with end >= start.
    start: float
    end: float
    # Half-open range [key_start, key_end) of context positions for this segment.
    key_start: int
    key_end: int
|
|
|
|
@dataclass(frozen=True)
class PromptRelayConditioning:
    """Encoded relay prompt plus the optional per-modality mask builders."""

    # Encoder output for the video branch.
    video_context: torch.Tensor
    # Encoder output for the audio branch; None when the encoder returned none.
    audio_context: torch.Tensor | None
    # Mask builders; None means no relay bias is produced for that branch.
    video_mask_builder: PromptRelayMaskBuilder | None
    audio_mask_builder: PromptRelayMaskBuilder | None

    @property
    def mask_builder(self) -> PromptRelayMaskBuilder | None:
        """Alias for ``video_mask_builder``."""
        return self.video_mask_builder
|
|
|
|
class PromptRelayMaskBuilder:
    """Callable that turns relay segments into an additive bias over context keys.

    For every segment, the context positions ``[key_start, key_end)`` receive a
    non-positive bias that is zero for query frames inside the segment's window
    and falls off quadratically with the distance outside it. Positions flagged
    invalid in ``key_valid`` receive ``padding_bias``.
    NOTE(review): presumably the result is added to cross-attention logits;
    confirm with the caller.
    """

    def __init__(
        self,
        key_valid: torch.Tensor,
        segments: list[_RuntimeSegment],
        positive_key_count: int,
        visible_start_ratio: float = 0.0,
        epsilon: float = 1e-3,
        padding_bias: float = -100.0,
    ) -> None:
        """Store the static inputs of the bias.

        Args:
            key_valid: Validity flags of the context positions; kept on CPU as bool.
            segments: Runtime segments (fractional start/end and key range).
            positive_key_count: Number of leading context positions the key
                ranges and ``key_valid`` refer to.
            visible_start_ratio: Fraction of the largest frame index at which
                the timeline starts; clamped to ``[0, 1]``.
            epsilon: For ``0 < epsilon < 1`` sigma is ``1 / ln(1 / epsilon)``;
                otherwise sigma falls back to ``0.1448``.
            padding_bias: Bias written onto invalid context positions.
        """
        self.key_valid = key_valid.detach().to("cpu", dtype=torch.bool)
        self.segments = tuple(segments)
        self.positive_key_count = int(positive_key_count)
        ratio = float(visible_start_ratio)
        self.visible_start_ratio = max(0.0, min(1.0, ratio))
        if 0 < epsilon < 1:
            self.sigma = float(1.0 / math.log(1.0 / epsilon))
        else:
            self.sigma = 0.1448
        self.padding_bias = float(padding_bias)

    def __call__(self, state: Any, frame_indices: torch.Tensor | None, context: Any) -> torch.Tensor | None:
        """Build the bias for one call.

        Args:
            state: Object exposing ``latent`` (its device and dtype are used);
                also handed to the frame-index resolver.
            frame_indices: ``(batch, query_len)`` frame index per query, or
                ``None`` to let the resolver derive it from ``state``.
            context: Tensor or wrapper whose sequence length sizes the key axis.

        Returns:
            A ``(batch, query_len, 1, context_len)`` tensor in the latent's
            floating dtype (float32 for non-floating latents), or ``None`` when
            there are no segments or the context length is not positive.
        """
        if not self.segments:
            return None
        context_len = _context_seq_len(context)
        if context_len <= 0:
            return None

        latent = state.latent
        device = latent.device
        out_dtype = latent.dtype if torch.is_floating_point(latent) else torch.float32
        frame_indices = _resolve_frame_indices(state, frame_indices).to(device=device)
        batch_size, query_len = frame_indices.shape

        # Fit the validity flags to exactly the positive key count first, then
        # to the context length; positions added along the way count as valid.
        positive_count = self.positive_key_count
        key_valid = self.key_valid.to(device=device)
        missing = positive_count - key_valid.numel()
        if missing > 0:
            key_valid = torch.cat([key_valid, torch.ones(missing, device=device, dtype=torch.bool)])
        key_valid = key_valid[:positive_count]
        extra = context_len - positive_count
        if extra > 0:
            key_valid = torch.cat([key_valid, torch.ones(extra, device=device, dtype=torch.bool)])
        else:
            key_valid = key_valid[:context_len]

        positive_len = min(positive_count, context_len)
        bias = torch.zeros((batch_size, query_len, context_len), device=device, dtype=torch.float32)

        # Shift the timeline so it starts at the visible offset; every span
        # used as a scale is kept at one frame or more.
        raw_frames = frame_indices.to(torch.float32)
        raw_last = raw_frames.amax(dim=1, keepdim=True).clamp_min(1.0)
        visible_start = raw_last * self.visible_start_ratio
        query_frames = raw_frames - visible_start
        last_frame = (raw_last - visible_start).clamp_min(1.0)
        denom = 2.0 * self.sigma * self.sigma

        for segment in self.segments:
            key_lo = min(segment.key_start, positive_len)
            key_hi = min(segment.key_end, positive_len)
            if key_lo >= key_hi:
                continue
            seg_start = torch.tensor(segment.start, device=device, dtype=torch.float32) * last_frame
            seg_end = torch.tensor(segment.end, device=device, dtype=torch.float32) * last_frame
            span = (seg_end - seg_start).clamp_min(1.0)
            center = (seg_start + seg_end) * 0.5
            # Penalty-free half-width: half the span minus two frames, floored at zero.
            half_window = (span * 0.5 - 2.0).clamp_min(0.0)
            overshoot = torch.relu((query_frames - center).abs() - half_window)
            penalty = overshoot.square() / denom
            # A later segment overwrites an earlier one on overlapping key ranges.
            bias[:, :, key_lo:key_hi] = -penalty.unsqueeze(-1)

        # Guard kept from the original flow: any still-missing flags count as
        # invalid. For a 1-D mask the fitting above already yields at least
        # context_len entries.
        if key_valid.numel() < context_len:
            shortfall = context_len - key_valid.numel()
            key_valid = torch.cat([key_valid, torch.zeros(shortfall, device=device, dtype=torch.bool)])
        bias[:, :, ~key_valid[:context_len]] = self.padding_bias
        return bias.to(dtype=out_dtype).unsqueeze(2)
|
|
|
|
def parse_prompt_relay(prompt: str) -> PromptRelayPlan | None:
    """Split a prompt into global text and marker-delimited timed segments.

    Bracketed spans that do not parse as a marker are left in the text. Text
    before the first valid marker becomes the global prompt; text after each
    valid marker, up to the next valid marker or the end of the prompt, becomes
    that marker's segment. Segments whose text is empty after stripping are
    dropped.

    Returns:
        The plan, or ``None`` when the prompt has no valid marker or every
        segment is empty.
    """
    segments: list[PromptRelaySegment] = []
    global_parts: list[str] = []
    active: tuple[PromptRelayBound, PromptRelayBound | None] | None = None
    cursor = 0

    for marker in _RELAY_MARKER_RE.finditer(prompt):
        parsed = _parse_marker(marker.group(1))
        if parsed is None:
            # Not a timing marker: do not advance the cursor, so the brackets
            # remain part of whatever text surrounds them.
            continue
        preceding = prompt[cursor : marker.start()]
        if active is None:
            global_parts.append(preceding)
        else:
            text = preceding.strip()
            if text:
                segments.append(PromptRelaySegment(active[0], active[1], text))
        active = parsed
        cursor = marker.end()

    if active is None:
        return None

    tail = prompt[cursor:].strip()
    if tail:
        segments.append(PromptRelaySegment(active[0], active[1], tail))
    if not segments:
        return None
    return PromptRelayPlan("".join(global_parts).strip(), tuple(segments))
|
|
|
|
def encode_prompt_relay(
    prompt: str,
    encode_fn: Callable[[list[str]], list[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor]]],
    text_encoder_cache: Any,
    device: torch.device | str,
    num_frames: int,
    frame_rate: float,
    tokenizer: Any,
    visible_frame_offset: int = 0,
) -> PromptRelayConditioning | None:
    """Parse a relay prompt, encode it once, and attach mask builders.

    The global text and all segment texts are joined into one prompt, which is
    encoded through ``text_encoder_cache.encode`` under the cache key
    ``("prompt_relay_full", full_prompt)``. A mask builder is created for the
    video context and, when an audio context is returned, for the audio
    context too.

    Returns:
        The conditioning bundle, or ``None`` when ``prompt`` contains no usable
        relay marker.
    """
    plan = parse_prompt_relay(prompt)
    if plan is None:
        return None

    full_prompt, token_ranges = _build_full_prompt_and_token_ranges(plan, tokenizer)
    encoded_batch = text_encoder_cache.encode(
        encode_fn,
        [full_prompt],
        device=device,
        parallel=True,
        cache_keys=[("prompt_relay_full", full_prompt)],
    )
    # A single prompt was submitted, so only the first result is used.
    video_context, audio_context, video_mask, audio_mask = encoded_batch[0]

    def builder_for(context: torch.Tensor, key_mask: torch.Tensor) -> PromptRelayMaskBuilder | None:
        return _build_mask_builder(plan, context, key_mask, token_ranges, num_frames, frame_rate, visible_frame_offset)

    video_builder = builder_for(video_context, video_mask)
    audio_builder = None if audio_context is None else builder_for(audio_context, audio_mask)
    return PromptRelayConditioning(
        video_context=video_context,
        audio_context=audio_context,
        video_mask_builder=video_builder,
        audio_mask_builder=audio_builder,
    )
|
|
|
|
def _build_mask_builder(
    plan: PromptRelayPlan,
    context: torch.Tensor,
    mask: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    num_frames: int,
    frame_rate: float,
    visible_frame_offset: int = 0,
) -> PromptRelayMaskBuilder | None:
    """Resolve the plan against a clip and wrap the result in a mask builder.

    Segment bounds are resolved against the visible part of the clip (the
    frames from ``visible_frame_offset`` onward) and token ranges are clamped
    to the context's sequence length. ``token_ranges`` must have exactly one
    entry per plan segment (``zip(..., strict=True)``).

    Returns:
        A builder, or ``None`` when no remaining segment restricts time, i.e.
        every one spans the whole clip or none has a non-empty token range.
    """
    num_frames = max(int(num_frames), 1)
    last_frame = max(num_frames - 1, 0)
    visible_frame_offset = min(max(int(visible_frame_offset), 0), last_frame)
    visible_num_frames = max(num_frames - visible_frame_offset, 1)
    if num_frames > 1:
        visible_start_ratio = float(visible_frame_offset) / float(num_frames - 1)
    else:
        visible_start_ratio = 0.0
    # Duration of the visible part; the frame rate is floored to avoid division by zero.
    total_seconds = max((visible_num_frames - 1) / max(float(frame_rate), 1e-6), 0.0)
    seq_len = _seq_len(context)

    runtime_segments = []
    for segment, (raw_start, raw_end) in zip(plan.segments, token_ranges, strict=True):
        start_key = min(max(int(raw_start), 0), seq_len)
        end_key = min(max(int(raw_end), start_key), seq_len)
        if start_key >= end_key:
            # Nothing of this segment lies inside the context.
            continue
        start = segment.start.resolve(total_seconds, visible_num_frames)
        if segment.end is None:
            end = 1.0
        else:
            end = segment.end.resolve(total_seconds, visible_num_frames, inclusive_end=True)
        runtime_segments.append(_RuntimeSegment(start, max(start, end), start_key, end_key))

    restricts_time = any(item.start > 0.0 or item.end < 1.0 for item in runtime_segments)
    if not restricts_time:
        return None
    key_valid = _normalize_key_mask(mask, seq_len)
    return PromptRelayMaskBuilder(key_valid, runtime_segments, seq_len, visible_start_ratio=visible_start_ratio)
|
|
|
|
def _parse_marker(marker: str) -> tuple[PromptRelayBound, PromptRelayBound | None] | None:
    """Interpret the text inside a bracket pair as ``start:end`` bounds.

    Every colon is tried as the separator, because a bound may itself contain
    colons. A split is accepted when the left side parses as a bound and the
    right side is either empty (open end) or a bound of the same unit whose
    value is not smaller than the start. When several splits qualify, the
    right-most one is returned.

    Returns:
        ``(start, end)`` with ``end`` set to ``None`` for an open end, or
        ``None`` when no split qualifies.
    """
    best = None
    colon_positions = (index for index, char in enumerate(marker) if char == ":")
    for split_at in colon_positions:
        start = _parse_bound(marker[:split_at].strip())
        if start is None:
            continue
        end_text = marker[split_at + 1 :].strip()
        if end_text:
            end = _parse_bound(end_text)
            if end is None or end.unit != start.unit or end.value < start.value:
                continue
        else:
            end = None
        best = (start, end)
    return best
|
|
|
|
def _parse_bound(text: str) -> PromptRelayBound | None:
    """Parse one bound. The accepted forms are tried in this order:

    * ``N%``: percent, stored as ``N / 100``;
    * ``N`` followed by a seconds suffix: seconds;
    * colon-separated numbers: seconds, each field weighted 60x the next;
    * a bare number: frame.

    Returns:
        The parsed bound, or ``None`` for empty or unrecognised text.
    """
    if not text:
        return None

    if text.endswith("%"):
        number = text[:-1].strip()
        if not _NUMERIC_RE.match(number):
            return None
        return PromptRelayBound(float(number) / 100.0, "percent")

    with_suffix = _SECONDS_RE.match(text)
    if with_suffix:
        return PromptRelayBound(float(with_suffix.group(1)), "seconds")

    if ":" in text:
        fields = text.split(":")
        if any(not _NUMERIC_RE.match(field) for field in fields):
            return None
        seconds = 0.0
        for field in fields:
            seconds = seconds * 60.0 + float(field)
        return PromptRelayBound(seconds, "seconds")

    if _NUMERIC_RE.match(text):
        return PromptRelayBound(float(text), "frame")
    return None
|
|
|
|
def _build_full_prompt_and_token_ranges(plan: PromptRelayPlan, tokenizer: Any) -> tuple[str, list[tuple[int, int]]]:
    """Join the global text and all segments; record each segment's token span.

    Pieces are joined with a newline. For every segment the token count of the
    text before it and of the text including it are measured; the pair is that
    segment's ``(start, end)`` range.

    Returns:
        ``(full_prompt, token_ranges)`` with one range per plan segment, in order.
    """
    assembled = plan.global_prompt.strip()
    ranges: list[tuple[int, int]] = []
    for segment in plan.segments:
        # No leading newline when nothing has been assembled yet.
        prefix = assembled + "\n" if assembled else assembled
        assembled = prefix + segment.prompt.strip()
        range_start = _content_token_count(tokenizer, prefix)
        range_end = _content_token_count(tokenizer, assembled)
        ranges.append((range_start, range_end))
    return assembled, ranges
|
|
|
|
def _content_token_count(tokenizer: Any, text: str) -> int:
    """Count the tokens ``text`` occupies, excluding one trailing EOS token.

    ``tokenizer`` may be the callable tokenizer itself or a wrapper exposing it
    as ``.tokenizer``. The truncation length is taken from the first of these
    that is set: ``tokenizer.max_length``, the raw tokenizer's
    ``model_max_length``, then ``1024``. An attribute that exists but is
    ``None`` counts as not set, so it no longer reaches ``int(None)`` and
    raises ``TypeError``.

    Args:
        tokenizer: Tokenizer or wrapper (see above).
        text: Text to measure; it is stripped before tokenizing.

    Returns:
        Number of ids returned for ``text``, minus a trailing EOS id if present.
    """
    raw_tokenizer = getattr(tokenizer, "tokenizer", tokenizer)
    max_length = getattr(tokenizer, "max_length", None)
    if max_length is None:
        max_length = getattr(raw_tokenizer, "model_max_length", None)
    if max_length is None:
        max_length = 1024
    encoded = raw_tokenizer(
        text.strip(),
        padding=False,
        max_length=int(max_length),
        truncation=True,
        return_tensors=None,
    )
    input_ids = encoded["input_ids"]
    # Some tokenizers return a batch of one even for a single string.
    if input_ids and isinstance(input_ids[0], list):
        input_ids = input_ids[0]
    eos_token_id = getattr(raw_tokenizer, "eos_token_id", None)
    if eos_token_id is not None and input_ids and input_ids[-1] == eos_token_id:
        input_ids = input_ids[:-1]
    return len(input_ids)
|
|
|
|
def _seq_len(context: torch.Tensor) -> int:
    """Return the size of the sequence axis: axis 0 for a 2-D tensor, axis 1 otherwise."""
    seq_axis = 0 if context.dim() == 2 else 1
    return int(context.shape[seq_axis])
|
|
|
|
def _normalize_key_mask(mask: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Return ``mask`` as a flat CPU bool vector of exactly ``seq_len`` entries.

    The mask is flattened, padded with ``True`` when shorter than ``seq_len``,
    and truncated when longer.
    """
    flat = mask.detach().to("cpu", dtype=torch.bool).reshape(-1)
    shortfall = seq_len - flat.numel()
    if shortfall > 0:
        padding = torch.ones(shortfall, dtype=torch.bool)
        flat = torch.cat([flat, padding])
    return flat[:seq_len]
|
|
|
|
def _context_seq_len(context: Any) -> int:
    """Return the sequence length of a context tensor or wrapper.

    The tensor is looked up as ``context.projected_context`` first, then as
    ``context.context``; if neither attribute exists, ``context`` itself is
    used. Returns ``0`` when the lookup ends in ``None``.
    """
    tensor = getattr(context, "projected_context", None)
    if tensor is None:
        tensor = getattr(context, "context", context)
    return 0 if tensor is None else _seq_len(tensor)
|
|
|
|
def _resolve_frame_indices(state: Any, frame_indices: torch.Tensor | None) -> torch.Tensor:
    """Return per-query frame indices, deriving them from ``state`` if needed.

    Explicit ``frame_indices`` are returned untouched. Otherwise, when
    ``state.positions`` has at least four dimensions and a non-empty axis 1,
    the slice ``positions[:, 0, :, 0]`` is scanned along the sequence and the
    index is increased by one wherever the value differs from its predecessor.
    NOTE(review): that slice is presumably the temporal coordinate of each
    token; confirm against the code that produces ``positions``.
    Without usable positions, an all-zero ``(latent.shape[0], latent.shape[1])``
    long tensor is returned.
    """
    if frame_indices is not None:
        return frame_indices

    positions = state.positions
    usable = positions is not None and positions.ndim >= 4 and positions.shape[1] > 0
    if not usable:
        latent = state.latent
        return torch.zeros((latent.shape[0], latent.shape[1]), device=latent.device, dtype=torch.long)

    frame_times = positions[:, 0, :, 0]
    # The first token of every row starts at index 0.
    first = torch.zeros((frame_times.shape[0], 1), device=frame_times.device, dtype=torch.long)
    stepped = (frame_times[:, 1:] != frame_times[:, :-1]).to(torch.long)
    return torch.cumsum(torch.cat([first, stepped], dim=1), dim=1)
|
|