Finished rearranging model and config files
Browse files- .gitignore +1 -0
- model_config.py → ar_configs_model_config.py +0 -0
- cosmos1/models/autoregressive/diffusion_decoder/inference.py → ar_diffusion_decoder_inference.py +3 -3
- cosmos1/models/autoregressive/diffusion_decoder/model.py → ar_diffusion_decoder_model.py +5 -5
- cosmos1/models/autoregressive/diffusion_decoder/utils.py → ar_diffusion_decoder_utils.py +0 -0
- cosmos1/models/autoregressive/inference/world_generation_pipeline.py +3 -3
- cosmos1/models/autoregressive/nemo/utils.py +2 -2
- futureworld_hf.py +29 -16
- text2world_prompt_upsampler_inference.py +1 -1
- video2world_prompt_upsampler_inference.py +1 -1
- world_generation_pipeline.py +11 -10
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
huggingface.txt
|
model_config.py → ar_configs_model_config.py
RENAMED
|
File without changes
|
cosmos1/models/autoregressive/diffusion_decoder/inference.py → ar_diffusion_decoder_inference.py
RENAMED
|
@@ -19,9 +19,9 @@ from typing import List
|
|
| 19 |
|
| 20 |
import torch
|
| 21 |
|
| 22 |
-
from inference_config import DiffusionDecoderSamplingConfig
|
| 23 |
-
from
|
| 24 |
-
from
|
| 25 |
from .log import log
|
| 26 |
|
| 27 |
|
|
|
|
| 19 |
|
| 20 |
import torch
|
| 21 |
|
| 22 |
+
from .inference_config import DiffusionDecoderSamplingConfig
|
| 23 |
+
from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
| 24 |
+
from .ar_diffusion_decoder_utils import linear_blend_video_list, split_with_overlap
|
| 25 |
from .log import log
|
| 26 |
|
| 27 |
|
cosmos1/models/autoregressive/diffusion_decoder/model.py → ar_diffusion_decoder_model.py
RENAMED
|
@@ -19,11 +19,11 @@ from typing import Callable, Dict, Optional, Tuple
|
|
| 19 |
import torch
|
| 20 |
from torch import Tensor
|
| 21 |
|
| 22 |
-
from conditioner import BaseVideoCondition
|
| 23 |
-
from batch_ops import batch_mul
|
| 24 |
-
from res_sampler import COMMON_SOLVER_OPTIONS
|
| 25 |
-
from model_t2w import DiffusionT2WModel as VideoDiffusionModel
|
| 26 |
-
from lazy_config_init import instantiate as lazy_instantiate
|
| 27 |
|
| 28 |
|
| 29 |
@dataclass
|
|
|
|
| 19 |
import torch
|
| 20 |
from torch import Tensor
|
| 21 |
|
| 22 |
+
from .conditioner import BaseVideoCondition
|
| 23 |
+
from .batch_ops import batch_mul
|
| 24 |
+
from .res_sampler import COMMON_SOLVER_OPTIONS
|
| 25 |
+
from .model_t2w import DiffusionT2WModel as VideoDiffusionModel
|
| 26 |
+
from .lazy_config_init import instantiate as lazy_instantiate
|
| 27 |
|
| 28 |
|
| 29 |
@dataclass
|
cosmos1/models/autoregressive/diffusion_decoder/utils.py → ar_diffusion_decoder_utils.py
RENAMED
|
File without changes
|
cosmos1/models/autoregressive/inference/world_generation_pipeline.py
CHANGED
|
@@ -22,7 +22,7 @@ import numpy as np
|
|
| 22 |
import torch
|
| 23 |
from einops import rearrange
|
| 24 |
|
| 25 |
-
from
|
| 26 |
from ar_config_tokenizer import TokenizerConfig
|
| 27 |
from inference_config import (
|
| 28 |
DataShapeConfig,
|
|
@@ -30,8 +30,8 @@ from inference_config import (
|
|
| 30 |
InferenceConfig,
|
| 31 |
SamplingConfig,
|
| 32 |
)
|
| 33 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
| 34 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
| 35 |
from ar_model import AutoRegressiveModel
|
| 36 |
from cosmos1.models.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving
|
| 37 |
from base_world_generation_pipeline import BaseWorldGenerationPipeline
|
|
|
|
| 22 |
import torch
|
| 23 |
from einops import rearrange
|
| 24 |
|
| 25 |
+
from ar_configs_model_config import create_video2world_model_config
|
| 26 |
from ar_config_tokenizer import TokenizerConfig
|
| 27 |
from inference_config import (
|
| 28 |
DataShapeConfig,
|
|
|
|
| 30 |
InferenceConfig,
|
| 31 |
SamplingConfig,
|
| 32 |
)
|
| 33 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
|
| 34 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
| 35 |
from ar_model import AutoRegressiveModel
|
| 36 |
from cosmos1.models.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving
|
| 37 |
from base_world_generation_pipeline import BaseWorldGenerationPipeline
|
cosmos1/models/autoregressive/nemo/utils.py
CHANGED
|
@@ -24,8 +24,8 @@ import torchvision
|
|
| 24 |
from huggingface_hub import snapshot_download
|
| 25 |
|
| 26 |
from inference_config import DiffusionDecoderSamplingConfig
|
| 27 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
| 28 |
-
from cosmos1.models.autoregressive.diffusion_decoder.
|
| 29 |
from inference_utils import (
|
| 30 |
load_network_model,
|
| 31 |
load_tokenizer_model,
|
|
|
|
| 24 |
from huggingface_hub import snapshot_download
|
| 25 |
|
| 26 |
from inference_config import DiffusionDecoderSamplingConfig
|
| 27 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
|
| 28 |
+
from cosmos1.models.autoregressive.diffusion_decoder.ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
| 29 |
from inference_utils import (
|
| 30 |
load_network_model,
|
| 31 |
load_tokenizer_model,
|
futureworld_hf.py
CHANGED
|
@@ -19,15 +19,23 @@ class AutoregressiveFutureWorldConfig(PretrainedConfig):
|
|
| 19 |
def __init__(self, **kwargs):
|
| 20 |
super().__init__(**kwargs)
|
| 21 |
self.checkpoint_dir = kwargs.get("checkpoint_dir", "checkpoints")
|
| 22 |
-
self.
|
| 23 |
self.disable_diffusion_decoder = kwargs.get("disable_diffusion_decoder", False)
|
| 24 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
| 25 |
self.offload_diffusion_decoder = kwargs.get("offload_diffusion_decoder", False)
|
| 26 |
-
self.
|
| 27 |
self.offload_tokenizer = kwargs.get("offload_tokenizer", False)
|
| 28 |
self.video_save_name = kwargs.get("video_save_name", "output")
|
| 29 |
self.video_save_folder = kwargs.get("video_save_folder", "outputs/")
|
| 30 |
-
self.seed = kwargs.get()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# custom model class
|
| 33 |
class AutoregressiveFutureWorld(PreTrainedModel):
|
|
@@ -37,17 +45,16 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
| 37 |
super().__init__(config)
|
| 38 |
torch._C._jit_set_texpr_fuser_enabled(False)
|
| 39 |
self.config = config
|
| 40 |
-
inference_type = "base"
|
| 41 |
-
sampling_config = validate_args(config, inference_type)
|
| 42 |
self.pipeline = ARBaseGenerationPipeline(
|
| 43 |
-
inference_type=inference_type,
|
| 44 |
-
checkpoint_dir=self.checkpoint_dir,
|
| 45 |
-
checkpoint_name=self.ar_model_dir,
|
| 46 |
-
disable_diffusion_decoder=self.disable_diffusion_decoder,
|
| 47 |
-
offload_guardrail_models=self.offload_guardrail_models,
|
| 48 |
-
offload_diffusion_decoder=self.offload_diffusion_decoder,
|
| 49 |
-
offload_network=self.offload_ar_model,
|
| 50 |
-
offload_tokenizer=self.offload_tokenizer,
|
| 51 |
)
|
| 52 |
|
| 53 |
# modifed from text2world.py demo function
|
|
@@ -63,6 +70,12 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
| 63 |
data_resolution=data_resolution,
|
| 64 |
num_input_frames=num_input_frames,
|
| 65 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
for idx, input_filename in enumerate(input_videos):
|
| 68 |
inp_vid = input_videos[input_filename]
|
|
@@ -71,7 +84,7 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
| 71 |
out_vid = self.pipeline.generate(
|
| 72 |
inp_vid=inp_vid,
|
| 73 |
num_input_frames=num_input_frames,
|
| 74 |
-
seed=
|
| 75 |
sampling_config=sampling_config,
|
| 76 |
)
|
| 77 |
if out_vid is None:
|
|
@@ -80,9 +93,9 @@ class AutoregressiveFutureWorld(PreTrainedModel):
|
|
| 80 |
|
| 81 |
# Save video
|
| 82 |
if input_image_or_video_path:
|
| 83 |
-
out_vid_path = os.path.join(
|
| 84 |
else:
|
| 85 |
-
out_vid_path = os.path.join(
|
| 86 |
|
| 87 |
imageio.mimsave(out_vid_path, out_vid, fps=25)
|
| 88 |
|
|
|
|
| 19 |
def __init__(self, **kwargs):
|
| 20 |
super().__init__(**kwargs)
|
| 21 |
self.checkpoint_dir = kwargs.get("checkpoint_dir", "checkpoints")
|
| 22 |
+
self.ar_model_dir = kwargs.get("ar_model_dir", "Cosmos-1.0-Autoregressive-4B")
|
| 23 |
self.disable_diffusion_decoder = kwargs.get("disable_diffusion_decoder", False)
|
| 24 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
| 25 |
self.offload_diffusion_decoder = kwargs.get("offload_diffusion_decoder", False)
|
| 26 |
+
self.offload_ar_model = kwargs.get("offload_ar_model", False)
|
| 27 |
self.offload_tokenizer = kwargs.get("offload_tokenizer", False)
|
| 28 |
self.video_save_name = kwargs.get("video_save_name", "output")
|
| 29 |
self.video_save_folder = kwargs.get("video_save_folder", "outputs/")
|
| 30 |
+
self.seed = kwargs.get("seed", 0)
|
| 31 |
+
self.temperature = kwargs.get("temperature", 1.0)
|
| 32 |
+
self.top_p = kwargs.get("top_p", 0.8)
|
| 33 |
+
self.input_type = None
|
| 34 |
+
self.batch_input_path = None
|
| 35 |
+
self.input_image_or_video_path = None
|
| 36 |
+
self.data_resolution = None
|
| 37 |
+
self.num_input_frames = None
|
| 38 |
+
|
| 39 |
|
| 40 |
# custom model class
|
| 41 |
class AutoregressiveFutureWorld(PreTrainedModel):
|
|
|
|
| 45 |
super().__init__(config)
|
| 46 |
torch._C._jit_set_texpr_fuser_enabled(False)
|
| 47 |
self.config = config
|
| 48 |
+
self.inference_type = "base"
|
|
|
|
| 49 |
self.pipeline = ARBaseGenerationPipeline(
|
| 50 |
+
inference_type=self.inference_type,
|
| 51 |
+
checkpoint_dir=self.config.checkpoint_dir,
|
| 52 |
+
checkpoint_name=self.config.ar_model_dir,
|
| 53 |
+
disable_diffusion_decoder=self.config.disable_diffusion_decoder,
|
| 54 |
+
offload_guardrail_models=self.config.offload_guardrail_models,
|
| 55 |
+
offload_diffusion_decoder=self.config.offload_diffusion_decoder,
|
| 56 |
+
offload_network=self.config.offload_ar_model,
|
| 57 |
+
offload_tokenizer=self.config.offload_tokenizer,
|
| 58 |
)
|
| 59 |
|
| 60 |
# modifed from text2world.py demo function
|
|
|
|
| 70 |
data_resolution=data_resolution,
|
| 71 |
num_input_frames=num_input_frames,
|
| 72 |
)
|
| 73 |
+
self.config.input_type = input_type
|
| 74 |
+
self.config.batch_input_path = batch_input_path
|
| 75 |
+
self.config.input_image_or_video_path = input_image_or_video_path
|
| 76 |
+
self.config.data_resolution = data_resolution
|
| 77 |
+
self.config.num_input_frames = num_input_frames
|
| 78 |
+
sampling_config = validate_args(self.config, self.inference_type)
|
| 79 |
|
| 80 |
for idx, input_filename in enumerate(input_videos):
|
| 81 |
inp_vid = input_videos[input_filename]
|
|
|
|
| 84 |
out_vid = self.pipeline.generate(
|
| 85 |
inp_vid=inp_vid,
|
| 86 |
num_input_frames=num_input_frames,
|
| 87 |
+
seed=self.config.seed,
|
| 88 |
sampling_config=sampling_config,
|
| 89 |
)
|
| 90 |
if out_vid is None:
|
|
|
|
| 93 |
|
| 94 |
# Save video
|
| 95 |
if input_image_or_video_path:
|
| 96 |
+
out_vid_path = os.path.join(self.config.video_save_folder, f"{self.config.video_save_name}.mp4")
|
| 97 |
else:
|
| 98 |
+
out_vid_path = os.path.join(self.config.video_save_folder, f"{idx}.mp4")
|
| 99 |
|
| 100 |
imageio.mimsave(out_vid_path, out_vid, fps=25)
|
| 101 |
|
text2world_prompt_upsampler_inference.py
CHANGED
|
@@ -23,7 +23,7 @@ import argparse
|
|
| 23 |
import os
|
| 24 |
import re
|
| 25 |
|
| 26 |
-
from .
|
| 27 |
from .ar_model import AutoRegressiveModel
|
| 28 |
from .inference import chat_completion
|
| 29 |
from .presets import presets as guardrail_presets
|
|
|
|
| 23 |
import os
|
| 24 |
import re
|
| 25 |
|
| 26 |
+
from .ar_configs_model_config import create_text_model_config
|
| 27 |
from .ar_model import AutoRegressiveModel
|
| 28 |
from .inference import chat_completion
|
| 29 |
from .presets import presets as guardrail_presets
|
video2world_prompt_upsampler_inference.py
CHANGED
|
@@ -26,7 +26,7 @@ from math import ceil
|
|
| 26 |
|
| 27 |
from PIL import Image
|
| 28 |
|
| 29 |
-
from .
|
| 30 |
from .ar_model import AutoRegressiveModel
|
| 31 |
from .inference import chat_completion
|
| 32 |
from .presets import presets as guardrail_presets
|
|
|
|
| 26 |
|
| 27 |
from PIL import Image
|
| 28 |
|
| 29 |
+
from .ar_configs_model_config import create_vision_language_model_config
|
| 30 |
from .ar_model import AutoRegressiveModel
|
| 31 |
from .inference import chat_completion
|
| 32 |
from .presets import presets as guardrail_presets
|
world_generation_pipeline.py
CHANGED
|
@@ -21,25 +21,26 @@ import numpy as np
|
|
| 21 |
import torch
|
| 22 |
from einops import rearrange
|
| 23 |
|
| 24 |
-
from
|
| 25 |
-
from
|
| 26 |
-
from
|
| 27 |
DataShapeConfig,
|
| 28 |
DiffusionDecoderSamplingConfig,
|
| 29 |
InferenceConfig,
|
| 30 |
SamplingConfig,
|
| 31 |
)
|
| 32 |
-
from
|
| 33 |
-
from
|
| 34 |
-
from
|
| 35 |
-
from
|
| 36 |
-
from
|
| 37 |
-
from
|
| 38 |
load_model_by_config,
|
| 39 |
load_network_model,
|
| 40 |
load_tokenizer_model,
|
| 41 |
)
|
| 42 |
-
from
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
|
|
|
|
| 21 |
import torch
|
| 22 |
from einops import rearrange
|
| 23 |
|
| 24 |
+
from .ar_configs_model_config import create_video2world_model_config
|
| 25 |
+
from .ar_config_tokenizer import TokenizerConfig
|
| 26 |
+
from .ar_configs_inference import (
|
| 27 |
DataShapeConfig,
|
| 28 |
DiffusionDecoderSamplingConfig,
|
| 29 |
InferenceConfig,
|
| 30 |
SamplingConfig,
|
| 31 |
)
|
| 32 |
+
from .ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
|
| 33 |
+
from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
|
| 34 |
+
from .ar_model import AutoRegressiveModel
|
| 35 |
+
from .ar_utils_inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving
|
| 36 |
+
from .base_world_generation_pipeline import BaseWorldGenerationPipeline
|
| 37 |
+
from .inference_utils import (
|
| 38 |
load_model_by_config,
|
| 39 |
load_network_model,
|
| 40 |
load_tokenizer_model,
|
| 41 |
)
|
| 42 |
+
from .log import log
|
| 43 |
+
from .misc import misc
|
| 44 |
|
| 45 |
|
| 46 |
def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
|