File size: 3,281 Bytes
0ede4e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Inference and Model Loading Utilities
"""

import os
import torch
from torch.nn import functional as F
import tiktoken
from model import GPT, GPTConfig


def get_device():
    """Return the name of the best available compute device.

    Preference order: CUDA GPU first, then Apple Metal (MPS),
    falling back to CPU when neither accelerator is available.
    """
    if torch.cuda.is_available():
        return 'cuda'
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return 'cpu'


def load_model(model_path=None, pretrained_model='gpt2', device=None):
    """
    Load model with priority: saved checkpoint > pretrained model
    
    Args:
        model_path: Path to saved model checkpoint (.pth or .pt file)
        pretrained_model: HuggingFace model name to fallback to ('gpt2', 'gpt2-medium', etc.)
        device: Device to load model on (auto-detected if None)
    
    Returns:
        Loaded model and device
    """
    if device is None:
        device = get_device()

    # Tier 1: a saved checkpoint on disk takes priority over everything else.
    if model_path and os.path.exists(model_path):
        print(f"Loading saved model from {model_path}...")
        try:
            checkpoint_model = GPT.load_checkpoint(model_path, device=device)
            return checkpoint_model, device
        except Exception as e:
            # Loading can fail (corrupt file, config mismatch); fall through.
            print(f"Failed to load saved model: {e}")
            print(f"Falling back to pretrained model: {pretrained_model}")

    # Tier 2: pull pretrained weights by name.
    print(f"Loading pretrained model: {pretrained_model}...")
    try:
        pretrained = GPT.from_pretrained(pretrained_model)
        pretrained.to(device)
        return pretrained, device
    except Exception as e:
        print(f"Failed to load pretrained model: {e}")

    # Tier 3: last resort — a freshly initialized, untrained model.
    print("Creating model with default config...")
    fallback = GPT(GPTConfig())
    fallback.to(device)
    return fallback, device


def generate_text(prompt, model, max_tokens=50, top_k=50, temperature=1.0, device="cpu"):
    """
    Generate text completion for a given prompt using the GPT model.
    
    Args:
        prompt: Input text prompt
        model: GPT model instance; its forward pass returns (logits, loss)
        max_tokens: Maximum number of tokens to generate
        top_k: Top-k sampling parameter (None for no top-k filtering).
            Values larger than the vocabulary size are clamped down to it.
        temperature: Temperature for sampling (higher = more random); must be > 0
        device: Device to run inference on (should match the model's device)
    
    Returns:
        Generated text string (including original prompt)
    """
    enc = tiktoken.get_encoding("gpt2")
    model.eval()
    
    with torch.no_grad():
        # tokenize prompt and add a batch dimension of 1
        input_ids = enc.encode(prompt)
        x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

        for _ in range(max_tokens):
            logits, _ = model(x)
            # only the final position's logits matter for the next token
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Clamp k to the vocabulary size: torch.topk raises a
                # RuntimeError if asked for more elements than exist.
                k = min(top_k, logits.size(-1))
                topk = torch.topk(logits, k, dim=-1)
                # mask out every logit below the k-th largest value
                mask = logits < topk.values[:, -1].unsqueeze(-1)
                logits = logits.masked_fill(mask, -float("inf"))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # append the sampled token and feed the full sequence back in
            x = torch.cat((x, next_token), dim=1)

        generated_ids = x[0].tolist()
    return enc.decode(generated_ids)