try adding a leading `.` to imports (relative imports)
Browse files- text2world_hf.py +27 -7
text2world_hf.py
CHANGED
|
@@ -3,11 +3,11 @@ import argparse
|
|
| 3 |
import torch
|
| 4 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 5 |
|
| 6 |
-
from cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
| 7 |
-
from cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
| 8 |
-
import cosmos1.utils.log as log
|
| 9 |
-
import cosmos1.utils.misc as misc
|
| 10 |
-
from cosmos1.utils.io import read_prompts_from_file, save_video
|
| 11 |
|
| 12 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
| 13 |
model_type = "DiffusionText2World"
|
|
@@ -46,8 +46,28 @@ class DiffusionText2World(PreTrainedModel):
|
|
| 46 |
torch.enable_grad(False) # TODO: do we need this?
|
| 47 |
self.config = config
|
| 48 |
inference_type = "text2world"
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
def forward(self, prompt):
|
| 53 |
cfg = self.config
|
|
|
|
| 3 |
import torch
|
| 4 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 5 |
|
| 6 |
+
from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
| 7 |
+
from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
| 8 |
+
from .cosmos1.utils import log
|
| 9 |
+
from .cosmos1.utils import misc
|
| 10 |
+
from .cosmos1.utils.io import read_prompts_from_file, save_video
|
| 11 |
|
| 12 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
| 13 |
model_type = "DiffusionText2World"
|
|
|
|
| 46 |
torch.enable_grad(False) # TODO: do we need this?
|
| 47 |
self.config = config
|
| 48 |
inference_type = "text2world"
|
| 49 |
+
config.prompt = 1 # TODO: this is to hack args validation, maybe find a better way
|
| 50 |
+
validate_args(config, inference_type)
|
| 51 |
+
del config.prompt
|
| 52 |
+
self.pipeline = DiffusionText2WorldGenerationPipeline(
|
| 53 |
+
inference_type=inference_type,
|
| 54 |
+
checkpoint_dir=config.checkpoint_dir,
|
| 55 |
+
checkpoint_name=config.diffusion_transformer_dir,
|
| 56 |
+
prompt_upsampler_dir=config.prompt_upsampler_dir,
|
| 57 |
+
enable_prompt_upsampler=not config.disable_prompt_upsampler,
|
| 58 |
+
offload_network=config.offload_diffusion_transformer,
|
| 59 |
+
offload_tokenizer=config.offload_tokenizer,
|
| 60 |
+
offload_text_encoder_model=config.offload_text_encoder_model,
|
| 61 |
+
offload_prompt_upsampler=config.offload_prompt_upsampler,
|
| 62 |
+
offload_guardrail_models=config.offload_guardrail_models,
|
| 63 |
+
guidance=config.guidance,
|
| 64 |
+
num_steps=config.num_steps,
|
| 65 |
+
height=config.height,
|
| 66 |
+
width=config.width,
|
| 67 |
+
fps=config.fps,
|
| 68 |
+
num_video_frames=config.num_video_frames,
|
| 69 |
+
seed=config.seed,
|
| 70 |
+
)
|
| 71 |
|
| 72 |
def forward(self, prompt):
|
| 73 |
cfg = self.config
|