tchoudha21 commited on
Commit
5604d40
·
verified ·
1 Parent(s): 7e47687

Change utils dependency

Browse files
Files changed (1) hide show
  1. video2world_hf.py +6 -21
video2world_hf.py CHANGED
@@ -1,33 +1,17 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
  import argparse
17
  import os
18
 
19
- from .utils import misc
20
  import torch
21
 
22
  from .inference_utils import add_common_arguments, check_input_frames, validate_args
23
  from .world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
24
- from .utils import log
25
- from .utils.utils_io import read_prompts_from_file, save_video
26
 
27
  from .download_diffusion import main as download_diffusion
28
  from transformers import PreTrainedModel, PretrainedConfig
29
 
30
- torch.enable_grad(False)
31
 
32
  #custom config class
33
  class DiffusionVideo2WorldConfig(PretrainedConfig):
@@ -65,9 +49,10 @@ class DiffusionVideo2World(PreTrainedModel):
65
 
66
  def __init__(self, config=DiffusionVideo2WorldConfig()):
67
  super().__init__(config)
 
68
  cfg = config
69
-
70
- misc.set_random_seed(cfg.seed)
71
  inference_type = "video2world"
72
  validate_args(cfg, inference_type)
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import os
3
 
4
+ # from .utils import misc
5
  import torch
6
 
7
  from .inference_utils import add_common_arguments, check_input_frames, validate_args
8
  from .world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
9
+ from .log import log
10
+ from .utils_io import read_prompts_from_file, save_video
11
 
12
  from .download_diffusion import main as download_diffusion
13
  from transformers import PreTrainedModel, PretrainedConfig
14
 
 
15
 
16
  #custom config class
17
  class DiffusionVideo2WorldConfig(PretrainedConfig):
 
49
 
50
  def __init__(self, config=DiffusionVideo2WorldConfig()):
51
  super().__init__(config)
52
+ torch.enable_grad(False)
53
  cfg = config
54
+ self.config = config
55
+ # misc.set_random_seed(cfg.seed)
56
  inference_type = "video2world"
57
  validate_args(cfg, inference_type)
58