# legion-coder-8m-10k / sagemaker_inference.py
# (uploaded via huggingface_hub by dineth554 — revision 0d9979c, verified)
"""
SageMaker Inference Script for Legion Coder 8M
This script handles model loading and inference for Amazon SageMaker deployment.
It follows the SageMaker inference container contract.
"""
import os
import json
import torch
import sys
from pathlib import Path
# Make the bundled model code importable (e.g. the `model` module providing
# TransformerBlock); /opt/ml/model/code is where SageMaker unpacks the code/
# directory of the model artifact.
sys.path.append('/opt/ml/model/code')
class LegionCoderModel(torch.nn.Module):
    """Simplified decoder-only transformer for inference.

    The LM head is weight-tied to the token embedding. Transformer blocks are
    supplied by the bundled `model` module (loaded from /opt/ml/model/code).
    """
    def __init__(self, vocab_size=16000, d_model=576, num_layers=13, num_heads=16, d_ff=1152, max_seq_len=1024, dropout=0.1, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
        self.position_embedding = torch.nn.Embedding(max_seq_len, d_model)
        self.blocks = torch.nn.ModuleList([self._create_block(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = torch.nn.LayerNorm(d_model)
        self.lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight
        self.dropout = torch.nn.Dropout(dropout)

    def _create_block(self, d_model, num_heads, d_ff, dropout):
        """Create a transformer block (imported lazily from the bundled model code)."""
        from model import TransformerBlock
        return TransformerBlock(d_model, num_heads, d_ff, dropout)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the decoder.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) padding mask, 1 = attend.
            labels: optional (batch, seq) targets; -100 positions are ignored.

        Returns:
            dict with 'logits' (batch, seq, vocab) and 'loss' (None if no labels).
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        token_embeds = self.token_embedding(input_ids)
        pos_embeds = self.position_embedding(positions)
        x = self.dropout(token_embeds + pos_embeds)
        # Lower-triangular causal mask: True where attention is allowed.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        causal_mask = mask == 0
        if attention_mask is not None:
            # FIX: coerce the padding mask to bool before the bitwise AND —
            # attention masks typically arrive as 0/1 int or float tensors,
            # and `bool & int/float` is fragile or raises in torch.
            pad_mask = attention_mask.unsqueeze(1).unsqueeze(2).bool()
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) & pad_mask
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so token t predicts token t+1; ignore_index=-100 skips
            # padded / masked label positions.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
        return {'logits': logits, 'loss': loss}

    def generate(self, input_ids, max_length=100, temperature=1.0, top_k=50, top_p=0.95, pad_token_id=0, eos_token_id=2):
        """Autoregressive sampling with temperature / top-k / top-p filtering.

        Returns the input ids extended with up to `max_length` sampled tokens.
        Sequences that emit `eos_token_id` stop and are padded with
        `pad_token_id` until the whole batch finishes.
        """
        self.eval()
        device = input_ids.device
        # Per-sequence flag: True while the sequence has not emitted EOS yet.
        unfinished = torch.ones(input_ids.shape[0], 1, dtype=torch.bool, device=device)
        with torch.no_grad():
            for _ in range(max_length):
                # Keep the context within the positional-embedding range.
                if input_ids.shape[1] > self.max_seq_len:
                    input_ids = input_ids[:, -self.max_seq_len:]
                outputs = self.forward(input_ids)
                logits = outputs['logits']
                # FIX: floor the temperature to avoid division by zero; a tiny
                # positive value makes sampling effectively greedy.
                next_token_logits = logits[:, -1, :] / max(temperature, 1e-8)
                if top_k > 0:
                    # FIX: clamp top_k to the vocab size — torch.topk raises
                    # when k exceeds the dimension being searched.
                    k = min(top_k, next_token_logits.size(-1))
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                # FIX: finished sequences emit padding instead of continuing to
                # sample (previously the loop only stopped when ALL sequences
                # hit EOS in the same step, and pad_token_id was never used).
                next_token = torch.where(unfinished, next_token, torch.full_like(next_token, pad_token_id))
                input_ids = torch.cat([input_ids, next_token], dim=1)
                unfinished = unfinished & (next_token != eos_token_id)
                if not unfinished.any():
                    break
        return input_ids
# SageMaker inference functions
def model_fn(model_dir):
    """SageMaker hook: construct the model and load its weights from model_dir."""
    print(f"Loading model from {model_dir}")
    # Hyperparameters come from config.json, falling back to the training
    # defaults for any missing key.
    with open(os.path.join(model_dir, 'config.json'), 'r') as f:
        config = json.load(f)
    defaults = {
        'vocab_size': 16000,
        'd_model': 576,
        'num_layers': 13,
        'num_heads': 16,
        'd_ff': 1152,
        'max_seq_len': 1024,
        'dropout': 0.1,
        'pad_token_id': 0,
    }
    model = LegionCoderModel(**{key: config.get(key, value) for key, value in defaults.items()})
    # Weights are stored in safetensors format. strict=False tolerates keys
    # missing from the checkpoint — NOTE(review): presumably because the tied
    # lm_head weight is not serialized separately; confirm against the export.
    from safetensors.torch import load_file
    state_dict = load_file(os.path.join(model_dir, 'model.safetensors'))
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    print("Model loaded successfully!")
    return model
def input_fn(request_body, request_content_type):
    """SageMaker hook: deserialize the request payload (JSON only)."""
    if request_content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)
def predict_fn(input_data, model):
    """SageMaker hook: run inference on the deserialized input.

    Expects {"inputs": <prompt>, "parameters": {...}}. NOTE: tokenization and
    real generation are not wired up yet — this returns a placeholder echo of
    the prompt. `model` is accepted per the SageMaker contract but unused.

    FIX: removed the redundant function-level `import torch` (already imported
    at module top) and four dead locals (max_length/temperature/top_k/top_p)
    that were assigned but never used.
    """
    text = input_data.get('inputs', '')
    parameters = input_data.get('parameters', {})
    # TODO: tokenize `text`, call model.generate(...) with the sampling
    # parameters (max_length, temperature, top_k, top_p), and detokenize.
    return {
        'generated_text': f"Generated response for: {text[:50]}...",
        'parameters': parameters
    }
def output_fn(prediction, response_content_type):
    """SageMaker hook: serialize the prediction (JSON only)."""
    if response_content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {response_content_type}")
    return json.dumps(prediction), response_content_type
if __name__ == "__main__":
# Test local inference
print("Testing SageMaker inference script...")
print("This script is designed to run within a SageMaker container.")
print("For local testing, use the Streamlit app or direct model loading.")