| import torch | |
| from torchvision.models.optical_flow import raft_large | |
| from modules.flow_models.raft.rfr_new import RAFT | |
# Frozen optical-flow networks, built once per (data_domain, device) pair and
# reused across calls. Constructing a RAFT network and moving it to the device
# is expensive, so it should not happen on every call.
_FLOW_MODELS: dict[tuple[str, str], torch.nn.Module] = {}


def _get_flow_model(data_domain: str, device: str) -> torch.nn.Module:
    """Return the cached frozen, eval-mode flow model for *data_domain* on *device*."""
    key = (data_domain, str(device))
    model = _FLOW_MODELS.get(key)
    if model is None:
        if data_domain == "animation":
            model = RAFT()
        else:
            # Without ``weights=`` torchvision's raft_large() is randomly
            # initialised, which would make the estimated flow meaningless
            # (and different on every construction). Load pretrained weights.
            from torchvision.models.optical_flow import Raft_Large_Weights

            model = raft_large(weights=Raft_Large_Weights.DEFAULT)
        # Frozen inference model: no parameter gradients, eval-mode layers.
        model = model.requires_grad_(False).eval().to(device)
        _FLOW_MODELS[key] = model
    return model


def raft_flow(
    I0: torch.Tensor,
    I1: torch.Tensor,
    data_domain: str = "animation",
    device: str = 'cuda'
) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate optical flow from frame ``I0`` to frame ``I1`` with a frozen RAFT model.

    Args:
        I0: First frame. It is cast to float32 and moved to *device*.
            NOTE(review): torchvision's RAFT documents image inputs in [-1, 1]
            with height/width divisible by 8; confirm callers normalise accordingly.
        I1: Second frame, handled the same way as ``I0``.
        data_domain: ``"animation"`` uses the project ``RAFT`` model;
            ``"photorealism"`` uses torchvision's pretrained ``raft_large``.
        device: Device the model runs on. The inputs are moved there as well.

    Returns:
        For ``"animation"``: whatever ``RAFT()(I0, I1)`` returns.
        NOTE(review): the annotation suggests a pair of tensors; confirm.
        For ``"photorealism"``: the final (last-iteration) flow prediction of
        ``raft_large``, which is a single tensor rather than a tuple.

    Raises:
        ValueError: If ``data_domain`` is not one of the supported values.
    """
    if data_domain not in ("animation", "photorealism"):
        raise ValueError("data_domain must be either 'animation' or 'photorealism'")

    raft = _get_flow_model(data_domain, device)

    # Put the inputs on the model's device as float32. This is a no-op when
    # they already match.
    I0 = I0.to(device=device, dtype=torch.float32)
    I1 = I1.to(device=device, dtype=torch.float32)

    if data_domain == "animation":
        return raft(I0, I1)
    # torchvision's RAFT returns a list of iteratively refined flow fields;
    # the last entry is the final estimate.
    return raft(I0, I1)[-1]