Spaces:
Runtime error
Runtime error
| import torch | |
| from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper | |
| from einops import rearrange, repeat | |
| class StreamingWrapper(OpenAIWrapper): | |
| """ | |
| Modelwrapper for StreamingSVD, which holds the CAM model and the base model | |
| """ | |
| def __init__(self, diffusion_model, controlnet, num_frame_conditioning: int, compile_model: bool = False, pipeline_offloading: bool = False): | |
| super().__init__(diffusion_model=diffusion_model, | |
| compile_model=compile_model) | |
| self.controlnet = controlnet | |
| self.num_frame_conditioning = num_frame_conditioning | |
| self.pipeline_offloading = pipeline_offloading | |
| if pipeline_offloading: | |
| raise NotImplementedError( | |
| "Pipeline offloading for StreamingI2V not implemented yet.") | |
| def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs): | |
| batch_size = kwargs.pop("batch_size") | |
| # We apply the controlnet model only to the control frames. | |
| def reduce_to_cond_frames(input): | |
| input = rearrange(input, "(B F) ... -> B F ...", B=batch_size) | |
| input = input[:, :self.num_frame_conditioning] | |
| return rearrange(input, "B F ... -> (B F) ...") | |
| x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) | |
| x_ctrl = reduce_to_cond_frames(x) | |
| t_ctrl = reduce_to_cond_frames(t) | |
| context = c.get("crossattn", None) | |
| # controlnet is not using APM so we remove potentially additional tokens | |
| context_ctrl = context[:, :1] | |
| context_ctrl = reduce_to_cond_frames(context_ctrl) | |
| y = c.get("vector", None) | |
| y_ctrl = reduce_to_cond_frames(y) | |
| num_video_frames = kwargs.pop("num_video_frames") | |
| image_only_indicator = kwargs.pop("image_only_indicator") | |
| ctrl_img_enc_frames = repeat( | |
| kwargs['ctrl_frames'], "B ... -> (2 B) ... ") | |
| controlnet_cond = rearrange( | |
| ctrl_img_enc_frames, "B F ... -> (B F) ...") | |
| if self.diffusion_model.controlnet_mode: | |
| hs_control_input, hs_control_mid = self.controlnet(x=x_ctrl, # video latent | |
| timesteps=t_ctrl, # timestep | |
| context=context_ctrl, # clip image conditioning | |
| y=y_ctrl, # conditionigs, e.g. fps | |
| controlnet_cond=controlnet_cond, # control frames | |
| num_video_frames=self.num_frame_conditioning, | |
| num_video_frames_conditional=self.num_frame_conditioning, | |
| image_only_indicator=image_only_indicator[:, | |
| :self.num_frame_conditioning] | |
| ) | |
| else: | |
| hs_control_input = None | |
| hs_control_mid = None | |
| kwargs["hs_control_input"] = hs_control_input | |
| kwargs["hs_control_mid"] = hs_control_mid | |
| out = self.diffusion_model( | |
| x=x, | |
| timesteps=t, | |
| context=context, # must be (B F) T C | |
| y=y, # must be (B F) 768 | |
| num_video_frames=num_video_frames, | |
| num_conditional_frames=self.num_frame_conditioning, | |
| image_only_indicator=image_only_indicator, | |
| hs_control_input=hs_control_input, | |
| hs_control_mid=hs_control_mid, | |
| ) | |
| return out | |