| """ |
| Codsworth Inference Script |
| Example: Load a trained model and generate text |
| """ |
|
|
| import json |
| import torch |
| import sys |
| sys.path.insert(0, '.') |
|
|
| from codsworth.config import CodsworthConfig |
| from codsworth.model import CodsworthTransformer |
| from codsworth.utils import get_device |
|
|
|
|
| |
| |
| |
| """ |
| Quick Example: |
| |
| # Generate text from a single prompt: |
| python inference.py --model codsworth_model.pt --prompt "the" |
| |
| # Interactive: |
| python inference.py --model codsworth_model.pt --interactive |
| |
| # With temperature: |
| python inference.py --model codsworth_model.pt --prompt "hello" --temperature 0.8 |
| """ |
|
|
|
|
def load_model(model_path: str, config_path: str = "config.json"):
    """
    Load a trained Codsworth model and its vocabulary from disk.

    Args:
        model_path: Path to the .pt checkpoint (e.g. "codsworth_model.pt").
            Its contents are passed straight to ``model.load_state_dict``.
        config_path: Path to the JSON config. Its "model" section supplies
            the architecture hyper-parameters; its "tokenizer" section must
            contain a "vocab_file" entry naming the vocabulary JSON file.

    Returns:
        Tuple ``(model, vocab, id_to_word, device)``:
            model: CodsworthTransformer in eval mode, moved to ``device``.
            vocab: word -> id mapping read from the vocabulary file.
            id_to_word: id -> word mapping (inverse of ``vocab``).
            device: torch device returned by ``get_device()``.

    Raises:
        FileNotFoundError: if the config, vocabulary, or checkpoint file
            does not exist.
        KeyError: if a required section or field is missing from the config.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # would mis-decode non-ASCII text in the config or vocabulary files.
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)

    model_cfg = config_data["model"]

    # Flash attention and gradient checkpointing are forced off here; the
    # config file's values for them (if any) are not read.
    # NOTE(review): presumably because they are training/hardware specific —
    # confirm.
    config = CodsworthConfig(
        vocab_size=model_cfg["vocab_size"],
        context_length=model_cfg["context_length"],
        embedding_dim=model_cfg["embedding_dim"],
        num_layers=model_cfg["num_layers"],
        num_heads=model_cfg["num_heads"],
        head_dim=model_cfg["head_dim"],
        ffn_hidden_dim=model_cfg["ffn_hidden_dim"],
        use_rope=model_cfg["use_rope"],
        rope_theta=model_cfg["rope_theta"],
        use_flash_attention=False,
        use_gradient_checkpointing=False,
    )

    tokenizer_cfg = config_data["tokenizer"]
    # NOTE: vocab_file is opened exactly as written in the config, i.e.
    # relative to the current working directory, not to config_path.
    with open(tokenizer_cfg["vocab_file"], "r", encoding="utf-8") as f:
        vocab = json.load(f)

    id_to_word = {idx: word for word, idx in vocab.items()}

    model = CodsworthTransformer(config)
    # Load tensors onto the CPU first so a checkpoint saved on a GPU can be
    # restored anywhere; the model is moved to the target device below.
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code — only load trusted checkpoints. Consider weights_only=True
    # (torch >= 1.13) once every target environment supports it.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    device = get_device()
    model.to(device)
    model.eval()

    return model, vocab, id_to_word, device
|
|
|
|
def generate(
    model: "CodsworthTransformer",
    prompt: str,
    vocab: dict,
    id_to_word: dict,
    device: torch.device,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: "int | None" = None,
) -> str:
    """
    Generate text from a prompt, one sampled token at a time.

    The prompt is lower-cased and split on whitespace; words missing from
    ``vocab`` become ``<unk>``. On each step only the last
    ``model.config.context_length`` tokens are fed to the model, left-padded
    with ``<pad>`` when shorter.

    Args:
        model: Trained Codsworth model. Calling it must return a mapping whose
            "logits" entry is indexed as ``logits[0, -1, :]`` (presumably
            (batch, seq, vocab) — confirm against the model).
        prompt: Input text.
        vocab: word -> id mapping. Needs "<unk>" if the prompt contains an
            unknown word and "<pad>" if the prompt is shorter than the
            context window; "<eos>" defaults to id 2 when absent.
        id_to_word: id -> word mapping used to render the output.
        device: torch device the model lives on.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (lower = more predictable).
            ``temperature <= 0`` selects greedy (argmax) decoding.
        top_k: Keep only the ``top_k`` most likely tokens before sampling.
            ``None`` or a non-positive value disables it; values larger
            than the vocabulary are clamped to the vocabulary size.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single
        spaces. The ``<eos>`` token is included if it was generated.
    """
    model.eval()

    context_length = model.config.context_length
    eos_id = vocab.get("<eos>", 2)

    # "<unk>" is only looked up when a word is actually unknown.
    token_ids = [
        vocab[w] if w in vocab else vocab["<unk>"] for w in prompt.lower().split()
    ]

    for _ in range(max_new_tokens):
        window = token_ids[-context_length:]
        shortfall = context_length - len(window)
        if shortfall > 0:
            window = [vocab["<pad>"]] * shortfall + window

        input_t = torch.tensor([window], dtype=torch.long, device=device)

        with torch.no_grad():
            next_logits = model(input_t)["logits"][0, -1, :]

        if temperature <= 0:
            # Greedy decoding: dividing by a zero temperature would yield
            # inf/nan logits and make torch.multinomial fail.
            next_token = int(torch.argmax(next_logits).item())
        else:
            next_logits = next_logits / temperature
            if top_k is not None and top_k > 0:
                # torch.topk raises if k exceeds the number of logits.
                k = min(top_k, next_logits.size(-1))
                kth_best = torch.topk(next_logits, k).values[-1]
                next_logits = next_logits.masked_fill(
                    next_logits < kth_best, float("-inf")
                )
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()

        token_ids.append(next_token)

        if next_token == eos_id:
            break

    return " ".join(id_to_word.get(t, "<unk>") for t in token_ids)
|
|
|
|
def main():
    """
    Command-line entry point.

    Parses the CLI arguments, loads the model with ``load_model`` and then
    either generates once from ``--prompt`` or, with ``--interactive``, reads
    prompts from stdin until the user types 'quit' or closes the input
    (Ctrl-D / Ctrl-C).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Codsworth Inference")
    parser.add_argument("--model", default="codsworth_model.pt",
                        help="Model checkpoint file")
    parser.add_argument("--config", default="config.json",
                        help="Config file")
    parser.add_argument("--prompt", default="the",
                        help="Input prompt")
    parser.add_argument("--max_tokens", type=int, default=50,
                        help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Temperature (0.1-2.0)")
    parser.add_argument("--top_k", type=int, default=None,
                        help="Top-k sampling")
    parser.add_argument("--interactive", action="store_true",
                        help="Interactive mode")

    args = parser.parse_args()

    print("Loading model...")
    model, vocab, id_to_word, device = load_model(args.model, args.config)
    print(f"Model loaded! Parameters: {model.get_num_params():,}")
    print(f"Vocabulary: {len(vocab)} words")
    print(f"Device: {device}")

    def run(prompt: str) -> str:
        # Both modes share the same sampling settings from the CLI.
        return generate(
            model, prompt, vocab, id_to_word, device,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
        )

    if not args.interactive:
        print(f"\nPrompt: {args.prompt}")
        print(f"\nGenerated:\n{run(args.prompt)}")
        return

    print("\nInteractive mode (type 'quit' to exit)")
    while True:
        try:
            prompt = input("\n> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of a traceback.
            print()
            break
        if prompt.strip().lower() == "quit":
            break
        if prompt.strip():
            print(run(prompt))
|
|
|
|
if __name__ == "__main__":
    # Start the CLI only when run as a script; importing this module
    # (e.g. to reuse load_model/generate) has no side effects here.
    main()