File size: 1,571 Bytes
58d8e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Example usage of CosmicFish model
"""
import json
import os

import torch
from transformers import GPT2Tokenizer

from modeling_cosmicfish import CosmicFish, CosmicConfig

def load_cosmicfish(model_dir):
    """Load a CosmicFish model and a GPT-2 tokenizer for inference.

    Reads ``config.json`` and ``pytorch_model.bin`` from *model_dir*, builds
    a ``CosmicFish`` model from the stored hyperparameters, loads the weights
    onto the CPU and calls ``model.eval()``.

    Args:
        model_dir: Directory (``str`` or ``os.PathLike``) containing
            ``config.json`` and ``pytorch_model.bin``. An empty string means
            the current working directory.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``CosmicFish`` instance and the
        ``GPT2Tokenizer`` obtained from ``from_pretrained("gpt2")``.

    Raises:
        OSError: If ``config.json`` cannot be opened (e.g. it is missing).
        json.JSONDecodeError: If ``config.json`` is not valid JSON.
        KeyError: If ``config.json`` lacks one of the required
            hyperparameter keys.
    """
    # Join with os.path.join rather than f"{model_dir}/..." so that an empty
    # model_dir resolves relative to the current directory instead of
    # silently pointing at the filesystem root ("/config.json").
    config_path = os.path.join(model_dir, "config.json")
    weights_path = os.path.join(model_dir, "pytorch_model.bin")

    # JSON is UTF-8 by specification; do not rely on the platform default
    # encoding (which is not UTF-8 on every OS).
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Hyperparameters copied verbatim from config.json. Any other keys in the
    # file are ignored; a missing key raises KeyError.
    required_keys = (
        "vocab_size",
        "block_size",
        "n_layer",
        "n_head",
        "n_embd",
        "bias",
        "use_rotary",
        "use_swiglu",
        "use_gqa",
        "n_query_groups",
        "use_qk_norm",
    )
    hyperparams = {key: config_dict[key] for key in required_keys}

    # dropout is deliberately NOT read from the file: it is forced to 0.0
    # because the model is being set up for inference.
    config = CosmicConfig(dropout=0.0, **hyperparams)

    model = CosmicFish(config)

    # SECURITY NOTE(review): torch.load unpickles the checkpoint, which can
    # execute arbitrary code on torch versions where weights_only is not the
    # default. Only load checkpoints from a trusted source; consider passing
    # weights_only=True if the minimum supported torch version allows it.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # The tokenizer is the stock "gpt2" one looked up by name; it is not read
    # from model_dir.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return model, tokenizer

# Example usage:
# model, tokenizer = load_cosmicfish("./")
# input_text = "The future of AI is"
# inputs = tokenizer.encode(input_text, return_tensors="pt")
# outputs = model.generate(inputs, max_length=50, temperature=0.7, do_sample=True)
# response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print(response)