codsworth-3.8m / codsworth /scripts /inference.py
Jaqshanahan's picture
Initial upload of Codsworth model
b84d85a verified
"""
Codsworth Inference Script
Example: Load a trained model and generate text
"""
import json
import torch
import sys
sys.path.insert(0, '.')
from codsworth.config import CodsworthConfig
from codsworth.model import CodsworthTransformer
from codsworth.utils import get_device
# ========================
# EXAMPLE USAGE
# ========================
"""
Quick Example:
# Single prompt:
python inference.py --model codsworth_model.pt --prompt "the"
# Interactive:
python inference.py --model codsworth_model.pt --interactive
# With temperature:
python inference.py --model codsworth_model.pt --prompt "hello" --temperature 0.8
"""
def load_model(model_path: str, config_path: str = "config.json"):
    """
    Load a trained Codsworth model and its vocabulary from disk.

    The config file is expected to contain a "model" section (architecture
    hyperparameters) and a "tokenizer" section whose "vocab_file" entry
    points to a JSON word -> id mapping. The vocab path is opened as given,
    so a relative path resolves against the current working directory.

    Args:
        model_path: Path to .pt model file (e.g., "codsworth_model.pt")
        config_path: Path to config.json

    Returns:
        model: CodsworthTransformer in eval mode, moved to ``device``
        vocab: word -> id mapping
        id_to_word: id -> word mapping (inverse of ``vocab``)
        device: device returned by ``get_device()``

    Raises:
        FileNotFoundError: If the config or vocabulary file does not exist.
        KeyError: If a required key is missing from the config.
    """
    # JSON is UTF-8 by specification; relying on the locale default encoding
    # breaks on platforms where that default is something else.
    with open(config_path, 'r', encoding="utf-8") as f:
        config_data = json.load(f)
    model_cfg = config_data["model"]

    # Flash attention and gradient checkpointing are always disabled here,
    # regardless of what the config file says.
    config = CodsworthConfig(
        vocab_size=model_cfg["vocab_size"],
        context_length=model_cfg["context_length"],
        embedding_dim=model_cfg["embedding_dim"],
        num_layers=model_cfg["num_layers"],
        num_heads=model_cfg["num_heads"],
        head_dim=model_cfg["head_dim"],
        ffn_hidden_dim=model_cfg["ffn_hidden_dim"],
        use_rope=model_cfg["use_rope"],
        rope_theta=model_cfg["rope_theta"],
        use_flash_attention=False,
        use_gradient_checkpointing=False,
    )

    # Load tokenizer vocabulary and build the reverse lookup for decoding.
    tokenizer_cfg = config_data["tokenizer"]
    with open(tokenizer_cfg["vocab_file"], 'r', encoding="utf-8") as f:
        vocab = json.load(f)
    id_to_word = {token_id: word for word, token_id in vocab.items()}

    # NOTE(security): torch.load unpickles the checkpoint; only load files
    # from a trusted source (consider weights_only=True where the installed
    # PyTorch version supports it).
    model = CodsworthTransformer(config)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)

    # Weights are materialized on CPU first, then moved to the target device.
    device = get_device()
    model.to(device)
    model.eval()
    return model, vocab, id_to_word, device
def generate(
    model: CodsworthTransformer,
    prompt: str,
    vocab: dict,
    id_to_word: dict,
    device: torch.device,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int = None,
) -> str:
    """
    Generate text from a prompt, one token at a time.

    The prompt is lowercased and split on whitespace; words not found in
    ``vocab`` are mapped to the "<unk>" id. At each step the most recent
    ``model.config.context_length`` ids are left-padded with "<pad>" to the
    full context length and fed to the model, and the next token is chosen
    from the logits at the last position.

    Args:
        model: Trained Codsworth model
        prompt: Input text
        vocab: Vocabulary dictionary (word -> id)
        id_to_word: ID to word mapping
        device: torch device the input tensor is created on
        max_new_tokens: Max tokens to generate
        temperature: Sampling temperature (lower = more predictable).
            0 selects the highest-scoring token deterministically.
        top_k: Top-k sampling (None = disabled). Values larger than the
            number of logits are clamped to that number.

    Returns:
        The prompt tokens followed by the generated tokens (including the
        EOS token if one was produced), joined by single spaces.

    Raises:
        ValueError: If ``temperature`` is negative or ``top_k`` is < 1.
        KeyError: If ``vocab`` lacks "<unk>" (or "<pad>" when padding is
            needed).
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")

    model.eval()

    # Loop invariants.
    context_length = model.config.context_length
    eos_id = vocab.get("<eos>", 2)
    # Dividing logits by zero yields inf/NaN, so temperature 0 is handled
    # as plain argmax instead of sampling.
    greedy = temperature == 0

    # Encode prompt
    words = prompt.lower().split()
    token_ids = [vocab.get(w, vocab["<unk>"]) for w in words]

    for _ in range(max_new_tokens):
        # Keep the most recent tokens and left-pad up to the context length.
        # NOTE(review): no attention mask is passed, so pad tokens are
        # visible to the model — presumably this matches training; confirm.
        window = token_ids[-context_length:]
        padding_needed = context_length - len(window)
        if padding_needed > 0:
            window = [vocab["<pad>"]] * padding_needed + window
        input_t = torch.tensor([window], dtype=torch.long, device=device)

        with torch.no_grad():
            logits = model(input_t)["logits"]
        next_logits = logits[0, -1, :]

        if greedy:
            next_token = int(torch.argmax(next_logits).item())
        else:
            next_logits = next_logits / temperature
            if top_k is not None:
                # torch.topk raises if k exceeds the dimension size.
                k = min(top_k, next_logits.size(-1))
                threshold = torch.topk(next_logits, k)[0][-1]
                next_logits = next_logits.masked_fill(
                    next_logits < threshold, float("-inf")
                )
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()

        token_ids.append(next_token)

        # Stop at EOS (the EOS token itself stays in the output).
        if next_token == eos_id:
            break

    # Decode; ids missing from id_to_word are rendered as "<unk>".
    return " ".join(id_to_word.get(t, "<unk>") for t in token_ids)
def main():
    """
    Main function for command-line usage.

    Parses the command-line options, loads the model and vocabulary, then
    either generates once from ``--prompt`` or, with ``--interactive``,
    reads prompts from standard input until the user types 'quit' or the
    input stream ends.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Codsworth Inference")
    parser.add_argument("--model", default="codsworth_model.pt",
                        help="Model checkpoint file")
    parser.add_argument("--config", default="config.json",
                        help="Config file")
    parser.add_argument("--prompt", default="the",
                        help="Input prompt")
    parser.add_argument("--max_tokens", type=int, default=50,
                        help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Temperature (0.1-2.0)")
    parser.add_argument("--top_k", type=int, default=None,
                        help="Top-k sampling")
    parser.add_argument("--interactive", action="store_true",
                        help="Interactive mode")
    args = parser.parse_args()

    # Load model
    print("Loading model...")
    model, vocab, id_to_word, device = load_model(args.model, args.config)
    print(f"Model loaded! Parameters: {model.get_num_params():,}")
    print(f"Vocabulary: {len(vocab)} words")
    print(f"Device: {device}")

    def run(prompt):
        # Single place that forwards the CLI sampling options to generate().
        return generate(
            model, prompt, vocab, id_to_word, device,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
        )

    if not args.interactive:
        print(f"\nPrompt: {args.prompt}")
        result = run(args.prompt)
        print(f"\nGenerated:\n{result}")
        return

    print("\nInteractive mode (type 'quit' to exit)")
    while True:
        try:
            prompt = input("\n> ")
        except (EOFError, KeyboardInterrupt):
            # End of piped input / Ctrl-D / Ctrl-C at the prompt: leave the
            # loop quietly instead of dying with a traceback.
            print()
            break
        prompt = prompt.strip()
        if prompt.lower() == 'quit':
            break
        if prompt:
            print(run(prompt))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()