| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| from abc import ABC | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| from flow_matching.loss import MixturePathGeneralizedKL, EditFlowsLoss, EditFlowsLossReParam | |
| from flow_matching.path import MixtureDiscreteProbPath # for the scheduler only | |
| from flow_matching.path.scheduler import PolynomialConvexScheduler | |
| from flow_matching.path.editflows_adapter import EditFlowsPathAdapter # <-- NEW import | |
| from torch import Tensor | |
| from torch.nn.modules.loss import _Loss | |
class SourceDistribution(ABC):
    """Base class for the source (time-0) distributions used in this module.

    It only declares the ``sample`` / ``sample_like`` interface. The base
    implementations are no-ops that return ``None``; concrete subclasses are
    expected to override them.
    """

    def __init__(self) -> None:
        """Create the base object; there is no state to initialise."""

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Draw a sample of shape ``tensor_size`` on ``device``.

        Args:
            tensor_size: Requested shape of the sample.
            device: Device on which the sample should live.

        Returns:
            The base implementation does nothing and returns ``None``.
        """

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Draw a sample using ``tensor_like`` as the reference tensor.

        Args:
            tensor_like: Reference tensor.

        Returns:
            The base implementation does nothing and returns ``None``.
        """
class MaskedSourceDistribution(SourceDistribution):
    """Source distribution whose samples consist only of the mask token."""

    def __init__(self, mask_token: int) -> None:
        """Store the id that fills every sampled position.

        Args:
            mask_token: Token id used as the mask.
        """
        super().__init__()
        self.mask_token = mask_token

    def masked(self) -> bool:
        """Return ``True``: this source is the masked one.

        NOTE(review): this is a plain method, not a property, so callers must
        invoke ``masked()``; a bare ``obj.masked`` is a bound method and is
        always truthy.
        """
        return True

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return a long tensor of shape ``tensor_size`` filled with the mask id.

        Args:
            tensor_size: Shape of the tensor to create.
            device: Device on which to create it.

        Returns:
            A ``torch.long`` tensor in which every entry equals ``mask_token``.
        """
        # Build the tensor directly as int64. Filling a float32 tensor first and
        # casting afterwards rounds ids above 2**24 and allocates a temporary.
        return torch.full(tensor_size, self.mask_token, dtype=torch.long, device=device)

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return a long tensor shaped like ``tensor_like`` filled with the mask id.

        Args:
            tensor_like: Reference tensor supplying shape and device.

        Returns:
            A ``torch.long`` tensor in which every entry equals ``mask_token``.
        """
        # dtype is forced to long so the result does not depend on (or get
        # truncated by) the dtype of the reference tensor.
        return torch.full_like(tensor_like, self.mask_token, dtype=torch.long)
| import torch | |
| from typing import List, Tuple, Optional, Union | |
| Tensor = torch.Tensor | |
class UniformSourceDistribution:
    """Source distribution that draws token ids uniformly from an allowed set.

    The default allowed set is ``range(vocab_size)`` with ``special_token_ids``
    removed. The sampling methods accept an ``allowed_tokens`` override.
    """

    def __init__(self, vocab_size: int, special_token_ids: Optional[List[int]] = None) -> None:
        """Pre-compute the default list of token ids that may be sampled.

        Args:
            vocab_size: Token ids are taken from ``range(vocab_size)``.
            special_token_ids: Ids that must never be sampled by default.

        Raises:
            ValueError: If excluding the special ids leaves nothing to sample.
        """
        self.vocab_size = vocab_size
        self.special_token_ids = set(special_token_ids) if special_token_ids is not None else set()
        # Default allowed tokens: the vocabulary minus every special token.
        self._allowed_tokens = [i for i in range(vocab_size) if i not in self.special_token_ids]
        if not self._allowed_tokens:
            raise ValueError(f"All tokens are special tokens: {special_token_ids}")

    def masked(self) -> bool:
        """Return ``False``: this source is not the masked one.

        NOTE(review): this is a plain method, not a property, so callers must
        invoke ``masked()``; a bare ``obj.masked`` is a bound method and is
        always truthy.
        """
        return False

    def _sample_from_allowed(
        self,
        shape: Union[torch.Size, Tuple[int, ...], List[int]],
        device: torch.device,
        allowed_tokens: Optional[Union[List[int], Tensor]] = None,
        generator: Optional[torch.Generator] = None,
    ) -> Tensor:
        """Sample token ids of the given shape uniformly from the allowed set.

        Args:
            shape: Output shape; a ``torch.Size`` or a plain tuple/list of ints.
            device: Device of the returned tensor.
            allowed_tokens: Ids to sample from; defaults to the set computed in
                ``__init__``.
            generator: Optional RNG forwarded to ``torch.randint``.

        Returns:
            A ``torch.long`` tensor of shape ``shape``.

        Raises:
            ValueError: If the allowed set is empty.
        """
        if allowed_tokens is None:
            allowed_tokens = self._allowed_tokens
        # len() rather than truthiness: allowed_tokens may be a tensor, whose
        # truth value is ambiguous when it holds more than one element.
        if len(allowed_tokens) == 0:
            raise ValueError("No allowed tokens provided")
        # Normalise so plain tuples (as passed through sample()) work too;
        # a bare tuple has no .numel().
        shape = torch.Size(shape)
        if shape.numel() == 0:
            # Nothing to draw: return an empty tensor (e.g. a length-0 sequence).
            return torch.empty(shape, dtype=torch.long, device=device)
        allowed_tensor = torch.as_tensor(allowed_tokens, dtype=torch.long, device=device)
        # Draw positions in [0, num_allowed) and map them to real token ids.
        indices = torch.randint(
            low=0, high=len(allowed_tokens), size=shape, device=device, generator=generator
        )
        return allowed_tensor[indices]

    def sample(
        self,
        tensor_size: Union[torch.Size, Tuple[int, ...], List[int]],
        device: torch.device,
        allowed_tokens: Optional[Union[List[int], Tensor]] = None,
        generator: Optional[torch.Generator] = None,
    ) -> Tensor:
        """Sample a tensor of shape ``tensor_size`` uniformly from the allowed ids.

        Args:
            tensor_size: Output shape (``torch.Size`` or tuple/list of ints).
            device: Device of the returned tensor.
            allowed_tokens: Optional override of the allowed ids.
            generator: Optional RNG.

        Returns:
            A ``torch.long`` tensor of shape ``tensor_size``.
        """
        return self._sample_from_allowed(tensor_size, device, allowed_tokens, generator)

    def sample_like(
        self,
        tensor_like: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
        allowed_tokens: Optional[List[int]] = None,
        min_len_factor: float = 0.0,
        max_len_factor: float = 2.0,
        generator: Optional[torch.Generator] = None,
    ) -> Union[Tensor, List[Tensor]]:
        """Sample one variable-length 1D sequence per reference tensor.

        For a reference whose last dimension has size ``N``, the length is drawn
        uniformly from the integers between ``int(min_len_factor * N)`` and
        ``int(max_len_factor * N)`` inclusive (``[0, 2 * N]`` by default).
        The result does NOT have the same shape as the reference.

        Args:
            tensor_like: A tensor with at least one dimension, or a list/tuple
                of such tensors (ragged input). Only the last dimension of each
                tensor is used as the reference length.
            allowed_tokens: Optional override of the allowed ids.
            min_len_factor: Lower length bound as a multiple of ``N``.
            max_len_factor: Upper length bound as a multiple of ``N``.
            generator: Optional RNG used for both the length and the tokens.

        Returns:
            A 1D ``torch.long`` tensor (possibly empty) for a tensor input, or
            a list of such tensors for a list/tuple input.
        """

        def _one(x1: Tensor) -> Tensor:
            assert x1.dim() >= 1, "x1 must be at least 1D"
            device = x1.device
            ref_len = int(x1.shape[-1])  # the last dimension is the reference length
            lo = max(0, int(min_len_factor * ref_len))
            hi = max(0, int(max_len_factor * ref_len))
            # randint's upper bound is exclusive, hence hi + 1.
            length = int(
                torch.randint(low=lo, high=hi + 1, size=(1,), device=device, generator=generator).item()
            )
            if length == 0:
                return torch.empty((0,), dtype=torch.long, device=device)
            return self._sample_from_allowed(torch.Size([length]), device, allowed_tokens, generator)

        if isinstance(tensor_like, (list, tuple)):
            return [_one(seq) for seq in tensor_like]
        return _one(tensor_like)

    def sample_x0_from_x1(
        self,
        x1: Tensor,
        pad_id: int,
        allowed_tokens: Optional[Union[Tensor, List[int]]],
        scale_size: float = 1.0,
        bos_id: int = 0,
        eos_id: int = 2,
        generator: Optional[torch.Generator] = None,
    ) -> Tensor:
        """Sample a padded batch ``x0`` whose lengths are tied to those of ``x1``.

        For every row of ``x1`` the valid length ``n`` counts the tokens that
        are none of ``pad_id``, ``bos_id``, ``eos_id``. A core length is drawn
        uniformly from the integers between
        ``max(0, int((1 - scale_size) * n))`` and ``int((1 + scale_size) * n)``
        inclusive (``[0, 2 * n]`` with the default ``scale_size``). The core
        tokens are drawn uniformly from ``allowed_tokens`` and wrapped as
        ``[bos_id] + core + [eos_id]``; the core length does not count the two
        wrapper tokens.

        Args:
            x1: 2D tensor of token ids, one sequence per row.
            pad_id: Id used to pad the returned batch; also excluded when
                counting valid tokens of ``x1``.
            allowed_tokens: Ids the core tokens are drawn from, as a 1D tensor
                or a list. ``None`` selects the default set from ``__init__``.
                The caller is responsible for excluding BOS/EOS/PAD from it.
            scale_size: Half-width of the length range relative to ``n``.
            bos_id: Id placed at the start of every sequence.
            eos_id: Id placed at the end of every sequence.
            generator: Optional RNG used for the lengths and the tokens.

        Returns:
            A ``torch.long`` tensor with one row per row of ``x1``, right-padded
            with ``pad_id`` to the longest sampled sequence.

        Raises:
            ValueError: If ``x1`` is not 2D.
        """
        if x1.dim() != 2:
            raise ValueError(f"x1 must be 2D (batch, length), got shape {tuple(x1.shape)}")
        device = x1.device
        if allowed_tokens is None:
            allowed_tokens = self._allowed_tokens
        # Accept lists and tensors on any device; indexing below needs the
        # candidates on the same device as the sampled indices.
        allowed = torch.as_tensor(allowed_tokens, dtype=torch.long, device=device)

        # Valid tokens of x1 are those that are not PAD, BOS or EOS.
        is_special = (x1 == pad_id) | (x1 == bos_id) | (x1 == eos_id)
        # One host transfer for all rows instead of one .item() per row.
        valid_lens = (~is_special).sum(dim=1).tolist()

        bos = torch.tensor([bos_id], dtype=torch.long, device=device)
        eos = torch.tensor([eos_id], dtype=torch.long, device=device)

        x0_seqs = []
        for n_valid in valid_lens:
            max_core_len = int((1 + scale_size) * n_valid)  # may be 0
            # Clamp at 0: for scale_size > 1 the raw lower bound is negative,
            # which is not a meaningful sequence length.
            min_core_len = max(0, int((1 - scale_size) * n_valid))
            core_len = int(
                torch.randint(
                    low=min_core_len, high=max_core_len + 1, size=(1,), device=device, generator=generator
                ).item()
            )
            if core_len > 0:
                idx = torch.randint(0, allowed.size(0), (core_len,), device=device, generator=generator)
                core_tokens = allowed[idx]
            else:
                core_tokens = torch.empty(0, dtype=torch.long, device=device)
            # Full sequence: [BOS] + core + [EOS], length core_len + 2.
            x0_seqs.append(torch.cat([bos, core_tokens, eos], dim=0))

        return torch.nn.utils.rnn.pad_sequence(x0_seqs, batch_first=True, padding_value=pad_id)
| # class UniformSourceDistribution(SourceDistribution): | |
| # def __init__(self, vocab_size: int, special_token_ids: Optional[List[int]] = None) -> None: | |
| # self.vocab_size = vocab_size | |
| # self.special_token_ids = set(special_token_ids) if special_token_ids is not None else set() | |
| # # Compute allowed tokens by removing all special tokens from vocab | |
| # self._allowed_tokens = [i for i in range(vocab_size) if i not in self.special_token_ids] | |
| # if len(self._allowed_tokens) == 0: | |
| # raise ValueError(f"All tokens are special tokens: {special_token_ids}") | |
| # @property | |
| # def masked(self) -> bool: | |
| # return False | |
| # def _sample_from_allowed( | |
| # self, shape: torch.Size, device: torch.device, allowed_tokens: Optional[List[int]] = None | |
| # ) -> Tensor: | |
| # """Sample uniformly from allowed tokens.""" | |
| # if allowed_tokens is None: | |
| # # Use default allowed tokens (vocab minus special tokens) | |
| # allowed_tokens = self._allowed_tokens | |
| # if len(allowed_tokens) == 0: | |
| # raise ValueError(f"No allowed tokens provided") | |
| # # print(f"allowed_tokens: {allowed_tokens}") | |
| # # Sample indices into allowed_tokens, then map to actual token IDs | |
| # allowed_tensor = torch.tensor(allowed_tokens, dtype=torch.long, device=device) | |
| # num_allowed = len(allowed_tokens) | |
| # # Sample indices in [0, num_allowed) | |
| # indices = torch.randint(size=shape, high=num_allowed, device=device) | |
| # return allowed_tensor[indices] | |
| # def sample( | |
| # self, | |
| # tensor_size: Tuple[int, ...], | |
| # device: torch.device, | |
| # allowed_tokens: Optional[List[int]] = None | |
| # ) -> Tensor: | |
| # return self._sample_from_allowed(tensor_size, device, allowed_tokens) | |
| # def sample_like( | |
| # self, | |
| # tensor_like: Union[Tensor, List[Tensor]], | |
| # allowed_tokens: Optional[List[int]] = None | |
| # ) -> Union[Tensor, List[Tensor]]: | |
| # """ | |
| # Sample uniform tokens matching the shape of tensor_like. | |
| # Supports both Tensor and List[Tensor] for ragged inputs. | |
| # Args: | |
| # tensor_like: Either a Tensor or List[Tensor] (for ragged inputs) | |
| # allowed_tokens: Optional list of allowed token IDs. If None, uses vocab minus special tokens. | |
| # Returns: | |
| # Tensor or List[Tensor] matching the input shape(s) | |
| # """ | |
| # # Handle ragged input (list of tensors) | |
| # if isinstance(tensor_like, (list, tuple)): | |
| # return [ | |
| # self._sample_from_allowed(seq.shape, seq.device, allowed_tokens) | |
| # for seq in tensor_like | |
| # ] | |
| # # Handle regular tensor input | |
| # return self._sample_from_allowed(tensor_like.shape, tensor_like.device, allowed_tokens) | |
class EmpiricalSourceDistribution(SourceDistribution):
    """Source distribution that samples token ids from given empirical weights."""

    def __init__(self, vocab_size: int, probs: torch.Tensor, length: int):
        """Normalise the weights and remember the sequence length.

        Args:
            vocab_size: Stored as-is; not used by the sampling code below.
            probs: Non-normalised weights; divided by their sum here.
            length: Number of tokens per sampled sequence (e.g. 100).
        """
        self.vocab_size = vocab_size
        self.length = length
        self.registered = probs / probs.sum()

    def masked(self) -> bool:
        """Return ``False``: this source is not the masked one."""
        return False

    def _draw(self, batch_size: int, device: torch.device) -> Tensor:
        """Draw ONE sequence of ``self.length`` ids and tile it ``batch_size`` times.

        NOTE(review): every row of the returned batch is the same sequence;
        confirm that sharing one draw across the batch is intended.
        """
        weights = self.registered.to(device)
        token_ids = torch.multinomial(weights, num_samples=self.length, replacement=True)
        return token_ids.view(1, self.length).repeat(batch_size, 1)

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return a ``(tensor_size[0], self.length)`` batch on ``device``.

        Only the first entry of ``tensor_size`` (the batch size) is used.
        """
        return self._draw(tensor_size[0], device)

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return a ``(tensor_like.shape[0], self.length)`` batch on its device.

        Only the first dimension and the device of ``tensor_like`` are used.
        """
        return self._draw(tensor_like.shape[0], tensor_like.device)
# NOTE: the return type is the (ragged) adapter. Callers only rely on .scheduler(t) and .sample(...).
def get_path(scheduler_type: str, exponent: Optional[float] = None, eps_id: int = -1) -> EditFlowsPathAdapter:
    """Build the Edit Flows path adapter for the requested scheduler.

    Args:
        scheduler_type: Only ``"polynomial"`` is supported.
        exponent: Passed as ``n`` to ``PolynomialConvexScheduler``.
            NOTE(review): ``None`` is forwarded unchanged; confirm the
            scheduler accepts it.
        eps_id: Forwarded to ``EditFlowsPathAdapter`` as ``eps_id``.

    Returns:
        An ``EditFlowsPathAdapter`` wrapping a ``MixtureDiscreteProbPath``.

    Raises:
        ValueError: If ``scheduler_type`` is not supported.
    """
    if scheduler_type != "polynomial":
        raise ValueError(f"{scheduler_type} is not supported")
    # The paper uses a cubic schedule, i.e. exponent=3.
    poly_scheduler = PolynomialConvexScheduler(n=exponent)
    # The mixture path only carries the scheduler; per the original note the
    # adapter samples the ragged z_t itself.
    dense_path = MixtureDiscreteProbPath(scheduler=poly_scheduler)
    return EditFlowsPathAdapter(mixture_path=dense_path, eps_id=eps_id)
def get_source_distribution(
    source_distribution: str,
    p_emp: Optional[Tensor] = None,
    length: Optional[int] = None,
    vocab_size: Optional[int] = None,
    special_token_ids: Optional[List[int]] = None,
) -> SourceDistribution:
    """Create the source distribution selected by the arguments.

    Args:
        source_distribution: ``"mask"`` or ``"uniform"``; ignored when
            ``p_emp`` is given.
        p_emp: Empirical weights. When not ``None`` an
            ``EmpiricalSourceDistribution`` is returned regardless of
            ``source_distribution``.
        length: Sequence length for the empirical source.
        vocab_size: Vocabulary size; required by every source. For the masked
            source it is also used as the mask token id.
        special_token_ids: Ids excluded from uniform sampling; only used by
            the uniform source.

    Returns:
        The constructed source distribution.

    Raises:
        AssertionError: If a required argument for the chosen source is ``None``.
        ValueError: If ``source_distribution`` is not supported.
    """
    # Empirical weights take precedence over the named source.
    if p_emp is not None:
        assert vocab_size is not None and length is not None, "Empirical source requires vocab_size and length"
        return EmpiricalSourceDistribution(vocab_size=vocab_size, probs=p_emp, length=length)

    if source_distribution == "mask":
        assert vocab_size is not None, "Masked source requires vocab_size"
        # The mask token id is vocab_size itself.
        return MaskedSourceDistribution(mask_token=vocab_size)

    if source_distribution == "uniform":
        assert vocab_size is not None, "Uniform source requires vocab_size"
        return UniformSourceDistribution(vocab_size=vocab_size, special_token_ids=special_token_ids)

    raise ValueError(f"{source_distribution} is not supported")
def get_loss_function(loss_function: str, path: Optional[Union[MixtureDiscreteProbPath, EditFlowsPathAdapter]] = None) -> _Loss:
    """Create the loss module selected by name.

    Args:
        loss_function: One of ``"cross_entropy"``, ``"generalized_kl"``,
            ``"editflows"`` or ``"editflows_reparam"``.
        path: Probability path; only used (and required) by
            ``"generalized_kl"``.

    Returns:
        The constructed loss module.

    Raises:
        AssertionError: If ``"generalized_kl"`` is requested without a path.
        ValueError: If ``loss_function`` is not supported.
    """
    if loss_function == "cross_entropy":
        return torch.nn.CrossEntropyLoss()

    if loss_function == "generalized_kl":
        assert path is not None
        # Per the original note, generalized KL still expects a dense path,
        # which is fine for the DFM experiments.
        return MixturePathGeneralizedKL(path=path)

    if loss_function == "editflows":
        # Per the original note, the ragged Edit Flows loss does not need the
        # path; the weight is precomputed in the training step.
        return EditFlowsLoss(reduction="mean")

    if loss_function == "editflows_reparam":
        return EditFlowsLossReParam(reduction="mean")

    raise ValueError(f"{loss_function} is not supported")
| # # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # # All rights reserved. | |
| # # | |
| # # This source code is licensed under the CC-by-NC license found in the | |
| # # LICENSE file in the root directory of this source tree. | |
| # from abc import ABC | |
| # from typing import Optional, Tuple | |
| # import torch | |
| # from flow_matching.loss import MixturePathGeneralizedKL, EditFlowsLoss | |
| # from flow_matching.path import MixtureDiscreteProbPath, ProbPath, EditFlowsPathAdapter | |
| # from flow_matching.path.scheduler import PolynomialConvexScheduler | |
| # from torch import Tensor | |
| # from torch.nn.modules.loss import _Loss | |
| # class SourceDistribution(ABC): | |
| # def __init__( | |
| # self, | |
| # ) -> None: | |
| # ... | |
| # def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor: | |
| # ... | |
| # def sample_like(self, tensor_like: Tensor) -> Tensor: | |
| # ... | |
| # class MaskedSourceDistribution(SourceDistribution): | |
| # def __init__(self, mask_token: int) -> None: | |
| # self.mask_token = mask_token | |
| # @property | |
| # def masked(self) -> bool: | |
| # return True | |
| # def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor: | |
| # return torch.zeros(tensor_size, device=device).fill_(self.mask_token).long() | |
| # def sample_like(self, tensor_like: Tensor) -> Tensor: | |
| # return torch.zeros_like(tensor_like).fill_(self.mask_token).long() | |
| # class UniformSourceDistribution(SourceDistribution): | |
| # def __init__(self, vocab_size: int) -> None: | |
| # self.vocab_size = vocab_size | |
| # @property | |
| # def masked(self) -> bool: | |
| # return False | |
| # def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor: | |
| # return torch.randint(size=tensor_size, high=self.vocab_size, device=device) | |
| # def sample_like(self, tensor_like: Tensor) -> Tensor: | |
| # return torch.randint_like(tensor_like, high=self.vocab_size) | |
| # class EmpiricalSourceDistribution(SourceDistribution): | |
| # def __init__(self, vocab_size: int, probs: torch.Tensor, length: int): | |
| # self.vocab_size = vocab_size | |
| # self.registered = probs / probs.sum() | |
| # self.length = length # e.g., 100 tokens as in the paper’s variant | |
| # @property | |
| # def masked(self) -> bool: return False | |
| # def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor: | |
| # B = tensor_size[0] | |
| # idx = torch.multinomial(self.registered.to(device), num_samples=self.length, replacement=True) | |
| # return idx.view(1, self.length).repeat(B, 1) | |
| # def sample_like(self, tensor_like: Tensor) -> Tensor: | |
| # B = tensor_like.shape[0] | |
| # idx = torch.multinomial(self.registered.to(tensor_like.device), num_samples=self.length, replacement=True) | |
| # return idx.view(1, self.length).repeat(B, 1) | |
| # def get_path(scheduler_type: str, exponent: Optional[float] = None) -> ProbPath: | |
| # if scheduler_type == "polynomial": | |
| # scheduler = PolynomialConvexScheduler(n=exponent) | |
| # else: | |
| # raise ValueError(f"{scheduler_type} is not supported") | |
| # return EditFlowsPathAdapter(path=MixtureDiscreteProbPath(scheduler=scheduler)) # still need to decide (1) whether to implement the z_0, z_1 creation here, and (2) how to pass the eps_id to the adapter | |
| # def get_source_distribution( | |
| # source_distribution: str, p_emp: Optional[Tensor] = None, length: Optional[int] = None, vocab_size: int = None | |
| # ) -> SourceDistribution: | |
| # if p_emp is not None: | |
| # return EmpiricalSourceDistribution(vocab_size=vocab_size, probs=p_emp, length=length) | |
| # if source_distribution == "mask": | |
| # return MaskedSourceDistribution(mask_token=vocab_size) | |
| # elif source_distribution == "uniform": | |
| # return UniformSourceDistribution(vocab_size=vocab_size) | |
| # else: | |
| # raise ValueError(f"{source_distribution} is not supported") | |
| # def get_loss_function(loss_function: str, path: Optional[ProbPath] = None) -> _Loss: | |
| # if loss_function == "cross_entropy": | |
| # return torch.nn.CrossEntropyLoss() | |
| # elif loss_function == "generalized_kl": | |
| # assert path is not None | |
| # return MixturePathGeneralizedKL(path=path) | |
| # elif loss_function == "editflows": | |
| # assert path is not None | |
| # return EditFlowsLoss(path=path) | |
| # else: | |
| # raise ValueError(f"{loss_function} is not supported") | |
Xet Storage Details
- Size:
- 17.9 kB
- Xet hash:
- e2a2020d42b35083a20b65c19034f8da8644fd31a86970050e357b777136a55a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.