| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""Utility to convert weights to safetensors."""
|
|
| import argparse |
|
|
| import torch |
|
|
| from .configuration_embed1 import CosmosEmbed1Config |
| from .modeling_embed1 import CosmosEmbed1 |
|
|
|
|
def parse_args():
    """Collect and return the command-line options for this conversion script."""
    arg_parser = argparse.ArgumentParser(
        description="Save model weights with optional format conversion and sharding."
    )
    arg_parser.add_argument(
        "--input_weights",
        type=str,
        required=True,
        help="Path to the input .pt weights file",
    )
    arg_parser.add_argument(
        "--output_weights",
        type=str,
        required=True,
        help="Path to the output directory where safetensors weights will be saved",
    )
    return arg_parser.parse_args()
|
|
|
|
def main():
    """Load a .pt checkpoint into CosmosEmbed1 and re-save it as sharded safetensors.

    Reads ``--input_weights`` (a torch state_dict checkpoint), loads it strictly
    into a freshly constructed bf16 model on CUDA, and writes sharded safetensors
    files to the ``--output_weights`` directory.
    """
    args = parse_args()
    model = CosmosEmbed1(CosmosEmbed1Config()).to("cuda", dtype=torch.bfloat16)

    # Re-wrap these parameters as independent copies. Presumably the Q-Former ties
    # the LM-head decoder to the word embeddings and shares bias storage — TODO
    # confirm against the model definition. Safetensors serialization rejects
    # tensors that share storage, so they must be untied before saving.
    model.qformer.cls.predictions.decoder.weight = torch.nn.Parameter(
        model.qformer.cls.predictions.decoder.weight.clone()
    )
    model.qformer.bert.embeddings.word_embeddings.weight = torch.nn.Parameter(
        model.qformer.bert.embeddings.word_embeddings.weight.clone()
    )
    model.qformer.cls.predictions.decoder.bias = torch.nn.Parameter(model.qformer.cls.predictions.decoder.bias.clone())
    model.qformer.cls.predictions.bias = torch.nn.Parameter(model.qformer.cls.predictions.bias.clone())

    # map_location="cpu" makes loading independent of the device the checkpoint was
    # saved on; weights_only=True restricts unpickling to tensor data, preventing
    # arbitrary code execution from an untrusted checkpoint (a plain state_dict
    # needs nothing more). torch.load accepts the path directly — no need to
    # manage a file handle here.
    state_dict = torch.load(args.input_weights, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict, strict=True)

    model.save_pretrained(
        args.output_weights,
        safe_serialization=True,
        max_shard_size="500MB",
    )
|
|
|
|
# Run the conversion only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|