File size: 3,414 Bytes
adcb9bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Test fixtures for ZeroGPU OpenCode Provider tests."""

import pytest
from unittest.mock import MagicMock, patch


@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer for testing.

    Mimics the subset of the Hugging Face tokenizer API the provider uses:
    ``apply_chat_template``, direct ``tokenizer(text, ...)`` tokenization,
    and ``decode``.
    """
    tokenizer = MagicMock()
    # Common HF default: no pad token configured; EOS doubles as padding.
    tokenizer.pad_token = None
    tokenizer.eos_token = "</s>"
    tokenizer.pad_token_id = 0
    tokenizer.eos_token_id = 2
    tokenizer.model_max_length = 4096

    def mock_apply_chat_template(messages, tokenize=False, add_generation_prompt=True):
        """Render messages in a ChatML-like format.

        Accepts both dict messages ({"role": ..., "content": ...}) and
        attribute-style objects exposing .role/.content.
        """
        parts = []
        for msg in messages:
            # BUGFIX: the original called msg.get(...) unconditionally, which
            # raised AttributeError for non-dict messages even though the
            # hasattr fallback suggested they were supported.
            if isinstance(msg, dict):
                role = msg.get("role", "user")
                content = msg.get("content", "")
            else:
                role = getattr(msg, "role", "user")
                content = getattr(msg, "content", "")
            if role in ("system", "user", "assistant"):
                parts.append(f"<|{role}|>{content}</s>")
        if add_generation_prompt:
            parts.append("<|assistant|>")
        return "".join(parts)

    tokenizer.apply_chat_template = mock_apply_chat_template

    def mock_call(text, return_tensors=None, truncation=True, max_length=None):
        import torch
        # Simple mock: roughly 4 characters per token, at least one token.
        token_count = max(1, len(text) // 4)
        # Honor max_length when truncation is requested, like a real tokenizer.
        if truncation and max_length is not None:
            token_count = min(token_count, max_length)
        return {
            "input_ids": torch.ones((1, token_count), dtype=torch.long),
            "attention_mask": torch.ones((1, token_count), dtype=torch.long),
        }

    # BUGFIX: assigning tokenizer.__call__ does not change how a MagicMock is
    # invoked (dunder lookup happens on the type), so tokenizer(text) ignored
    # the input entirely. Routing through side_effect makes calls actually
    # reflect the input length. return_value is kept as a harmless fallback.
    tokenizer.side_effect = mock_call
    tokenizer.return_value = mock_call("test")

    def mock_decode(tokens, skip_special_tokens=True):
        return "This is a test response."

    tokenizer.decode = mock_decode

    return tokenizer


@pytest.fixture
def mock_model():
    """Create a mock model for testing."""
    import torch

    model = MagicMock()

    def mock_generate(**kwargs):
        # Default prompt of 10 tokens when no input_ids are supplied.
        prompt_ids = kwargs.get("input_ids", torch.ones((1, 10), dtype=torch.long))
        # Pretend the model emitted 20 new tokens after the prompt.
        total_length = prompt_ids.shape[1] + 20
        return torch.ones((1, total_length), dtype=torch.long)

    model.generate = mock_generate
    model.device = "cpu"

    return model


@pytest.fixture
def sample_messages():
    """Sample chat messages for testing."""
    system_message = {"role": "system", "content": "You are a helpful assistant."}
    user_message = {"role": "user", "content": "Hello!"}
    return [system_message, user_message]


@pytest.fixture
def sample_request_data():
    """Sample request data for OpenAI-compatible endpoint."""
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": conversation,
        "temperature": 0.7,
        "max_tokens": 512,
        "stream": False,
    }


@pytest.fixture
def sample_streaming_request_data():
    """Sample streaming request data."""
    conversation = [{"role": "user", "content": "Tell me a joke."}]
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": conversation,
        "temperature": 0.7,
        "max_tokens": 256,
        "stream": True,
    }


@pytest.fixture(autouse=True)
def mock_torch_cuda():
    """Mock CUDA availability for tests (torch.cuda.is_available -> False)."""
    patcher = patch("torch.cuda.is_available", return_value=False)
    patcher.start()
    try:
        yield
    finally:
        # Always undo the patch, even if the test errors out.
        patcher.stop()