|
|
""" |
|
|
Example usage of CosmicFish model |
|
|
""" |
|
|
import torch |
|
|
from transformers import GPT2Tokenizer |
|
|
from modeling_cosmicfish import CosmicFish, CosmicConfig |
|
|
import json |
|
|
|
|
|
def load_cosmicfish(model_dir):
    """Load a CosmicFish model and its tokenizer from a local directory.

    Args:
        model_dir: Path to a directory containing ``config.json`` and
            ``pytorch_model.bin``.

    Returns:
        A ``(model, tokenizer)`` tuple — the model with weights loaded and
        switched to eval mode, and the pretrained GPT-2 tokenizer.

    Raises:
        FileNotFoundError: If the config or weights file is missing.
        KeyError: If ``config.json`` lacks a required field.
    """
    # JSON is defined as UTF-8; don't depend on the platform default encoding.
    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        # Deliberately override whatever the training config said: this loader
        # is for inference (model.eval() below), so dropout is forced off.
        dropout=0.0,
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        use_qk_norm=config_dict["use_qk_norm"],
    )

    model = CosmicFish(config)

    # weights_only=True blocks arbitrary-code execution through pickle when
    # loading a checkpoint from an untrusted source; a state dict is pure
    # tensors, so this is safe. (Requires torch >= 1.13 — TODO confirm the
    # project's minimum torch version.)
    state_dict = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location="cpu",
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|