# Source: victor-shirasuna — "Upload files" (commit 3d83373)
from .bamba_config import BambaEncoderDecoderConfig
from .bamba import BambaConfig, BambaEncoderDecoder
from .tokenizer.str_tokenizer import load_tokenizer
import torch
import numpy as np
import random
import json
import os
def load_strbamba(ckpt_filename,
                  base_folder='./str_bamba',
                  config_filename='config_encoder-decoder_436M.json',
                  tokenizer_filename='str_bamba_tokenizer.json',
                  eval_model=True,
                  device='cuda:0',
                  dtype=torch.float32
                  ):
    """Load a STR-Bamba encoder-decoder model from a checkpoint on disk.

    Expects ``base_folder`` to contain ``config/``, ``tokenizer/`` and
    ``checkpoints/`` subdirectories holding the named files.

    Args:
        ckpt_filename: checkpoint file name inside ``<base_folder>/checkpoints``.
        base_folder: root folder of the STR-Bamba assets.
        config_filename: JSON config file name inside ``<base_folder>/config``.
        tokenizer_filename: tokenizer JSON file name inside ``<base_folder>/tokenizer``.
        eval_model: if True (default), return the model switched to eval mode.
        device: torch device string used both as ``map_location`` and for model placement.
        dtype: torch dtype for the model parameters.

    Returns:
        The loaded ``BambaEncoderDecoder`` with weights restored and, when the
        checkpoint carries them, global RNG states (torch / CUDA / numpy / python)
        restored as well.
    """
    # load config
    with open(os.path.join(base_folder, f'config/{config_filename}')) as json_data:
        config_json = json.load(json_data)
    bamba_config = BambaEncoderDecoderConfig(
        encoder_config=BambaConfig(**config_json['encoder_config']),
        decoder_config=BambaConfig(**config_json['decoder_config']),
        tie_word_embeddings=config_json['tie_word_embeddings'],
        seed=config_json['seed']
    )
    # load tokenizer
    tokenizer = load_tokenizer(os.path.join(base_folder, f'tokenizer/{tokenizer_filename}'))
    # load model
    model = BambaEncoderDecoder(bamba_config, tokenizer, device=device, dtype=dtype)
    # load weights
    # NOTE(security): weights_only=False unpickles arbitrary objects. It is
    # required here because the checkpoint stores non-tensor RNG states
    # (numpy/python), but only load checkpoints from trusted sources.
    ckpt_dict = torch.load(
        os.path.join(base_folder, f'checkpoints/{ckpt_filename}'),
        map_location=device,
        weights_only=False
    )
    model.load_state_dict(ckpt_dict['module'])
    # restore RNG states each time the model and states are loaded from checkpoint
    if 'rng' in ckpt_dict:
        rng = ckpt_dict['rng']
        for key, value in rng.items():
            if key == 'torch_state':
                torch.set_rng_state(value.cpu())
            elif key == 'cuda_state':
                # Fix: only restore the CUDA generator when CUDA exists —
                # previously this crashed on CPU-only hosts even though the
                # caller asked for a CPU device.
                if torch.cuda.is_available():
                    torch.cuda.set_rng_state(value.cpu())
            elif key == 'numpy_state':
                np.random.set_state(value)
            elif key == 'python_state':
                random.setstate(value)
            else:
                # Fix: include the offending key so the warning is actionable.
                print(f'unrecognized state: {key}')
    if eval_model:
        return model.eval()
    return model