Ursa_Minor_Smashed / inference_cuda.py
Kaileh57's picture
Upload folder using huggingface_hub
d575ce4 verified
#!/usr/bin/env python3
"""
CUDA-optimized inference script for Ursa Minor Smashed model
Requires CUDA-capable GPU
"""
import torch
import torch.nn.functional as F
import argparse
import tiktoken
from typing import Optional, List, Tuple
import warnings
warnings.filterwarnings('ignore')
# Direct PyTorch Implementation
class GPTConfig:
    """Hyperparameter container for the GPT model.

    Recognised keyword arguments and the value used when one is absent:
    ``block_size`` (1024), ``vocab_size`` (50304), ``n_layer`` (12),
    ``n_head`` (12) and ``n_embd`` (768). Any other keyword argument is
    accepted and ignored.
    """

    def __init__(self, **kwargs):
        fallbacks = (
            ('block_size', 1024),
            ('vocab_size', 50304),
            ('n_layer', 12),
            ('n_head', 12),
            ('n_embd', 768),
        )
        for field, fallback in fallbacks:
            setattr(self, field, kwargs.get(field, fallback))
class CausalSelfAttention(torch.nn.Module):
    """Multi-head causal self-attention using one fused Q/K/V projection.

    ``config`` must provide ``n_embd`` and ``n_head``, with ``n_embd``
    divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # c_attn emits queries, keys and values concatenated on the last axis.
        self.c_attn = torch.nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Attend over ``x`` of size (B, T, C) and return a tensor of the same size."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        fused = self.c_attn(x)
        query, key, value = (to_heads(part) for part in fused.split(self.n_embd, dim=2))
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        # Put the heads back side by side before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
class MLP(torch.nn.Module):
    """Feed-forward sub-layer: expand to 4x width, tanh-approximated GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = torch.nn.Linear(config.n_embd, hidden_width)
        self.gelu = torch.nn.GELU(approximate='tanh')
        self.c_proj = torch.nn.Linear(hidden_width, config.n_embd)

    def forward(self, x):
        """Apply ``c_fc`` -> ``gelu`` -> ``c_proj`` to ``x``."""
        return self.c_proj(self.gelu(self.c_fc(x)))
class Block(torch.nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = torch.nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = torch.nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(self.ln_1(x))
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln_2(hidden))
        return hidden + mlp_out
class GPT(torch.nn.Module):
    """Decoder-only transformer language model.

    Holds token and position embeddings, ``config.n_layer`` ``Block``
    layers, a final LayerNorm and a bias-free output head whose weight is
    shared with the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        token_table = torch.nn.Embedding(config.vocab_size, config.n_embd)
        position_table = torch.nn.Embedding(config.block_size, config.n_embd)
        layers = torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        final_norm = torch.nn.LayerNorm(config.n_embd)
        self.transformer = torch.nn.ModuleDict({
            'wte': token_table,
            'wpe': position_table,
            'h': layers,
            'ln_f': final_norm,
        })
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the embedding table reuses the output head's parameter.
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx):
        """Return logits of size (B, T, vocab_size) for token ids ``idx`` of size (B, T)."""
        _, seq_len = idx.size()
        assert seq_len <= self.config.block_size, f"Sequence length {seq_len} exceeds block size {self.config.block_size}"

        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        # The (T, C) position embeddings broadcast across the batch dimension.
        hidden = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        return self.lm_head(self.transformer.ln_f(hidden))
def apply_repetition_penalty(logits: torch.Tensor, token_ids: List[int], penalty: float = 1.1):
    """Down-weight tokens that already appeared, in place.

    A positive logit is divided by ``penalty`` and a negative logit is
    multiplied by it, so that for ``penalty > 1`` the score always moves
    toward "less likely". Dividing a negative logit would instead raise
    that token's probability.

    Args:
        logits: 2-D tensor of next-token scores. Only row 0 is modified.
        token_ids: Previously seen token ids; duplicates are ignored.
        penalty: Penalty factor; 1.0 leaves the logits unchanged.

    Returns:
        The same ``logits`` tensor that was passed in, after modification.
    """
    if not token_ids:
        return logits

    seen = torch.tensor(sorted(set(token_ids)), dtype=torch.long, device=logits.device)
    scores = logits[0, seen]
    # Sign-aware scaling: both branches push the score toward "less likely".
    logits[0, seen] = torch.where(scores < 0, scores * penalty, scores / penalty)
    return logits
def top_k_top_p_filtering(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9):
    """Mask out unlikely tokens with top-k and/or nucleus (top-p) filtering.

    ``logits`` is expected to be 2-D and is modified in place: filtered
    entries are set to ``-inf``. Top-k runs when ``top_k > 0`` and top-p
    runs when ``top_p < 1.0``; the same tensor is returned.
    """
    neg_inf = float('-inf')

    if top_k > 0:
        keep = min(top_k, logits.size(-1))
        # Smallest score that still makes the top-k cut, one per row.
        kth_best = torch.topk(logits, keep).values[:, -1:]
        logits.masked_fill_(logits < kth_best, neg_inf)

    if top_p < 1.0:
        ranked, order = torch.sort(logits, descending=True)
        mass = F.softmax(ranked, dim=-1).cumsum(dim=-1)
        over_threshold = mass > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # which also guarantees the best-ranked token is never dropped.
        first_column = torch.zeros_like(over_threshold[..., :1])
        drop_ranked = torch.cat([first_column, over_threshold[..., :-1]], dim=-1)
        # Map the decision from sorted order back to vocabulary positions.
        drop = drop_ranked.scatter(1, order, drop_ranked)
        logits.masked_fill_(drop, neg_inf)

    return logits
def generate_direct(
    model: GPT,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
):
    """Sample a continuation of ``prompt`` from ``model``.

    Args:
        model: The language model; tensors are created on the device of its
            first parameter.
        prompt: Text to continue. An empty prompt is seeded with the
            tokenizer's end-of-text token, which is not included in the
            returned string.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Softmax temperature; must be greater than zero.
        top_k: Keep only the ``top_k`` best tokens (0 disables).
        top_p: Nucleus sampling threshold (1.0 disables).
        repetition_penalty: Penalty applied to the last 20 generated tokens
            (values <= 1.0 disable it).

    Returns:
        The prompt followed by the generated text, decoded to a string.
        Generation stops early once the end-of-text token is sampled; that
        token is kept in the output.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Follow the model's device instead of assuming the default CUDA device.
    device = next(model.parameters()).device
    block_size = model.config.block_size

    # Initialize tokenizer and encode the prompt
    enc = tiktoken.get_encoding("gpt2")
    tokens = enc.encode(prompt)
    # The model cannot run on a zero-length sequence, so seed an empty prompt.
    context = tokens if tokens else [enc.eot_token]
    x = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)

    model.eval()
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop BEFORE the forward pass so an over-long prompt never
            # exceeds the model's block size.
            if x.size(1) > block_size:
                x = x[:, -block_size:]

            # Mixed-precision forward pass
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                logits = model(x)

            # Focus on the last position; sample in float32 so the softmax
            # and cumulative sums over the vocabulary are not done in bf16.
            logits = logits[:, -1, :].float() / temperature

            # Apply repetition penalty
            if repetition_penalty > 1.0 and len(generated_tokens) > 0:
                logits = apply_repetition_penalty(logits, generated_tokens[-20:], repetition_penalty)

            # Apply top-k and top-p filtering
            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Sample
            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # One host/device sync per step instead of two.
            token_id = next_token.item()

            # Append to sequence
            x = torch.cat([x, next_token], dim=1)
            generated_tokens.append(token_id)

            # Stop if EOS token
            if token_id == enc.eot_token:
                break

    # Decode prompt + continuation
    return enc.decode(tokens + generated_tokens)
def load_model_direct(checkpoint_path: str):
    """Load a GPT model from a PyTorch checkpoint onto the CUDA device.

    The checkpoint must be a mapping with a ``'config'`` entry (a dict or
    an object whose attributes are the hyperparameters) and a ``'model'``
    entry holding the state dict.

    Args:
        checkpoint_path: Path to the ``.pt`` checkpoint file.

    Returns:
        The model on CUDA, wrapped by ``torch.compile`` when available.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Use inference_cpu.py for CPU inference.")

    import sys
    import types

    device = "cuda"
    print(f"Loading model from checkpoint: {checkpoint_path}")

    # The pickled config may reference train_gpt2.GPTConfig, so a stand-in
    # module with a permissive GPTConfig class is registered while loading.
    class DummyGPTConfig:
        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    stub_module = types.ModuleType('train_gpt2')
    stub_module.GPTConfig = DummyGPTConfig

    # Remember any real train_gpt2 module so it is restored afterwards
    # instead of being silently removed from sys.modules.
    not_loaded = object()
    previous_module = sys.modules.get('train_gpt2', not_loaded)
    sys.modules['train_gpt2'] = stub_module
    try:
        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only load checkpoints from a trusted source.
        # Load to CPU first to avoid device issues.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    finally:
        if previous_module is not_loaded:
            sys.modules.pop('train_gpt2', None)
        else:
            sys.modules['train_gpt2'] = previous_module

    # Handle the config - it might be a train_gpt2.GPTConfig object
    config_obj = checkpoint['config']
    if hasattr(config_obj, '__dict__'):
        # If it's an object, extract its attributes
        config_dict = vars(config_obj)
    else:
        # If it's already a dict
        config_dict = config_obj
    config = GPTConfig(**config_dict)

    # A state dict saved from a torch.compile-wrapped model prefixes every
    # key with "_orig_mod."; strip it so the keys match the plain GPT module.
    compiled_prefix = '_orig_mod.'
    state_dict = {
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
        for key, value in checkpoint['model'].items()
    }

    model = GPT(config)
    model.load_state_dict(state_dict)
    model.to(device)

    # Enable optimizations
    model = torch.compile(model) if hasattr(torch, 'compile') else model
    return model
def main():
    """Command-line entry point: parse options, load the checkpoint, print a sample."""
    parser = argparse.ArgumentParser(description="Generate text with Ursa Minor Smashed model (CUDA)")

    # (flag, type, default, help) for every supported option.
    option_specs = [
        ("--model", str, "model_optimized.pt", "Path to model checkpoint (.pt file)"),
        ("--prompt", str, "Hello, I'm a language model", "Input prompt"),
        ("--max-tokens", int, 100, "Maximum number of tokens to generate"),
        ("--temperature", float, 0.8, "Sampling temperature (0.1=conservative, 1.0=creative)"),
        ("--top-k", int, 50, "Top-k sampling (0=disabled)"),
        ("--top-p", float, 0.9, "Top-p (nucleus) sampling"),
        ("--repetition-penalty", float, 1.1, "Repetition penalty (1.0=disabled)"),
    ]
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # Load model from checkpoint
    model = load_model_direct(args.model)

    result = generate_direct(
        model,
        args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )

    separator = "-" * 50
    print("\nGenerated text:")
    print(separator)
    print(result)
    print(separator)


if __name__ == "__main__":
    main()