"""SageMaker inference script for Legion Coder 8M.

This script handles model loading and inference for Amazon SageMaker
deployment. It follows the SageMaker inference container contract
(``model_fn`` / ``input_fn`` / ``predict_fn`` / ``output_fn``).
"""
import os
import json
import torch
import sys
from pathlib import Path

# Add model code to path (the `model` module imported in
# LegionCoderModel._create_block is expected to live here).
sys.path.append('/opt/ml/model/code')

# Parameter names that share one tensor (see the weight tying in
# LegionCoderModel.__init__); a checkpoint only needs to contain one of them.
_TIED_WEIGHT_KEYS = frozenset({'lm_head.weight', 'token_embedding.weight'})

# Generation defaults that predict_fn is meant to apply once a tokenizer is
# wired in (kept here so the intended values are not lost).
_DEFAULT_GENERATION_PARAMS = {
    'max_length': 100,
    'temperature': 0.8,
    'top_k': 50,
    'top_p': 0.95,
}


def _media_type(content_type):
    """Return the lower-cased media type of a Content-Type/Accept value.

    Strips parameters such as ``; charset=utf-8`` so that
    ``"application/json; charset=utf-8"`` compares equal to
    ``"application/json"``. ``None`` yields ``""``.
    """
    return (content_type or '').split(';', 1)[0].strip().lower()


class LegionCoderModel(torch.nn.Module):
    """Simplified decoder-only transformer used for inference.

    Token and learned position embeddings feed a stack of ``TransformerBlock``
    modules (imported from the project's ``model`` module). A final LayerNorm
    follows, and the LM head shares its weight with the token embedding.
    """

    def __init__(self, vocab_size=16000, d_model=576, num_layers=13,
                 num_heads=16, d_ff=1152, max_seq_len=1024, dropout=0.1,
                 pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id

        self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
        self.position_embedding = torch.nn.Embedding(max_seq_len, d_model)
        self.blocks = torch.nn.ModuleList(
            [self._create_block(d_model, num_heads, d_ff, dropout)
             for _ in range(num_layers)]
        )
        self.norm = torch.nn.LayerNorm(d_model)
        self.lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: output projection reuses the input embedding matrix.
        self.lm_head.weight = self.token_embedding.weight
        self.dropout = torch.nn.Dropout(dropout)

    def _create_block(self, d_model, num_heads, d_ff, dropout):
        """Create one transformer block from the project's ``model`` module."""
        from model import TransformerBlock
        return TransformerBlock(d_model, num_heads, d_ff, dropout)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model on a batch of token ids.

        Args:
            input_ids: ``(batch, seq_len)`` integer tensor, with
                ``seq_len <= max_seq_len``.
            attention_mask: optional ``(batch, seq_len)`` tensor; non-zero
                marks positions that may be attended to.
            labels: optional ``(batch, seq_len)`` targets. They are shifted
                internally for next-token prediction, and ``-100`` is ignored.

        Returns:
            dict with ``'logits'`` of shape ``(batch, seq_len, vocab_size)``
            and ``'loss'``, which is a scalar tensor or ``None`` when no
            labels are given.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_len``.
        """
        batch_size, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            # Otherwise the position embedding fails with an opaque IndexError.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        device = input_ids.device

        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        x = self.dropout(self.token_embedding(input_ids) + self.position_embedding(positions))

        # True where attention is allowed: position i may see positions <= i.
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        )
        if attention_mask is not None:
            # Cast to bool so `&` yields a boolean mask. An integer 0/1 mask
            # would otherwise promote the result to int64.
            # NOTE(review): TransformerBlock is defined elsewhere; presumably
            # it expects a boolean "allowed" mask. Confirm.
            padding_mask = attention_mask[:, None, None, :].bool()
            causal_mask = causal_mask[None, None, :, :] & padding_mask

        for block in self.blocks:
            x = block(x, causal_mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        return {'logits': logits, 'loss': loss}

    @staticmethod
    def _sample_next_token(logits, temperature, top_k, top_p):
        """Choose one token id per row from ``(batch, vocab)`` logits.

        ``temperature <= 0`` selects greedy (argmax) decoding. Otherwise the
        logits are temperature-scaled, then top-k filtered (k clamped to the
        vocabulary size), then nucleus (top-p) filtered, and finally sampled.
        Returns a ``(batch, 1)`` tensor.
        """
        if temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)

        logits = logits / temperature

        if top_k > 0:
            # torch.topk raises if k exceeds the vocabulary size.
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_value, float('-inf'))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(
                torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
            )
            remove_sorted = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept,
            # and always keep the single most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float('-inf'))

        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def generate(self, input_ids, max_length=100, temperature=1.0, top_k=50,
                 top_p=0.95, pad_token_id=0, eos_token_id=2):
        """Autoregressively extend ``input_ids``.

        Args:
            input_ids: ``(batch, prompt_len)`` prompt token ids.
            max_length: maximum number of NEW tokens to generate.
            temperature: sampling temperature. A value ``<= 0`` means greedy
                decoding.
            top_k: keep only the k most likely tokens (0 disables this).
            top_p: nucleus sampling threshold (1.0 disables this).
            pad_token_id: token written for rows that already produced EOS.
            eos_token_id: generation stops once every row has produced it.

        Returns:
            ``(batch, prompt_len + n_new)`` tensor that contains the full,
            untruncated prompt followed by the generated tokens. Only the last
            ``max_seq_len`` tokens are fed to the model at each step.
        """
        self.eval()
        generated = input_ids
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

        with torch.no_grad():
            for _ in range(max_length):
                # Slide the context window, but keep the full history in
                # `generated` so the returned sequence still has the prompt.
                window = generated[:, -self.max_seq_len:]
                next_token_logits = self.forward(window)['logits'][:, -1, :]
                next_token = self._sample_next_token(
                    next_token_logits, temperature, top_k, top_p
                )
                # Rows that already ended keep emitting padding.
                next_token = next_token.masked_fill(finished.unsqueeze(1), pad_token_id)
                generated = torch.cat([generated, next_token], dim=1)

                finished |= next_token.squeeze(1) == eos_token_id
                if finished.all():
                    break

        return generated


# SageMaker inference functions

def model_fn(model_dir):
    """Load the model for inference.

    Reads ``config.json`` and ``model.safetensors`` from ``model_dir``,
    builds ``LegionCoderModel`` (with defaults for absent config keys), loads
    the weights non-strictly, and returns the model in eval mode. Any missing
    or unexpected checkpoint keys are reported instead of silently ignored.
    """
    from safetensors.torch import load_file

    print(f"Loading model from {model_dir}")
    model_path = Path(model_dir)

    # Load config
    with open(model_path / 'config.json', 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Create model
    model = LegionCoderModel(
        vocab_size=config.get('vocab_size', 16000),
        d_model=config.get('d_model', 576),
        num_layers=config.get('num_layers', 13),
        num_heads=config.get('num_heads', 16),
        d_ff=config.get('d_ff', 1152),
        max_seq_len=config.get('max_seq_len', 1024),
        dropout=config.get('dropout', 0.1),
        pad_token_id=config.get('pad_token_id', 0),
    )

    # Load weights. strict=False because tied weights may be stored only once.
    state_dict = load_file(str(model_path / 'model.safetensors'))
    result = model.load_state_dict(state_dict, strict=False)

    missing = set(result.missing_keys)
    if not _TIED_WEIGHT_KEYS <= missing:
        # At least one of the tied pair was loaded, which fills both.
        missing -= _TIED_WEIGHT_KEYS
    if missing:
        print(f"WARNING: missing keys in checkpoint: {sorted(missing)}")
    if result.unexpected_keys:
        print(f"WARNING: unexpected keys in checkpoint: {sorted(result.unexpected_keys)}")

    model.eval()
    print("Model loaded successfully!")
    return model


def input_fn(request_body, request_content_type):
    """Parse a JSON request body (``str`` or ``bytes``).

    Content-type parameters such as ``; charset=utf-8`` are accepted.

    Raises:
        ValueError: for any non-JSON content type.
    """
    if _media_type(request_content_type) == 'application/json':
        # json.loads accepts both str and UTF-8/16/32 encoded bytes.
        return json.loads(request_body)
    raise ValueError(f"Unsupported content type: {request_content_type}")


def predict_fn(input_data, model):
    """Make a prediction (currently a placeholder).

    Accepts either a JSON object ``{"inputs": ..., "parameters": {...}}`` or a
    bare JSON string that is used as the input text.

    TODO: tokenize ``text``, call ``model.generate`` with ``parameters``
    (falling back to ``_DEFAULT_GENERATION_PARAMS``) and decode the result.
    This requires the project tokenizer, which is not wired in yet.

    Raises:
        ValueError: if ``input_data`` is neither a dict nor a string.
    """
    if isinstance(input_data, str):
        text, parameters = input_data, {}
    elif isinstance(input_data, dict):
        text = input_data.get('inputs', '')
        # An explicit `"parameters": null` is treated as no parameters.
        parameters = input_data.get('parameters') or {}
    else:
        raise ValueError("Request body must be a JSON object or a JSON string")

    return {
        'generated_text': f"Generated response for: {str(text)[:50]}...",
        'parameters': parameters,
    }


def output_fn(prediction, response_content_type):
    """Serialize ``prediction`` as JSON.

    Returns a ``(body, content_type)`` tuple. Content-type parameters such as
    ``; charset=utf-8`` are accepted.

    Raises:
        ValueError: for any non-JSON content type.
    """
    if _media_type(response_content_type) == 'application/json':
        return json.dumps(prediction), response_content_type
    raise ValueError(f"Unsupported content type: {response_content_type}")


if __name__ == "__main__":
    # Test local inference
    print("Testing SageMaker inference script...")
    print("This script is designed to run within a SageMaker container.")
    print("For local testing, use the Streamlit app or direct model loading.")