|
|
|
|
|
"""
|
|
|
CUDA-optimized inference script for Ursa Minor Smashed model
|
|
|
Requires CUDA-capable GPU
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
import argparse
|
|
|
import tiktoken
|
|
|
from typing import Optional, List, Tuple
|
|
|
import warnings
|
|
|
# Blanket-suppress all warnings (e.g. torch/CUDA deprecation noise) so CLI
# output stays clean. NOTE(review): this also hides genuinely useful warnings.
warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
|
class GPTConfig:
    """Model hyperparameters; unspecified values default to GPT-2 small (124M).

    Unknown keyword arguments are silently ignored.
    """

    def __init__(self, **kwargs):
        defaults = {
            'block_size': 1024,   # maximum context length
            'vocab_size': 50304,  # GPT-2 vocab padded up to a multiple of 64
            'n_layer': 12,        # number of transformer blocks
            'n_head': 12,         # attention heads per block
            'n_embd': 768,        # embedding / residual-stream width
        }
        for name, fallback in defaults.items():
            setattr(self, name, kwargs.get(name, fallback))
|
|
|
|
|
|
class CausalSelfAttention(torch.nn.Module):
    """Multi-head causal self-attention with a fused QKV projection."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One matmul produces query, key and value for all heads at once.
        self.c_attn = torch.nn.Linear(config.n_embd, 3 * config.n_embd)
        # Projects the concatenated head outputs back to the residual stream.
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Attend causally over x of shape (batch, seq, n_embd); same shape out."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        # Fused attention kernel; is_causal=True applies the lower-triangular mask.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(out)
|
|
|
|
|
|
class MLP(torch.nn.Module):
    """Position-wise feed-forward block: expand 4x, tanh-GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = torch.nn.Linear(config.n_embd, hidden)
        # GPT-2 historically uses the tanh approximation of GELU.
        self.gelu = torch.nn.GELU(approximate='tanh')
        self.c_proj = torch.nn.Linear(hidden, config.n_embd)

    def forward(self, x):
        """Apply fc -> GELU -> projection; preserves the input shape."""
        return self.c_proj(self.gelu(self.c_fc(x)))
|
|
|
|
|
|
class Block(torch.nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = torch.nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = torch.nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run both sublayers, normalizing the input of each (pre-LN)."""
        x = self.attn(self.ln_1(x)) + x
        return self.mlp(self.ln_2(x)) + x
|
|
|
|
|
|
class GPT(torch.nn.Module):
    """GPT-2-style decoder-only transformer language model."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = torch.nn.ModuleDict(dict(
            wte = torch.nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe = torch.nn.Embedding(config.block_size, config.n_embd),  # learned positions
            h = torch.nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            ln_f = torch.nn.LayerNorm(config.n_embd),  # final pre-head norm
        ))
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the token embedding and the output projection share
        # one parameter tensor (standard GPT-2 practice).
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx):
        """Return next-token logits of shape (B, T, vocab_size) for token ids idx.

        Raises:
            AssertionError: if the sequence is longer than the block size.
        """
        batch, seq_len = idx.size()
        assert seq_len <= self.config.block_size, f"Sequence length {seq_len} exceeds block size {self.config.block_size}"

        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        # Token embeddings (B, T, C) broadcast-add position embeddings (T, C).
        x = self.transformer.wte(idx) + self.transformer.wpe(positions)

        for layer in self.transformer.h:
            x = layer(x)

        x = self.transformer.ln_f(x)
        return self.lm_head(x)
|
|
|
|
|
|
def apply_repetition_penalty(logits: torch.Tensor, token_ids: List[int], penalty: float = 1.1):
    """Discourage recently generated tokens (CTRL-style repetition penalty).

    Positive logits are divided by ``penalty`` and negative logits are
    multiplied by it, so a previously seen token always becomes *less*
    likely. (The naive version — always dividing — makes negative logits
    less negative and therefore INCREASES their probability, which is the
    opposite of what a repetition penalty should do.)

    Args:
        logits: (1, vocab) tensor of next-token scores; modified in place.
        token_ids: token ids to penalize; duplicates are penalized once.
        penalty: strength; values > 1.0 penalize, 1.0 is a no-op.

    Returns:
        The same ``logits`` tensor, for chaining.
    """
    for token_id in set(token_ids):
        score = logits[0, token_id]
        if score > 0:
            logits[0, token_id] = score / penalty
        else:
            logits[0, token_id] = score * penalty
    return logits
|
|
|
|
|
|
def top_k_top_p_filtering(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9):
    """Restrict sampling with top-k and/or top-p (nucleus) filtering.

    Args:
        logits: (batch, vocab) tensor of scores; modified in place.
        top_k: keep only the k highest-scoring tokens (0 disables).
        top_p: keep the smallest sorted prefix whose cumulative probability
            exceeds this threshold (1.0 disables).

    Returns:
        The same tensor with filtered-out entries set to -inf.
    """
    neg_inf = float('-inf')

    if top_k > 0:
        k = min(top_k, logits.size(-1))
        # Everything strictly below the k-th largest score is masked out.
        kth_best = torch.topk(logits, k)[0][:, -1].unsqueeze(-1)
        logits[logits < kth_best] = neg_inf

    if top_p < 1.0:
        sorted_logits, sorted_indices = logits.sort(dim=-1, descending=True)
        cumulative = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Mark everything past the nucleus, then shift the mask right by one
        # so the token that crosses the threshold is still kept; the single
        # most likely token is always retained.
        beyond_nucleus = cumulative > top_p
        beyond_nucleus = torch.roll(beyond_nucleus, 1, dims=-1)
        beyond_nucleus[..., 0] = False

        # Map the mask from sorted order back to vocabulary order.
        remove_mask = beyond_nucleus.scatter(1, sorted_indices, beyond_nucleus)
        logits[remove_mask] = neg_inf

    return logits
|
|
|
|
|
|
def generate_direct(
    model: GPT,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
):
    """Autoregressively sample a continuation of ``prompt`` on CUDA.

    Args:
        model: a GPT instance already resident on the CUDA device.
        prompt: conditioning text, encoded with the GPT-2 tokenizer.
        max_new_tokens: maximum number of tokens to sample.
        temperature: softmax temperature; lower values are more deterministic.
        top_k: top-k filtering (0 disables).
        top_p: nucleus (top-p) filtering threshold.
        repetition_penalty: penalty applied to the last 20 generated tokens
            (1.0 disables).

    Returns:
        The prompt plus the generated continuation, decoded to a string.
    """
    device = "cuda"

    enc = tiktoken.get_encoding("gpt2")

    tokens = enc.encode(prompt)
    x = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    model.eval()
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_new_tokens):

            # bf16 autocast for throughput. torch.autocast is the supported
            # API; torch.cuda.amp.autocast is deprecated since PyTorch 2.4.
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(x)

            # Only the final position predicts the next token.
            logits = logits[:, -1, :] / temperature

            if repetition_penalty > 1.0 and len(generated_tokens) > 0:
                # Penalize only a recent window so common tokens are not
                # suppressed forever in long generations.
                logits = apply_repetition_penalty(logits, generated_tokens[-20:], repetition_penalty)

            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_token], dim=1)
            generated_tokens.append(next_token.item())

            # Stop at end-of-text. NOTE(review): the EOT token is appended
            # before the break, so its text appears in the decoded output —
            # preserved here to keep behavior unchanged.
            if next_token.item() == enc.eot_token:
                break

            # Sliding window: keep the context within the model's block size.
            if x.size(1) > model.config.block_size:
                x = x[:, -model.config.block_size:]

    all_tokens = tokens + generated_tokens
    return enc.decode(all_tokens)
|
|
|
|
|
|
def load_model_direct(checkpoint_path: str):
    """Load model from a PyTorch checkpoint - CUDA optimized.

    Expects ``checkpoint_path`` to contain a dict with keys 'config'
    (a config object or plain dict) and 'model' (the state_dict).

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Use inference_cpu.py for CPU inference.")

    device = "cuda"
    print(f"Loading model from checkpoint: {checkpoint_path}")

    import sys
    import types

    # The checkpoint was pickled while a 'train_gpt2' module was importable,
    # so unpickling will look that module up. Register a temporary stub that
    # exposes a GPTConfig class so torch.load can reconstruct the config.
    train_gpt2_module = types.ModuleType('train_gpt2')

    class DummyGPTConfig:
        # Accepts arbitrary keyword attributes, standing in for the real
        # training-time GPTConfig during unpickling.
        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    train_gpt2_module.GPTConfig = DummyGPTConfig
    sys.modules['train_gpt2'] = train_gpt2_module

    try:
        # NOTE(review): weights_only=False allows arbitrary pickle code to
        # execute — only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    finally:
        # Always remove the stub so it cannot shadow a real module later.
        if 'train_gpt2' in sys.modules:
            del sys.modules['train_gpt2']

    # The stored config may be an unpickled object or a plain dict.
    config_obj = checkpoint['config']
    if hasattr(config_obj, '__dict__'):
        # Object form: pull its attributes into a dict.
        config_dict = vars(config_obj)
    else:
        # Already a dict.
        config_dict = config_obj

    config = GPTConfig(**config_dict)
    model = GPT(config)
    model.load_state_dict(checkpoint['model'])
    model.to(device)

    # Compile for faster inference on PyTorch 2.x; no-op fallback otherwise.
    model = torch.compile(model) if hasattr(torch, 'compile') else model

    return model
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, load the model, print one sample."""
    parser = argparse.ArgumentParser(description="Generate text with Ursa Minor Smashed model (CUDA)")
    parser.add_argument("--model", type=str, default="model_optimized.pt",
                        help="Path to model checkpoint (.pt file)")
    parser.add_argument("--prompt", type=str, default="Hello, I'm a language model",
                        help="Input prompt")
    parser.add_argument("--max-tokens", type=int, default=100,
                        help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature (0.1=conservative, 1.0=creative)")
    parser.add_argument("--top-k", type=int, default=50,
                        help="Top-k sampling (0=disabled)")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Top-p (nucleus) sampling")
    parser.add_argument("--repetition-penalty", type=float, default=1.1,
                        help="Repetition penalty (1.0=disabled)")

    args = parser.parse_args()

    model = load_model_direct(args.model)

    text = generate_direct(
        model,
        args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )

    divider = "-" * 50
    print("\nGenerated text:")
    print(divider)
    print(text)
    print(divider)


if __name__ == "__main__":
    main()