akkiisfrommars commited on
Commit
95f1d58
·
verified ·
1 Parent(s): 46b4e09

Delete example_usage.py

Browse files
Files changed (1) hide show
  1. example_usage.py +0 -50
example_usage.py DELETED
@@ -1,50 +0,0 @@
1
- """
2
- Example usage of CosmicFish model
3
- """
4
- import torch
5
- from transformers import GPT2Tokenizer
6
- from modeling_cosmicfish import CosmicFish, CosmicConfig
7
- import json
8
-
9
- def load_cosmicfish(model_dir):
10
- """Load CosmicFish model and tokenizer"""
11
- # Load config
12
- with open(f"{model_dir}/config.json", "r") as f:
13
- config_dict = json.load(f)
14
-
15
- # Create CosmicConfig
16
- config = CosmicConfig(
17
- vocab_size=config_dict["vocab_size"],
18
- block_size=config_dict["block_size"],
19
- n_layer=config_dict["n_layer"],
20
- n_head=config_dict["n_head"],
21
- n_embd=config_dict["n_embd"],
22
- bias=config_dict["bias"],
23
- dropout=0.0, # Set to 0 for inference
24
- use_rotary=config_dict["use_rotary"],
25
- use_swiglu=config_dict["use_swiglu"],
26
- use_gqa=config_dict["use_gqa"],
27
- n_query_groups=config_dict["n_query_groups"],
28
- use_qk_norm=config_dict["use_qk_norm"]
29
- )
30
-
31
- # Create model
32
- model = CosmicFish(config)
33
-
34
- # Load weights
35
- state_dict = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu")
36
- model.load_state_dict(state_dict)
37
- model.eval()
38
-
39
- # Load tokenizer
40
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
41
-
42
- return model, tokenizer
43
-
44
- # Example usage:
45
- # model, tokenizer = load_cosmicfish("./")
46
- # input_text = "The future of AI is"
47
- # inputs = tokenizer.encode(input_text, return_tensors="pt")
48
- # outputs = model.generate(inputs, max_length=50, temperature=0.7, do_sample=True)
49
- # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
50
- # print(response)