vidfom committed on
Commit
8c6346f
·
verified ·
1 Parent(s): 61f9885
Files changed (8) hide show
  1. .gitattributes +4 -35
  2. .gitignore +165 -0
  3. deploy.sh +19 -0
  4. gradio_ap.py +541 -0
  5. gradio_app.py +593 -0
  6. gradio_appp.py +312 -0
  7. inference.py +444 -0
  8. requirements.txt +15 -0
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
3
+ *.png filter=lfs diff=lfs merge=lfs -text
4
+ *.gif filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ .idea/
163
+
164
+ # From inference.py
165
+ video_output_*.mp4
deploy.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Replace <repository_url> with your GitHub repository URL
4
+ REPO_URL="https://github.com/unitconvert/Model-.git"
5
+
6
+ # Navigate to the directory containing your pages
7
+ cd "D:/modle/LTX-Video-gradio-ui" || exit
8
+
9
+ # Initialize a Git repository if not already initialized
10
+ if [ ! -d ".git" ]; then
11
+ git init
12
+ git remote add origin "$REPO_URL"
13
+ git checkout -b gh-pages
14
+ fi
15
+
16
+ # Add all files, commit, and push to gh-pages branch
17
+ git add .
18
+ git commit -m "Update pages"
19
+ git push origin gh-pages
gradio_ap.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from gradio_toggle import Toggle
4
+ import argparse
5
+ import json
6
+ import os
7
+ import random
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from diffusers.utils import logging
11
+
12
+ import imageio
13
+ import numpy as np
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from PIL import Image
18
+ from transformers import T5EncoderModel, T5Tokenizer
19
+ import tempfile
20
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
21
+ CausalVideoAutoencoder,
22
+ )
23
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
24
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
25
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
26
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
27
+ from ltx_video.utils.conditioning_method import ConditioningMethod
28
+ from torchao.quantization import quantize_, int8_weight_only
29
+
30
+ MAX_HEIGHT = 720
31
+ MAX_WIDTH = 1280
32
+ MAX_NUM_FRAMES = 257
33
+
34
+ def load_vae(vae_dir, int8=False):
35
+ vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
36
+ vae_config_path = vae_dir / "config.json"
37
+ with open(vae_config_path, "r") as f:
38
+ vae_config = json.load(f)
39
+ vae = CausalVideoAutoencoder.from_config(vae_config)
40
+ vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
41
+ vae.load_state_dict(vae_state_dict)
42
+ if int8:
43
+ print("vae - quantization = true")
44
+ quantize_(vae, int8_weight_only())
45
+ return vae.to(torch.bfloat16)
46
+
47
+ def load_unet(unet_dir, int8=False):
48
+ unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
49
+ unet_config_path = unet_dir / "config.json"
50
+ transformer_config = Transformer3DModel.load_config(unet_config_path)
51
+ transformer = Transformer3DModel.from_config(transformer_config)
52
+ unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
53
+ transformer.load_state_dict(unet_state_dict, strict=True)
54
+ if int8:
55
+ print("unet - quantization = true")
56
+ quantize_(transformer, int8_weight_only())
57
+ return transformer
58
+
59
+ def load_scheduler(scheduler_dir):
60
+ scheduler_config_path = scheduler_dir / "scheduler_config.json"
61
+ scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
62
+ return RectifiedFlowScheduler.from_config(scheduler_config)
63
+
64
+ def load_image_to_tensor_with_resize_and_crop(image_path, target_height=512, target_width=768):
65
+ image = Image.open(image_path).convert("RGB")
66
+ input_width, input_height = image.size
67
+ aspect_ratio_target = target_width / target_height
68
+ aspect_ratio_frame = input_width / input_height
69
+ if aspect_ratio_frame > aspect_ratio_target:
70
+ new_width = int(input_height * aspect_ratio_target)
71
+ new_height = input_height
72
+ x_start = (input_width - new_width) // 2
73
+ y_start = 0
74
+ else:
75
+ new_width = input_width
76
+ new_height = int(input_width / aspect_ratio_target)
77
+ x_start = 0
78
+ y_start = (input_height - new_height) // 2
79
+
80
+ image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
81
+ image = image.resize((target_width, target_height))
82
+ frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
83
+ frame_tensor = (frame_tensor / 127.5) - 1.0
84
+ return frame_tensor.unsqueeze(0).unsqueeze(2)
85
+
86
+ def calculate_padding(
87
+ source_height: int, source_width: int, target_height: int, target_width: int
88
+ ) -> tuple[int, int, int, int]:
89
+ pad_height = target_height - source_height
90
+ pad_width = target_width - source_width
91
+ pad_top = pad_height // 2
92
+ pad_bottom = pad_height - pad_top
93
+ pad_left = pad_width // 2
94
+ pad_right = pad_width - pad_left
95
+ padding = (pad_left, pad_right, pad_top, pad_bottom)
96
+ return padding
97
+
98
+ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
99
+ clean_text = "".join(
100
+ char.lower() for char in text if char.isalpha() or char.isspace()
101
+ )
102
+ words = clean_text.split()
103
+ result = []
104
+ current_length = 0
105
+
106
+ for word in words:
107
+ new_length = current_length + len(word)
108
+ if new_length <= max_len:
109
+ result.append(word)
110
+ current_length += len(word)
111
+ else:
112
+ break
113
+
114
+ return "-".join(result)
115
+
116
+ def get_unique_filename(
117
+ base: str,
118
+ ext: str,
119
+ prompt: str,
120
+ seed: int,
121
+ resolution: tuple[int, int, int],
122
+ dir: Path,
123
+ endswith=None,
124
+ index_range=1000,
125
+ ) -> Path:
126
+ base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
127
+ for i in range(index_range):
128
+ filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
129
+ if not os.path.exists(filename):
130
+ return filename
131
+ raise FileExistsError(
132
+ f"Could not find a unique filename after {index_range} attempts."
133
+ )
134
+
135
+ def seed_everething(seed: int):
136
+ random.seed(seed)
137
+ np.random.seed(seed)
138
+ torch.manual_seed(seed)
139
+
140
+ def main(
141
+ img2vid_image="",
142
+ prompt="",
143
+ txt2vid_analytics_toggle=False,
144
+ negative_prompt="",
145
+ frame_rate=25,
146
+ seed=0,
147
+ num_inference_steps=30,
148
+ guidance_scale=3,
149
+ height=512,
150
+ width=768,
151
+ num_frames=121,
152
+ progress=gr.Progress(),
153
+ ):
154
+
155
+ logger = logging.get_logger(__name__)
156
+
157
+ args = {
158
+ "ckpt_dir": "Lightricks/LTX-Video",
159
+ "num_inference_steps": num_inference_steps,
160
+ "guidance_scale": guidance_scale,
161
+ "height": height,
162
+ "width": width,
163
+ "num_frames": num_frames,
164
+ "frame_rate": frame_rate,
165
+ "prompt": prompt,
166
+ "negative_prompt": negative_prompt,
167
+ "seed": 0,
168
+ "output_path": os.path.join(tempfile.gettempdir(), "gradio"),
169
+ "num_images_per_prompt": 1,
170
+ "input_image_path": img2vid_image,
171
+ "input_video_path": "",
172
+ "bfloat16": True,
173
+ "disable_load_needed_only": False
174
+ }
175
+ logger.warning(f"Running generation with arguments: {args}")
176
+
177
+ seed_everething(args['seed'])
178
+
179
+ output_dir = (
180
+ Path(args['output_path'])
181
+ if args['output_path']
182
+ else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
183
+ )
184
+ output_dir.mkdir(parents=True, exist_ok=True)
185
+
186
+ if args['input_image_path']:
187
+ media_items_prepad = load_image_to_tensor_with_resize_and_crop(
188
+ args['input_image_path'], args['height'], args['width']
189
+ )
190
+ else:
191
+ media_items_prepad = None
192
+
193
+ height = args['height'] if args['height'] else media_items_prepad.shape[-2]
194
+ width = args['width'] if args['width'] else media_items_prepad.shape[-1]
195
+ num_frames = args['num_frames']
196
+
197
+ if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
198
+ logger.warning(
199
+ f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
200
+ )
201
+
202
+ height_padded = ((height - 1) // 32 + 1) * 32
203
+ width_padded = ((width - 1) // 32 + 1) * 32
204
+ num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
205
+
206
+ padding = calculate_padding(height, width, height_padded, width_padded)
207
+
208
+ logger.warning(
209
+ f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
210
+ )
211
+
212
+ if media_items_prepad is not None:
213
+ media_items = F.pad(
214
+ media_items_prepad, padding, mode="constant", value=-1
215
+ )
216
+ else:
217
+ media_items = None
218
+
219
+ ckpt_dir = Path(args['ckpt_dir'])
220
+ unet_dir = ckpt_dir / "unet"
221
+ vae_dir = ckpt_dir / "vae"
222
+ scheduler_dir = ckpt_dir / "scheduler"
223
+
224
+ vae = load_vae(vae_dir, txt2vid_analytics_toggle)
225
+ unet = load_unet(unet_dir, txt2vid_analytics_toggle)
226
+ scheduler = load_scheduler(scheduler_dir)
227
+ patchifier = SymmetricPatchifier(patch_size=1)
228
+ text_encoder = T5EncoderModel.from_pretrained(
229
+ "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
230
+ ).to(torch.bfloat16)
231
+
232
+ tokenizer = T5Tokenizer.from_pretrained(
233
+ "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
234
+ )
235
+
236
+ if args['bfloat16'] and unet.dtype != torch.bfloat16:
237
+ unet = unet.to(torch.bfloat16)
238
+
239
+ submodel_dict = {
240
+ "transformer": unet,
241
+ "patchifier": patchifier,
242
+ "text_encoder": text_encoder,
243
+ "tokenizer": tokenizer,
244
+ "scheduler": scheduler,
245
+ "vae": vae,
246
+ }
247
+
248
+ pipeline = LTXVideoPipeline(**submodel_dict)
249
+
250
+ sample = {
251
+ "prompt": args['prompt'],
252
+ "prompt_attention_mask": None,
253
+ "negative_prompt": args['negative_prompt'],
254
+ "negative_prompt_attention_mask": None,
255
+ "media_items": media_items,
256
+ }
257
+
258
+ generator = torch.Generator(
259
+ device="cpu"
260
+ ).manual_seed(args['seed'])
261
+
262
+ images = pipeline(
263
+ num_inference_steps=args['num_inference_steps'],
264
+ num_images_per_prompt=args['num_images_per_prompt'],
265
+ guidance_scale=args['guidance_scale'],
266
+ generator=generator,
267
+ output_type="pt",
268
+ callback_on_step_end=None,
269
+ height=height_padded,
270
+ width=width_padded,
271
+ num_frames=num_frames_padded,
272
+ frame_rate=args['frame_rate'],
273
+ **sample,
274
+ is_video=True,
275
+ vae_per_channel_normalize=True,
276
+ conditioning_method=(
277
+ ConditioningMethod.FIRST_FRAME
278
+ if media_items is not None
279
+ else ConditioningMethod.UNCONDITIONAL
280
+ ),
281
+ mixed_precision=not args['bfloat16'],
282
+ load_needed_only=not args['disable_load_needed_only']
283
+ ).images
284
+
285
+ (pad_left, pad_right, pad_top, pad_bottom) = padding
286
+ pad_bottom = -pad_bottom
287
+ pad_right = -pad_right
288
+ if pad_bottom == 0:
289
+ pad_bottom = images.shape[3]
290
+ if pad_right == 0:
291
+ pad_right = images.shape[4]
292
+ images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
293
+
294
+ for i in range(images.shape[0]):
295
+ video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
296
+ video_np = (video_np * 255).astype(np.uint8)
297
+ fps = args['frame_rate']
298
+ height, width = video_np.shape[1:3]
299
+ if video_np.shape[0] == 1:
300
+ output_filename = get_unique_filename(
301
+ f"image_output_{i}",
302
+ ".png",
303
+ prompt=args['prompt'],
304
+ seed=args['seed'],
305
+ resolution=(height, width, num_frames),
306
+ dir=output_dir,
307
+ )
308
+ imageio.imwrite(output_filename, video_np[0])
309
+ else:
310
+ if args['input_image_path']:
311
+ base_filename = f"img_to_vid_{i}"
312
+ else:
313
+ base_filename = f"text_to_vid_{i}"
314
+ output_filename = get_unique_filename(
315
+ base_filename,
316
+ ".mp4",
317
+ prompt=args['prompt'],
318
+ seed=args['seed'],
319
+ resolution=(height, width, num_frames),
320
+ dir=output_dir,
321
+ )
322
+
323
+ with imageio.get_writer(output_filename, fps=fps) as video:
324
+ for frame in video_np:
325
+ video.append_data(frame)
326
+
327
+ if args['input_image_path']:
328
+ reference_image = (
329
+ (
330
+ media_items_prepad[0, :, 0].permute(1, 2, 0).cpu().data.numpy()
331
+ + 1.0
332
+ )
333
+ / 2.0
334
+ * 255
335
+ )
336
+ imageio.imwrite(
337
+ get_unique_filename(
338
+ base_filename,
339
+ ".png",
340
+ prompt=args['prompt'],
341
+ seed=args['seed'],
342
+ resolution=(height, width, num_frames),
343
+ dir=output_dir,
344
+ endswith="_condition",
345
+ ),
346
+ reference_image.astype(np.uint8),
347
+ )
348
+ logger.warning(f"Output saved {output_filename}")
349
+ return output_filename
350
+
351
+ preset_options = [
352
+ {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
353
+ {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
354
+ {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
355
+ {"label": "992x608, 65 frames", "width": 992, "height": 608, "num_frames": 65},
356
+ {"label": "896x608, 73 frames", "width": 896, "height": 608, "num_frames": 73},
357
+ {"label": "896x544, 81 frames", "width": 896, "height": 544, "num_frames": 81},
358
+ {"label": "832x544, 89 frames", "width": 832, "height": 544, "num_frames": 89},
359
+ {"label": "800x512, 97 frames", "width": 800, "height": 512, "num_frames": 97},
360
+ {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
361
+ {"label": "800x480, 105 frames", "width": 800, "height": 480, "num_frames": 105},
362
+ {"label": "736x480, 113 frames", "width": 736, "height": 480, "num_frames": 113},
363
+ {"label": "704x480, 121 frames", "width": 704, "height": 480, "num_frames": 121},
364
+ {"label": "704x448, 129 frames", "width": 704, "height": 448, "num_frames": 129},
365
+ {"label": "672x448, 137 frames", "width": 672, "height": 448, "num_frames": 137},
366
+ {"label": "640x416, 153 frames", "width": 640, "height": 416, "num_frames": 153},
367
+ {"label": "672x384, 161 frames", "width": 672, "height": 384, "num_frames": 161},
368
+ {"label": "640x384, 169 frames", "width": 640, "height": 384, "num_frames": 169},
369
+ {"label": "608x384, 177 frames", "width": 608, "height": 384, "num_frames": 177},
370
+ {"label": "576x384, 185 frames", "width": 576, "height": 384, "num_frames": 185},
371
+ {"label": "608x352, 193 frames", "width": 608, "height": 352, "num_frames": 193},
372
+ {"label": "576x352, 201 frames", "width": 576, "height": 352, "num_frames": 201},
373
+ {"label": "544x352, 209 frames", "width": 544, "height": 352, "num_frames": 209},
374
+ {"label": "512x352, 225 frames", "width": 512, "height": 352, "num_frames": 225},
375
+ {"label": "512x352, 233 frames", "width": 512, "height": 352, "num_frames": 233},
376
+ {"label": "544x320, 241 frames", "width": 544, "height": 320, "num_frames": 241},
377
+ {"label": "512x320, 249 frames", "width": 512, "height": 320, "num_frames": 249},
378
+ {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
379
+ ]
380
+
381
+ def create_advanced_options():
382
+ with gr.Accordion("Advanced Options (Optional)", open=False):
383
+ seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=0)
384
+ inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=30)
385
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=3.0)
386
+
387
+ height_slider = gr.Slider(
388
+ label="Height",
389
+ minimum=256,
390
+ maximum=1024,
391
+ step=64,
392
+ value=512,
393
+ visible=False,
394
+ )
395
+ width_slider = gr.Slider(
396
+ label="Width",
397
+ minimum=256,
398
+ maximum=1024,
399
+ step=64,
400
+ value=768,
401
+ visible=False,
402
+ )
403
+ num_frames_slider = gr.Slider(
404
+ label="Number of Frames",
405
+ minimum=1,
406
+ maximum=200,
407
+ step=1,
408
+ value=97,
409
+ visible=False,
410
+ )
411
+
412
+ return [
413
+ seed,
414
+ inference_steps,
415
+ guidance_scale,
416
+ height_slider,
417
+ width_slider,
418
+ num_frames_slider,
419
+ ]
420
+
421
+ def preset_changed(preset):
422
+ if preset != "Custom":
423
+ selected = next(item for item in preset_options if item["label"] == preset)
424
+ return (
425
+ selected["height"],
426
+ selected["width"],
427
+ selected["num_frames"],
428
+ gr.update(visible=False),
429
+ gr.update(visible=False),
430
+ gr.update(visible=False),
431
+ )
432
+ else:
433
+ return (
434
+ None,
435
+ None,
436
+ None,
437
+ gr.update(visible=True),
438
+ gr.update(visible=True),
439
+ gr.update(visible=True),
440
+ )
441
+
442
+
443
+ css="""
444
+ #col-container {
445
+ margin: 0 auto;
446
+ max-width: 1220px;
447
+ }
448
+ """
449
+
450
+ with gr.Blocks(css=css) as demo:
451
+ with gr.Row():
452
+ with gr.Column():
453
+ img2vid_image = gr.Image(
454
+ type="filepath",
455
+ label="Upload Input Image",
456
+ elem_id="image_upload",
457
+ )
458
+
459
+ txt2vid_prompt = gr.Textbox(
460
+ label="Enter Your Prompt",
461
+ placeholder="Describe the video you want to generate (minimum 50 characters)...",
462
+ value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.",
463
+ lines=5,
464
+ )
465
+
466
+ txt2vid_analytics_toggle = Toggle(
467
+ label="torchao.quantization.",
468
+ value=False,
469
+ interactive=True,
470
+ )
471
+
472
+ txt2vid_negative_prompt = gr.Textbox(
473
+ label="Enter Negative Prompt",
474
+ placeholder="Describe what you don't want in the video...",
475
+ value="low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
476
+ lines=2,
477
+ )
478
+
479
+ txt2vid_preset = gr.Dropdown(
480
+ choices=[p["label"] for p in preset_options],
481
+ value="768x512, 97 frames",
482
+ label="Choose Resolution Preset",
483
+ )
484
+
485
+ txt2vid_frame_rate = gr.Slider(
486
+ label="Frame Rate",
487
+ minimum=21,
488
+ maximum=30,
489
+ step=1,
490
+ value=25,
491
+ )
492
+
493
+ txt2vid_advanced = create_advanced_options()
494
+
495
+ txt2vid_generate = gr.Button(
496
+ "Generate Video",
497
+ variant="primary",
498
+ size="lg",
499
+ )
500
+
501
+ with gr.Column():
502
+ txt2vid_output = gr.Video(label="Generated Output")
503
+
504
+ with gr.Row():
505
+ gr.Examples(
506
+ examples=[
507
+ [
508
+ "A young woman in a traditional Mongolian dress is peeking through a sheer white curtain, her face showing a mix of curiosity and apprehension. The woman has long black hair styled in two braids, adorned with white beads, and her eyes are wide with a hint of surprise. Her dress is a vibrant blue with intricate gold embroidery, and she wears a matching headband with a similar design. The background is a simple white curtain, which creates a sense of mystery and intrigue.ith long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair’s face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
509
+ "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
510
+ ],
511
+ [
512
+ "A young man with blond hair wearing a yellow jacket stands in a forest and looks around. He has light skin and his hair is styled with a middle part. He looks to the left and then to the right, his gaze lingering in each direction. The camera angle is low, looking up at the man, and remains stationary throughout the video. The background is slightly out of focus, with green trees and the sun shining brightly behind the man. The lighting is natural and warm, with the sun creating a lens flare that moves across the man’s face. The scene is captured in real-life footage.",
513
+ "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
514
+ ],
515
+ [
516
+ "A cyclist races along a winding mountain road. Clad in aerodynamic gear, he pedals intensely, sweat glistening on his brow. The camera alternates between close-ups of his determined expression and wide shots of the breathtaking landscape. Pine trees blur past, and the sky is a crisp blue. The scene is invigorating and competitive.",
517
+ "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
518
+ ],
519
+ ],
520
+ inputs=[txt2vid_prompt, txt2vid_negative_prompt, txt2vid_output],
521
+ label="Example Text-to-Video Generations",
522
+ )
523
+
524
+ txt2vid_preset.change(fn=preset_changed, inputs=[txt2vid_preset], outputs=txt2vid_advanced[3:])
525
+
526
+ txt2vid_generate.click(
527
+ fn=main,
528
+ inputs=[
529
+ img2vid_image,
530
+ txt2vid_prompt,
531
+ txt2vid_analytics_toggle,
532
+ txt2vid_negative_prompt,
533
+ txt2vid_frame_rate,
534
+ *txt2vid_advanced,
535
+ ],
536
+ outputs=txt2vid_output,
537
+ concurrency_limit=1,
538
+ concurrency_id="generate_video",
539
+ queue=True,
540
+ )
541
+ demo.launch()
gradio_app.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from gradio_toggle import Toggle
4
+ import argparse
5
+ import json
6
+ import os
7
+ import random
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from diffusers.utils import logging
11
+
12
+ import imageio
13
+ import numpy as np
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from PIL import Image
18
+ from transformers import T5EncoderModel, T5Tokenizer
19
+ import tempfile
20
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
21
+ CausalVideoAutoencoder,
22
+ )
23
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
24
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
25
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
26
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
27
+ from ltx_video.utils.conditioning_method import ConditioningMethod
28
+ from torchao.quantization import quantize_, int8_weight_only
29
+
30
+ MAX_HEIGHT = 720
31
+ MAX_WIDTH = 1280
32
+ MAX_NUM_FRAMES = 257
33
+
34
+
35
def load_vae(vae_dir, int8=False):
    """Build the causal video VAE from a checkpoint directory.

    Reads ``config.json`` and the safetensors weights found in *vae_dir*,
    moves the model to CUDA when available, optionally applies int8
    weight-only quantization, and returns the model cast to bfloat16.
    """
    config_file = vae_dir / "config.json"
    weights_file = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    with open(config_file, "r") as cfg:
        model = CausalVideoAutoencoder.from_config(json.load(cfg))
    model.load_state_dict(safetensors.torch.load_file(weights_file))
    if torch.cuda.is_available():
        model = model.cuda()
    if int8:
        # Weight-only int8 quantization (torchao); frees cached CUDA blocks
        # left over from the full-precision weights.
        print("vae - quantization = true")
        quantize_(model, int8_weight_only())
        torch.cuda.empty_cache()
    return model.to(torch.bfloat16)
50
+
51
+
52
def load_unet(unet_dir, int8=False):
    """Build the 3D transformer ("unet") from a checkpoint directory.

    Loads config + safetensors weights, moves the model to CUDA when
    available, and optionally applies int8 weight-only quantization.
    Dtype is left unchanged here; main() casts to bfloat16 later.
    """
    config_file = unet_dir / "config.json"
    weights_file = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    model = Transformer3DModel.from_config(Transformer3DModel.load_config(config_file))
    model.load_state_dict(safetensors.torch.load_file(weights_file), strict=True)
    if torch.cuda.is_available():
        model = model.cuda()
    if int8:
        print("unet - quantization = true")
        quantize_(model, int8_weight_only())
        torch.cuda.empty_cache()
    return model
66
+
67
+
68
def load_scheduler(scheduler_dir):
    """Instantiate the rectified-flow scheduler from its saved config."""
    config = RectifiedFlowScheduler.load_config(scheduler_dir / "scheduler_config.json")
    return RectifiedFlowScheduler.from_config(config)
72
+
73
+
74
def load_image_to_tensor_with_resize_and_crop(image_path, target_height=512, target_width=768):
    """Load an image, center-crop it to the target aspect ratio, resize,
    and normalize pixel values to [-1, 1].

    Returns a 5D float tensor shaped (batch=1, channels=3, frames=1,
    target_height, target_width), ready to be used as a conditioning frame.
    """
    image = Image.open(image_path).convert("RGB")
    src_w, src_h = image.size
    target_ar = target_width / target_height
    source_ar = src_w / src_h

    if source_ar > target_ar:
        # Source is too wide: crop equally from the left/right edges.
        crop_w, crop_h = int(src_h * target_ar), src_h
        left, top = (src_w - crop_w) // 2, 0
    else:
        # Source is too tall: crop equally from the top/bottom edges.
        crop_w, crop_h = src_w, int(src_w / target_ar)
        left, top = 0, (src_h - crop_h) // 2

    image = image.crop((left, top, left + crop_w, top + crop_h))
    image = image.resize((target_width, target_height))

    pixels = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    # Map uint8 [0, 255] to [-1, 1].
    pixels = pixels / 127.5 - 1.0
    # Add batch and frame axes: (1, 3, 1, H, W).
    return pixels.unsqueeze(0).unsqueeze(2)
96
+
97
+
98
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute per-side padding that grows a source frame to the target size.

    Returns (left, right, top, bottom) — the ordering expected by
    ``torch.nn.functional.pad`` for the last two dimensions.  Any odd
    remainder goes to the right/bottom edge.
    """
    extra_h = target_height - source_height
    extra_w = target_width - source_width
    top = extra_h // 2
    left = extra_w // 2
    return (left, extra_w - left, top, extra_h - top)
116
+
117
+
118
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a prompt into a short, filesystem-safe slug.

    Keeps only letters and whitespace (lowercased), then greedily takes
    whole words while their combined letter count stays within *max_len*,
    joining them with hyphens.  The hyphens themselves are not counted
    toward *max_len*, so the result may be slightly longer than it.
    """
    cleaned = "".join(c.lower() for c in text if c.isalpha() or c.isspace())
    kept = []
    used = 0
    for token in cleaned.split():
        if used + len(token) > max_len:
            break
        kept.append(token)
        used += len(token)
    return "-".join(kept)
142
+
143
+
144
# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first non-existing path of the form
    ``{base}_{slug}_{seed}_{HxWxF}_{i}{endswith}{ext}`` inside *dir*.

    *resolution* is (height, width, num_frames).  Raises FileExistsError
    when all *index_range* candidates already exist.  (The parameter name
    ``dir`` shadows the builtin but is kept for caller compatibility.)
    """
    slug = convert_prompt_to_filename(prompt, max_len=30)
    stem = f"{base}_{slug}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    suffix = endswith if endswith else ""
    for idx in range(index_range):
        candidate = dir / f"{stem}_{idx}{suffix}{ext}"
        if not os.path.exists(candidate):
            return candidate
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
163
+
164
+
165
def seed_everething(seed: int):
    """Seed every RNG used during generation (python, numpy, torch, CUDA).

    NOTE(review): the name carries a historical typo ("everething"); it is
    kept unchanged because callers reference it by this exact name.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
171
+
172
+
173
def main(
    img2vid_image="",
    prompt="",
    txt2vid_analytics_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=0,
    num_inference_steps=30,
    guidance_scale=3,
    height=512,
    width=768,
    num_frames=121,
    progress=gr.Progress(),
):
    """Run one LTX-Video generation and return the output file path.

    Loads the pipeline components, generates a video (or a single image
    when only one frame is produced), writes the result under the temp
    directory, and returns the path for the Gradio Video component.

    txt2vid_analytics_toggle enables int8 weight-only quantization of the
    VAE and transformer (the UI labels it "torchao.quantization.").
    """

    logger = logging.get_logger(__name__)

    # CLI parsing is replaced by a fixed argument dict for the Gradio app.
    args = {
        "ckpt_dir": "Lightricks/LTX-Video",
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "frame_rate": frame_rate,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        # Bug fix: the UI seed slider was previously ignored (hard-coded 0).
        "seed": seed,
        "output_path": os.path.join(tempfile.gettempdir(), "gradio"),
        "num_images_per_prompt": 1,
        "input_image_path": img2vid_image,
        "input_video_path": "",
        "bfloat16": True,
        "disable_load_needed_only": False
    }
    logger.warning(f"Running generation with arguments: {args}")

    seed_everething(args['seed'])

    output_dir = (
        Path(args['output_path'])
        if args['output_path']
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the optional conditioning image as a normalized 5D tensor.
    if args['input_image_path']:
        media_items_prepad = load_image_to_tensor_with_resize_and_crop(
            args['input_image_path'], args['height'], args['width']
        )
    else:
        media_items_prepad = None

    # NOTE(review): if height/width are falsy AND no image was given, the
    # fallback dereferences media_items_prepad (None) — the UI always
    # supplies non-zero dimensions, so this path is not hit from Gradio.
    height = args['height'] if args['height'] else media_items_prepad.shape[-2]
    width = args['width'] if args['width'] else media_items_prepad.shape[-1]
    num_frames = args['num_frames']

    if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
        logger.warning(
            f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
        )

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    if media_items_prepad is not None:
        media_items = F.pad(
            media_items_prepad, padding, mode="constant", value=-1
        )  # -1 is the value for padding since the image is normalized to -1, 1
    else:
        media_items = None

    # Paths for the separate model sub-directories of the checkpoint.
    ckpt_dir = Path(args['ckpt_dir'])
    unet_dir = ckpt_dir / "unet"
    vae_dir = ckpt_dir / "vae"
    scheduler_dir = ckpt_dir / "scheduler"

    # Load models; the toggle enables int8 weight-only quantization.
    vae = load_vae(vae_dir, txt2vid_analytics_toggle)
    unet = load_unet(unet_dir, txt2vid_analytics_toggle)
    scheduler = load_scheduler(scheduler_dir)
    patchifier = SymmetricPatchifier(patch_size=1)
    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    ).to(torch.bfloat16)

    # Text encoder intentionally stays off CUDA here; per-stage placement
    # is delegated to the pipeline via load_needed_only below.

    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    if args['bfloat16'] and unet.dtype != torch.bfloat16:
        unet = unet.to(torch.bfloat16)

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": unet,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    if torch.cuda.is_available() and args['disable_load_needed_only']:
        pipeline = pipeline.to("cuda")

    # Prepare input for the pipeline
    sample = {
        "prompt": args['prompt'],
        "prompt_attention_mask": None,
        "negative_prompt": args['negative_prompt'],
        "negative_prompt_attention_mask": None,
        "media_items": media_items,
    }

    generator = torch.Generator(
        device="cuda" if torch.cuda.is_available() else "cpu"
    ).manual_seed(args['seed'])

    images = pipeline(
        num_inference_steps=args['num_inference_steps'],
        num_images_per_prompt=args['num_images_per_prompt'],
        guidance_scale=args['guidance_scale'],
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=args['frame_rate'],
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
        conditioning_method=(
            ConditioningMethod.FIRST_FRAME
            if media_items is not None
            else ConditioningMethod.UNCONDITIONAL
        ),
        mixed_precision=not args['bfloat16'],
        load_needed_only=not args['disable_load_needed_only']
    ).images

    # Crop the padded images back to the requested resolution/frame count.
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    # A zero negative bound would produce an empty slice; use the full extent.
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = args['frame_rate']
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=args['prompt'],
                seed=args['seed'],
                resolution=(height, width, num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            if args['input_image_path']:
                base_filename = f"img_to_vid_{i}"
            else:
                base_filename = f"text_to_vid_{i}"
            output_filename = get_unique_filename(
                base_filename,
                ".mp4",
                prompt=args['prompt'],
                seed=args['seed'],
                resolution=(height, width, num_frames),
                dir=output_dir,
            )

            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

            # Write condition image
            if args['input_image_path']:
                # Map the [-1, 1] conditioning frame back to [0, 255].
                reference_image = (
                    (
                        media_items_prepad[0, :, 0].permute(1, 2, 0).cpu().data.numpy()
                        + 1.0
                    )
                    / 2.0
                    * 255
                )
                imageio.imwrite(
                    get_unique_filename(
                        base_filename,
                        ".png",
                        prompt=args['prompt'],
                        seed=args['seed'],
                        resolution=(height, width, num_frames),
                        dir=output_dir,
                        endswith="_condition",
                    ),
                    reference_image.astype(np.uint8),
                )
    # NOTE(review): with num_images_per_prompt > 1 this returns only the
    # last file; currently the app always generates exactly one.
    logger.warning(f"Output saved {output_filename}")
    return output_filename
401
+
402
+
403
# Supported resolution/length presets: the spatial size shrinks as the frame
# count grows, keeping total pixel volume roughly constant.  Labels are shown
# verbatim in the "Choose Resolution Preset" dropdown and matched back by
# exact string in preset_changed().
preset_options = [
    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "992x608, 65 frames", "width": 992, "height": 608, "num_frames": 65},
    {"label": "896x608, 73 frames", "width": 896, "height": 608, "num_frames": 73},
    {"label": "896x544, 81 frames", "width": 896, "height": 544, "num_frames": 81},
    {"label": "832x544, 89 frames", "width": 832, "height": 544, "num_frames": 89},
    {"label": "800x512, 97 frames", "width": 800, "height": 512, "num_frames": 97},
    {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
    {"label": "800x480, 105 frames", "width": 800, "height": 480, "num_frames": 105},
    {"label": "736x480, 113 frames", "width": 736, "height": 480, "num_frames": 113},
    {"label": "704x480, 121 frames", "width": 704, "height": 480, "num_frames": 121},
    {"label": "704x448, 129 frames", "width": 704, "height": 448, "num_frames": 129},
    {"label": "672x448, 137 frames", "width": 672, "height": 448, "num_frames": 137},
    {"label": "640x416, 153 frames", "width": 640, "height": 416, "num_frames": 153},
    {"label": "672x384, 161 frames", "width": 672, "height": 384, "num_frames": 161},
    {"label": "640x384, 169 frames", "width": 640, "height": 384, "num_frames": 169},
    {"label": "608x384, 177 frames", "width": 608, "height": 384, "num_frames": 177},
    {"label": "576x384, 185 frames", "width": 576, "height": 384, "num_frames": 185},
    {"label": "608x352, 193 frames", "width": 608, "height": 352, "num_frames": 193},
    {"label": "576x352, 201 frames", "width": 576, "height": 352, "num_frames": 201},
    {"label": "544x352, 209 frames", "width": 544, "height": 352, "num_frames": 209},
    {"label": "512x352, 225 frames", "width": 512, "height": 352, "num_frames": 225},
    {"label": "512x352, 233 frames", "width": 512, "height": 352, "num_frames": 233},
    {"label": "544x320, 241 frames", "width": 544, "height": 320, "num_frames": 241},
    {"label": "512x320, 249 frames", "width": 512, "height": 320, "num_frames": 249},
    {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
]
432
+
433
def create_advanced_options():
    """Build the "Advanced Options" accordion and return its controls.

    Components are returned in the fixed order expected by ``main``:
    [seed, inference_steps, guidance_scale, height, width, num_frames].
    The three dimension sliders start hidden; ``preset_changed`` toggles
    their visibility.
    """
    with gr.Accordion("Advanced Options (Optional)", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=0)
        inference_steps = gr.Slider(
            label="Inference Steps", minimum=1, maximum=50, step=1, value=30
        )
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=3.0
        )

        # (label, min, max, step, default) for the hidden dimension sliders.
        dimension_specs = [
            ("Height", 256, 1024, 64, 512),
            ("Width", 256, 1024, 64, 768),
            ("Number of Frames", 1, 200, 1, 97),
        ]
        dimension_sliders = [
            gr.Slider(
                label=label,
                minimum=lo,
                maximum=hi,
                step=step,
                value=default,
                visible=False,
            )
            for label, lo, hi, step, default in dimension_specs
        ]

    return [seed, inference_steps, guidance_scale, *dimension_sliders]
472
+
473
def preset_changed(preset):
    """Handle a change of the resolution-preset dropdown.

    Returns one ``gr.update`` per dimension slider (height, width,
    num_frames) — matching the three ``outputs`` wired in
    ``txt2vid_preset.change``.  The previous version returned six values
    (three raw numbers plus three visibility updates) for only three
    outputs, which Gradio rejects; value and visibility are now folded
    into a single update per slider.
    """
    if preset != "Custom":
        # Labels are matched by exact string against preset_options.
        selected = next(item for item in preset_options if item["label"] == preset)
        return (
            gr.update(value=selected["height"], visible=False),
            gr.update(value=selected["width"], visible=False),
            gr.update(value=selected["num_frames"], visible=False),
        )
    # "Custom": reveal the sliders and keep their current values.
    return (
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )
493
+
494
+
495
# Page-level CSS applied to the Blocks app.
css="""
#col-container {
    margin: 0 auto;
    max-width: 1220px;
}
"""

# Top-level UI definition and event wiring.  Component creation order and
# the input lists below must stay in sync with main()'s signature.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            # Optional conditioning image; when provided, main() runs img2vid.
            img2vid_image = gr.Image(
                type="filepath",
                label="Upload Input Image",
                elem_id="image_upload",
            )

            txt2vid_prompt = gr.Textbox(
                label="Enter Your Prompt",
                placeholder="Describe the video you want to generate (minimum 50 characters)...",
                value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.",
                lines=5,
            )

            # Enables int8 weight-only quantization in load_vae/load_unet.
            txt2vid_analytics_toggle = Toggle(
                label="torchao.quantization.",
                value=False,
                interactive=True,
            )

            txt2vid_negative_prompt = gr.Textbox(
                label="Enter Negative Prompt",
                placeholder="Describe what you don't want in the video...",
                value="low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
                lines=2,
            )

            txt2vid_preset = gr.Dropdown(
                choices=[p["label"] for p in preset_options],
                value="768x512, 97 frames",
                label="Choose Resolution Preset",
            )

            txt2vid_frame_rate = gr.Slider(
                label="Frame Rate",
                minimum=21,
                maximum=30,
                step=1,
                value=25,
            )

            # [seed, steps, guidance, height, width, num_frames] — order
            # matters for the generate-click input list below.
            txt2vid_advanced = create_advanced_options()

            txt2vid_generate = gr.Button(
                "Generate Video",
                variant="primary",
                size="lg",
            )

        with gr.Column():
            txt2vid_output = gr.Video(label="Generated Output")

    with gr.Row():
        # NOTE(review): each example row provides 2 values but 3 inputs are
        # listed (txt2vid_output has no example value) — confirm Gradio
        # accepts this inputs/values count mismatch.
        gr.Examples(
            examples=[
                [
                    "A young woman in a traditional Mongolian dress is peeking through a sheer white curtain, her face showing a mix of curiosity and apprehension. The woman has long black hair styled in two braids, adorned with white beads, and her eyes are wide with a hint of surprise. Her dress is a vibrant blue with intricate gold embroidery, and she wears a matching headband with a similar design. The background is a simple white curtain, which creates a sense of mystery and intrigue.ith long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair’s face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
                    "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
                ],
                [
                    "A young man with blond hair wearing a yellow jacket stands in a forest and looks around. He has light skin and his hair is styled with a middle part. He looks to the left and then to the right, his gaze lingering in each direction. The camera angle is low, looking up at the man, and remains stationary throughout the video. The background is slightly out of focus, with green trees and the sun shining brightly behind the man. The lighting is natural and warm, with the sun creating a lens flare that moves across the man’s face. The scene is captured in real-life footage.",
                    "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
                ],
                [
                    "A cyclist races along a winding mountain road. Clad in aerodynamic gear, he pedals intensely, sweat glistening on his brow. The camera alternates between close-ups of his determined expression and wide shots of the breathtaking landscape. Pine trees blur past, and the sky is a crisp blue. The scene is invigorating and competitive.",
                    "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
                ],
            ],
            inputs=[txt2vid_prompt, txt2vid_negative_prompt, txt2vid_output],
            label="Example Text-to-Video Generations",
        )

    # Preset selection drives the three dimension sliders (advanced[3:]).
    txt2vid_preset.change(fn=preset_changed, inputs=[txt2vid_preset], outputs=txt2vid_advanced[3:])

    # Input order must match main()'s parameter order exactly.
    txt2vid_generate.click(
        fn=main,
        inputs=[
            img2vid_image,
            txt2vid_prompt,
            txt2vid_analytics_toggle,
            txt2vid_negative_prompt,
            txt2vid_frame_rate,
            *txt2vid_advanced,
        ],
        outputs=txt2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video",
        queue=True,
    )
demo.launch()
gradio_appp.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from gradio_toggle import Toggle
4
+ import argparse
5
+ import json
6
+ import os
7
+ import random
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from diffusers.utils import logging
11
+
12
+ import imageio
13
+ import numpy as np
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from PIL import Image
18
+ from transformers import T5EncoderModel, T5Tokenizer
19
+ import tempfile
20
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
21
+ CausalVideoAutoencoder,
22
+ )
23
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
24
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
25
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
26
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
27
+ from ltx_video.utils.conditioning_method import ConditioningMethod
28
+ from torchao.quantization import quantize_, int8_weight_only
29
+
30
# Soft upper bounds for the generation request. Exceeding them only logs a
# warning in main() — the pipeline is still invoked with the padded values.
MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257
33
+
34
+
35
def load_vae(vae_dir, int8=False):
    """Load the causal video VAE onto the CPU (CPU-only app variant).

    Reads ``config.json`` and the safetensors weights from *vae_dir* and
    optionally applies int8 weight-only quantization.  No dtype cast and
    no CUDA transfer happen here.
    """
    config_file = vae_dir / "config.json"
    weights_file = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    with open(config_file, "r") as cfg:
        model = CausalVideoAutoencoder.from_config(json.load(cfg))
    model.load_state_dict(safetensors.torch.load_file(weights_file))
    # Ensure everything runs on the CPU
    model = model.to('cpu')
    if int8:
        print("vae - quantization = true")
        quantize_(model, int8_weight_only())
    return model
49
+
50
+
51
def load_unet(unet_dir, int8=False):
    """Load the 3D transformer ("unet") onto the CPU (CPU-only variant).

    Loads config + safetensors weights and optionally applies int8
    weight-only quantization.
    """
    config_file = unet_dir / "config.json"
    weights_file = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    model = Transformer3DModel.from_config(Transformer3DModel.load_config(config_file))
    model.load_state_dict(safetensors.torch.load_file(weights_file), strict=True)
    # Ensure everything runs on the CPU
    model = model.to('cpu')
    if int8:
        print("unet - quantization = true")
        quantize_(model, int8_weight_only())
    return model
64
+
65
+
66
def load_scheduler(scheduler_dir):
    """Instantiate the rectified-flow scheduler from its saved config."""
    config = RectifiedFlowScheduler.load_config(scheduler_dir / "scheduler_config.json")
    return RectifiedFlowScheduler.from_config(config)
70
+
71
+
72
def load_image_to_tensor_with_resize_and_crop(image_path, target_height=512, target_width=768):
    """Load an image, center-crop it to the target aspect ratio, resize,
    and normalize pixel values to [-1, 1].

    Returns a 5D float tensor shaped (batch=1, channels=3, frames=1,
    target_height, target_width), ready to be used as a conditioning frame.
    """
    image = Image.open(image_path).convert("RGB")
    src_w, src_h = image.size
    target_ar = target_width / target_height
    source_ar = src_w / src_h

    if source_ar > target_ar:
        # Source is too wide: crop equally from the left/right edges.
        crop_w, crop_h = int(src_h * target_ar), src_h
        left, top = (src_w - crop_w) // 2, 0
    else:
        # Source is too tall: crop equally from the top/bottom edges.
        crop_w, crop_h = src_w, int(src_w / target_ar)
        left, top = 0, (src_h - crop_h) // 2

    image = image.crop((left, top, left + crop_w, top + crop_h))
    image = image.resize((target_width, target_height))

    pixels = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    # Map uint8 [0, 255] to [-1, 1].
    pixels = pixels / 127.5 - 1.0
    # Add batch and frame axes: (1, 3, 1, H, W).
    return pixels.unsqueeze(0).unsqueeze(2)
94
+
95
+
96
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute per-side padding that grows a source frame to the target size.

    Returns (left, right, top, bottom) — the ordering expected by
    ``torch.nn.functional.pad`` for the last two dimensions.  Any odd
    remainder goes to the right/bottom edge.
    """
    extra_h = target_height - source_height
    extra_w = target_width - source_width
    top = extra_h // 2
    left = extra_w // 2
    return (left, extra_w - left, top, extra_h - top)
114
+
115
+
116
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a prompt into a short, filesystem-safe slug.

    Keeps only letters and whitespace (lowercased), then greedily takes
    whole words while their combined letter count stays within *max_len*,
    joining them with hyphens.  The hyphens themselves are not counted
    toward *max_len*, so the result may be slightly longer than it.
    """
    cleaned = "".join(c.lower() for c in text if c.isalpha() or c.isspace())
    kept = []
    used = 0
    for token in cleaned.split():
        if used + len(token) > max_len:
            break
        kept.append(token)
        used += len(token)
    return "-".join(kept)
140
+
141
+
142
+ # Generate output video name
143
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first non-existing path of the form
    ``{base}_{slug}_{seed}_{HxWxF}_{i}{endswith}{ext}`` inside *dir*.

    *resolution* is (height, width, num_frames).  Raises FileExistsError
    when all *index_range* candidates already exist.  (The parameter name
    ``dir`` shadows the builtin but is kept for caller compatibility.)
    """
    slug = convert_prompt_to_filename(prompt, max_len=30)
    stem = f"{base}_{slug}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    suffix = endswith if endswith else ""
    for idx in range(index_range):
        candidate = dir / f"{stem}_{idx}{suffix}{ext}"
        if not os.path.exists(candidate):
            return candidate
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
161
+
162
+
163
def seed_everething(seed: int):
    """Seed python, numpy and torch RNGs (CPU-only variant: no CUDA seeding).

    NOTE(review): the name carries a historical typo ("everething"); it is
    kept unchanged because callers reference it by this exact name.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
167
+
168
+
169
def main(
    img2vid_image="",
    prompt="",
    txt2vid_analytics_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=0,
    num_inference_steps=30,
    guidance_scale=3,
    height=512,
    width=768,
    num_frames=121,
    progress=gr.Progress(),
):
    """CPU-only generation entry point.

    Loads all pipeline components onto the CPU, runs inference with the
    given settings, and returns None — output handling is not implemented
    in this variant (the pipeline result is computed but discarded).

    txt2vid_analytics_toggle enables int8 weight-only quantization of the
    VAE and transformer.
    """

    logger = logging.get_logger(__name__)

    # Fixed argument dict replacing CLI parsing for the Gradio app.
    args = {
        "ckpt_dir": "Lightricks/LTX-Video",
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "frame_rate": frame_rate,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        # Bug fix: the UI seed slider was previously ignored (hard-coded 0).
        "seed": seed,
        "output_path": os.path.join(tempfile.gettempdir(), "gradio"),
        "num_images_per_prompt": 1,
        "input_image_path": img2vid_image,
        "input_video_path": "",
        "bfloat16": True,
        "disable_load_needed_only": False
    }
    logger.warning(f"Running generation with arguments: {args}")

    seed_everething(args['seed'])

    output_dir = (
        Path(args['output_path'])
        if args['output_path']
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the optional conditioning image as a normalized 5D tensor.
    if args['input_image_path']:
        media_items_prepad = load_image_to_tensor_with_resize_and_crop(
            args['input_image_path'], args['height'], args['width']
        )
    else:
        media_items_prepad = None

    # NOTE(review): if height/width are falsy AND no image was given, the
    # fallback dereferences media_items_prepad (None) — the UI always
    # supplies non-zero dimensions, so this path is not hit from Gradio.
    height = args['height'] if args['height'] else media_items_prepad.shape[-2]
    width = args['width'] if args['width'] else media_items_prepad.shape[-1]
    num_frames = args['num_frames']

    if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
        logger.warning(
            f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
        )

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    if media_items_prepad is not None:
        media_items = F.pad(
            media_items_prepad, padding, mode="constant", value=-1
        )  # -1 is the value for padding since the image is normalized to -1, 1
    else:
        media_items = None

    # Load models (toggle enables int8 weight-only quantization).
    vae = load_vae(Path(args['ckpt_dir']) / "vae", txt2vid_analytics_toggle)
    unet = load_unet(Path(args['ckpt_dir']) / "unet", txt2vid_analytics_toggle)
    scheduler = load_scheduler(Path(args['ckpt_dir']) / "scheduler")
    patchifier = SymmetricPatchifier(patch_size=1)
    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    ).to('cpu')  # Force to CPU

    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": unet,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to('cpu')  # Ensure pipeline runs on CPU

    # Prepare input for the pipeline
    sample = {
        "prompt": args['prompt'],
        "prompt_attention_mask": None,
        "negative_prompt": args['negative_prompt'],
        "negative_prompt_attention_mask": None,
        "media_items": media_items,
    }

    generator = torch.Generator(device="cpu").manual_seed(args['seed'])  # Force CPU

    images = pipeline(
        num_inference_steps=args['num_inference_steps'],
        num_images_per_prompt=args['num_images_per_prompt'],
        guidance_scale=args['guidance_scale'],
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=args['frame_rate'],
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
        conditioning_method=(
            ConditioningMethod.FIRST_FRAME
            if media_items is not None
            else ConditioningMethod.UNCONDITIONAL
        ),
        mixed_precision=not args['bfloat16'],
        load_needed_only=not args['disable_load_needed_only']
    ).images

    # Further processing and saving logic can go here...
311
+
312
+
inference.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from diffusers.utils import logging
8
+
9
+ import imageio
10
+ import numpy as np
11
+ import safetensors.torch
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ from transformers import T5EncoderModel, T5Tokenizer
16
+
17
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
18
+ CausalVideoAutoencoder,
19
+ )
20
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
21
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
22
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
23
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
24
+ from ltx_video.utils.conditioning_method import ConditioningMethod
25
+
26
+
27
# Suggested upper bounds for generation. Exceeding them only triggers a
# warning in main(); it does not hard-fail.
MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257  # = 32 * 8 + 1, matching the (N * 8 + 1) frame padding in main()
30
+
31
+
32
def load_vae(vae_dir):
    """Build a CausalVideoAutoencoder from an on-disk checkpoint directory.

    Reads ``config.json`` and the safetensors weights from *vae_dir*, moves
    the model to CUDA when available, and casts it to bfloat16.
    """
    config_path = vae_dir / "config.json"
    weights_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"

    with open(config_path, "r") as cfg_file:
        config = json.load(cfg_file)

    model = CausalVideoAutoencoder.from_config(config)
    model.load_state_dict(safetensors.torch.load_file(weights_path))

    if torch.cuda.is_available():
        model = model.cuda()
    # The VAE is always run in bfloat16.
    return model.to(torch.bfloat16)
43
+
44
+
45
def load_unet(unet_dir):
    """Build the Transformer3DModel ("unet") from a checkpoint directory."""
    config = Transformer3DModel.load_config(unet_dir / "config.json")
    model = Transformer3DModel.from_config(config)

    state_dict = safetensors.torch.load_file(
        unet_dir / "unet_diffusion_pytorch_model.safetensors"
    )
    # strict=True: every checkpoint key must match the model.
    model.load_state_dict(state_dict, strict=True)

    return model.cuda() if torch.cuda.is_available() else model
55
+
56
+
57
def load_scheduler(scheduler_dir):
    """Instantiate the RectifiedFlowScheduler from its saved config."""
    config = RectifiedFlowScheduler.load_config(
        scheduler_dir / "scheduler_config.json"
    )
    return RectifiedFlowScheduler.from_config(config)
61
+
62
+
63
def load_image_to_tensor_with_resize_and_crop(
    image_path, target_height=512, target_width=768
):
    """Load an image, center-crop to the target aspect ratio, resize, and
    return it as a normalized 5D tensor.

    Returns:
        Tensor of shape (1, 3, 1, target_height, target_width) with values
        in [-1, 1].
    """
    image = Image.open(image_path).convert("RGB")
    src_w, src_h = image.size
    target_ar = target_width / target_height
    src_ar = src_w / src_h

    # Center-crop to the target aspect ratio before resizing.
    if src_ar > target_ar:
        crop_w, crop_h = int(src_h * target_ar), src_h
        left, top = (src_w - crop_w) // 2, 0
    else:
        crop_w, crop_h = src_w, int(src_w / target_ar)
        left, top = 0, (src_h - crop_h) // 2

    image = image.crop((left, top, left + crop_w, top + crop_h))
    image = image.resize((target_width, target_height))

    frame = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame = (frame / 127.5) - 1.0  # [0, 255] -> [-1, 1]
    # Add batch and frame axes: (batch=1, C=3, frames=1, H, W).
    return frame.unsqueeze(0).unsqueeze(2)
87
+
88
+
89
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute symmetric padding that grows an (H, W) frame to the target size.

    Returns:
        (left, right, top, bottom) padding amounts, in the order expected by
        ``torch.nn.functional.pad``. Odd remainders go to the right/bottom.
    """
    extra_h = target_height - source_height
    extra_w = target_width - source_width

    top = extra_h // 2
    left = extra_w // 2
    # The complementary side absorbs any odd remainder.
    return (left, extra_w - left, top, extra_h - top)
107
+
108
+
109
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Derive a short, filesystem-safe slug from a prompt.

    Keeps only letters and spaces (lowercased), then joins as many leading
    words as fit within *max_len* characters, separated by "-".

    Fix: the original length check ignored the separator between words
    (despite the comment saying it should be counted), so the returned
    name could exceed *max_len*. The separator is now included in the budget.

    Args:
        text: Free-form prompt text.
        max_len: Maximum length of the returned slug, separators included.

    Returns:
        A hyphen-joined, lowercase slug no longer than *max_len* characters
        (empty string if no word fits).
    """
    # Remove non-letters and convert to lowercase.
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )
    words = clean_text.split()

    result = []
    current_length = 0
    for word in words:
        # One extra character for the "-" separator after the first word.
        sep = 1 if result else 0
        if current_length + sep + len(word) > max_len:
            break
        result.append(word)
        current_length += sep + len(word)

    return "-".join(result)
133
+
134
+
135
# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,  # NOTE: shadows the builtin; kept for caller compatibility
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first non-existing path of the form
    ``{base}_{slug}_{seed}_{HxWxF}_{i}{endswith}{ext}`` inside *dir*.

    Args:
        base: Filename stem prefix (e.g. "img_to_vid_0").
        ext: File extension including the dot (e.g. ".mp4").
        prompt: Prompt text, slugified into the filename (max 30 chars).
        seed: RNG seed baked into the filename.
        resolution: (height, width, num_frames) triple.
        dir: Output directory.
        endswith: Optional suffix inserted just before the extension.
        index_range: How many indices to try before giving up.

    Raises:
        FileExistsError: If all *index_range* candidate names already exist.
    """
    base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    suffix = endswith if endswith else ""
    for i in range(index_range):
        candidate = dir / f"{base_filename}_{i}{suffix}{ext}"
        # pathlib idiom instead of os.path.exists on a Path.
        if not candidate.exists():
            return candidate
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
154
+
155
+
156
def seed_everething(seed: int):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible generation.

    Fix: use ``torch.cuda.manual_seed_all`` so every visible GPU is seeded,
    not only the current device.

    Note: the misspelled name ("everething") is kept because callers
    reference it by this spelling.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
162
+
163
+
164
def main():
    """CLI entry point: load the LTX-Video submodels from --ckpt_dir, run
    text-to-video (or image-to-video when --input_image_path is given), and
    write the resulting video/image files under the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Load models from separate directories and run the pipeline."
    )

    # Directories
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        required=True,
        help="Path to the directory containing unet, vae, and scheduler subdirectories",
    )
    # NOTE(review): --input_video_path is parsed but never used below.
    parser.add_argument(
        "--input_video_path",
        type=str,
        help="Path to the input video file (first frame used)",
    )
    parser.add_argument(
        "--input_image_path", type=str, help="Path to the input image file"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to the folder to save output video, if None will save in outputs/ directory.",
    )
    # NOTE(review): string default is coerced through type=int by argparse;
    # prefer default=171198 for clarity.
    parser.add_argument("--seed", type=int, default="171198")

    # Pipeline parameters
    parser.add_argument(
        "--num_inference_steps", type=int, default=40, help="Number of inference steps"
    )
    parser.add_argument(
        "--num_images_per_prompt",
        type=int,
        default=1,
        help="Number of images per prompt",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3,
        help="Guidance scale for the pipeline",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Height of the output video frames. Optional if an input image provided.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=704,
        help="Width of the output video frames. If None will infer from input image.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=121,
        help="Number of frames to generate in the output video",
    )
    parser.add_argument(
        "--frame_rate", type=int, default=25, help="Frame rate for the output video"
    )

    parser.add_argument(
        "--bfloat16",
        action="store_true",
        help="Denoise in bfloat16",
    )

    # Prompts
    parser.add_argument(
        "--prompt",
        type=str,
        help="Text prompt to guide generation",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        help="Negative prompt for undesired features",
    )

    logger = logging.get_logger(__name__)

    args = parser.parse_args()

    logger.warning(f"Running generation with arguments: {args}")

    seed_everething(args.seed)

    # Default output directory is dated, e.g. outputs/2024-01-31.
    output_dir = (
        Path(args.output_path)
        if args.output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load image
    if args.input_image_path:
        media_items_prepad = load_image_to_tensor_with_resize_and_crop(
            args.input_image_path, args.height, args.width
        )
    else:
        media_items_prepad = None

    # NOTE(review): if --height/--width were falsy AND no input image was
    # given, these fallbacks would dereference None; the non-zero defaults
    # (480/704) make this unreachable in practice — confirm.
    height = args.height if args.height else media_items_prepad.shape[-2]
    width = args.width if args.width else media_items_prepad.shape[-1]
    num_frames = args.num_frames

    if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
        logger.warning(
            f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
        )

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    if media_items_prepad is not None:
        media_items = F.pad(
            media_items_prepad, padding, mode="constant", value=-1
        )  # -1 is the value for padding since the image is normalized to -1, 1
    else:
        media_items = None

    # Paths for the separate mode directories
    ckpt_dir = Path(args.ckpt_dir)
    unet_dir = ckpt_dir / "unet"
    vae_dir = ckpt_dir / "vae"
    scheduler_dir = ckpt_dir / "scheduler"

    # Load models
    vae = load_vae(vae_dir)
    unet = load_unet(unet_dir)
    scheduler = load_scheduler(scheduler_dir)
    patchifier = SymmetricPatchifier(patch_size=1)
    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    )
    if torch.cuda.is_available():
        text_encoder = text_encoder.to("cuda")
    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    if args.bfloat16 and unet.dtype != torch.bfloat16:
        unet = unet.to(torch.bfloat16)

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": unet,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    if torch.cuda.is_available():
        pipeline = pipeline.to("cuda")

    # Prepare input for the pipeline
    sample = {
        "prompt": args.prompt,
        "prompt_attention_mask": None,
        "negative_prompt": args.negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": media_items,
    }

    generator = torch.Generator(
        device="cuda" if torch.cuda.is_available() else "cpu"
    ).manual_seed(args.seed)

    # When an input image is given, condition generation on the first frame;
    # otherwise generate unconditionally from the prompt alone.
    images = pipeline(
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.num_images_per_prompt,
        guidance_scale=args.guidance_scale,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=args.frame_rate,
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
        conditioning_method=(
            ConditioningMethod.FIRST_FRAME
            if media_items is not None
            else ConditioningMethod.UNCONDITIONAL
        ),
        mixed_precision=not args.bfloat16,
    ).images

    # Crop the padded images to the desired resolution and number of frames
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    # Negate right/bottom padding to use them as Python end-slices; a zero
    # pad would slice to an empty tensor, so substitute the full extent.
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = args.frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=args.prompt,
                seed=args.seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            if args.input_image_path:
                base_filename = f"img_to_vid_{i}"
            else:
                base_filename = f"text_to_vid_{i}"
            output_filename = get_unique_filename(
                base_filename,
                ".mp4",
                prompt=args.prompt,
                seed=args.seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )

            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

            # Write condition image
            if args.input_image_path:
                # Un-normalize the conditioning frame from [-1, 1] to [0, 255].
                reference_image = (
                    (
                        media_items_prepad[0, :, 0].permute(1, 2, 0).cpu().data.numpy()
                        + 1.0
                    )
                    / 2.0
                    * 255
                )
                imageio.imwrite(
                    get_unique_filename(
                        base_filename,
                        ".png",
                        prompt=args.prompt,
                        seed=args.seed,
                        resolution=(height, width, num_frames),
                        dir=output_dir,
                        endswith="_condition",
                    ),
                    reference_image.astype(np.uint8),
                )
    logger.warning(f"Output saved to {output_dir}")
441
+
442
+
443
if __name__ == "__main__":
    # Script entry point.
    main()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch>=2.1.0
3
+ torchvision
4
+ diffusers>=0.28.2
5
+ transformers==4.44.2
6
+ sentencepiece>=0.1.96
7
+ wheel==0.44.0
8
+ einops==0.8.0
9
+ accelerate==1.1.1
10
+ matplotlib
11
+ imageio[ffmpeg]
12
+ gradio==5.7.1
13
+ gradio_toggle==2.0.2
14
+ --extra-index-url https://download.pytorch.org/whl/nightly/cpu
15
+ torchao