|
|
from .bamba_config import BambaEncoderDecoderConfig |
|
|
from .bamba import BambaConfig, BambaEncoderDecoder |
|
|
from .tokenizer.str_tokenizer import load_tokenizer |
|
|
import torch |
|
|
import numpy as np |
|
|
import random |
|
|
import json |
|
|
import os |
|
|
|
|
|
|
|
|
def load_strbamba(ckpt_filename,
                  base_folder='./str_bamba',
                  config_filename='config_encoder-decoder_436M.json',
                  tokenizer_filename='str_bamba_tokenizer.json',
                  eval_model=True,
                  device='cuda:0',
                  dtype=torch.float32
                  ):
    """Load a StrBamba encoder-decoder model from a saved checkpoint.

    Reads the model config, tokenizer, and checkpoint weights from the
    ``config/``, ``tokenizer/`` and ``checkpoints/`` subdirectories of
    ``base_folder``, and restores any RNG states saved alongside the
    weights so a run can resume reproducibly.

    Parameters
    ----------
    ckpt_filename : str
        Checkpoint file name inside ``<base_folder>/checkpoints/``.
    base_folder : str
        Root folder containing the ``config/``, ``tokenizer/`` and
        ``checkpoints/`` subdirectories.
    config_filename : str
        JSON config file name inside ``<base_folder>/config/``.
    tokenizer_filename : str
        Tokenizer file name inside ``<base_folder>/tokenizer/``.
    eval_model : bool
        If True, return the model switched to eval mode.
    device : str
        Device the model is built on and checkpoint tensors are mapped to.
    dtype : torch.dtype
        Parameter dtype for the model.

    Returns
    -------
    BambaEncoderDecoder
        The model with the checkpoint weights loaded.
    """
    # Build the encoder-decoder configuration from the JSON config file.
    # Use os.path.join component-wise (no hard-coded '/') for portability.
    with open(os.path.join(base_folder, 'config', config_filename)) as json_data:
        config_json = json.load(json_data)
    bamba_config = BambaEncoderDecoderConfig(
        encoder_config=BambaConfig(**config_json['encoder_config']),
        decoder_config=BambaConfig(**config_json['decoder_config']),
        tie_word_embeddings=config_json['tie_word_embeddings'],
        seed=config_json['seed']
    )

    tokenizer = load_tokenizer(os.path.join(base_folder, 'tokenizer', tokenizer_filename))

    model = BambaEncoderDecoder(bamba_config, tokenizer, device=device, dtype=dtype)

    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary Python objects (required here for the saved RNG state
    # dicts) — only load checkpoints from trusted sources.
    ckpt_dict = torch.load(
        os.path.join(base_folder, 'checkpoints', ckpt_filename),
        map_location=device,
        weights_only=False
    )
    model.load_state_dict(ckpt_dict['module'])

    # Restore RNG states saved with the checkpoint (torch CPU, torch CUDA,
    # numpy, and Python's `random`) so downstream sampling/training resumes
    # deterministically. Torch expects the state tensors on CPU.
    if 'rng' in ckpt_dict:
        for key, value in ckpt_dict['rng'].items():
            if key == 'torch_state':
                torch.set_rng_state(value.cpu())
            elif key == 'cuda_state':
                torch.cuda.set_rng_state(value.cpu())
            elif key == 'numpy_state':
                np.random.set_state(value)
            elif key == 'python_state':
                random.setstate(value)
            else:
                print('unrecognized state')

    return model.eval() if eval_model else model