"""Test fixtures for ZeroGPU OpenCode Provider tests.

The fixtures stand in for a tokenizer and a text-generation model so the
provider can be exercised without loading real weights or touching a GPU.
"""

import pytest
from unittest.mock import MagicMock, patch


def _message_field(message, name, default):
    """Read ``name`` from a chat message given either as a dict or an object.

    Dict messages use key lookup; any other object (for example a request
    model instance) uses attribute lookup. ``default`` is returned when the
    field is missing.
    """
    if isinstance(message, dict):
        return message.get(name, default)
    return getattr(message, name, default)


@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer for testing.

    Behaviour of the returned ``MagicMock``:

    - ``apply_chat_template(messages, ...)`` renders each message as
      ``<|role|>content`` (messages may be dicts or objects with ``role`` /
      ``content`` attributes) and appends ``<|assistant|>`` when
      ``add_generation_prompt`` is true.
    - Calling the tokenizer returns a dict with ``input_ids`` and
      ``attention_mask`` tensors of shape ``(1, max(1, len(text) // 4))``,
      capped at ``max_length`` when ``truncation`` is enabled and
      ``max_length`` is given.
    - ``decode(...)`` always returns ``"This is a test response."``.

    To force a fixed encoding in one test, clear the side effect first:
    ``tok.side_effect = None; tok.return_value = {...}``.
    """
    tokenizer = MagicMock()
    tokenizer.pad_token = None
    # "</s>" pairs with eos_token_id 2 below. An empty-string EOS would make
    # e.g. ``text.split(tokenizer.eos_token)`` raise ValueError in callers.
    tokenizer.eos_token = "</s>"
    tokenizer.pad_token_id = 0
    tokenizer.eos_token_id = 2
    tokenizer.model_max_length = 4096

    def mock_apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, **kwargs
    ):
        # ``tokenize`` and any extra keyword arguments are accepted for
        # signature compatibility but ignored: a prompt string is returned.
        parts = []
        for msg in messages:
            role = _message_field(msg, "role", "user")
            content = _message_field(msg, "content", "")
            if content is None:
                content = ""
            # Unknown roles (e.g. "tool") are rendered rather than dropped so
            # they stay visible in the prompt under test.
            parts.append(f"<|{role}|>{content}")
        if add_generation_prompt:
            parts.append("<|assistant|>")
        return "".join(parts)

    tokenizer.apply_chat_template = mock_apply_chat_template

    def mock_call(
        text, return_tensors=None, truncation=True, max_length=None, **kwargs
    ):
        import torch

        # Rough heuristic: about 4 characters per token, never fewer than one.
        token_count = max(1, len(text) // 4)
        if truncation and max_length is not None:
            token_count = min(token_count, max_length)
        return {
            "input_ids": torch.ones((1, token_count), dtype=torch.long),
            "attention_mask": torch.ones((1, token_count), dtype=torch.long),
        }

    # Python resolves ``__call__`` on the type, so assigning it on the mock
    # instance is ignored; ``side_effect`` is how a Mock routes calls to a
    # function and returns that function's result.
    tokenizer.side_effect = mock_call

    def mock_decode(tokens, skip_special_tokens=True, **kwargs):
        return "This is a test response."

    tokenizer.decode = mock_decode
    return tokenizer


@pytest.fixture
def mock_model():
    """Create a mock model for testing.

    ``generate`` accepts the prompt positionally or as ``input_ids=`` /
    ``inputs=`` (a ``(1, 10)`` prompt is assumed when none is given) and
    returns an all-ones ``torch.long`` tensor with the prompt's batch size
    and ``new_tokens`` more columns than the prompt, where ``new_tokens`` is
    20, or ``max_new_tokens`` when that is smaller. ``device`` is ``"cpu"``.
    """
    import torch

    model = MagicMock()
    default_new_tokens = 20

    def mock_generate(*args, **kwargs):
        if args:
            input_ids = args[0]
        else:
            input_ids = kwargs.get("input_ids", kwargs.get("inputs"))
        if input_ids is None:
            input_ids = torch.ones((1, 10), dtype=torch.long)
        batch_size, input_length = input_ids.shape[0], input_ids.shape[1]

        new_tokens = default_new_tokens
        max_new_tokens = kwargs.get("max_new_tokens")
        if max_new_tokens is not None:
            # A real generate never produces more than max_new_tokens.
            new_tokens = min(new_tokens, max_new_tokens)

        # NOTE(review): a ``streamer=`` kwarg is not fed (no put()/end()
        # calls); a consumer iterating a streamer would presumably block.
        return torch.ones(
            (batch_size, input_length + new_tokens), dtype=torch.long
        )

    model.generate = mock_generate
    model.device = "cpu"
    return model


@pytest.fixture
def sample_messages():
    """Sample chat messages for testing (a fresh list per test)."""
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]


@pytest.fixture
def sample_request_data():
    """Sample non-streaming request body for the OpenAI-compatible endpoint."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "temperature": 0.7,
        "max_tokens": 512,
        "stream": False,
    }


@pytest.fixture
def sample_streaming_request_data():
    """Sample streaming request body for the OpenAI-compatible endpoint."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "user", "content": "Tell me a joke."},
        ],
        "temperature": 0.7,
        "max_tokens": 256,
        "stream": True,
    }


@pytest.fixture(autouse=True)
def mock_torch_cuda():
    """Report CUDA as unavailable for every test by patching torch."""
    with patch("torch.cuda.is_available", return_value=False):
        yield