akkiisfrommars committed on
Commit
04146a5
·
verified ·
1 Parent(s): ad3aadd

Upload example_usage.py

Browse files
Files changed (1) hide show
  1. example_usage.py +78 -0
example_usage.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple example usage of CosmicFish model (local model)
3
+ """
4
+ import torch
5
+ from transformers import GPT2Tokenizer
6
+ from modeling_cosmicfish import CosmicFish, CosmicConfig
7
+ from safetensors.torch import load_file
8
+ import json
9
+
10
def load_cosmicfish(model_dir, tokenizer_name="gpt2"):
    """Load a CosmicFish model and its tokenizer from a local directory.

    Args:
        model_dir: Directory (``str`` or path-like) that contains
            ``config.json`` and ``model.safetensors``.
        tokenizer_name: Name or path passed to
            ``GPT2Tokenizer.from_pretrained``. Defaults to ``"gpt2"``.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model.eval()`` has been called on
        the model before it is returned.

    Raises:
        FileNotFoundError: If ``config.json`` cannot be opened.
        KeyError: If a required key is missing from ``config.json``.
    """
    # Local import keeps this function self-contained without touching the
    # file-level import block.
    from pathlib import Path

    model_path = Path(model_dir)

    # JSON is UTF-8 by specification; do not depend on the platform default
    # encoding when reading it.
    with open(model_path / "config.json", "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Build the model config from the stored hyper-parameters.
    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        dropout=0.0,  # always 0.0 here, whatever config.json says
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        # Optional key: configs without it fall back to False.
        use_qk_norm=config_dict.get("use_qk_norm", False),
    )

    # Instantiate the model and read the checkpoint tensors.
    model = CosmicFish(config)
    state_dict = load_file(str(model_path / "model.safetensors"))

    # The checkpoint may omit 'lm_head.weight'; in that case reuse the token
    # embedding tensor (presumably the two are tied in the model -- confirm
    # against modeling_cosmicfish).
    if "lm_head.weight" not in state_dict and "transformer.wte.weight" in state_dict:
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer
46
+
47
def simple_generate(model, tokenizer, prompt, max_tokens=50, temperature=0.7, top_k=40):
    """Generate text from a prompt.

    Args:
        model: Object exposing ``generate(ids, max_new_tokens=...,
            temperature=..., top_k=...)``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=True)``.
        prompt: Text to encode and pass to the model.
        max_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: Forwarded to ``model.generate`` unchanged.
        top_k: Forwarded to ``model.generate`` unchanged. Defaults to 40,
            the value that was previously hard-coded.

    Returns:
        The decoded first sequence returned by ``model.generate``.
    """
    inputs = tokenizer.encode(prompt, return_tensors="pt")

    # Put the input ids on the same device as the model's parameters so a
    # model that was moved off the CPU does not fail with a device mismatch.
    # Skipped when the model exposes no parameters.
    params = getattr(model, "parameters", None)
    first_param = next(iter(params()), None) if callable(params) else None
    if first_param is not None:
        inputs = inputs.to(first_param.device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
60
+
61
if __name__ == "__main__":
    # Load the model and tokenizer from the current directory and report
    # the parameter count in millions.
    print("Loading CosmicFish...")
    model, tokenizer = load_cosmicfish("./")
    print(f"Model loaded! ({model.get_num_params()/1e6:.1f}M parameters)")

    # Demo prompts, run one after another.
    prompts = (
        "What is climate change?",
        "Write a poem",
        "Define ML",
    )

    # Print each prompt followed by up to 30 newly generated tokens.
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        response = simple_generate(model, tokenizer, prompt, max_tokens=30)
        print(f"Response: {response}")