# NOTE(review): removed non-Python page-scrape residue (Hugging Face Spaces
# viewer header: status, file size, commit 0ede4e9, line-number gutter).
"""
Inference and Model Loading Utilities
"""
import os
import torch
from torch.nn import functional as F
import tiktoken
from model import GPT, GPTConfig
def get_device():
    """Return the best available torch device string: 'cuda' > 'mps' > 'cpu'."""
    # Prefer NVIDIA GPUs first, then Apple-silicon MPS, then fall back to CPU.
    if torch.cuda.is_available():
        return 'cuda'
    # hasattr guard: older torch builds may lack the mps backend attribute.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return 'cpu'
def load_model(model_path=None, pretrained_model='gpt2', device=None):
    """
    Load a GPT model, preferring a saved checkpoint over pretrained weights.

    Resolution order: local checkpoint -> pretrained weights -> freshly
    initialized (untrained) model with the default config.

    Args:
        model_path: Path to saved model checkpoint (.pth or .pt file)
        pretrained_model: HuggingFace model name to fallback to ('gpt2', 'gpt2-medium', etc.)
        device: Device to load model on (auto-detected if None)

    Returns:
        Tuple of (model, device string).
    """
    if device is None:
        device = get_device()

    # Priority 1: a checkpoint file the caller pointed us at.
    if model_path and os.path.exists(model_path):
        try:
            print(f"Loading saved model from {model_path}...")
            # load_checkpoint is expected to place the model on `device` itself.
            return GPT.load_checkpoint(model_path, device=device), device
        except Exception as e:
            # Broad catch is deliberate: any checkpoint problem falls through
            # to the pretrained path rather than aborting.
            print(f"Failed to load saved model: {e}")
            print(f"Falling back to pretrained model: {pretrained_model}")

    # Priority 2: published pretrained weights.
    print(f"Loading pretrained model: {pretrained_model}...")
    try:
        model = GPT.from_pretrained(pretrained_model)
    except Exception as e:
        print(f"Failed to load pretrained model: {e}")
        # Priority 3 (last resort): an untrained model with default hyperparams.
        print("Creating model with default config...")
        model = GPT(GPTConfig())
    model.to(device)
    return model, device
def generate_text(prompt, model, max_tokens=50, top_k=50, temperature=1.0, device="cpu"):
    """
    Generate text completion for a given prompt using the GPT model.

    Args:
        prompt: Input text prompt
        model: GPT model instance
        max_tokens: Maximum number of tokens to generate
        top_k: Top-k sampling parameter (None for no top-k filtering);
            values larger than the vocab size are clamped down
        temperature: Temperature for sampling (higher = more random);
            a value <= 0 selects greedy (argmax) decoding
        device: Device to run inference on

    Returns:
        Generated text string (including original prompt)
    """
    enc = tiktoken.get_encoding("gpt2")
    model.eval()
    with torch.no_grad():
        # tokenize prompt -> (1, T) batch of token ids
        input_ids = enc.encode(prompt)
        x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(max_tokens):
            logits, _ = model(x)
            logits = logits[:, -1, :]  # logits for the last position only
            if temperature <= 0:
                # BUGFIX: dividing by temperature == 0 produced inf/nan logits
                # and made torch.multinomial raise; treat it as greedy decoding.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    # BUGFIX: clamp k so torch.topk never sees k > vocab size.
                    k = min(top_k, logits.size(-1))
                    topk = torch.topk(logits, k, dim=-1)
                    # Mask out everything below the k-th largest logit.
                    mask = logits < topk.values[:, -1].unsqueeze(-1)
                    logits = logits.masked_fill(mask, -float("inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
    generated_ids = x[0].tolist()
    return enc.decode(generated_ids)