"""
Simple example usage of CosmicFish model (local model)
"""
import torch
from transformers import GPT2Tokenizer
from modeling_cosmicfish import CosmicFish, CosmicConfig
from safetensors.torch import load_file
import json

def load_cosmicfish(model_dir):
    """Load CosmicFish model and tokenizer"""
    # Load config
    with open(f"{model_dir}/config.json", "r") as f:
        config_dict = json.load(f)

    # Create model config (dropout forced to 0.0 for inference)
    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        dropout=0.0,
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        use_qk_norm=config_dict.get("use_qk_norm", False)
    )

    # Create and load model
    model = CosmicFish(config)
    state_dict = load_file(f"{model_dir}/model.safetensors")

    # Handle weight tying: checkpoints with tied input/output embeddings may
    # store only the token embedding matrix, so mirror it into the LM head.
    if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
        state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']

    model.load_state_dict(state_dict)
    model.eval()

    # Load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return model, tokenizer
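
# Optional: run on a GPU when one is available. A minimal sketch; it assumes
# CosmicFish is a standard torch.nn.Module, so the usual .to(device)
# semantics apply. to_device is a hypothetical helper, not part of CosmicFish.
def to_device(model, inputs):
    """Place the model and an encoded prompt on the best available device."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device), inputs.to(device)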

def simple_generate(model, tokenizer, prompt, max_tokens=50, temperature=0.7):
    """Generate text from a prompt"""
    inputs = tokenizer.encode(prompt, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=40
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
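
# Note: prompts longer than the model's context window (block_size) can break
# generation. A minimal guard, assuming the block_size from the loaded config
# is at hand (truncate_prompt is a hypothetical helper, not part of CosmicFish):
def truncate_prompt(inputs, block_size):
    """Keep only the last block_size tokens of an encoded prompt."""
    return inputs[:, -block_size:]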

if __name__ == "__main__":
    # Load model
    print("Loading CosmicFish...")
    model, tokenizer = load_cosmicfish("./")
    print(f"Model loaded! ({model.get_num_params()/1e6:.1f}M parameters)")

    # Example prompts
    prompts = [
        "What is climate change?",
        "Write a poem",
        "Define ML"
    ]

    # Generate responses
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        response = simple_generate(model, tokenizer, prompt, max_tokens=30)
        print(f"Response: {response}")