Delete example_usage.py
Browse files- example_usage.py +0 -50
example_usage.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Example usage of CosmicFish model
|
| 3 |
-
"""
|
| 4 |
-
import torch
|
| 5 |
-
from transformers import GPT2Tokenizer
|
| 6 |
-
from modeling_cosmicfish import CosmicFish, CosmicConfig
|
| 7 |
-
import json
|
| 8 |
-
|
| 9 |
-
def load_cosmicfish(model_dir):
    """Load a CosmicFish model and its tokenizer from *model_dir*.

    Parameters
    ----------
    model_dir : str
        Directory containing ``config.json`` and ``pytorch_model.bin``.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` — the model in eval mode and a GPT-2
        byte-pair tokenizer.

    Raises
    ------
    FileNotFoundError
        If the config or weights file is missing.
    KeyError
        If a required architecture key is absent from ``config.json``.
    """
    # Load the hyperparameters saved alongside the weights.
    with open(f"{model_dir}/config.json", "r") as f:
        config_dict = json.load(f)

    # Architecture keys that must match the values used at training time.
    arch_keys = (
        "vocab_size", "block_size", "n_layer", "n_head", "n_embd", "bias",
        "use_rotary", "use_swiglu", "use_gqa", "n_query_groups", "use_qk_norm",
    )
    config = CosmicConfig(
        dropout=0.0,  # disable dropout for deterministic inference
        **{key: config_dict[key] for key in arch_keys},
    )

    model = CosmicFish(config)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # preventing arbitrary-code execution from a tampered checkpoint file.
    state_dict = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location="cpu",
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()

    # CosmicFish reuses the standard GPT-2 vocabulary, so the stock
    # pretrained tokenizer is loaded rather than a project-local one.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return model, tokenizer
|
| 43 |
-
|
| 44 |
-
# Example usage:
# model, tokenizer = load_cosmicfish("./")
# input_text = "The future of AI is"
# inputs = tokenizer.encode(input_text, return_tensors="pt")
# outputs = model.generate(inputs, max_length=50, temperature=0.7, do_sample=True)
# response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|