Usage
ONNX Runtime
from transformers import AutoConfig, AutoProcessor, GenerationConfig
from transformers.image_utils import load_image
import onnxruntime
import numpy as np
import os
# 1. Load config, processor, and model
model_id = "onnx-community/gemma-3-4b-it-ONNX"
# The ONNX graphs are loaded from a local copy of the model repository; this
# script does not download them itself. Fetch the repository into `local_dir`
# beforehand (e.g. `huggingface-cli download <model_id> --local-dir <local_dir>`).
local_dir = "./gemma-3-4b-it"
onnx_dir = os.path.join(local_dir, "onnx")
if not os.path.isdir(onnx_dir):
    # Fail early with an actionable message instead of an opaque onnxruntime
    # load error from the first InferenceSession below.
    raise FileNotFoundError(
        f"ONNX model directory {onnx_dir!r} not found; "
        f"download {model_id!r} into {local_dir!r} first."
    )
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
generation_config = GenerationConfig.from_pretrained(model_id)

## Load sessions
# Three separate graphs: the image encoder, the token-embedding lookup, and the
# decoder that consumes embeddings plus the KV cache. CPU-only here; other
# onnxruntime execution providers can be listed instead if installed.
providers = ['CPUExecutionProvider']
vision_session = onnxruntime.InferenceSession(os.path.join(onnx_dir, "vision_encoder.onnx"), providers=providers)
embed_session = onnxruntime.InferenceSession(os.path.join(onnx_dir, "embed_tokens.onnx"), providers=providers)
decoder_session = onnxruntime.InferenceSession(os.path.join(onnx_dir, "decoder_model_merged.onnx"), providers=providers)

## Set config values
text_config = config.text_config
# Used to size the empty per-layer KV cache fed to the decoder.
num_key_value_heads = text_config.num_key_value_heads
head_dim = text_config.head_dim
num_hidden_layers = text_config.num_hidden_layers
# Token id(s) that stop generation (an int or a list; np.isin accepts both).
eos_token_id = generation_config.eos_token_id
# Placeholder token whose embeddings are replaced by the image features.
image_token_index = config.image_token_index
# 2. Prepare inputs
image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = load_image(image_url)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "What is in this image?"},
        ],
    },
]
# Tokenize the chat (with the assistant generation prompt appended) and
# preprocess the image in one call; tensors come back as PyTorch and are
# converted to NumPy for onnxruntime.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    tokenize=True,
)
input_ids = inputs['input_ids'].numpy()
attention_mask = inputs['attention_mask'].numpy()
has_vision_inputs = 'pixel_values' in inputs
pixel_values = inputs['pixel_values'].numpy() if has_vision_inputs else None
# Ask the decoder for logits at the final position only.
num_logits_to_keep = np.array(1, dtype=np.int64)

# Empty KV cache: one key and one value tensor per decoder layer, inserted
# layer by layer (key before value). The third axis has length 0 so the first
# decoder call starts with no cached positions.
# NOTE(review): assumes the layout is (batch, kv_heads, seq, head_dim) and that
# the decoder graph is float32 — confirm against decoder_session.get_inputs().
batch_size = input_ids.shape[0]
past_cache_values = {
    f'past_key_values.{layer}.{kv}': np.zeros(
        [batch_size, num_key_value_heads, 0, head_dim], dtype=np.float32
    )
    for layer in range(num_hidden_layers)
    for kv in ('key', 'value')
}
# 3. Generation loop (greedy decoding)
max_new_tokens = 1024
# One row per batch item, so the concatenation below also works for batch_size > 1.
generated_tokens = np.zeros((batch_size, 0), dtype=np.int64)
image_features = None
for _ in range(max_new_tokens):
    # Embed the current tokens: the full prompt on the first step, then only
    # the single token produced by the previous step.
    inputs_embeds = embed_session.run(None, {'input_ids': input_ids})[0]

    if has_vision_inputs and image_features is None:
        ## Only compute vision features on first iteration
        image_features = vision_session.run(None, dict(
            pixel_values=pixel_values,
        ))[0]
        ## Merge text and vision embeddings: overwrite the embeddings at the
        ## image-placeholder positions with the flattened image features.
        inputs_embeds[input_ids == image_token_index] = image_features.reshape(-1, image_features.shape[-1])

    decoder_inputs = dict(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        num_logits_to_keep=num_logits_to_keep,
        **past_cache_values,
    )
    logits, *present_cache_values = decoder_session.run(None, decoder_inputs)
    if len(present_cache_values) != len(past_cache_values):
        # Guard the positional cache update below: a count mismatch would
        # otherwise silently drop outputs or fail with an opaque IndexError.
        raise RuntimeError(
            f"decoder returned {len(present_cache_values)} cache tensors, "
            f"expected {len(past_cache_values)}"
        )

    ## Update values for next generation loop
    # Greedy decoding: take the highest-scoring token at the last position.
    input_ids = logits[:, -1].argmax(-1, keepdims=True)
    attention_mask = np.concatenate(
        [attention_mask, np.ones((batch_size, 1), dtype=attention_mask.dtype)], axis=-1
    )
    # NOTE(review): assumes the decoder's present.* outputs come back in the same
    # order as the past_key_values.* inputs were created — confirm against
    # decoder_session.get_outputs().
    past_cache_values = dict(zip(past_cache_values, present_cache_values))

    generated_tokens = np.concatenate([generated_tokens, input_ids], axis=-1)
    # Stops as soon as ANY sequence in the batch emits an EOS token.
    if np.isin(input_ids, eos_token_id).any():
        break
    ## (Optional) Streaming: print the first sequence's new token as it arrives
    print(processor.decode(input_ids[0], skip_special_tokens=False), end='', flush=True)
print()

# 4. Output result
# Decodes the whole generation for the first sequence (including the EOS token
# if one was produced), repeating what was streamed above.
print(processor.batch_decode(generated_tokens, skip_special_tokens=False)[0])
- Downloads last month: 682
Inference Providers (new)
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support.