#!/usr/bin/env python3 """ CUDA-optimized inference script for Ursa Minor Smashed model Requires CUDA-capable GPU """ import torch import torch.nn.functional as F import argparse import tiktoken from typing import Optional, List, Tuple import warnings warnings.filterwarnings('ignore') # Direct PyTorch Implementation class GPTConfig: def __init__(self, **kwargs): self.block_size = kwargs.get('block_size', 1024) self.vocab_size = kwargs.get('vocab_size', 50304) self.n_layer = kwargs.get('n_layer', 12) self.n_head = kwargs.get('n_head', 12) self.n_embd = kwargs.get('n_embd', 768) class CausalSelfAttention(torch.nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 self.c_attn = torch.nn.Linear(config.n_embd, 3 * config.n_embd) self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd) self.n_head = config.n_head self.n_embd = config.n_embd def forward(self, x): B, T, C = x.size() qkv = self.c_attn(x) q, k, v = qkv.split(self.n_embd, dim=2) k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.c_proj(y) return y class MLP(torch.nn.Module): def __init__(self, config): super().__init__() self.c_fc = torch.nn.Linear(config.n_embd, 4 * config.n_embd) self.gelu = torch.nn.GELU(approximate='tanh') self.c_proj = torch.nn.Linear(4 * config.n_embd, config.n_embd) def forward(self, x): x = self.c_fc(x) x = self.gelu(x) x = self.c_proj(x) return x class Block(torch.nn.Module): def __init__(self, config): super().__init__() self.ln_1 = torch.nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) self.ln_2 = torch.nn.LayerNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x): x = x + self.attn(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x class GPT(torch.nn.Module): def __init__(self, config): super().__init__() self.config = config self.transformer = torch.nn.ModuleDict(dict( wte = torch.nn.Embedding(config.vocab_size, config.n_embd), wpe = torch.nn.Embedding(config.block_size, config.n_embd), h = torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]), ln_f = torch.nn.LayerNorm(config.n_embd), )) self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False) # Weight tying self.transformer.wte.weight = self.lm_head.weight def forward(self, idx): B, T = idx.size() assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}" pos = torch.arange(0, T, dtype=torch.long, device=idx.device) pos_emb = self.transformer.wpe(pos) tok_emb = self.transformer.wte(idx) x = tok_emb + pos_emb for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) logits = self.lm_head(x) return logits def apply_repetition_penalty(logits: torch.Tensor, token_ids: List[int], penalty: float = 1.1): """Apply repetition penalty to logits""" for token_id in set(token_ids): logits[0, token_id] /= penalty return logits def top_k_top_p_filtering(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9): """Filter logits using top-k and/or top-p (nucleus) filtering""" if top_k > 0: values, indices = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < values[:, [-1]]] = float('-inf') if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = float('-inf') return logits def generate_direct( model: GPT, prompt: str, max_new_tokens: int = 100, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9, repetition_penalty: float = 1.1 ): """Generate text using CUDA-optimized PyTorch implementation""" device = "cuda" # Initialize tokenizer enc = tiktoken.get_encoding("gpt2") # Encode prompt tokens = enc.encode(prompt) x = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0) model.eval() generated_tokens = [] with torch.no_grad(): for _ in range(max_new_tokens): # Get logits with CUDA autocast for performance with torch.cuda.amp.autocast(dtype=torch.bfloat16): logits = model(x) # Focus on last token logits = logits[:, -1, :] / temperature # Apply repetition penalty if repetition_penalty > 1.0 and len(generated_tokens) > 0: logits = apply_repetition_penalty(logits, generated_tokens[-20:], repetition_penalty) # Apply top-k and top-p filtering filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) # Sample probs = F.softmax(filtered_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # Append to sequence x = torch.cat([x, next_token], dim=1) generated_tokens.append(next_token.item()) # Stop if EOS token if next_token.item() == enc.eot_token: break # Truncate if exceeding block size if x.size(1) > model.config.block_size: x = x[:, -model.config.block_size:] # Decode all_tokens = tokens + generated_tokens return enc.decode(all_tokens) def load_model_direct(checkpoint_path: str): """Load model from a PyTorch checkpoint - CUDA optimized""" if not torch.cuda.is_available(): raise RuntimeError("CUDA is not available. Use inference_cpu.py for CPU inference.") device = "cuda" print(f"Loading model from checkpoint: {checkpoint_path}") # Create a dummy class to handle train_gpt2.GPTConfig references import sys import types # Create a fake train_gpt2 module to handle the reference train_gpt2_module = types.ModuleType('train_gpt2') class DummyGPTConfig: def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) train_gpt2_module.GPTConfig = DummyGPTConfig sys.modules['train_gpt2'] = train_gpt2_module try: # Load to CPU first to avoid device issues checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) finally: # Clean up if 'train_gpt2' in sys.modules: del sys.modules['train_gpt2'] # Handle the config - it might be a train_gpt2.GPTConfig object config_obj = checkpoint['config'] if hasattr(config_obj, '__dict__'): # If it's an object, extract its attributes config_dict = vars(config_obj) else: # If it's already a dict config_dict = config_obj config = GPTConfig(**config_dict) model = GPT(config) model.load_state_dict(checkpoint['model']) model.to(device) # Enable optimizations model = torch.compile(model) if hasattr(torch, 'compile') else model return model def main(): parser = argparse.ArgumentParser(description="Generate text with Ursa Minor Smashed model (CUDA)") parser.add_argument("--model", type=str, default="model_optimized.pt", help="Path to model checkpoint (.pt file)") parser.add_argument("--prompt", type=str, default="Hello, I'm a language model", help="Input prompt") parser.add_argument("--max-tokens", type=int, default=100, help="Maximum number of tokens to generate") parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature (0.1=conservative, 1.0=creative)") parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling (0=disabled)") parser.add_argument("--top-p", type=float, default=0.9, help="Top-p (nucleus) sampling") parser.add_argument("--repetition-penalty", type=float, default=1.1, help="Repetition penalty (1.0=disabled)") args = parser.parse_args() # Load model from checkpoint model = load_model_direct(args.model) result = generate_direct( model, args.prompt, args.max_tokens, args.temperature, args.top_k, args.top_p, args.repetition_penalty ) print("\nGenerated text:") print("-" * 50) print(result) print("-" * 50) if __name__ == "__main__": main()