AhmadMustafa committed

Commit 068b511 · 0 Parent(s)

Initial commit for CogVideoXInterp
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "python-envs.defaultEnvManager": "ms-python.python:conda",
+     "python-envs.defaultPackageManager": "ms-python.python:conda",
+     "python-envs.pythonProjects": []
+ }
FILES.txt ADDED
@@ -0,0 +1,61 @@
+ Bare Minimum Files for CogVideoX-Interpolation Gradio App
+ ===========================================================
+ 
+ ESSENTIAL FILES (must have all):
+ 
+ 1. app.py (7.7KB)
+    - Main Gradio application
+    - Handles the UI and video generation
+ 
+ 2. requirements.txt (103B)
+    - Python package dependencies
+    - Install with: pip install -r requirements.txt
+ 
+ 3. README.md (232B)
+    - HuggingFace Spaces configuration
+    - Contains the YAML frontmatter for Spaces
+ 
+ 4. cogvideox_interpolation/ (directory)
+    - pipeline.py (~38KB)
+      * Core CogVideoX interpolation pipeline
+      * Custom diffusion pipeline implementation
+ 
+    - datasets.py (~6KB)
+      * Dataset loading utilities
+      * Not used during inference but required for imports
+ 
+ OPTIONAL (helpful but not required):
+ 
+ 5. SETUP.md (3.1KB)
+    - Quick setup instructions
+    - Can be deleted after setup
+ 
+ TOTAL SIZE: ~64KB (excluding model weights)
+ 
+ MODEL DOWNLOAD:
+ - The model auto-downloads on first run (~20GB)
+ - Model: feizhengcong/CogvideoX-Interpolation
+ - Downloads to: ~/.cache/huggingface/ (an optional pre-download sketch follows below)
+ 
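+ OPTIONAL PRE-DOWNLOAD (a sketch; huggingface_hub ships with the dependencies above):
+ 
+     from huggingface_hub import snapshot_download
+     snapshot_download("feizhengcong/CogvideoX-Interpolation")
+ 
+ Running this once fetches the ~20GB checkpoint into ~/.cache/huggingface/, so the
+ first "Load Model" click in the app does not block on the download.
+ 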
+ WHAT'S NOT NEEDED:
+ ✗ Training scripts (finetune.py, finetune.sh)
+ ✗ Documentation files (CLAUDE.md, GPU_REQUIREMENTS.md, GRADIO_README.md)
+ ✗ Example cases (cases/ directory)
+ ✗ Git files (.git, .gitignore)
+ ✗ Compiled files (__pycache__, *.pyc)
+ ✗ The original README.md from the repo
+ ✗ requirement.txt (the original file; requirements.txt is used instead)
+ 
+ TO RUN LOCALLY:
+ 1. pip install -r requirements.txt
+ 2. python app.py
+ 3. Open http://localhost:7860
+ 
+ TO DEPLOY ON HUGGINGFACE SPACES:
+ 1. Upload all files in this directory
+ 2. Select GPU hardware (T4 minimum, A10G recommended)
+ 3. The Space auto-deploys
+ 
+ GPU REQUIREMENTS:
+ - Minimum: 16GB VRAM
+ - Recommended: 24GB VRAM
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: CogVideoXInterp
+ emoji: ⚡
+ colorFrom: red
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.47.2
+ app_file: app.py
+ pinned: false
+ ---
+ 
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SETUP.md ADDED
@@ -0,0 +1,130 @@
+ # CogVideoX Keyframe Interpolation - Quick Setup
+ 
+ This directory contains the **bare minimum files** needed to run the CogVideoX Keyframe Interpolation Gradio app.
+ 
+ ## 📁 Contents
+ 
+ ```
+ CogVideoXInterp/
+ ├── README.md                  # HuggingFace Spaces README
+ ├── app.py                     # Main Gradio application
+ ├── requirements.txt           # Python dependencies
+ ├── cogvideox_interpolation/   # Core pipeline module
+ │   ├── datasets.py            # Dataset loading (not needed for inference)
+ │   └── pipeline.py            # Custom interpolation pipeline
+ └── SETUP.md                   # This file
+ ```
+ 
+ **Total size:** ~64KB (the model downloads separately)
+ 
+ ---
+ 
+ ## 🚀 Quick Start
+ 
+ ### Local Setup
+ 
+ 1. **Install dependencies:**
+    ```bash
+    pip install -r requirements.txt
+    ```
+ 
+ 2. **Run the app:**
+    ```bash
+    python app.py
+    ```
+ 
+ 3. **Open a browser:**
+    Navigate to `http://localhost:7860`
+ 
+ ### GPU Requirements
+ 
+ - **Minimum:** 16GB VRAM (RTX 4060 Ti 16GB, RTX 4080)
+ - **Recommended:** 24GB VRAM (RTX 3090, RTX 4090)
+ 
+ ---
+ 
+ ## 🤗 Deploy to HuggingFace Spaces
+ 
+ ### Method 1: Web Upload
+ 
+ 1. Go to https://huggingface.co/spaces
+ 2. Click "Create new Space"
+ 3. Choose **Gradio** as the SDK
+ 4. Upload all files from this directory
+ 5. Select GPU hardware (T4 minimum, A10G recommended)
+ 6. The Space will auto-deploy!
+ 
+ ### Method 2: Git Push
+ 
+ ```bash
+ # Create a Space on HuggingFace first, then:
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
+ cd YOUR_SPACE_NAME
+ 
+ # Copy files
+ cp -r /path/to/CogVideoXInterp/* .
+ 
+ # Push
+ git add .
+ git commit -m "Initial commit"
+ git push
+ ```
+ 
+ ### HuggingFace Spaces Hardware Options
+ 
+ | Hardware | VRAM | Speed | Cost/hr |
+ |----------|------|-------|---------|
+ | CPU | 0GB | ❌ Won't work | Free |
+ | T4 | 16GB | ⚠️ Slow (5-8 min) | ~$0.60 |
+ | A10G | 24GB | ✅ Good (2-4 min) | ~$3.15 |
+ | A100 | 40GB | ✅ Fast (1-2 min) | ~$7.00 |
+ 
+ **Note:** The model auto-downloads on first run (~20GB).
+ 
+ ---
+ 
+ ## 📝 Usage
+ 
+ 1. **Load Model** - Enter a model path or use the default `feizhengcong/CogvideoX-Interpolation`
+ 2. **Upload Images** - Provide the start and end frames
+ 3. **Write Prompt** - Describe the motion/transition
+ 4. **Generate** - Wait 2-5 minutes for the video (for scripted use, see the sketch below)
+ 
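+ ### Programmatic Use (optional)
+ 
+ If you want to skip the UI, the same pipeline can be driven from a short script. The
+ sketch below mirrors what `app.py` in this directory does; it assumes a CUDA GPU, and
+ `start.png`, `end.png`, and `output.mp4` are placeholder paths you supply yourself.
+ 
+ ```python
+ import torch
+ from PIL import Image
+ from diffusers.utils import export_to_video
+ from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
+ 
+ model_path = "feizhengcong/CogvideoX-Interpolation"
+ # Same dtype heuristic as app.py
+ dtype = torch.bfloat16 if "5b" in model_path.lower() else torch.float16
+ 
+ # Downloads ~20GB on first use
+ pipe = CogVideoXInterpolationPipeline.from_pretrained(model_path, torch_dtype=dtype)
+ pipe.enable_sequential_cpu_offload()  # low VRAM, slower
+ pipe.vae.enable_tiling()
+ pipe.vae.enable_slicing()
+ 
+ first = Image.open("start.png").convert("RGB")
+ last = Image.open("end.png").convert("RGB")
+ 
+ video = pipe(
+     prompt="A dancer gracefully transitions from one pose to another.",
+     first_image=first,
+     last_image=last,
+     num_frames=49,
+     num_inference_steps=50,
+     guidance_scale=6.0,
+     generator=torch.Generator(device="cuda").manual_seed(42),
+ )[0]
+ 
+ export_to_video(video, "output.mp4", fps=8)
+ ```
+ 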
+ ### Example Prompts
+ 
+ ✅ "A person walks forward slowly, their body moving naturally with each step"
+ 
+ ✅ "The camera smoothly pans from left to right, revealing the scene"
+ 
+ ✅ "A dancer gracefully transitions from one pose to another"
+ 
+ ---
+ 
+ ## 🔧 Troubleshooting
+ 
+ ### Out of Memory
+ 
+ Reduce the parameters in the app:
+ - Frames: 49 → 25
+ - Steps: 50 → 30
+ 
+ ### Model Download Fails
+ 
+ Check your internet connection. The model is ~20GB and downloads to:
+ - Linux/Mac: `~/.cache/huggingface/`
+ - Windows: `C:\Users\USERNAME\.cache\huggingface\`
+ 
+ ### Import Errors
+ 
+ Make sure all files from this directory are in the same location, especially the `cogvideox_interpolation/` folder.
+ 
+ ---
+ 
+ ## 📚 More Information
+ 
+ For detailed documentation, see the parent repository at:
+ https://github.com/feizc/CogvideX-Interpolation
+ 
+ **Model:** https://huggingface.co/feizhengcong/CogvideoX-Interpolation
+ 
+ **License:** Apache 2.0
app.py ADDED
@@ -0,0 +1,254 @@
+ import gradio as gr
+ import torch
+ from diffusers.utils import export_to_video
+ from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
+ from PIL import Image
+ import tempfile
+ import os
+ 
+ # Global variable to store the pipeline
+ pipe = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ def load_model(model_path):
+     """Load the CogVideoX-Interpolation model"""
+     global pipe
+ 
+     print(f"Loading model from {model_path}...")
+     print(f"Using device: {device}")
+ 
+     # Determine dtype based on model variant
+     dtype = torch.bfloat16 if "5b" in model_path.lower() else torch.float16
+ 
+     pipe = CogVideoXInterpolationPipeline.from_pretrained(
+         model_path,
+         torch_dtype=dtype
+     )
+ 
+     # Memory optimization
+     if device == "cuda":
+         pipe.enable_sequential_cpu_offload()
+     else:
+         pipe = pipe.to(device)
+ 
+     pipe.vae.enable_tiling()
+     pipe.vae.enable_slicing()
+ 
+     print("Model loaded successfully!")
+     return "✓ Model loaded successfully!"
+ 
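+ # VRAM note: enable_sequential_cpu_offload() streams weights layer by layer, which keeps
+ # peak memory low but is slow. If roughly 24GB of VRAM is available, diffusers'
+ # pipe.enable_model_cpu_offload() is a commonly used, faster alternative.
+ 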
+ def generate_interpolation(
+     first_image,
+     last_image,
+     prompt,
+     num_frames=49,
+     num_inference_steps=50,
+     guidance_scale=6.0,
+     fps=8,
+     seed=42
+ ):
+     """Generate interpolated video between two keyframes"""
+ 
+     if pipe is None:
+         return None, "⚠️ Please load the model first!"
+ 
+     if first_image is None or last_image is None:
+         return None, "⚠️ Please upload both start and end frame images!"
+ 
+     if not prompt.strip():
+         return None, "⚠️ Please provide a text prompt describing the motion!"
+ 
+     try:
+         # Convert numpy arrays to PIL Images if needed
+         if not isinstance(first_image, Image.Image):
+             first_image = Image.fromarray(first_image)
+         if not isinstance(last_image, Image.Image):
+             last_image = Image.fromarray(last_image)
+ 
+         print(f"Generating video with prompt: {prompt}")
+         print(f"Parameters: frames={num_frames}, steps={num_inference_steps}, guidance={guidance_scale}")
+ 
+         # Generate video
+         generator = torch.Generator(device=device).manual_seed(seed)
+ 
+         video = pipe(
+             prompt=prompt,
+             first_image=first_image,
+             last_image=last_image,
+             num_videos_per_prompt=1,
+             num_inference_steps=num_inference_steps,
+             num_frames=num_frames,
+             guidance_scale=guidance_scale,
+             generator=generator,
+         )[0]
+ 
+         # Export to temporary file
+         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+         output_path = temp_file.name
+         temp_file.close()
+ 
+         export_to_video(video, output_path, fps=fps)
+ 
+         status = f"✓ Video generated successfully! ({num_frames} frames at {fps} fps)"
+         print(status)
+ 
+         return output_path, status
+ 
+     except Exception as e:
+         error_msg = f"❌ Error: {str(e)}"
+         print(error_msg)
+         return None, error_msg
+ 
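+ # The frame-count slider below steps in increments of 4 starting at 13 because the
+ # CogVideoX VAE compresses time 4x, so the pipeline expects num_frames of the form
+ # 4k + 1 (13, 17, ..., 49).
+ 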
+ # Create Gradio interface
+ with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
+     gr.Markdown("""
+     # 🎬 CogVideoX Keyframe Interpolation
+ 
+     Generate smooth video transitions between two keyframe images using AI.
+ 
+     **Instructions:**
+     1. First, load the model by providing the path to your checkpoint
+     2. Upload start and end frame images
+     3. Describe the motion/transition in the text prompt
+     4. Adjust parameters and generate!
+     """)
+ 
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### 🔧 Model Setup")
+             model_path_input = gr.Textbox(
+                 label="Model Path",
+                 placeholder="e.g., /path/to/CogVideoX-5b-I2V-inter or feizhengcong/CogvideoX-Interpolation",
+                 value="feizhengcong/CogvideoX-Interpolation"
+             )
+             load_btn = gr.Button("Load Model", variant="primary")
+             model_status = gr.Textbox(label="Status", interactive=False)
+ 
+     gr.Markdown("---")
+ 
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### 🖼️ Input Keyframes")
+             first_image_input = gr.Image(
+                 label="Start Frame",
+                 type="pil",
+                 height=300
+             )
+             last_image_input = gr.Image(
+                 label="End Frame",
+                 type="pil",
+                 height=300
+             )
+ 
+         with gr.Column():
+             gr.Markdown("### ⚙️ Generation Settings")
+             prompt_input = gr.Textbox(
+                 label="Motion Description",
+                 placeholder="Describe the motion/transition between the frames...",
+                 lines=4
+             )
+ 
+             with gr.Row():
+                 num_frames_slider = gr.Slider(
+                     label="Number of Frames",
+                     minimum=13,
+                     maximum=49,
+                     step=4,
+                     value=49,
+                     info="Must be 4k+1 format (13, 17, 21, ..., 49)"
+                 )
+                 fps_slider = gr.Slider(
+                     label="FPS",
+                     minimum=4,
+                     maximum=16,
+                     step=2,
+                     value=8
+                 )
+ 
+             with gr.Row():
+                 num_steps_slider = gr.Slider(
+                     label="Inference Steps",
+                     minimum=20,
+                     maximum=100,
+                     step=5,
+                     value=50,
+                     info="More steps = better quality but slower"
+                 )
+                 guidance_slider = gr.Slider(
+                     label="Guidance Scale",
+                     minimum=1.0,
+                     maximum=15.0,
+                     step=0.5,
+                     value=6.0,
+                     info="Higher = stronger prompt following"
+                 )
+ 
+             seed_input = gr.Number(
+                 label="Random Seed",
+                 value=42,
+                 precision=0
+             )
+ 
+             generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
+ 
+     gr.Markdown("---")
+ 
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### 🎥 Generated Video")
+             output_video = gr.Video(label="Output")
+             generation_status = gr.Textbox(label="Generation Status", interactive=False)
+ 
+     # Examples
+     gr.Markdown("---")
+     gr.Markdown("### 💡 Example Prompts")
+     gr.Examples(
+         examples=[
+             ["A person walks forward slowly, their body moving naturally with each step."],
+             ["The camera smoothly pans from left to right, revealing the scene."],
+             ["A dancer gracefully transitions from one pose to another."],
+             ["The sun sets gradually, changing the lighting and colors of the scene."],
+             ["A car accelerates down the street, moving from standstill to motion."],
+         ],
+         inputs=prompt_input,
+         label="Click to use example prompts"
+     )
+ 
+     # Event handlers
+     load_btn.click(
+         fn=load_model,
+         inputs=[model_path_input],
+         outputs=[model_status]
+     )
+ 
+     generate_btn.click(
+         fn=generate_interpolation,
+         inputs=[
+             first_image_input,
+             last_image_input,
+             prompt_input,
+             num_frames_slider,
+             num_steps_slider,
+             guidance_slider,
+             fps_slider,
+             seed_input
+         ],
+         outputs=[output_video, generation_status]
+     )
+ 
+ if __name__ == "__main__":
+     print("="*50)
+     print("CogVideoX Keyframe Interpolation Gradio App")
+     print("="*50)
+     print(f"Device: {device}")
+     print(f"CUDA available: {torch.cuda.is_available()}")
+     if torch.cuda.is_available():
+         print(f"GPU: {torch.cuda.get_device_name(0)}")
+         print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+     print("="*50)
+ 
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False
+     )
cogvideox_interpolation/datasets.py ADDED
@@ -0,0 +1,154 @@
+ import json
+ import torch
+ import numpy as np  # required by _resize_for_rectangle_crop (np.random.randint)
+ from typing import Any, Dict, List, Optional, Tuple
+ from torch.utils.data import DataLoader, Dataset
+ import torchvision.transforms as TT
+ from torchvision import transforms
+ from torchvision.transforms.functional import center_crop, resize
+ from torchvision.transforms import InterpolationMode
+ import random
+ try:
+     import decord
+ except ImportError:
+     raise ImportError(
+         "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
+     )
+ 
+ decord.bridge.set_bridge("torch")
+ 
+ class ImageVideoDataset(Dataset):
+     def __init__(
+         self,
+         data_root,
+         tokenizer,
+         max_sequence_length: int = 226,
+         height: int = 480,
+         width: int = 720,
+         video_reshape_mode: str = "center",
+         fps: int = 8,
+         stripe: int = 2,
+         max_num_frames: int = 49,
+         skip_frames_start: int = 0,
+         skip_frames_end: int = 0,
+         random_flip: Optional[float] = None,
+     ) -> None:
+         super().__init__()
+ 
+         with open(data_root, 'r') as f:
+             self.data_list = json.load(f)
+ 
+         self.tokenizer = tokenizer
+         self.max_sequence_length = max_sequence_length
+         self.height = height
+         self.width = width
+         self.video_reshape_mode = video_reshape_mode
+         self.fps = fps
+         self.max_num_frames = max_num_frames
+         self.skip_frames_start = skip_frames_start
+         self.skip_frames_end = skip_frames_end
+         self.stripe = stripe
+         self.video_transforms = transforms.Compose(
+             [
+                 transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(lambda x: x),
+                 transforms.Lambda(lambda x: x / 255.0),
+                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+             ]
+         )
+ 
+     def __len__(self):
+         return len(self.data_list)
+ 
+     def _resize_for_rectangle_crop(self, arr):
+         image_size = self.height, self.width
+         reshape_mode = self.video_reshape_mode
+         if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+             arr = resize(
+                 arr,
+                 size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                 interpolation=InterpolationMode.BICUBIC,
+             )
+         else:
+             arr = resize(
+                 arr,
+                 size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                 interpolation=InterpolationMode.BICUBIC,
+             )
+ 
+         h, w = arr.shape[2], arr.shape[3]
+         arr = arr.squeeze(0)
+ 
+         delta_h = h - image_size[0]
+         delta_w = w - image_size[1]
+ 
+         if reshape_mode == "random" or reshape_mode == "none":
+             top = np.random.randint(0, delta_h + 1)
+             left = np.random.randint(0, delta_w + 1)
+         elif reshape_mode == "center":
+             top, left = delta_h // 2, delta_w // 2
+         else:
+             raise NotImplementedError
+         arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+         return arr
+ 
+     def __getitem__(self, index):
+         while True:
+             try:
+                 video_reader = decord.VideoReader(self.data_list[index]['file_path'], width=self.width, height=self.height)
+                 video_num_frames = len(video_reader)
+                 # print(video_num_frames, video_reader.get_avg_fps())
+                 if self.stripe * self.max_num_frames > video_num_frames:
+                     stripe = 1
+                 else:
+                     stripe = self.stripe
+ 
+                 random_range = video_num_frames - stripe * self.max_num_frames - 1
+                 random_range = max(1, random_range)
+                 start_frame = random.randint(1, random_range) if random_range > 0 else 1
+ 
+                 indices = list(range(start_frame, start_frame + stripe * self.max_num_frames, stripe))
+                 frames = video_reader.get_batch(indices)
+ 
+                 # Ensure that we don't go over the limit
+                 frames = frames[: self.max_num_frames]
+                 selected_num_frames = frames.shape[0]
+ 
+                 # Keep the first (4k + 1) frames, as this is what the VAE requires
+                 remainder = (3 + (selected_num_frames % 4)) % 4
+                 if remainder != 0:
+                     frames = frames[:-remainder]
+                 selected_num_frames = frames.shape[0]
+ 
+                 assert (selected_num_frames - 1) % 4 == 0
+                 if selected_num_frames == self.max_num_frames:
+                     break
+                 else:
+                     index = (index + 1) % len(self.data_list)
+                     continue
+ 
+             except Exception as e:
+                 index = (index + 1) % len(self.data_list)
+                 print("Error encountered while loading video sample: ", e)
+                 continue
+ 
+         # Training transforms
+         # frames = (frames - 127.5) / 127.5
+         frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]
+         frames = self._resize_for_rectangle_crop(frames)
+         frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
+ 
+         text_inputs = self.tokenizer(
+             [self.data_list[index]['text']],
+             padding="max_length",
+             max_length=self.max_sequence_length,
+             truncation=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids[0]
+ 
+         return frames.contiguous(), text_input_ids
+ 
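+ # Expected annotation format (inferred from the loader above, not an official spec):
+ # `data_root` points to a JSON file containing a list of entries, each with a
+ # 'file_path' to a video and a 'text' caption, e.g.
+ #
+ #     [
+ #         {"file_path": "/data/videos/clip_0001.mp4", "text": "A dancer spins slowly."},
+ #         {"file_path": "/data/videos/clip_0002.mp4", "text": "Waves roll onto a beach."}
+ #     ]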
cogvideox_interpolation/pipeline.py ADDED
@@ -0,0 +1,799 @@
1
+ import math
2
+ import PIL
3
+ import inspect
4
+ import torch
5
+ from typing import Callable, Dict, List, Optional, Tuple, Union
6
+ from transformers import T5EncoderModel, T5Tokenizer
7
+
8
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
9
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
10
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
11
+ from diffusers.utils import (
12
+ logging,
13
+ replace_example_docstring,
14
+ )
15
+ from diffusers.image_processor import PipelineImageInput
16
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
17
+ from diffusers.video_processor import VideoProcessor
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
+ 
+ # Module-level logger used by the warnings emitted below.
+ logger = logging.get_logger(__name__)
+ 
22
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
23
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
24
+ tw = tgt_width
25
+ th = tgt_height
26
+ h, w = src
27
+ r = h / w
28
+ if r > (th / tw):
29
+ resize_height = th
30
+ resize_width = int(round(th / h * w))
31
+ else:
32
+ resize_width = tw
33
+ resize_height = int(round(tw / w * h))
34
+
35
+ crop_top = int(round((th - resize_height) / 2.0))
36
+ crop_left = int(round((tw - resize_width) / 2.0))
37
+
38
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
39
+
40
+
41
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
42
+ def retrieve_timesteps(
43
+ scheduler,
44
+ num_inference_steps: Optional[int] = None,
45
+ device: Optional[Union[str, torch.device]] = None,
46
+ timesteps: Optional[List[int]] = None,
47
+ sigmas: Optional[List[float]] = None,
48
+ **kwargs,
49
+ ):
50
+ """
51
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
52
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
53
+
54
+ Args:
55
+ scheduler (`SchedulerMixin`):
56
+ The scheduler to get timesteps from.
57
+ num_inference_steps (`int`):
58
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
59
+ must be `None`.
60
+ device (`str` or `torch.device`, *optional*):
61
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
62
+ timesteps (`List[int]`, *optional*):
63
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
64
+ `num_inference_steps` and `sigmas` must be `None`.
65
+ sigmas (`List[float]`, *optional*):
66
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
67
+ `num_inference_steps` and `timesteps` must be `None`.
68
+
69
+ Returns:
70
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
71
+ second element is the number of inference steps.
72
+ """
73
+ if timesteps is not None and sigmas is not None:
74
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
75
+ if timesteps is not None:
76
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
77
+ if not accepts_timesteps:
78
+ raise ValueError(
79
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
80
+ f" timestep schedules. Please check whether you are using the correct scheduler."
81
+ )
82
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
83
+ timesteps = scheduler.timesteps
84
+ num_inference_steps = len(timesteps)
85
+ elif sigmas is not None:
86
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
87
+ if not accept_sigmas:
88
+ raise ValueError(
89
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
90
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
91
+ )
92
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
93
+ timesteps = scheduler.timesteps
94
+ num_inference_steps = len(timesteps)
95
+ else:
96
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
97
+ timesteps = scheduler.timesteps
98
+ return timesteps, num_inference_steps
99
+
100
+
101
+
102
+
103
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
104
+ def retrieve_latents(
105
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
106
+ ):
107
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
108
+ return encoder_output.latent_dist.sample(generator)
109
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
110
+ return encoder_output.latent_dist.mode()
111
+ elif hasattr(encoder_output, "latents"):
112
+ return encoder_output.latents
113
+ else:
114
+ raise AttributeError("Could not access latents of provided encoder_output")
115
+
116
+
117
+
118
+
119
+ class CogVideoXInterpolationPipeline(DiffusionPipeline):
120
+ _optional_components = []
121
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
122
+
123
+ _callback_tensor_inputs = [
124
+ "latents",
125
+ "prompt_embeds",
126
+ "negative_prompt_embeds",
127
+ ]
128
+ def __init__(
129
+ self,
130
+ tokenizer: T5Tokenizer,
131
+ text_encoder: T5EncoderModel,
132
+ vae: AutoencoderKLCogVideoX,
133
+ transformer: CogVideoXTransformer3DModel,
134
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
135
+ ):
136
+ super().__init__()
137
+
138
+ self.register_modules(
139
+ tokenizer=tokenizer,
140
+ text_encoder=text_encoder,
141
+ vae=vae,
142
+ transformer=transformer,
143
+ scheduler=scheduler,
144
+ )
145
+ self.vae_scale_factor_spatial = (
146
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
147
+ )
148
+ self.vae_scale_factor_temporal = (
149
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
150
+ )
151
+
152
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
153
+
154
+
155
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
156
+ def _get_t5_prompt_embeds(
157
+ self,
158
+ prompt: Union[str, List[str]] = None,
159
+ num_videos_per_prompt: int = 1,
160
+ max_sequence_length: int = 226,
161
+ device: Optional[torch.device] = None,
162
+ dtype: Optional[torch.dtype] = None,
163
+ ):
164
+ device = device or self._execution_device
165
+ dtype = dtype or self.text_encoder.dtype
166
+
167
+ prompt = [prompt] if isinstance(prompt, str) else prompt
168
+ batch_size = len(prompt)
169
+
170
+ text_inputs = self.tokenizer(
171
+ prompt,
172
+ padding="max_length",
173
+ max_length=max_sequence_length,
174
+ truncation=True,
175
+ add_special_tokens=True,
176
+ return_tensors="pt",
177
+ )
178
+ text_input_ids = text_inputs.input_ids
179
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
180
+
181
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
182
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
183
+ logger.warning(
184
+ "The following part of your input was truncated because `max_sequence_length` is set to "
185
+ f" {max_sequence_length} tokens: {removed_text}"
186
+ )
187
+
188
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
189
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
190
+
191
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
192
+ _, seq_len, _ = prompt_embeds.shape
193
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
194
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
195
+
196
+ return prompt_embeds
197
+
198
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
199
+ def encode_prompt(
200
+ self,
201
+ prompt: Union[str, List[str]],
202
+ negative_prompt: Optional[Union[str, List[str]]] = None,
203
+ do_classifier_free_guidance: bool = True,
204
+ num_videos_per_prompt: int = 1,
205
+ prompt_embeds: Optional[torch.Tensor] = None,
206
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
207
+ max_sequence_length: int = 226,
208
+ device: Optional[torch.device] = None,
209
+ dtype: Optional[torch.dtype] = None,
210
+ ):
211
+ r"""
212
+ Encodes the prompt into text encoder hidden states.
213
+
214
+ Args:
215
+ prompt (`str` or `List[str]`, *optional*):
216
+ prompt to be encoded
217
+ negative_prompt (`str` or `List[str]`, *optional*):
218
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
219
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
220
+ less than `1`).
221
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
222
+ Whether to use classifier free guidance or not.
223
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
224
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
225
+ prompt_embeds (`torch.Tensor`, *optional*):
226
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
227
+ provided, text embeddings will be generated from `prompt` input argument.
228
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
229
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
230
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
231
+ argument.
232
+ device: (`torch.device`, *optional*):
233
+ torch device
234
+ dtype: (`torch.dtype`, *optional*):
235
+ torch dtype
236
+ """
237
+ device = device or self._execution_device
238
+
239
+ prompt = [prompt] if isinstance(prompt, str) else prompt
240
+ if prompt is not None:
241
+ batch_size = len(prompt)
242
+ else:
243
+ batch_size = prompt_embeds.shape[0]
244
+
245
+ if prompt_embeds is None:
246
+ prompt_embeds = self._get_t5_prompt_embeds(
247
+ prompt=prompt,
248
+ num_videos_per_prompt=num_videos_per_prompt,
249
+ max_sequence_length=max_sequence_length,
250
+ device=device,
251
+ dtype=dtype,
252
+ )
253
+
254
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
255
+ negative_prompt = negative_prompt or ""
256
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
257
+
258
+ if prompt is not None and type(prompt) is not type(negative_prompt):
259
+ raise TypeError(
260
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
261
+ f" {type(prompt)}."
262
+ )
263
+ elif batch_size != len(negative_prompt):
264
+ raise ValueError(
265
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
266
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
267
+ " the batch size of `prompt`."
268
+ )
269
+
270
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
271
+ prompt=negative_prompt,
272
+ num_videos_per_prompt=num_videos_per_prompt,
273
+ max_sequence_length=max_sequence_length,
274
+ device=device,
275
+ dtype=dtype,
276
+ )
277
+
278
+ return prompt_embeds, negative_prompt_embeds
279
+
280
+ def prepare_latents(
281
+ self,
282
+ first_image: torch.Tensor,
283
+ last_image: torch.Tensor,
284
+ batch_size: int = 1,
285
+ num_channels_latents: int = 16,
286
+ num_frames: int = 13,
287
+ height: int = 60,
288
+ width: int = 90,
289
+ dtype: Optional[torch.dtype] = None,
290
+ device: Optional[torch.device] = None,
291
+ generator: Optional[torch.Generator] = None,
292
+ latents: Optional[torch.Tensor] = None,
293
+ ):
294
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
295
+ shape = (
296
+ batch_size,
297
+ num_frames,
298
+ num_channels_latents,
299
+ height // self.vae_scale_factor_spatial,
300
+ width // self.vae_scale_factor_spatial,
301
+ )
302
+
303
+ if isinstance(generator, list) and len(generator) != batch_size:
304
+ raise ValueError(
305
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
306
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
307
+ )
308
+
309
+ first_image = first_image.unsqueeze(2) # [B, C, F, H, W]
310
+ last_image = last_image.unsqueeze(2) # [B, C, F, H, W]
311
+
312
+ if isinstance(generator, list):
313
+ first_image_latents = [
314
+ retrieve_latents(self.vae.encode(first_image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
315
+ ]
316
+ else:
317
+ first_image_latents = [retrieve_latents(self.vae.encode(first_img.unsqueeze(0)), generator) for first_img in first_image]
318
+
319
+ if isinstance(generator, list):
320
+ last_image_latents = [
321
+ retrieve_latents(self.vae.encode(last_image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
322
+ ]
323
+ else:
324
+ last_image_latents = [retrieve_latents(self.vae.encode(last_img.unsqueeze(0)), generator) for last_img in last_image]
325
+
326
+
327
+ first_image_latents = torch.cat(first_image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
328
+ first_image_latents = self.vae.config.scaling_factor * first_image_latents
329
+ last_image_latents = torch.cat(last_image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
330
+ last_image_latents = self.vae.config.scaling_factor * last_image_latents
331
+
332
+
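+ # Only the first and last latent frames carry encoded image content; the
+ # (num_frames - 2) latent frames in between are zero padding that the transformer
+ # learns to fill in.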
333
+ padding_shape = (
334
+ batch_size,
335
+ num_frames - 2,
336
+ num_channels_latents,
337
+ height // self.vae_scale_factor_spatial,
338
+ width // self.vae_scale_factor_spatial,
339
+ )
340
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
341
+ image_latents = torch.cat([first_image_latents, latent_padding, last_image_latents], dim=1)
342
+
343
+ if latents is None:
344
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
345
+ else:
346
+ latents = latents.to(device)
347
+
348
+ # scale the initial noise by the standard deviation required by the scheduler
349
+ latents = latents * self.scheduler.init_noise_sigma
350
+ return latents, image_latents
351
+
352
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
353
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
354
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
355
+ latents = 1 / self.vae.config.scaling_factor * latents
356
+
357
+ frames = self.vae.decode(latents).sample
358
+ return frames
359
+
360
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
361
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
362
+ # get the original timestep using init_timestep
363
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
364
+
365
+ t_start = max(num_inference_steps - init_timestep, 0)
366
+ timesteps = timesteps[t_start * self.scheduler.order :]
367
+
368
+ return timesteps, num_inference_steps - t_start
369
+
370
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
371
+ def prepare_extra_step_kwargs(self, generator, eta):
372
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
373
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
374
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
375
+ # and should be between [0, 1]
376
+
377
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
378
+ extra_step_kwargs = {}
379
+ if accepts_eta:
380
+ extra_step_kwargs["eta"] = eta
381
+
382
+ # check if the scheduler accepts generator
383
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
384
+ if accepts_generator:
385
+ extra_step_kwargs["generator"] = generator
386
+ return extra_step_kwargs
387
+
388
+ def check_inputs(
389
+ self,
390
+ first_image,
391
+ last_image,
392
+ prompt,
393
+ height,
394
+ width,
395
+ negative_prompt,
396
+ callback_on_step_end_tensor_inputs,
397
+ video=None,
398
+ latents=None,
399
+ prompt_embeds=None,
400
+ negative_prompt_embeds=None,
401
+ ):
402
+ if (
403
+ not isinstance(first_image, torch.Tensor)
404
+ and not isinstance(first_image, PIL.Image.Image)
405
+ and not isinstance(first_image, list)
406
+ ):
407
+ raise ValueError(
408
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
409
+ f" {type(first_image)}"
410
+ )
411
+
412
+ if (
413
+ not isinstance(last_image, torch.Tensor)
414
+ and not isinstance(last_image, PIL.Image.Image)
415
+ and not isinstance(last_image, list)
416
+ ):
417
+ raise ValueError(
418
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
419
+ f" {type(last_image)}"
420
+ )
421
+
422
+
423
+ if height % 8 != 0 or width % 8 != 0:
424
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
425
+
426
+ if callback_on_step_end_tensor_inputs is not None and not all(
427
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
428
+ ):
429
+ raise ValueError(
430
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
431
+ )
432
+ if prompt is not None and prompt_embeds is not None:
433
+ raise ValueError(
434
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
435
+ " only forward one of the two."
436
+ )
437
+ elif prompt is None and prompt_embeds is None:
438
+ raise ValueError(
439
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
440
+ )
441
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
442
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
443
+
444
+ if prompt is not None and negative_prompt_embeds is not None:
445
+ raise ValueError(
446
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
447
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
448
+ )
449
+
450
+ if negative_prompt is not None and negative_prompt_embeds is not None:
451
+ raise ValueError(
452
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
453
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
454
+ )
455
+
456
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
457
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
458
+ raise ValueError(
459
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
460
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
461
+ f" {negative_prompt_embeds.shape}."
462
+ )
463
+
464
+ if video is not None and latents is not None:
465
+ raise ValueError("Only one of `video` or `latents` should be provided")
466
+
467
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
468
+ def fuse_qkv_projections(self) -> None:
469
+ r"""Enables fused QKV projections."""
470
+ self.fusing_transformer = True
471
+ self.transformer.fuse_qkv_projections()
472
+
473
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
474
+ def unfuse_qkv_projections(self) -> None:
475
+ r"""Disable QKV projection fusion if enabled."""
476
+ if not self.fusing_transformer:
477
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
478
+ else:
479
+ self.transformer.unfuse_qkv_projections()
480
+ self.fusing_transformer = False
481
+
482
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
483
+ def _prepare_rotary_positional_embeddings(
484
+ self,
485
+ height: int,
486
+ width: int,
487
+ num_frames: int,
488
+ device: torch.device,
489
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
490
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
491
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
492
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
493
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
494
+
495
+ grid_crops_coords = get_resize_crop_region_for_grid(
496
+ (grid_height, grid_width), base_size_width, base_size_height
497
+ )
498
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
499
+ embed_dim=self.transformer.config.attention_head_dim,
500
+ crops_coords=grid_crops_coords,
501
+ grid_size=(grid_height, grid_width),
502
+ temporal_size=num_frames,
503
+ )
504
+
505
+ freqs_cos = freqs_cos.to(device=device)
506
+ freqs_sin = freqs_sin.to(device=device)
507
+ return freqs_cos, freqs_sin
508
+
509
+ @property
510
+ def guidance_scale(self):
511
+ return self._guidance_scale
512
+
513
+ @property
514
+ def num_timesteps(self):
515
+ return self._num_timesteps
516
+
517
+ @property
518
+ def interrupt(self):
519
+ return self._interrupt
520
+
521
+ @torch.no_grad()
522
+ def __call__(
523
+ self,
524
+ first_image: PipelineImageInput,
525
+ last_image: PipelineImageInput,
526
+ prompt: Optional[Union[str, List[str]]] = None,
527
+ negative_prompt: Optional[Union[str, List[str]]] = None,
528
+ height: int = 480,
529
+ width: int = 720,
530
+ num_frames: int = 49,
531
+ num_inference_steps: int = 50,
532
+ timesteps: Optional[List[int]] = None,
533
+ guidance_scale: float = 6,
534
+ use_dynamic_cfg: bool = False,
535
+ num_videos_per_prompt: int = 1,
536
+ eta: float = 0.0,
537
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
538
+ latents: Optional[torch.FloatTensor] = None,
539
+ prompt_embeds: Optional[torch.FloatTensor] = None,
540
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
541
+ output_type: str = "pil",
542
+ return_dict: bool = True,
543
+ callback_on_step_end: Optional[
544
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
545
+ ] = None,
546
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
547
+ max_sequence_length: int = 226,
548
+ ):
549
+ """
550
+ Function invoked when calling the pipeline for generation.
551
+
552
+ Args:
553
+ image (`PipelineImageInput`):
554
+ The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
555
+ prompt (`str` or `List[str]`, *optional*):
556
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
557
+ instead.
558
+ negative_prompt (`str` or `List[str]`, *optional*):
559
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
560
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
561
+ less than `1`).
562
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
563
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
564
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
565
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
566
+ num_frames (`int`, defaults to `49`):
567
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
568
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
569
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
570
+ needs to be satisfied is that of divisibility mentioned above.
571
+ num_inference_steps (`int`, *optional*, defaults to 50):
572
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
573
+ expense of slower inference.
574
+ timesteps (`List[int]`, *optional*):
575
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
576
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
577
+ passed will be used. Must be in descending order.
578
+ guidance_scale (`float`, *optional*, defaults to 6.0):
579
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
580
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
581
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
582
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
583
+ usually at the expense of lower image quality.
584
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
585
+ The number of videos to generate per prompt.
586
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
587
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
588
+ to make generation deterministic.
589
+ latents (`torch.FloatTensor`, *optional*):
590
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
591
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
592
+ tensor will ge generated by sampling using the supplied random `generator`.
593
+ prompt_embeds (`torch.FloatTensor`, *optional*):
594
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
595
+ provided, text embeddings will be generated from `prompt` input argument.
596
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
597
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
598
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
599
+ argument.
600
+ output_type (`str`, *optional*, defaults to `"pil"`):
601
+ The output format of the generate image. Choose between
602
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
603
+ return_dict (`bool`, *optional*, defaults to `True`):
604
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
605
+ of a plain tuple.
606
+ callback_on_step_end (`Callable`, *optional*):
607
+ A function that calls at the end of each denoising steps during the inference. The function is called
608
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
609
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
610
+ `callback_on_step_end_tensor_inputs`.
611
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
612
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
613
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
614
+ `._callback_tensor_inputs` attribute of your pipeline class.
615
+ max_sequence_length (`int`, defaults to `226`):
616
+ Maximum sequence length in encoded prompt. Must be consistent with
617
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
618
+
619
+ Examples:
620
+
621
+ Returns:
622
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
623
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
624
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
625
+ """
626
+
627
+ if num_frames > 49:
628
+ raise ValueError(
629
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
630
+ )
631
+
632
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
633
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
634
+
635
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
636
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
637
+ num_videos_per_prompt = 1
638
+
639
+ # 1. Check inputs. Raise error if not correct
640
+ self.check_inputs(
641
+ first_image,
642
+ last_image,
643
+ prompt,
644
+ height,
645
+ width,
646
+ negative_prompt,
647
+ callback_on_step_end_tensor_inputs,
+ # Pass these by keyword; positionally they would bind to check_inputs' `video`/`latents`.
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
650
+ )
651
+ self._guidance_scale = guidance_scale
652
+ self._interrupt = False
653
+
654
+ # 2. Default call parameters
655
+ if prompt is not None and isinstance(prompt, str):
656
+ batch_size = 1
657
+ elif prompt is not None and isinstance(prompt, list):
658
+ batch_size = len(prompt)
659
+ else:
660
+ batch_size = prompt_embeds.shape[0]
661
+
662
+ device = self._execution_device
663
+
664
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
665
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
666
+ # corresponds to doing no classifier free guidance.
667
+ do_classifier_free_guidance = guidance_scale > 1.0
668
+
669
+ # 3. Encode input prompt
670
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
671
+ prompt=prompt,
672
+ negative_prompt=negative_prompt,
673
+ do_classifier_free_guidance=do_classifier_free_guidance,
674
+ num_videos_per_prompt=num_videos_per_prompt,
675
+ prompt_embeds=prompt_embeds,
676
+ negative_prompt_embeds=negative_prompt_embeds,
677
+ max_sequence_length=max_sequence_length,
678
+ device=device,
679
+ )
680
+ if do_classifier_free_guidance:
681
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
682
+
683
+ # 4. Prepare timesteps
684
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
685
+ self._num_timesteps = len(timesteps)
686
+
687
+ # 5. Prepare latents
688
+ first_image = self.video_processor.preprocess(first_image, height=height, width=width).to(
689
+ device, dtype=prompt_embeds.dtype
690
+ )
691
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
692
+ device, dtype=prompt_embeds.dtype
693
+ )
694
+
695
+ latent_channels = self.transformer.config.in_channels // 2
696
+ latents, image_latents = self.prepare_latents(
697
+ first_image,
698
+ last_image,
699
+ batch_size * num_videos_per_prompt,
700
+ latent_channels,
701
+ num_frames,
702
+ height,
703
+ width,
704
+ prompt_embeds.dtype,
705
+ device,
706
+ generator,
707
+ latents,
708
+ )
709
+
710
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
711
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
712
+
713
+ # 7. Create rotary embeds if required
714
+ image_rotary_emb = (
715
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
716
+ if self.transformer.config.use_rotary_positional_embeddings
717
+ else None
718
+ )
719
+
720
+ # 8. Denoising loop
721
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
722
+
723
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
724
+ # for DPM-solver++
725
+ old_pred_original_sample = None
726
+ for i, t in enumerate(timesteps):
727
+ if self.interrupt:
728
+ continue
729
+
730
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
731
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
732
+
733
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
734
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
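+ # The noisy video latents and the first/last-frame conditioning latents are concatenated
+ # along the channel axis; this is why `latent_channels` above is
+ # `self.transformer.config.in_channels // 2`.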
735
+
736
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
737
+ timestep = t.expand(latent_model_input.shape[0])
738
+
739
+ # predict noise model_output
740
+ noise_pred = self.transformer(
741
+ hidden_states=latent_model_input,
742
+ encoder_hidden_states=prompt_embeds,
743
+ timestep=timestep,
744
+ image_rotary_emb=image_rotary_emb,
745
+ return_dict=False,
746
+ )[0]
747
+ noise_pred = noise_pred.float()
748
+
749
+ # perform guidance
750
+ if use_dynamic_cfg:
751
+ self._guidance_scale = 1 + guidance_scale * (
752
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
753
+ )
754
+ if do_classifier_free_guidance:
755
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
756
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
757
+
758
+ # compute the previous noisy sample x_t -> x_t-1
759
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
760
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
761
+ else:
762
+ latents, old_pred_original_sample = self.scheduler.step(
763
+ noise_pred,
764
+ old_pred_original_sample,
765
+ t,
766
+ timesteps[i - 1] if i > 0 else None,
767
+ latents,
768
+ **extra_step_kwargs,
769
+ return_dict=False,
770
+ )
771
+ latents = latents.to(prompt_embeds.dtype)
772
+
773
+ # call the callback, if provided
774
+ if callback_on_step_end is not None:
775
+ callback_kwargs = {}
776
+ for k in callback_on_step_end_tensor_inputs:
777
+ callback_kwargs[k] = locals()[k]
778
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
779
+
780
+ latents = callback_outputs.pop("latents", latents)
781
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
782
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
783
+
784
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
785
+ progress_bar.update()
786
+
787
+ if not output_type == "latent":
788
+ video = self.decode_latents(latents)
789
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
790
+ else:
791
+ video = latents
792
+
793
+ # Offload all models
794
+ self.maybe_free_model_hooks()
795
+
796
+ if not return_dict:
797
+ return (video,)
798
+
799
+ return (video,)
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ diffusers==0.30.3
+ transformers==4.44.2
+ accelerate==0.34.0
+ gradio>=4.0.0
+ torch>=2.0.0
+ torchvision
+ Pillow
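+ # decord is not listed here: it is only needed by cogvideox_interpolation/datasets.py
+ # (training data loading), not by the Gradio app; install it separately with `pip install decord` if you use that module.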