""" Example usage of CosmicFish model """ import torch from transformers import GPT2Tokenizer from modeling_cosmicfish import CosmicFish, CosmicConfig import json def load_cosmicfish(model_dir): """Load CosmicFish model and tokenizer""" # Load config with open(f"{model_dir}/config.json", "r") as f: config_dict = json.load(f) # Create CosmicConfig config = CosmicConfig( vocab_size=config_dict["vocab_size"], block_size=config_dict["block_size"], n_layer=config_dict["n_layer"], n_head=config_dict["n_head"], n_embd=config_dict["n_embd"], bias=config_dict["bias"], dropout=0.0, # Set to 0 for inference use_rotary=config_dict["use_rotary"], use_swiglu=config_dict["use_swiglu"], use_gqa=config_dict["use_gqa"], n_query_groups=config_dict["n_query_groups"], use_qk_norm=config_dict["use_qk_norm"] ) # Create model model = CosmicFish(config) # Load weights state_dict = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu") model.load_state_dict(state_dict) model.eval() # Load tokenizer tokenizer = GPT2Tokenizer.from_pretrained("gpt2") return model, tokenizer # Example usage: # model, tokenizer = load_cosmicfish("./") # input_text = "The future of AI is" # inputs = tokenizer.encode(input_text, return_tensors="pt") # outputs = model.generate(inputs, max_length=50, temperature=0.7, do_sample=True) # response = tokenizer.decode(outputs[0], skip_special_tokens=True) # print(response)