"""
Codsworth Inference Script

Load a trained Codsworth model and generate text from a prompt.

Usage:
    # Generate from a single prompt
    python inference.py --model codsworth_model.pt --prompt "the"

    # Interactive session
    python inference.py --model codsworth_model.pt --interactive

    # Custom sampling temperature
    python inference.py --model codsworth_model.pt --prompt "hello" --temperature 0.8
"""
import argparse
import json
import sys
from typing import Dict, Optional, Tuple

import torch

# Make the local ``codsworth`` package importable when run from the repo root.
sys.path.insert(0, '.')

from codsworth.config import CodsworthConfig  # noqa: E402  (needs sys.path tweak)
from codsworth.model import CodsworthTransformer  # noqa: E402
from codsworth.utils import get_device  # noqa: E402

# Vocabulary keys for special tokens.
# NOTE(review): these are empty strings in the current source, so unknown
# words, padding and end-of-sequence all resolve through the same "" key.
# They look like "<unk>", "<pad>" and "<eos>" whose angle-bracketed text was
# lost at some point -- confirm against the vocab file and fix if so.
_UNK_TOKEN = ""
_PAD_TOKEN = ""
_EOS_TOKEN = ""
# Id treated as end-of-sequence when the vocabulary has no _EOS_TOKEN entry.
_DEFAULT_EOS_ID = 2


def load_model(
    model_path: str,
    config_path: str = "config.json",
) -> Tuple[CodsworthTransformer, Dict[str, int], Dict[int, str], torch.device]:
    """
    Load a trained Codsworth model and its word-level vocabulary.

    Args:
        model_path: Path to the ``.pt`` state-dict file (e.g. "codsworth_model.pt").
        config_path: Path to the JSON config. Its "model" section supplies the
            architecture hyperparameters; its "tokenizer" section supplies
            ``vocab_file``, a path opened relative to the current working
            directory.

    Returns:
        model: CodsworthTransformer in eval mode, moved to ``device``.
        vocab: word -> id mapping.
        id_to_word: id -> word mapping (inverse of ``vocab``).
        device: the device returned by ``get_device()``.

    Raises:
        FileNotFoundError: if the config, vocab or checkpoint file is missing.
        KeyError: if a required config field is absent.
    """
    # Decode as UTF-8 explicitly: the platform default codec (e.g. cp1252 on
    # Windows) may fail on non-ASCII content.
    with open(config_path, 'r', encoding='utf-8') as f:
        config_data = json.load(f)

    model_cfg = config_data["model"]
    config = CodsworthConfig(
        vocab_size=model_cfg["vocab_size"],
        context_length=model_cfg["context_length"],
        embedding_dim=model_cfg["embedding_dim"],
        num_layers=model_cfg["num_layers"],
        num_heads=model_cfg["num_heads"],
        head_dim=model_cfg["head_dim"],
        ffn_hidden_dim=model_cfg["ffn_hidden_dim"],
        use_rope=model_cfg["use_rope"],
        rope_theta=model_cfg["rope_theta"],
        # Forced off here regardless of what the config file says.
        use_flash_attention=False,
        use_gradient_checkpointing=False,
    )

    tokenizer_cfg = config_data["tokenizer"]
    with open(tokenizer_cfg["vocab_file"], 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    id_to_word = {idx: word for word, idx in vocab.items()}

    model = CodsworthTransformer(config)
    # Map to CPU first so a checkpoint saved on a GPU also loads on CPU-only
    # machines. torch.load unpickles the file: only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)

    device = get_device()
    model.to(device)
    model.eval()

    return model, vocab, id_to_word, device


def generate(
    model: CodsworthTransformer,
    prompt: str,
    vocab: dict,
    id_to_word: dict,
    device: torch.device,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> str:
    """
    Autoregressively generate text continuing ``prompt``.

    The prompt is lower-cased and split on whitespace; words missing from
    ``vocab`` map to the ``_UNK_TOKEN`` id. Each step feeds the most recent
    ``model.config.context_length`` ids, left-padded with the ``_PAD_TOKEN``
    id when shorter, and samples the next id from the last position's logits.

    Args:
        model: Trained Codsworth model (its output dict must contain "logits").
        prompt: Input text.
        vocab: word -> id mapping.
        id_to_word: id -> word mapping used for decoding.
        device: Device the input tensor is created on.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (lower = more predictable). A value
            <= 0 selects greedy (argmax) decoding.
        top_k: Keep only the k highest logits before sampling (ties with the
            k-th value are kept). ``None`` or a value <= 0 disables it; values
            above the vocabulary size are clamped.

    Returns:
        The normalized prompt words followed by the generated words, joined
        by single spaces. Generation stops early after the EOS id is produced;
        the EOS token itself is included in the output.

    Raises:
        KeyError: if an unknown word (or required padding) is encountered and
            the vocabulary has no corresponding special-token entry.
    """
    model.eval()

    context_length = model.config.context_length
    eos_id = vocab.get(_EOS_TOKEN, _DEFAULT_EOS_ID)

    # The unknown-word id is looked up only when a word is actually unknown.
    token_ids = [
        vocab[word] if word in vocab else vocab[_UNK_TOKEN]
        for word in prompt.lower().split()
    ]

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Keep the newest ids and left-pad so the latest token always sits
            # in the final position the logits are read from.
            window = token_ids[-context_length:]
            padding_needed = context_length - len(window)
            if padding_needed > 0:
                window = [vocab[_PAD_TOKEN]] * padding_needed + window
            input_t = torch.tensor([window], dtype=torch.long, device=device)

            logits = model(input_t)["logits"]
            # Assumes logits are (batch, seq, vocab) -- TODO confirm; this
            # takes the final position of the single batch row.
            next_logits = logits[0, -1, :]

            if temperature <= 0:
                # Dividing by a zero temperature would yield inf/nan logits and
                # make torch.multinomial fail; fall back to greedy decoding.
                next_token = int(torch.argmax(next_logits).item())
            else:
                next_logits = next_logits / temperature
                if top_k is not None and top_k > 0:
                    # Clamp k: torch.topk raises when k exceeds the vocab size.
                    k = min(top_k, next_logits.size(-1))
                    kth_largest = torch.topk(next_logits, k).values[-1]
                    next_logits = next_logits.masked_fill(
                        next_logits < kth_largest, float('-inf')
                    )
                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()

            token_ids.append(next_token)
            if next_token == eos_id:
                break

    return " ".join(id_to_word.get(t, _UNK_TOKEN) for t in token_ids)


def main():
    """Parse command-line arguments, load the model and generate text."""
    parser = argparse.ArgumentParser(description="Codsworth Inference")
    parser.add_argument("--model", default="codsworth_model.pt", help="Model checkpoint file")
    parser.add_argument("--config", default="config.json", help="Config file")
    parser.add_argument("--prompt", default="the", help="Input prompt")
    parser.add_argument("--max_tokens", type=int, default=50, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature (0.1-2.0)")
    parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")
    args = parser.parse_args()

    print("Loading model...")
    model, vocab, id_to_word, device = load_model(args.model, args.config)
    print(f"Model loaded! Parameters: {model.get_num_params():,}")
    print(f"Vocabulary: {len(vocab)} words")
    print(f"Device: {device}")

    def run(text: str) -> str:
        """Generate a continuation of ``text`` using the CLI sampling options."""
        return generate(
            model,
            text,
            vocab,
            id_to_word,
            device,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
        )

    if args.interactive:
        print("\nInteractive mode (type 'quit' to exit)")
        while True:
            try:
                prompt = input("\n> ")
            except (EOFError, KeyboardInterrupt):
                # End of piped input, Ctrl-D or Ctrl-C: exit without a traceback.
                print()
                break
            stripped = prompt.strip()
            if stripped.lower() == 'quit':
                break
            if stripped:
                print(run(prompt))
    else:
        print(f"\nPrompt: {args.prompt}")
        result = run(args.prompt)
        print(f"\nGenerated:\n{result}")


if __name__ == "__main__":
    main()