Spaces:
Runtime error
Runtime error
| import torch | |
| from genie.config import GenieConfig | |
| from genie.st_mask_git import STMaskGIT | |
def convert_lightning_checkpoint(lightning_checkpoint, num_layers, num_heads, d_model, save_dir):
    """Convert a v0.0.1 Lightning checkpoint into a Hugging Face-style checkpoint.

    v0.0.1 saved models as PyTorch Lightning checkpoints. This rebuilds an
    ``STMaskGIT`` from the given architecture hyperparameters, loads the
    checkpoint's ``"state_dict"`` weights into it, and writes the result with
    ``model.save_pretrained``.

    Args:
        lightning_checkpoint: Path (or file-like object) accepted by ``torch.load``.
        num_layers: Passed to ``GenieConfig``; must match the checkpointed model.
        num_heads: Passed to ``GenieConfig``; must match the checkpointed model.
        d_model: Passed to ``GenieConfig``; must match the checkpointed model.
        save_dir: Destination directory passed to ``model.save_pretrained``.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
        RuntimeError: From ``load_state_dict`` (strict by default) if the
            checkpoint's keys or tensor shapes do not match the configured model.
    """
    config = GenieConfig(num_layers=num_layers, num_heads=num_heads, d_model=d_model)
    model = STMaskGIT(config)

    # map_location="cpu" lets a checkpoint saved on GPU be converted on a
    # CPU-only machine; the freshly built model's parameters are on CPU anyway,
    # so load_state_dict produces the same result.
    # NOTE(review): torch.load unpickles the file -- only convert trusted checkpoints.
    checkpoint = torch.load(lightning_checkpoint, map_location="cpu")
    lightning_state_dict = checkpoint["state_dict"]

    # Keys are expected to carry a leading "model." prefix (presumably from the
    # LightningModule holding the network as `self.model` -- confirm). Strip only
    # that leading prefix: str.replace would also mangle keys that contain
    # "model." elsewhere (e.g. "submodel.weight" -> "subweight").
    prefix = "model."
    model_state_dict = {
        (name[len(prefix):] if name.startswith(prefix) else name): params
        for name, params in lightning_state_dict.items()
    }
    model.load_state_dict(model_state_dict)
    model.save_pretrained(save_dir)