"""Inference and model loading utilities."""

import os

import tiktoken
import torch
from torch.nn import functional as F

from model import GPT, GPTConfig


def get_device():
    """Return the best available torch device name: 'cuda' > 'mps' > 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    # Older torch builds lack torch.backends.mps entirely, so guard with hasattr.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return 'cpu'


def load_model(model_path=None, pretrained_model='gpt2', device=None):
    """
    Load a model, preferring a saved checkpoint over pretrained weights.

    Priority: saved checkpoint > pretrained model > untrained model with the
    default config. This function never raises for load failures; it reports
    each failure and falls through to the next option.

    Args:
        model_path: Path to a saved model checkpoint (.pth or .pt file), or None.
        pretrained_model: HuggingFace model name to fall back to
            ('gpt2', 'gpt2-medium', etc.).
        device: Device to load the model on (auto-detected if None).

    Returns:
        Tuple of (model, device).
    """
    if device is None:
        device = get_device()

    # Try to load the saved checkpoint first.
    if model_path and os.path.exists(model_path):
        try:
            print(f"Loading saved model from {model_path}...")
            model = GPT.load_checkpoint(model_path, device=device)
            return model, device
        except Exception as e:
            # Deliberate best-effort: report and fall through to pretrained.
            print(f"Failed to load saved model: {e}")
            print(f"Falling back to pretrained model: {pretrained_model}")

    # Fall back to pretrained weights.
    print(f"Loading pretrained model: {pretrained_model}...")
    try:
        model = GPT.from_pretrained(pretrained_model)
        model.to(device)
        return model, device
    except Exception as e:
        print(f"Failed to load pretrained model: {e}")

    # Last resort: create an untrained model with the default config.
    print("Creating model with default config...")
    config = GPTConfig()
    model = GPT(config)
    model.to(device)
    return model, device


def generate_text(prompt, model, max_tokens=50, top_k=50, temperature=1.0, device="cpu"):
    """
    Generate text completion for a given prompt using the GPT model.

    Args:
        prompt: Input text prompt.
        model: GPT model instance. Its forward pass is assumed to return a
            (logits, loss) pair — TODO confirm against model.py.
        max_tokens: Maximum number of tokens to generate.
        top_k: Top-k sampling parameter (None for no top-k filtering).
        temperature: Temperature for sampling (higher = more random).
            Must be non-zero — the logits are divided by it directly.
        device: Device to run inference on.

    Returns:
        Generated text string (including the original prompt).
    """
    enc = tiktoken.get_encoding("gpt2")
    model.eval()
    with torch.no_grad():
        # Tokenize the prompt into a (1, T) batch of token ids.
        input_ids = enc.encode(prompt)
        x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(max_tokens):
            logits, _ = model(x)
            # Only the last position's logits determine the next token.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Fix: clamp k to the vocabulary size so torch.topk cannot
                # fail when top_k exceeds the number of logits.
                k = min(top_k, logits.size(-1))
                topk = torch.topk(logits, k, dim=-1)
                # Mask out everything below the k-th largest logit.
                mask = logits < topk.values[:, -1].unsqueeze(-1)
                logits = logits.masked_fill(mask, -float("inf"))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
    # NOTE(review): the sequence grows unbounded — presumably the model's
    # context window (block_size) limits usable max_tokens; verify upstream.
    generated_ids = x[0].tolist()
    return enc.decode(generated_ids)