Spaces:
Sleeping
Sleeping
| """ | |
| Inference and Model Loading Utilities | |
| """ | |
| import os | |
| import torch | |
| from torch.nn import functional as F | |
| import tiktoken | |
| from model import GPT, GPTConfig | |
def get_device():
    """Pick the best available compute device: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return 'cuda'
    # Older torch builds lack the mps backend entirely, so probe defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return 'cpu'
def load_model(model_path=None, pretrained_model='gpt2', device=None):
    """
    Load a model, preferring a saved checkpoint over pretrained weights.

    Args:
        model_path: Path to saved model checkpoint (.pth or .pt file)
        pretrained_model: HuggingFace model name to fall back to ('gpt2', 'gpt2-medium', etc.)
        device: Device to load model on (auto-detected if None)

    Returns:
        Tuple of (loaded model, device string).
    """
    target_device = device if device is not None else get_device()

    # First preference: a checkpoint saved on disk.
    if model_path and os.path.exists(model_path):
        try:
            print(f"Loading saved model from {model_path}...")
            checkpoint = GPT.load_checkpoint(model_path, device=target_device)
            return checkpoint, target_device
        except Exception as err:
            print(f"Failed to load saved model: {err}")
            print(f"Falling back to pretrained model: {pretrained_model}")

    # Second preference: pretrained weights by name.
    print(f"Loading pretrained model: {pretrained_model}...")
    try:
        pretrained = GPT.from_pretrained(pretrained_model)
        pretrained.to(target_device)
        return pretrained, target_device
    except Exception as err:
        print(f"Failed to load pretrained model: {err}")

    # Last resort: a freshly initialized (untrained) model with defaults.
    print("Creating model with default config...")
    fallback = GPT(GPTConfig())
    fallback.to(target_device)
    return fallback, target_device
def generate_text(prompt, model, max_tokens=50, top_k=50, temperature=1.0, device="cpu"):
    """
    Generate text completion for a given prompt using the GPT model.

    Args:
        prompt: Input text prompt
        model: GPT model instance
        max_tokens: Maximum number of tokens to generate
        top_k: Top-k sampling parameter (None for no top-k filtering).
            Clamped to the vocabulary size if larger.
        temperature: Temperature for sampling (higher = more random); must be > 0
        device: Device to run inference on

    Returns:
        Generated text string (including original prompt)
    """
    enc = tiktoken.get_encoding("gpt2")
    model.eval()
    with torch.no_grad():
        # tokenize prompt and add a batch dimension
        input_ids = enc.encode(prompt)
        x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(max_tokens):
            logits, _ = model(x)
            # Only the logits for the last position matter for sampling.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Clamp k to the vocab size: torch.topk raises a RuntimeError
                # if k exceeds the size of the last dimension.
                k = min(top_k, logits.size(-1))
                topk = torch.topk(logits, k, dim=-1)
                # Mask out everything below the k-th largest logit so it
                # receives zero probability after softmax.
                mask = logits < topk.values[:, -1].unsqueeze(-1)
                logits = logits.masked_fill(mask, -float("inf"))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
        generated_ids = x[0].tolist()
    return enc.decode(generated_ids)