| | """ |
| | SageMaker Inference Script for Legion Coder 8M |
| | |
| | This script handles model loading and inference for Amazon SageMaker deployment. |
| | It follows the SageMaker inference container contract. |
| | """ |
| |
|
| | import os |
| | import json |
| | import torch |
| | import sys |
| | from pathlib import Path |
| |
|
| | |
# Make the code directory bundled with the model artifact importable.
# NOTE(review): presumably this is where the project's `model` module
# (imported lazily by LegionCoderModel._create_block) lives -- confirm.
sys.path.extend(['/opt/ml/model/code'])
| |
|
class LegionCoderModel(torch.nn.Module):
    """Decoder-only transformer language model used for inference.

    Token and learned position embeddings are summed, passed through
    ``num_layers`` ``TransformerBlock`` modules, normalised with a final
    ``LayerNorm`` and projected to vocabulary logits by an output head whose
    weight is tied to the token embedding.
    """

    def __init__(self, vocab_size=16000, d_model=576, num_layers=13, num_heads=16, d_ff=1152, max_seq_len=1024, dropout=0.1, pad_token_id=0):
        """Build the embeddings, transformer blocks and output head.

        Args:
            vocab_size: Number of entries in the token embedding / output head.
            d_model: Embedding and hidden width.
            num_layers: Number of transformer blocks.
            num_heads: Attention heads per block (forwarded to the block).
            d_ff: Feed-forward width per block (forwarded to the block).
            max_seq_len: Number of learned positions; longer inputs cannot be
                embedded by ``forward``.
            dropout: Dropout probability for the embeddings and the blocks.
            pad_token_id: Stored on the instance; not used by ``forward``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
        self.position_embedding = torch.nn.Embedding(max_seq_len, d_model)
        self.blocks = torch.nn.ModuleList([self._create_block(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = torch.nn.LayerNorm(d_model)
        self.lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight
        self.dropout = torch.nn.Dropout(dropout)

    def _create_block(self, d_model, num_heads, d_ff, dropout):
        """Create a transformer block."""
        # Imported at call time so that merely importing this script does not
        # require the project's `model` module to be on sys.path.
        from model import TransformerBlock
        return TransformerBlock(d_model, num_heads, d_ff, dropout)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute next-token logits and, optionally, the LM loss.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``;
                ``seq_len`` must not exceed ``max_seq_len``.
            attention_mask: Optional ``(batch, seq_len)`` tensor; non-zero /
                True marks positions that may be attended to. Any dtype that
                converts to bool is accepted.
            labels: Optional ``(batch, seq_len)`` targets. They are shifted by
                one position internally; entries equal to ``-100`` are ignored.

        Returns:
            dict with ``'logits'`` of shape ``(batch, seq_len, vocab_size)``
            and ``'loss'`` (a scalar tensor, or ``None`` without ``labels``).
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        token_embeds = self.token_embedding(input_ids)
        pos_embeds = self.position_embedding(positions)
        x = self.dropout(token_embeds + pos_embeds)

        # Boolean mask, True where attention is allowed (on/below the diagonal).
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        causal_mask = mask == 0

        if attention_mask is not None:
            # Cast to bool first: `&` between a bool mask and a float mask
            # raises, and with an integer mask it would silently change the
            # dtype of the mask handed to the blocks.
            attention_mask = attention_mask.to(torch.bool).unsqueeze(1).unsqueeze(2)
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) & attention_mask

        for block in self.blocks:
            x = block(x, causal_mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        return {'logits': logits, 'loss': loss}

    def generate(self, input_ids, max_length=100, temperature=1.0, top_k=50, top_p=0.95, pad_token_id=0, eos_token_id=2):
        """Autoregressively extend ``input_ids`` by up to ``max_length`` tokens.

        Args:
            input_ids: Prompt tensor of shape ``(batch, seq_len)``.
            max_length: Maximum number of NEW tokens to append.
            temperature: Softmax temperature. Values ``<= 0`` select greedy
                (argmax) decoding instead of dividing by zero.
            top_k: Keep only the ``top_k`` most likely tokens (``0`` disables).
                Values above the vocabulary size are clamped.
            top_p: Nucleus sampling threshold (``>= 1.0`` disables).
            pad_token_id: Token appended to sequences that already emitted
                ``eos_token_id`` while other sequences in the batch continue.
            eos_token_id: Token that marks a sequence as finished.

        Returns:
            Tensor ``(batch, seq_len + n_generated)`` holding the full prompt
            followed by the generated tokens. Only the last ``max_seq_len``
            tokens are fed to the model at each step; the returned sequence
            itself is never truncated.
        """
        self.eval()
        softmax = torch.nn.functional.softmax
        neg_inf = float('-inf')
        # Per-sequence flag: True once that sequence has produced EOS.
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

        with torch.no_grad():
            for _ in range(max_length):
                # Sliding context window; the full history stays in input_ids.
                context = input_ids[:, -self.max_seq_len:]
                next_token_logits = self.forward(context)['logits'][:, -1, :]

                if temperature <= 0:
                    # Greedy decoding: the limit of sampling as temperature -> 0.
                    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                else:
                    next_token_logits = next_token_logits / temperature

                    if top_k > 0:
                        # torch.topk raises when k exceeds the last dimension.
                        k = min(top_k, next_token_logits.size(-1))
                        kth_value = torch.topk(next_token_logits, k)[0][..., -1, None]
                        next_token_logits = next_token_logits.masked_fill(next_token_logits < kth_value, neg_inf)

                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold
                        # is kept; the most likely token is always kept.
                        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
                        sorted_remove[..., 0] = False
                        # Map the removal flags back to vocabulary order.
                        remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
                        next_token_logits = next_token_logits.masked_fill(remove, neg_inf)

                    probs = softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                # Sequences that already ended only receive padding.
                next_token = next_token.masked_fill(finished.unsqueeze(1), pad_token_id)
                input_ids = torch.cat([input_ids, next_token], dim=1)

                finished = finished | (next_token.squeeze(1) == eos_token_id)
                if finished.all():
                    break

        return input_ids
| |
|
| |
|
| | |
def model_fn(model_dir):
    """Load the model for inference.

    Reads ``config.json`` and ``model.safetensors`` from ``model_dir``, builds
    a ``LegionCoderModel`` from the config (falling back to the built-in
    defaults for absent keys), loads the weights and switches to eval mode.

    Args:
        model_dir: Directory containing ``config.json`` and
            ``model.safetensors``.

    Returns:
        The loaded ``LegionCoderModel`` in eval mode.

    Raises:
        FileNotFoundError: If ``config.json`` is missing.
    """
    print(f"Loading model from {model_dir}")

    with open(os.path.join(model_dir, 'config.json'), 'r', encoding='utf-8') as f:
        config = json.load(f)

    model = LegionCoderModel(
        vocab_size=config.get('vocab_size', 16000),
        d_model=config.get('d_model', 576),
        num_layers=config.get('num_layers', 13),
        num_heads=config.get('num_heads', 16),
        d_ff=config.get('d_ff', 1152),
        max_seq_len=config.get('max_seq_len', 1024),
        dropout=config.get('dropout', 0.1),
        pad_token_id=config.get('pad_token_id', 0),
    )

    from safetensors.torch import load_file
    state_dict = load_file(os.path.join(model_dir, 'model.safetensors'))

    # strict=False tolerates key mismatches (e.g. a tied `lm_head.weight`
    # that the checkpoint may omit), but a mismatch can also mean a wrong or
    # truncated checkpoint, so report it instead of discarding it silently.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")

    model.eval()

    print("Model loaded successfully!")
    return model
| |
|
| |
|
def input_fn(request_body, request_content_type):
    """Parse input data.

    Args:
        request_body: JSON payload as ``str``, ``bytes`` or ``bytearray``.
        request_content_type: Value of the request's Content-Type header.
            Media-type parameters (``; charset=utf-8``) and letter case are
            ignored when matching.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If the media type is not ``application/json`` or the body
            is not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    # Compare only the bare media type: headers such as
    # "application/json; charset=utf-8" are valid and must be accepted.
    media_type = (request_content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)
| |
|
| |
|
def predict_fn(input_data, model):
    """Make prediction.

    NOTE: this is a placeholder. It does not call ``model``; it returns a
    canned string echoing the first 50 characters of the prompt together with
    the request parameters. Generation settings such as ``max_length``,
    ``temperature``, ``top_k`` and ``top_p`` are passed through untouched in
    ``'parameters'`` but are not acted on.

    Args:
        input_data: Mapping with optional ``'inputs'`` (prompt) and
            ``'parameters'`` (generation settings) entries.
        model: The object returned by ``model_fn`` (currently unused).

    Returns:
        dict with ``'generated_text'`` and ``'parameters'``.
    """
    # `or` also covers explicit JSON nulls, which previously crashed on
    # `None[:50]` / `None.get(...)`.
    text = input_data.get('inputs') or ''
    parameters = input_data.get('parameters') or {}

    return {
        'generated_text': f"Generated response for: {text[:50]}...",
        'parameters': parameters,
    }
| |
|
| |
|
def output_fn(prediction, response_content_type):
    """Format output.

    Args:
        prediction: JSON-serialisable value returned by ``predict_fn``.
        response_content_type: Requested response media type. Parameters
            (``; charset=utf-8``) and letter case are ignored when matching.

    Returns:
        Tuple ``(body, content_type)`` where ``body`` is the JSON text and
        ``content_type`` is ``'application/json'``.

    Raises:
        ValueError: If the requested media type is not ``application/json``.
    """
    # Match on the bare media type so that parameterised or differently
    # cased headers are not rejected.
    media_type = (response_content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError(f"Unsupported content type: {response_content_type}")
    return json.dumps(prediction), media_type
| |
|
| |
|
if __name__ == "__main__":
    # Direct execution only prints guidance; the SageMaker container calls
    # the *_fn hooks above instead of running this file as a script.
    print(
        "Testing SageMaker inference script...",
        "This script is designed to run within a SageMaker container.",
        "For local testing, use the Streamlit app or direct model loading.",
        sep="\n",
    )
| |
|