# opencode-zerogpu/tests/conftest.py
"""Test fixtures for ZeroGPU OpenCode Provider tests."""
import pytest
from unittest.mock import MagicMock, patch
@pytest.fixture
def mock_tokenizer():
"""Create a mock tokenizer for testing."""
tokenizer = MagicMock()
tokenizer.pad_token = None
tokenizer.eos_token = "</s>"
tokenizer.pad_token_id = 0
tokenizer.eos_token_id = 2
tokenizer.model_max_length = 4096
    def mock_apply_chat_template(messages, tokenize=False, add_generation_prompt=True):
        """Render messages into a simple ChatML-style prompt string."""
        parts = []
        for msg in messages:
            # Accept both plain dicts and message objects exposing
            # .role/.content attributes; calling .get() on a non-dict
            # would raise AttributeError.
            if isinstance(msg, dict):
                role = msg.get("role", "user")
                content = msg.get("content", "")
            else:
                role = getattr(msg, "role", "user")
                content = getattr(msg, "content", "")
            if role == "system":
                parts.append(f"<|system|>{content}</s>")
            elif role == "user":
                parts.append(f"<|user|>{content}</s>")
            elif role == "assistant":
                parts.append(f"<|assistant|>{content}</s>")
        if add_generation_prompt:
            parts.append("<|assistant|>")
        return "".join(parts)

    tokenizer.apply_chat_template = mock_apply_chat_template

    def mock_call(text, return_tensors=None, truncation=True, max_length=None):
        # Simple mock: derive a token count from text length (~4 chars/token)
        # and honor max_length when truncation is requested.
        token_count = max(1, len(text) // 4)
        if truncation and max_length is not None:
            token_count = min(token_count, max_length)
        return {
            "input_ids": torch.ones((1, token_count), dtype=torch.long),
            "attention_mask": torch.ones((1, token_count), dtype=torch.long),
        }

    # Assigning to tokenizer.__call__ has no effect on a MagicMock, because
    # Python resolves dunder methods on the type, not the instance. Routing
    # calls through side_effect makes tokenizer(text) invoke mock_call and
    # return real tensors (side_effect also supersedes return_value).
    tokenizer.side_effect = mock_call

    def mock_decode(tokens, skip_special_tokens=True):
        return "This is a test response."

    tokenizer.decode = mock_decode
    return tokenizer

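# Usage sketch (a hypothetical test, not part of the original suite): the
# mocked chat template can be exercised directly, e.g.
#
#     def test_apply_chat_template(mock_tokenizer, sample_messages):
#         prompt = mock_tokenizer.apply_chat_template(sample_messages)
#         assert prompt.startswith("<|system|>")
#         assert prompt.endswith("<|assistant|>")
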
@pytest.fixture
def mock_model():
    """Create a mock model for testing."""
    model = MagicMock()

    def mock_generate(**kwargs):
        input_ids = kwargs.get("input_ids", torch.ones((1, 10), dtype=torch.long))
        input_length = input_ids.shape[1]
        # Return the prompt followed by 20 newly "generated" tokens.
        generated = torch.ones((1, input_length + 20), dtype=torch.long)
        return generated

    model.generate = mock_generate
    model.device = "cpu"
    return model

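# Usage sketch (hypothetical test): provider code typically slices the prompt
# off the generated sequence to keep only the new tokens; this mock supports
# that pattern because generate() returns input_length + 20 positions:
#
#     def test_generate_appends_tokens(mock_model):
#         input_ids = torch.ones((1, 10), dtype=torch.long)
#         output = mock_model.generate(input_ids=input_ids)
#         new_tokens = output[:, input_ids.shape[1]:]
#         assert new_tokens.shape == (1, 20)
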
@pytest.fixture
def sample_messages():
    """Sample chat messages for testing."""
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]


@pytest.fixture
def sample_request_data():
    """Sample request data for the OpenAI-compatible endpoint."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "temperature": 0.7,
        "max_tokens": 512,
        "stream": False,
    }

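# Usage sketch (hedged: assumes the provider exposes a FastAPI/Starlette `app`
# with an OpenAI-compatible /v1/chat/completions route; neither name is
# confirmed by this file):
#
#     from fastapi.testclient import TestClient
#     from app import app  # hypothetical module
#
#     def test_chat_completion(sample_request_data):
#         client = TestClient(app)
#         resp = client.post("/v1/chat/completions", json=sample_request_data)
#         assert resp.status_code == 200
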
@pytest.fixture
def sample_streaming_request_data():
    """Sample streaming request data."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "user", "content": "Tell me a joke."},
        ],
        "temperature": 0.7,
        "max_tokens": 256,
        "stream": True,
    }

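# Usage sketch (hedged: assumes an httpx-based TestClient, which supports
# .stream(), and OpenAI-style "data: ..." / "data: [DONE]" SSE framing):
#
#     def test_streaming(sample_streaming_request_data):
#         client = TestClient(app)  # see the sketch above
#         with client.stream(
#             "POST", "/v1/chat/completions", json=sample_streaming_request_data
#         ) as resp:
#             chunks = [ln for ln in resp.iter_lines() if ln.startswith("data:")]
#         assert chunks[-1] == "data: [DONE]"
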
@pytest.fixture(autouse=True)
def mock_torch_cuda():
    """Mock CUDA availability for tests."""
    with patch("torch.cuda.is_available", return_value=False):
        yield

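# Because the fixture above is autouse, every test in this suite observes
# torch.cuda.is_available() == False, e.g.:
#
#     def test_cuda_is_mocked_off():
#         assert not torch.cuda.is_available()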