File size: 1,978 Bytes
e226163 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""
Example usage of CosmicFish model (using safetensors)
"""
import json
import os

import torch
from safetensors.torch import load_file
from transformers import GPT2Tokenizer

from modeling_cosmicfish import CosmicFish, CosmicConfig
def load_cosmicfish(model_dir, tokenizer_name="gpt2", device="cpu"):
    """Load a CosmicFish model and its tokenizer from a local directory.

    The directory must contain ``config.json`` (model hyperparameters) and
    ``model.safetensors`` (model weights).

    Args:
        model_dir: Directory holding ``config.json`` and ``model.safetensors``.
        tokenizer_name: Name or path passed to
            ``GPT2Tokenizer.from_pretrained``. Defaults to ``"gpt2"``.
        device: Device the model is moved to after its weights are loaded.
            Defaults to ``"cpu"``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode with
        dropout disabled.

    Raises:
        KeyError: If ``config.json`` lacks one of the required hyperparameters.
    """
    config_path = os.path.join(model_dir, "config.json")
    weights_path = os.path.join(model_dir, "model.safetensors")

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Hyperparameters copied verbatim from config.json into CosmicConfig.
    required_keys = (
        "vocab_size",
        "block_size",
        "n_layer",
        "n_head",
        "n_embd",
        "bias",
        "use_rotary",
        "use_swiglu",
        "use_gqa",
        "n_query_groups",
        "use_qk_norm",
    )
    missing = [key for key in required_keys if key not in config_dict]
    if missing:
        # Report every missing key at once instead of failing on the first.
        raise KeyError(f"{config_path} is missing required keys: {missing}")

    config = CosmicConfig(
        **{key: config_dict[key] for key in required_keys},
        dropout=0.0,  # Set to 0 for inference
    )

    model = CosmicFish(config)

    # Load weights from safetensors (safer and faster than pickle-based files).
    state_dict = load_file(weights_path)

    # Handle weight sharing: lm_head.weight shares with transformer.wte.weight,
    # so the checkpoint may store only the embedding. Re-add the lm_head key
    # so the strict load_state_dict below finds every expected entry.
    if "lm_head.weight" not in state_dict and "transformer.wte.weight" in state_dict:
        print("Weight sharing detected: tying lm_head.weight to transformer.wte.weight")
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer
# Example usage:
# model, tokenizer = load_cosmicfish("./")
# input_text = "The future of AI is"
# inputs = tokenizer.encode(input_text, return_tensors="pt")
# outputs = model.generate(inputs, max_length=50, temperature=0.7, do_sample=True)
# response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print(response)
|