|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A CLI to export a pre-trained tokenizer checkpoint into a torch.ScriptModule.

Usage:
    python3 -m cosmos_predict1.tokenizer.training.jit_cli \
        --ckpt_path=checkpoints/Cosmos-0.1-Tokenizer-CV4x8x8/iter_001000000.pt \
        --output_dir=checkpoints/Cosmos-0.1-Tokenizer-CV4x8x8/exported \
        --strict_resume \
        --config=cosmos_predict1/tokenizer/training/configs/config.py -- \
        experiment=CV720_Causal_AE49_4x8x8_cosmos

    This will output:
        <output-dir>/iter_001000000_ema.jit
        <output-dir>/iter_001000000_enc.jit
        <output-dir>/iter_001000000_dec.jit

    If --reg is specified, the regular (non-EMA) model is exported instead:
        <output-dir>/iter_001000000_reg.jit
        <output-dir>/iter_001000000_enc.jit
        <output-dir>/iter_001000000_dec.jit
"""
|
|
|
|
|
import argparse |
|
|
import importlib |
|
|
import os |
|
|
|
|
|
import torch |
|
|
from loguru import logger as logging |
|
|
from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule |
|
|
|
|
|
from cosmos_predict1.tokenizer.training.checkpointer import TokenizerCheckpointer |
|
|
from cosmos_predict1.utils import callback, ema |
|
|
from cosmos_predict1.utils.config import Config |
|
|
from cosmos_predict1.utils.config_helper import get_config_module, override |
|
|
from cosmos_predict1.utils.lazy_config import instantiate |
|
|
from cosmos_predict1.utils.model import Model |
|
|
|
|
|
# ---------------------------------------------------------------------------
# Command-line interface.
#
# NOTE: arguments are parsed and the config is built at import time; the
# functions below read the resulting module-level `args` and `config`.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Export a pre-trained model into a torch.jit.ScriptModule.",
)
parser.add_argument(
    "--config",
    type=str,
    default="cosmos_predict1/tokenizer/training/configs/config.py",
    help="Path to the config file",
)
parser.add_argument(
    "--ckpt_path",
    type=str,
    default=None,
    help="The full ckpt path.",
)
parser.add_argument(
    "--credentials",
    type=str,
    default="credentials/pdx_vfm_base.secret",
    help="The credentials file.",
)
parser.add_argument(
    "--strict_resume",
    action="store_true",
    help="Enable strictly loading into every network weight.",
)
parser.add_argument(
    "--reg",
    action="store_true",
    help="Enable regular model export.",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default=None,
    help="Optional output directory.",
)

# Everything left over on the command line is collected verbatim into `opts`
# and later handed to `override` together with the config.
parser.add_argument(
    "opts",
    help=(
        "Modify config options at the end of the command. For Yacs configs, use\n"
        'space-separated "PATH.KEY VALUE" pairs.\n'
        'For python-based LazyConfig, use "path.key=value".'
    ),
    default=None,
    nargs=argparse.REMAINDER,
)

logging.info("Initialize args, cfg from command line arguments ...")
args = parser.parse_args()

# The value returned for the --config path is imported as a module, and that
# module's `make_config()` provides the base config.
config_module = get_config_module(args.config)
config: Config = importlib.import_module(config_module).make_config()

# Apply the trailing command-line `opts` on top of the base config.
config = override(config, args.opts)
|
|
|
|
|
|
|
|
def _compile_jit_models(model: Model) -> dict[str, torch.ScriptModule]:
    """Trace the full network, its encoder and its decoder with ``torch.jit.trace``.

    Tracing settings (input shape, dtype, device, strictness) are read from the
    module-level ``config.checkpoint.jit``. Tracing runs inside
    ``ema.ema_scope``, which is enabled only when the model's config has EMA
    enabled and ``--reg`` was not passed.

    Args:
        model: The model whose ``network`` attribute is traced.

    Returns:
        A dict holding the traced full network under ``"reg"`` (when ``--reg``
        was passed) or ``"ema"`` (otherwise), the traced encoder under
        ``"enc"`` and the traced decoder under ``"dec"``.
    """
    assert hasattr(config, "checkpoint") and hasattr(config.checkpoint, "jit")
    jit_cfg = config.checkpoint.jit

    # Random example input that drives tracing: created with the configured
    # shape, then moved to the configured device and cast to the configured dtype.
    sample = torch.randn(tuple(jit_cfg.input_shape))
    trace_dtype = getattr(torch, jit_cfg.dtype)
    sample = sample.to(jit_cfg.device).to(trace_dtype)

    # Turn off the TensorExpr fuser before tracing.
    # NOTE(review): presumably to keep traced-model outputs stable across
    # consecutive calls — confirm against the project history.
    torch._C._jit_set_texpr_fuser_enabled(False)

    use_ema_weights = model.config.ema.enabled and not args.reg
    with ema.ema_scope(model, enabled=use_ema_weights):
        net = model.network.eval()
        # If the network is wrapped in an OptimizedModule, trace the wrapped
        # original module instead of the wrapper.
        if isinstance(net, torch_OptimizedModule):
            net = net._orig_mod

        full_jit = torch.jit.trace(net, sample, strict=jit_cfg.strict)
        enc_jit = torch.jit.trace(net.encoder_jit(), sample, strict=jit_cfg.strict)
        # The first element of the traced encoder's output is the decoder's example input.
        dec_sample = enc_jit(sample)[0]
        dec_jit = torch.jit.trace(net.decoder_jit(), dec_sample, strict=jit_cfg.strict)

    # The full-model entry is keyed by which weights were requested.
    full_key = "reg" if args.reg else "ema"
    return {full_key: full_jit, "enc": enc_jit, "dec": dec_jit}
|
|
|
|
|
|
|
|
def _run_export() -> None:
    """Load a tokenizer checkpoint and save its TorchScript exports to disk.

    Reads the module-level ``args`` and ``config``. The checkpoint section of
    the config is pointed at ``--ckpt_path``, the model is instantiated and
    moved to CUDA, the weights are loaded, and every module returned by
    ``_compile_jit_models`` is saved as ``<output_dir>/<ckpt_name>_<key>.jit``.

    ``<ckpt_name>`` is the checkpoint file name without its extension.
    ``<output_dir>`` is ``--output_dir`` if given, otherwise the checkpoint's
    directory (or the current directory when the checkpoint path has no
    directory component).

    Raises:
        ValueError: If ``--ckpt_path`` was not provided.
    """
    # Fail early with a clear message instead of a TypeError/AttributeError
    # deep inside path handling further down.
    if not args.ckpt_path:
        raise ValueError("--ckpt_path is required to export a checkpoint.")

    config.validate()
    config.checkpoint.load_path = args.ckpt_path
    config.checkpoint.strict_resume = args.strict_resume
    config.checkpoint.load_training_state = False
    # normpath drops a trailing separator so that "some/dir/" yields "dir"
    # rather than an empty job name.
    job_name_source = args.output_dir if args.output_dir else args.ckpt_path
    config.job.name = os.path.basename(os.path.normpath(job_name_source))

    # All config mutations must happen before the config is frozen.
    config.freeze()
    callbacks = callback.CallBackGroup(config=config, trainer=None)
    checkpointer = TokenizerCheckpointer(config.checkpoint, config.job, callbacks=callbacks)

    logging.info(f"Instantiate model={config.model.config.network.name} ...")
    model = instantiate(config.model)
    model = model.to("cuda", memory_format=config.trainer.memory_format)
    # NOTE(review): presumably finishes model setup that must exist before the
    # weights are loaded — confirm against Model.on_train_start.
    model.on_train_start(config.trainer.memory_format)

    logging.info(f"loading weights from {config.checkpoint.load_path}...")
    checkpointer.load(model)
    model.eval()

    load_path = config.checkpoint.load_path
    # Strip the real extension rather than blindly chopping three characters,
    # so names that do not end in a 3-character suffix like ".pt" stay intact.
    ckpt_name = os.path.splitext(os.path.basename(load_path))[0]

    # dirname() is "" for a bare file name; fall back to the current directory
    # because os.makedirs("") raises and "/<name>" would target the root.
    output_dir = args.output_dir or os.path.dirname(load_path) or "."
    os.makedirs(output_dir, exist_ok=True)

    logging.info("Performing JIT compilation ...")
    jit_models = _compile_jit_models(model)
    for name, jit_model in jit_models.items():
        output_path = os.path.join(output_dir, f"{ckpt_name}_{name}.jit")
        logging.info(f"Outputing torch.jit: {output_path}")
        torch.jit.save(jit_model, output_path)
|
|
|
|
|
|
|
|
@logging.catch(reraise=True)
def main() -> None:
    """Command-line entry point: run the export.

    The ``logging.catch(reraise=True)`` decorator wraps the call so that an
    exception raised during the export is still propagated to the caller.
    """
    _run_export()


# Only start the export when this file is executed directly
# (for example via ``python3 -m ...``).
if __name__ == "__main__":
    main()
|
|
|