transformers / tests /generation /test_continuous_batching.py
AbdulElahGwaith's picture
Upload folder using huggingface_hub
a9bd396 verified
# Copyright 2025 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import itertools
import unittest
from unittest.mock import patch
import torch
from parameterized import parameterized
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
CompileConfig,
GenerationConfig,
LogitsProcessorList,
)
from transformers.generation.continuous_batching.cache import (
FullAttentionCacheAllocator,
PagedAttentionCache,
SlidingAttentionCacheAllocator,
group_layers_by_attn_type,
)
from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, build_attention_mask
from transformers.testing_utils import (
Expectations,
require_deterministic_for_xpu,
require_flash_attn,
require_torch_accelerator,
slow,
torch_device,
)
from transformers.utils import is_flash_attn_2_available, is_kernels_available
def flush_memory(flush_compile: bool = True) -> None:
    """Collect Python garbage and empty the accelerator cache, optionally resetting torch.compile state.

    Args:
        flush_compile: When True, `torch._dynamo` is reset and every inductor code cache that exists in the
            installed torch version is cleared.
    """
    gc.collect()
    # If needed, flush everything related to torch.compile
    if flush_compile:
        # Dynamo resets
        torch._dynamo.reset()
        torch._dynamo.reset_code_caches()
        # The inductor caches are private attributes, so each one is only cleared when it is present
        codecache = getattr(torch._inductor, "codecache", None)
        if codecache is not None:
            # Clear FX graph cache
            fx_graph_cache = getattr(codecache, "FxGraphCache", None)
            if fx_graph_cache is not None:
                fx_graph_cache.clear()
            # Clear PyCodeCache
            py_code_cache = getattr(codecache, "PyCodeCache", None)
            if py_code_cache is not None:
                py_code_cache.cache_clear()
            # Clear TritonFuture cache (for async compilation)
            triton_future = getattr(codecache, "TritonFuture", None)
            compile_cache = getattr(triton_future, "_compile_cache", None)
            if compile_cache is not None:
                compile_cache.clear()
    # Clear the accelerator cache: CUDA first, XPU otherwise
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    # `torch.xpu` is guarded like the inductor caches above, so a torch build without it does not raise
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
        torch.xpu.synchronize()
    gc.collect()
class ContinuousBatchingNonGenerationTest(unittest.TestCase):
@parameterized.expand(
    [
        (None, None, "0"),
        (None, 4096, "0"),
        ("f", None, "0"),
        ("ffff", None, "0000"),
        ("sssss", 4096, "00000"),
        ("fs", 4096, "01"),
        ("ssfssf", 4096, "001221"),
        ("ssssf", 4096, "01234"),
        ("fffsffs", 4096, "0123456"),
    ]
)
def test_group_layers(
    self,
    layer_types_str: str | None,
    sliding_window: int | None,
    expected_groups: str,
) -> None:
    """Check the layer groups and group types returned by `group_layers_by_attn_type`.

    Args:
        layer_types_str: One character per layer ("f" or "s"), or None to leave `config.layer_types` unset.
        sliding_window: Value written to `config.sliding_window`.
        expected_groups: One digit per layer; the i-th digit is the id of the group layer i should land in.
    """
    char_to_layer_type = {"f": "full_attention", "s": "sliding_window"}

    # Start from a real config and overwrite the attributes that drive the grouping
    config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
    if layer_types_str is None:
        layer_types = None
        config.num_hidden_layers = len(expected_groups)
    else:
        layer_types = [char_to_layer_type[char] for char in layer_types_str]
    config.layer_types = layer_types
    config.sliding_window = sliding_window

    # Decode the expected grouping: collect the layer indices of each group id, then order by group id
    layers_by_group_id: dict[int, list[int]] = {}
    for layer_idx, group_digit in enumerate(expected_groups):
        layers_by_group_id.setdefault(int(group_digit), []).append(layer_idx)
    expected_layer_groups = [layers_by_group_id[group_id] for group_id in sorted(layers_by_group_id)]

    # Test layer groups formation (group order is not significant, hence the sort on both sides)
    layer_groups, group_types = group_layers_by_attn_type(config)
    self.assertEqual(
        sorted(expected_layer_groups),
        sorted(layer_groups),
        f"Test failed for: {layer_types_str = }, {sliding_window = }, {expected_layer_groups = }, {layer_groups = }",
    )

    if layer_types is None:
        # Without explicit layer types, every group must share the type implied by the sliding window
        config_window = getattr(config, "sliding_window", None)
        expected_group_type = "full_attention" if config_window is None else "sliding_attention"
        for group_type in group_types:
            self.assertEqual(
                group_type,
                expected_group_type,
                f"Test failed for: {layer_types_str = }, {sliding_window = }, {group_types = }",
            )
    else:
        # With explicit layer types, every layer of a group must have the type reported for that group
        for layer_group, group_type in zip(layer_groups, group_types):
            types_in_group = [config.layer_types[i] for i in layer_group]
            self.assertEqual(types_in_group, [group_type] * len(types_in_group))
@parameterized.expand(
    [
        ([0, 4], [0, 4], 1, ["1000", "1100", "1110", "1111"]),
        ([0, 4], [0, 4], 2, ["1000", "1100", "0110", "0011"]),
        ([0, 3], [0, 5], 1, ["11100", "11110", "11111"]),
        ([0, 3], [0, 5], 3, ["11100", "01110", "00111"]),
        ([0, 3, 6], [0, 3, 6], 1, ["100000", "110000", "111000", "000100", "000110", "000111"]),
        ([0, 3, 6], [0, 3, 6], 2, ["100000", "110000", "011000", "000100", "000110", "000011"]),
    ]
)
def test_attention_mask(
    self,
    cumulative_seqlens_q: list[int],
    cumulative_seqlens_k: list[int],
    sliding_window: int,  # the sliding window size, 1 means no sliding window
    str_expected_mask: list[str],  # the attention mask, broken down by line as a string of 0s and 1s
) -> None:
    """Compare the mask filled in by `build_attention_mask` with a mask spelled out as rows of 0s and 1s.

    A "0" character stands for the float32 minimum and any other character for 0.
    """
    minus_inf = torch.finfo(torch.float32).min
    num_queries = cumulative_seqlens_q[-1]
    num_keys = cumulative_seqlens_k[-1]

    # Expected mask, filled row by row from its string form
    expected_mask = torch.empty((num_queries, num_keys), dtype=torch.float32)
    for row_idx, row in enumerate(str_expected_mask):
        expected_mask[row_idx, :] = torch.tensor([0 if char != "0" else minus_inf for char in row])

    # Actual mask: starts fully masked because the function under test writes into it in place
    actual_mask = torch.full_like(expected_mask, minus_inf)
    q_tensor = torch.tensor(cumulative_seqlens_q)
    k_tensor = torch.tensor(cumulative_seqlens_k)
    build_attention_mask(actual_mask, q_tensor, k_tensor, sliding_window)

    if torch.equal(expected_mask, actual_mask):
        return

    # On mismatch, render both masks as 0/1 rows so the failure is readable
    actual_rows = ["".join("1" if value == 0 else "0" for value in mask_row) for mask_row in actual_mask]
    rendered_actual = "\n".join(actual_rows)
    rendered_expected = "\n".join(str_expected_mask)
    self.fail(
        f"Test failed for: {cumulative_seqlens_q = }, {cumulative_seqlens_k = }, {sliding_window = }\n"
        f"Expected mask:\n{rendered_expected}\n"
        f"Actual mask:\n{rendered_actual}"
    )
@parameterized.expand(
    [
        # Case 1: Only full attention groups, allocation succeeds
        # needed_blocks = 2 * 1 = 2, free_blocks = 10 -> 2 <= 10 = True
        (2, 0, 1, 0, 0, 10, True),
        # Case 2: Only full attention groups, allocation fails
        # needed_blocks = 5 * 2 = 10, free_blocks = 5 -> 10 <= 5 = False
        (5, 0, 2, 0, 0, 5, False),
        # Case 3: Mixed attention, sliding window not yet full
        # needed_blocks = 2 * 1 + min(4 - 0, 2) * 1 = 2 + 2 = 4, free_blocks = 10 -> 4 <= 10 = True
        (2, 0, 1, 1, 4, 10, True),
        # Case 4: Mixed attention, sliding window partially filled
        # needed_blocks = 3 * 1 + min(4 - 2, 3) * 1 = 3 + 2 = 5, free_blocks = 5 -> 5 <= 5 = True
        (3, 2, 1, 1, 4, 5, True),
        # Case 5: Mixed attention, sliding window already full (allocated_blocks >= max_sliding)
        # blocks_left = max(4 - 5, 0) = 0, needed_blocks = 3 * 1 + 0 = 3, free_blocks = 5 -> 3 <= 5 = True
        (3, 5, 1, 1, 4, 5, True),
        # Case 6: Mixed attention, sliding window full, allocation fails due to full attention
        # blocks_left = max(4 - 4, 0) = 0, needed_blocks = 6 * 1 + 0 = 6, free_blocks = 5 -> 6 <= 5 = False
        (6, 4, 1, 1, 4, 5, False),
        # Case 7: Multiple full attention groups
        # needed_blocks = 3 * 2 = 6, free_blocks = 6 -> 6 <= 6 = True
        (3, 0, 2, 0, 0, 6, True),
        # Case 8: Multiple sliding attention groups, not full
        # needed_blocks = 2 * 1 + min(4 - 1, 2) * 2 = 2 + 4 = 6, free_blocks = 6 -> 6 <= 6 = True
        (2, 1, 1, 2, 4, 6, True),
        # Case 9: Edge case - requesting 0 blocks always succeeds
        # needed_blocks = 0, free_blocks = 0 -> 0 <= 0 = True
        (0, 0, 1, 1, 4, 0, True),
        # Case 10: Edge case - exactly enough blocks
        # needed_blocks = 2 * 1 + min(3 - 0, 2) * 1 = 2 + 2 = 4, free_blocks = 4 -> 4 <= 4 = True
        (2, 0, 1, 1, 3, 4, True),
    ]
)
@require_torch_accelerator
def test_continuous_batching_will_allocation_be_successful(
    self,
    num_requested_blocks: int,
    allocated_blocks: int,
    num_full_attention_groups: int,
    num_sliding_attention_groups: int,
    max_sliding_window_blocks_per_request: int,
    num_free_blocks: int,
    expected_result: bool,
) -> None:
    """Test the will_allocation_be_successful method of PagedAttentionCache, overloading the relevant attributes
    of a dummy cache."""
    # Build a small real cache; only the attributes overridden below matter for the scenario
    dummy_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="sdpa")
    dummy_generation_config = GenerationConfig(num_blocks=8, block_size=16, max_batch_tokens=8)
    cache = PagedAttentionCache(
        config=dummy_config,
        generation_config=dummy_generation_config,
        device=torch_device,
    )

    # Overload cache parameters to match the test scenario
    scenario_attributes = {
        "num_full_attention_groups": num_full_attention_groups,
        "num_sliding_attention_groups": num_sliding_attention_groups,
        "max_sliding_window_blocks_per_request": max_sliding_window_blocks_per_request,
    }
    for attribute_name, attribute_value in scenario_attributes.items():
        setattr(cache, attribute_name, attribute_value)
    # Overload the free-block counter so it reports the scenario's value
    cache.get_num_free_blocks = lambda: num_free_blocks

    result = cache.will_allocation_be_successful(num_requested_blocks, allocated_blocks)
    failure_message = (
        f"Failed for: {num_requested_blocks=}, {allocated_blocks=}, {num_full_attention_groups=}, "
        f"{num_sliding_attention_groups=}, {max_sliding_window_blocks_per_request=}, {num_free_blocks=}. "
        f"Expected {expected_result}, got {result}"
    )
    self.assertEqual(result, expected_result, failure_message)
class ContinuousBatchingGenerationTest(unittest.TestCase):
# -----------------------------------------------Parity tests----------------------------------------------- #
# Ensure continuous batching and non-continuous batching generation produce the same outputs #
# ---------------------------------------------------------------------------------------------------------- #
@require_deterministic_for_xpu
def _test_continuous_batching_parity(
    self,
    model_id: str,
    allow_block_sharing: bool,
    attn_implementation: str,
    use_cuda_graph: bool,
    use_compile: bool,
    max_new_tokens: int = 20,
    num_blocks: int | None = None,
    num_repeat_prompts: int = 1,
) -> None:
    """Tests the parity between continuous batching and non-continuous batching generation.

    The same chat prompts are generated greedily once with `generate_batch` and once with `generate`, and the
    generated token ids of every prompt must be identical.

    Args:
        model_id: Checkpoint loaded for both generation paths.
        allow_block_sharing: Forwarded to `generate_batch`.
        attn_implementation: Attention implementation used to load the model. The test is skipped for
            "flash_attention_2" when neither flash attention 2 nor the kernels library is available.
        use_cuda_graph: Written to the generation config. The test is skipped if True on a non-CUDA device.
        use_compile: When True, a `CompileConfig(fullgraph=True, mode="default")` is set on the generation config.
        max_new_tokens: Number of new tokens requested from both paths.
        num_blocks: Written to the generation config of the continuous batching run only.
        num_repeat_prompts: How many times the list of prompts is repeated.
    """
    # Skip the test if Flash Attention 2 is required but not available
    if attn_implementation == "flash_attention_2" and not (is_flash_attn_2_available() or is_kernels_available()):
        self.skipTest("Flash Attention 2 is not available and neither is the kernels library. Skipping test.")
    # Skip the test if cuda graph is on but the device is not CUDA
    if use_cuda_graph and torch_device != "cuda":
        self.skipTest("CUDA graph is only supported on CUDA devices. Skipping test.")

    # Eager and SDPA implementations get a precision boost to account for the fact that an attention mask is used in
    # continuous batching but not in generate
    dtype = "auto" if attn_implementation == "flash_attention_2" else torch.float32

    def load_model():
        """Load a fresh model on the test device with the generation settings shared by both runs."""
        loaded = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype=dtype)
        loaded = loaded.to(torch_device).eval()
        loaded.generation_config.max_new_tokens = max_new_tokens
        loaded.generation_config.do_sample = False
        loaded.generation_config.use_cuda_graph = use_cuda_graph
        if use_compile:
            loaded.generation_config.compile_config = CompileConfig(fullgraph=True, mode="default")
        return loaded

    model = None
    try:
        # Prepare continuous batching inputs
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        tokenizer.pad_token = tokenizer.eos_token
        user_messages = [
            "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?",
            "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
            "A basket contains 25 oranges among which 1 is bad, 20% are unripe, 2 are sour and the rest are good. How many oranges are good?",
        ]  # fmt: skip
        if num_repeat_prompts > 1:
            user_messages = user_messages * num_repeat_prompts
        chats = [[{"role": "user", "content": user_message}] for user_message in user_messages]
        tokenized = [tokenizer.apply_chat_template(chat, add_generation_prompt=True) for chat in chats]
        input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized]

        # Generation with continuous batching
        model = load_model()
        model.generation_config.num_blocks = num_blocks
        continuous_batching_outputs = model.generate_batch(
            inputs=input_ids, generation_config=model.generation_config, allow_block_sharing=allow_block_sharing
        )

        # Prepare non-continuous batching inputs
        inputs = tokenizer.apply_chat_template(
            chats,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            return_dict=True,
            return_attention_mask=True,
        )
        num_input_tokens = inputs.input_ids.shape[1]

        # Generation without continuous batching, on a freshly loaded model
        model = load_model()
        generate_outputs = model.generate(**inputs.to(torch_device), generation_config=model.generation_config)
        pad_token_id = model.generation_config.pad_token_id

        for i, user_message in enumerate(user_messages):
            # Find the corresponding request in the continuous batching outputs, matching on the unpadded prompt
            input_tokens = inputs.input_ids[i][inputs.attention_mask[i] == 1].tolist()
            key_to_pop = next(
                (key for key, state in continuous_batching_outputs.items() if state.prompt_ids == input_tokens),
                None,
            )
            if key_to_pop is None:
                self.fail(f"Request {i} not found in continuous batching outputs")
            # Popping the matched request ensures a repeated prompt is paired with a different request each time
            continuous_batching_output = continuous_batching_outputs.pop(key_to_pop).generated_tokens

            # Drop the trailing padding of the batched output; the emptiness check avoids an IndexError when the
            # whole generated sequence is padding
            generate_output = generate_outputs[i][num_input_tokens:].tolist()
            while generate_output and generate_output[-1] == pad_token_id:
                generate_output.pop()

            if continuous_batching_output != generate_output:
                decoded_continuous_batching_output = tokenizer.decode(continuous_batching_output)
                decoded_generate_output = tokenizer.decode(generate_output)
                msg = f"Test failed for {model_id = } {allow_block_sharing = }, {attn_implementation = }, {use_cuda_graph = }, {use_compile = }\n"
                msg += f"User message : {repr(user_message)}\n"
                msg += f"Continuous batching output: {repr(decoded_continuous_batching_output)}\n"
                msg += f"Generate output : {repr(decoded_generate_output)}"
                self.fail(msg)
    finally:
        # Free the model and cached memory even when the test fails, so later tests are not affected
        del model
        flush_memory(flush_compile=use_compile)
@parameterized.expand(
    list(
        itertools.product(
            [False, True],
            ["eager", "sdpa", "flash_attention_2"],
            [False, True],
            [False, True],
        )
    )
)
@require_torch_accelerator
@slow
def test_continuous_batching_config_combinations(
    self,
    allow_block_sharing: bool,
    attn_implementation: str,
    use_cuda_graph: bool,
    use_compile: bool,
) -> None:
    """Run the parity check on TinyLlama for every combination of block sharing, attention implementation, CUDA
    graph and compile settings."""
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        allow_block_sharing=allow_block_sharing,
        attn_implementation=attn_implementation,
        use_cuda_graph=use_cuda_graph,
        use_compile=use_compile,
    )
# FIXME: Qwen2.5-0.5B-Instruct is not here because it's broken (it uses a repetition penalty logits processor)
# TODO: replace gemma2 with a tiny version of GPT-OSS? That way we can test sliding window AND attention sink
@parameterized.expand(
    list(
        itertools.product(
            ["TinyLlama/TinyLlama-1.1B-Chat-v1.0", "google/gemma-2-2b-it"],
            [False, True],
            [False, True],
        )
    )
)
@require_torch_accelerator
@slow
def test_continuous_batching_diverse_models(self, model_id: str, use_cuda_graph: bool, use_compile: bool) -> None:
    """Run the parity check with block sharing and flash attention 2 on each listed model, with and without CUDA
    graph and compile."""
    try:
        self._test_continuous_batching_parity(
            model_id=model_id,
            allow_block_sharing=True,
            attn_implementation="flash_attention_2",
            use_cuda_graph=use_cuda_graph,
            use_compile=use_compile,
        )
    finally:
        # Always release memory, whether the parity check passed or not
        flush_memory(flush_compile=use_compile)
@require_torch_accelerator
def test_continuous_batching_fast(self) -> None:
    """Run a single parity check on TinyLlama with SDPA and every optional feature turned off."""
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        allow_block_sharing=False,
        attn_implementation="sdpa",
        use_cuda_graph=False,
        use_compile=False,
    )
@require_torch_accelerator
def test_continuous_batching_long_generate(self) -> None:
    """Run the parity check on TinyLlama with 80 new tokens, with block sharing, flash attention 2, CUDA graph and
    compile all enabled."""
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        allow_block_sharing=True,
        attn_implementation="flash_attention_2",
        use_cuda_graph=True,
        use_compile=True,
        max_new_tokens=80,
    )
@require_torch_accelerator
def test_continuous_batching_few_blocks(self) -> None:
    """This test verifies that generation works with a very small number of blocks, ie. small enough that we need to
    offload a request at some point. To add more complexity, we repeat the same prompt 4 times and enable prefix
    sharing."""
    # Wrap soft_reset_one_request in a mock that still runs the real method, so calls can be observed
    real_soft_reset = ContinuousBatchProcessor.soft_reset_one_request
    soft_reset_spy = patch.object(
        ContinuousBatchProcessor, "soft_reset_one_request", autospec=True, side_effect=real_soft_reset
    )
    with soft_reset_spy as mock_soft_reset:
        self._test_continuous_batching_parity(
            model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            allow_block_sharing=True,
            attn_implementation="sdpa",
            use_cuda_graph=True,
            use_compile=False,
            num_blocks=4,
            num_repeat_prompts=4,
        )
        # The scenario is only meaningful if at least one request had to be soft reset
        self.assertTrue(mock_soft_reset.called, "Soft reset method was not called.")
# ---------------------------------------Streaming tests--------------------------------------- #
# Ensures the requests have the right behavior with and without streaming #
# --------------------------------------------------------------------------------------------- #
def _test_streaming_or_not_request(self, with_streaming: bool, with_non_streaming: bool) -> None:
    """Check how many generated tokens each yielded chunk carries, in streaming and non-streaming mode.

    Args:
        with_streaming: If True, add a streaming request and check that successive chunks hold 1, 2, ... up to
            `max_new_tokens` generated tokens.
        with_non_streaming: If True, add a non-streaming request (before the streaming one, if both are enabled)
            and check that its first chunk already holds `max_new_tokens` generated tokens.
    """
    model_id = "Qwen/Qwen2.5-0.5B-Instruct"
    max_new_tokens = 3
    model = AutoModelForCausalLM.from_pretrained(model_id)
    manager = model.init_continuous_batching()
    manager.logit_processor = LogitsProcessorList()
    manager.start()
    # From here on the manager is started: stop it in all cases, including when an assertion fails
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
        ).to(model.device)[0]

        # Test with non-streaming
        if with_non_streaming:
            request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
            # In non-streaming mode, the total number of generated tokens is equal to the max new tokens
            chunk = next(manager.request_id_iter(request_id))
            self.assertEqual(len(chunk.generated_tokens), max_new_tokens)

        # Test with streaming
        if with_streaming:
            request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
            # In streaming mode, the total number of generated tokens is incremented by 1 on each iteration
            for expected_num_tokens in range(1, max_new_tokens + 1):
                chunk = next(manager.request_id_iter(request_id))
                self.assertEqual(len(chunk.generated_tokens), expected_num_tokens)
    finally:
        manager.stop(block=True)
@require_torch_accelerator
def test_streaming_request(self) -> None:
    """Run the shared streaming scenario with only the streaming request enabled."""
    self._test_streaming_or_not_request(with_non_streaming=False, with_streaming=True)
@require_torch_accelerator
def test_non_streaming_request(self) -> None:
    """Run the shared streaming scenario with only the non-streaming request enabled."""
    self._test_streaming_or_not_request(with_non_streaming=True, with_streaming=False)
@require_torch_accelerator
def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
    """Run the shared streaming scenario with both request kinds enabled on the same manager."""
    self._test_streaming_or_not_request(with_non_streaming=True, with_streaming=True)
# -----------------------------------------Misc. tests----------------------------------------- #
# Various tests that don't fit into the other categories #
# --------------------------------------------------------------------------------------------- #
def _test_block_sharing(
    self, model_id: str, expected_layer_types: dict[str, int], input_msg: str, expected_output_tokens: list[int]
) -> None:
    """Send the same request twice to one manager and check the block hash table, prefix length and outputs.

    Args:
        model_id: Checkpoint to load.
        expected_layer_types: Expected number of layer groups, keyed by "full_attention" and "sliding_window".
        input_msg: User message; once templated it must be at least 32 and less than 64 tokens long.
        expected_output_tokens: Token ids the first request is expected to generate.
    """
    # The block size is also used as the number of new tokens, so that decoding completes one more block
    block_size = 32
    num_fa = expected_layer_types["full_attention"]
    num_sw = expected_layer_types["sliding_window"]

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    generation_config = GenerationConfig(do_sample=False, block_size=block_size)

    with model.continuous_batching_context_manager(generation_config=generation_config) as manager:
        manager.logit_processor = LogitsProcessorList()

        # Create a request with at least 32 tokens but less than 64 so prefill only generates one complete block
        messages = [{"content": input_msg, "role": "user"}]
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
        )
        inputs = inputs.to(model.device)[0].tolist()
        self.assertGreaterEqual(
            len(inputs), block_size, f"Input length is {len(inputs)} instead of at least {block_size}"
        )
        self.assertLess(
            len(inputs), 2 * block_size, f"Input length is {len(inputs)} instead of less than {2 * block_size}"
        )

        # First request, which populates the cache w/ 2 complete blocks for each full attention layer group
        request_id = manager.add_request(inputs, max_new_tokens=block_size)
        chunk_no_reuse = next(manager.request_id_iter(request_id))

        hash_table = manager.batch_processor.cache._block_manager._hash_to_id
        self.assertEqual(
            len(hash_table),
            2 * num_fa,  # 2 = 1 for prefill + 1 for decode
            f"There should be {2 * num_fa} blocks, 2 for each full attention layer group, but {len(hash_table) = }",
        )
        total_prefix_length = manager.batch_processor.cache._total_prefix_length
        self.assertEqual(
            total_prefix_length, 0, f"Expected total prefix length to be 0, got {total_prefix_length}"
        )

        # Assert the number of layer groups and their types are the expected ones
        layer_groups = manager.batch_processor.cache.group_cache_managers
        self.assertEqual(
            len(layer_groups),
            num_fa + num_sw,
            f"There should be {num_fa + num_sw} layer groups, but {len(layer_groups) = }",
        )
        layer_group_types = {"full_attention": 0, "sliding_window": 0}
        for cm in layer_groups:
            if isinstance(cm, FullAttentionCacheAllocator):
                layer_group_types["full_attention"] += 1
            elif isinstance(cm, SlidingAttentionCacheAllocator):
                layer_group_types["sliding_window"] += 1
            else:
                # An unknown allocator type is reported as a test failure
                self.fail(f"Invalid layer group type: {type(cm)}")
        self.assertEqual(
            layer_group_types,
            expected_layer_types,
            f"The expected layer group types are\n{expected_layer_types}\nbut got\n{layer_group_types}",
        )

        # Second request, which should reuse the same blocks for the full attention layer groups
        request_id = manager.add_request(inputs, max_new_tokens=block_size)
        chunk_with_reuse = next(manager.request_id_iter(request_id))

        # There should only still be two blocks in the hash table because of block reuse
        self.assertEqual(
            len(hash_table),
            2 * num_fa,
            f"Because of block reuse, there should still be two blocks in the hash table, but {len(hash_table) = }",
        )

        # Check that the whole prefill was matched if there are only full attention layers
        expected_total_prefix_length = block_size if num_sw == 0 else 0
        total_prefix_length = manager.batch_processor.cache._total_prefix_length
        self.assertEqual(
            total_prefix_length,
            expected_total_prefix_length,
            f"Expected total prefix length to be {expected_total_prefix_length}, but got {total_prefix_length = }",
        )

    # Check the outputs were the same
    self.assertEqual(chunk_no_reuse.generated_tokens, chunk_with_reuse.generated_tokens)
    # As an additional sanity check, we also compare to the reference tokens supplied by the caller; both lists are
    # put in the failure message instead of being printed on every run
    self.assertEqual(
        chunk_no_reuse.generated_tokens,
        expected_output_tokens,
        f"{chunk_no_reuse.generated_tokens = } {expected_output_tokens = }",
    )
@require_torch_accelerator
def test_prefix_sharing(self) -> None:
    """Run the block sharing scenario on Qwen2.5-0.5B-Instruct, expecting a single full attention layer group."""
    reference_tokens = Expectations({
        (None, None): [785, 80532, 6733, 374, 3881, 369, 1181, 5726, 311, 1855, 323, 36635, 3460, 12934, 4128, 4119, 11, 2670, 1846, 429, 646, 6923, 1467, 11, 14683, 1467, 11, 323, 2736, 1008, 4128, 13904]
    })  # fmt: skip
    self._test_block_sharing(
        model_id="Qwen/Qwen2.5-0.5B-Instruct",
        expected_layer_types={"full_attention": 1, "sliding_window": 0},
        input_msg="What is the Transformers library known for?",
        expected_output_tokens=reference_tokens.get_expectation(),
    )
@require_torch_accelerator
def test_block_sharing_with_hybrid_model(self) -> None:
    """Run the block sharing scenario on gemma-3-1b-it, expecting 2 full attention and 11 sliding window layer
    groups."""
    reference_tokens = Expectations({
        (None, None): [19058, 236764, 1531, 236789, 236751, 2541, 1679, 1144, 506, 128282, 9427, 563, 3224, 573, 236764, 10916, 528, 506, 4403, 529, 3788, 12498, 11362, 236761, 1030, 236789, 236751, 496, 808, 120749, 236829, 532]
    })  # fmt: skip
    self._test_block_sharing(
        model_id="google/gemma-3-1b-it",
        expected_layer_types={"full_attention": 2, "sliding_window": 11},
        input_msg="I am a software engineer looking to use open source software to build a new AI agent. What is the Transformers library known for?",
        expected_output_tokens=reference_tokens.get_expectation(),
    )
@parameterized.expand([True, False])
@require_torch_accelerator
@require_flash_attn  # otherwise the test can fail because attention bias has a very slight impact on SDPA and eager
def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
    """Request two return sequences for one prompt under greedy decoding and check both results hold the same
    generated tokens."""
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    prompt = "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?"

    # Tokenize the single chat prompt; the template may return either a list of ids or a mapping holding them
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    input_ids = []
    for user_message in [prompt]:
        chat = [{"role": "user", "content": user_message}]
        encoded = tokenizer.apply_chat_template(chat, add_generation_prompt=True)
        input_ids.append(encoded if isinstance(encoded, list) else encoded["input_ids"])

    # Model set up for greedy decoding with flash attention 2
    model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
    model = model.to(torch_device).eval()
    model.generation_config.max_new_tokens = 30
    model.generation_config.do_sample = False

    manager_cm = model.continuous_batching_context_manager(
        allow_block_sharing=allow_block_sharing, block=True, timeout=5
    )

    # Main loop: collect finished results until both arrived or the manager stopped running
    results = []
    with manager_cm as manager:
        manager.num_return_sequences = 2
        manager.add_requests(inputs=input_ids, max_new_tokens=30)
        while len(results) < 2:
            result = manager.get_result(timeout=1)
            if result and result.is_finished():
                results.append(result)
            elif not manager.is_running():
                break

    self.assertEqual(len(results), 2, f"Expected 2 results, but got {len(results) = }")
    first_result, second_result = results
    self.assertEqual(first_result.generated_tokens, second_result.generated_tokens)