from __future__ import annotations

import onnx_ir as ir
import torch
import transformers
from transformers import LlamaConfig

import onnxscript.optimizer

# Configuration for a one-layer variant of HuggingFaceTB/SmolLM-1.7B.
# A single hidden layer keeps the test model small while preserving the
# Llama attention and MLP structure of the full model.
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=1,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Shapes for the synthetic test inputs.
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
dim = _hidden_size // _num_attention_heads  # dimension of each attention head
_vocab_size = _config.vocab_size


class _SmollmTestData:
    """Lazily builds a small SmolLM test model, its inputs, and reference outputs."""

    def get_torch_model(self):
        # Build the eager PyTorch model once and cache it on the instance.
        if not hasattr(self, "_torch_model"):
            model = transformers.LlamaForCausalLM(_config)
            model.eval()
            self._torch_model = model
        return self._torch_model

    def get_onnx_model(self) -> ir.Model:
        model = self.get_torch_model()
        inputs = self.get_inputs()
        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
        # Export through the dynamo-based exporter, falling back to the
        # legacy TorchScript exporter if the dynamo export fails.
        exported = torch.onnx.export(
            model, inputs, input_names=input_names, dynamo=True, fallback=True
        )
        # Clean up the exported model (constant folding, removal of
        # redundant nodes) before returning it.
        exported_model = exported.model
        onnxscript.optimizer.optimize(exported_model)
        return exported_model

    def get_inputs(self):
        # Build and cache the inputs: random token ids, an all-ones
        # (no padding) attention mask, and sequential position ids.
        if not hasattr(self, "_inputs"):
            input_ids = torch.randint(0, _vocab_size, (_batch_size, _seq_len)).to(torch.int64)
            attention_mask = torch.ones(input_ids.shape)
            position_ids = torch.arange(0, input_ids.size(-1)).unsqueeze(0)
            self._inputs = (input_ids, attention_mask, position_ids)
        return self._inputs

    def get_torch_outputs(self):
        # Run the eager model and return the logits plus the key/value
        # cache of the single layer, all as numpy arrays.
        output = self.get_torch_model()(*self.get_inputs())
        logits = output.logits
        past_key_value = output.past_key_values[0]
        key = past_key_value[0]
        value = past_key_value[1]
        return (logits.detach().numpy(), key.detach().numpy(), value.detach().numpy())

    def get_ort_inputs(self):
        # The same inputs, keyed by the positional names assigned at export
        # time ("input0", "input1", ...).
        inputs = self.get_inputs()
        return {
            f"input{i}": input.numpy() for i, input in enumerate(inputs) if input is not None
        }
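

# Minimal usage sketch (not part of the test data itself): run the exported
# model under onnxruntime and compare against the PyTorch reference outputs.
# Assumptions: onnxruntime is installed, the logits are the first graph
# output, and the tolerances below are illustrative only.
if __name__ == "__main__":
    import numpy as np
    import onnxruntime

    data = _SmollmTestData()
    onnx_model = data.get_onnx_model()
    session = onnxruntime.InferenceSession(
        ir.serde.serialize_model(onnx_model).SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    ort_logits = session.run(None, data.get_ort_inputs())[0]
    torch_logits, _, _ = data.get_torch_outputs()
    np.testing.assert_allclose(ort_logits, torch_logits, rtol=1e-3, atol=1e-3)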
|
|
|