"""
SageMaker Inference Script for Legion Coder 8M
This script handles model loading and inference for Amazon SageMaker deployment.
It follows the SageMaker inference container contract.
"""
import os
import json
import torch
import sys
from pathlib import Path
# Add model code to path
sys.path.append('/opt/ml/model/code')
class LegionCoderModel(torch.nn.Module):
    """Simplified decoder-only transformer for inference.

    Learned positional embeddings, a stack of project-provided
    ``TransformerBlock``s, and a weight-tied LM head. Intended for loading
    exported weights inside a SageMaker container (see ``model_fn``).
    """

    def __init__(self, vocab_size=16000, d_model=576, num_layers=13, num_heads=16,
                 d_ff=1152, max_seq_len=1024, dropout=0.1, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
        self.position_embedding = torch.nn.Embedding(max_seq_len, d_model)
        self.blocks = torch.nn.ModuleList(
            [self._create_block(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.norm = torch.nn.LayerNorm(d_model)
        self.lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)
        # Tie input and output embeddings: halves the embedding parameter count
        # and means the checkpoint need not carry a separate lm_head weight.
        self.lm_head.weight = self.token_embedding.weight
        self.dropout = torch.nn.Dropout(dropout)

    def _create_block(self, d_model, num_heads, d_ff, dropout):
        """Create one transformer block (imported lazily from the deployed model code)."""
        from model import TransformerBlock
        return TransformerBlock(d_model, num_heads, d_ff, dropout)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run a forward pass.

        Args:
            input_ids: (batch, seq_len) token ids; seq_len must be <= max_seq_len.
            attention_mask: optional (batch, seq_len) padding mask, 1 = real token.
            labels: optional (batch, seq_len) targets; positions with -100 are ignored.

        Returns:
            dict with 'logits' (batch, seq_len, vocab_size) and 'loss'
            (scalar tensor, or None when labels is None).
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Fail loudly instead of with a cryptic embedding index error.
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        token_embeds = self.token_embedding(input_ids)
        pos_embeds = self.position_embedding(positions)
        x = self.dropout(token_embeds + pos_embeds)
        # Boolean causal mask: True where attention is allowed (lower triangle).
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        causal_mask = mask == 0
        if attention_mask is not None:
            # (batch, seq) padding mask -> (batch, 1, 1, seq); cast to bool so
            # the `&` with the boolean causal mask is well-defined.
            pad_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) & pad_mask
        # NOTE(review): the mask is (seq, seq) without attention_mask but
        # (batch, 1, seq, seq) with it — assumes TransformerBlock broadcasts
        # both shapes; confirm against model.TransformerBlock.
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Standard next-token shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
        return {'logits': logits, 'loss': loss}

    def generate(self, input_ids, max_length=100, temperature=1.0, top_k=50, top_p=0.95,
                 pad_token_id=0, eos_token_id=2):
        """Autoregressively sample up to ``max_length`` new tokens.

        Args:
            input_ids: (batch, prompt_len) prompt token ids.
            max_length: maximum number of NEW tokens to append.
            temperature: logit divisor; must be > 0.
            top_k: keep only the k highest logits (0 disables).
            top_p: nucleus sampling threshold (>= 1.0 disables).
            pad_token_id: token forced into sequences that already finished.
            eos_token_id: token that marks a sequence as finished.

        Returns:
            (batch, prompt_len + generated) tensor including the prompt.
        """
        self.eval()
        batch_size = input_ids.shape[0]
        device = input_ids.device
        # Track sequences that have emitted EOS so we stop sampling into them:
        # previously every sequence kept receiving random tokens until ALL
        # sequences happened to sample EOS simultaneously.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        with torch.no_grad():
            for _ in range(max_length):
                # Keep the context within the positional-embedding range.
                if input_ids.shape[1] > self.max_seq_len:
                    input_ids = input_ids[:, -self.max_seq_len:]
                outputs = self.forward(input_ids)
                next_token_logits = outputs['logits'][:, -1, :] / temperature
                if top_k > 0:
                    # Clamp k: torch.topk raises if k exceeds the vocab size.
                    k = min(top_k, next_token_logits.size(-1))
                    kth_logit = torch.topk(next_token_logits, k)[0][..., -1, None]
                    next_token_logits[next_token_logits < kth_logit] = float('-inf')
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(
                        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
                    )
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    next_token_logits[indices_to_remove] = float('-inf')
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                # Finished sequences emit pad tokens instead of random samples.
                next_token = torch.where(
                    finished.unsqueeze(1),
                    torch.full_like(next_token, pad_token_id),
                    next_token,
                )
                input_ids = torch.cat([input_ids, next_token], dim=1)
                finished |= next_token.squeeze(1) == eos_token_id
                if finished.all():
                    break
        return input_ids
# SageMaker inference functions
def model_fn(model_dir):
    """Load the model for inference (SageMaker ``model_fn`` contract).

    Args:
        model_dir: directory containing ``config.json`` and ``model.safetensors``.

    Returns:
        A ``LegionCoderModel`` in eval mode with the checkpoint weights loaded.
    """
    print(f"Loading model from {model_dir}")
    # Load config
    with open(os.path.join(model_dir, 'config.json'), 'r') as f:
        config = json.load(f)
    # Create model
    model = LegionCoderModel(
        vocab_size=config.get('vocab_size', 16000),
        d_model=config.get('d_model', 576),
        num_layers=config.get('num_layers', 13),
        num_heads=config.get('num_heads', 16),
        d_ff=config.get('d_ff', 1152),
        max_seq_len=config.get('max_seq_len', 1024),
        dropout=config.get('dropout', 0.1),
        pad_token_id=config.get('pad_token_id', 0)
    )
    # Load weights
    from safetensors.torch import load_file
    state_dict = load_file(os.path.join(model_dir, 'model.safetensors'))
    # strict=False tolerates the tied lm_head weight being absent from the
    # checkpoint, but don't swallow mismatches silently — report them so a
    # bad export is visible in the container logs.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning: missing keys in checkpoint: {missing}")
    if unexpected:
        print(f"Warning: unexpected keys in checkpoint: {unexpected}")
    model.eval()
    print("Model loaded successfully!")
    return model
def input_fn(request_body, request_content_type):
    """Deserialize the request body (SageMaker ``input_fn`` contract).

    Only ``application/json`` payloads are accepted; any other content type
    raises ``ValueError``.
    """
    if request_content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)
def predict_fn(input_data, model):
    """Make a prediction (SageMaker ``predict_fn`` contract).

    NOTE: generation is not wired up yet — this is a placeholder that echoes
    the first 50 characters of the prompt. The real implementation should
    tokenize ``text``, call ``model.generate`` with the sampling parameters
    (max_length / temperature / top_k / top_p), and detokenize the result.

    Args:
        input_data: parsed JSON request with 'inputs' (str) and optional
            'parameters' (dict).
        model: the object returned by ``model_fn`` (unused by the placeholder).

    Returns:
        dict with 'generated_text' and the raw 'parameters' echoed back.
    """
    text = input_data.get('inputs', '')
    parameters = input_data.get('parameters', {})
    # (Removed dead code: an unused local `import torch` and resolved sampling
    # defaults that were never read.)
    return {
        'generated_text': f"Generated response for: {text[:50]}...",
        'parameters': parameters
    }
def output_fn(prediction, response_content_type):
    """Serialize the prediction (SageMaker ``output_fn`` contract).

    Returns a ``(body, content_type)`` pair for JSON; any other content type
    raises ``ValueError``.
    """
    if response_content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {response_content_type}")
    return json.dumps(prediction), response_content_type
if __name__ == "__main__":
# Test local inference
print("Testing SageMaker inference script...")
print("This script is designed to run within a SageMaker container.")
print("For local testing, use the Streamlit app or direct model loading.")