File size: 3,208 Bytes
1df0e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient

import aetheris.api.server
from aetheris.api.server import app, get_engine

# Swap the server's module-level engine for a mock while a test runs.
@pytest.fixture
def mock_engine():
    """Yield a mock standing in for ``aetheris.api.server.engine``.

    ``patch`` installs the mock on the server module for the lifetime of the
    ``with`` block and restores the original attribute when the fixture is
    torn down, so no manual assignment to the module attribute is needed.

    The mock is configured so that:
      * ``generate_full(...)`` returns ``"This is a generated response."``
      * ``generate(...)`` returns a fresh generator yielding ``"This "``,
        ``"is "`` and ``"streamed."`` on every call.
    """
    def fake_stream(*args, **kwargs):
        # Used as side_effect: each call to the mock hands back a new generator.
        yield from ("This ", "is ", "streamed.")

    with patch("aetheris.api.server.engine") as fake_engine:
        fake_engine.generate_full.return_value = "This is a generated response."
        fake_engine.generate.side_effect = fake_stream
        yield fake_engine

# Module-level test client wrapping the imported `app`; every test below
# sends its requests through this single shared instance.
client = TestClient(app)

def test_list_models(mock_engine):
    """GET /v1/models answers 200 with a non-empty model list.

    The first listed model must have the id ``aetheris-hybrid-mamba-moe``.
    """
    resp = client.get("/v1/models")
    assert resp.status_code == 200

    body = resp.json()
    assert body["object"] == "list"

    # The listing must not be empty and must lead with the expected model.
    assert len(body["data"]) > 0
    first_model = body["data"][0]
    assert first_model["id"] == "aetheris-hybrid-mamba-moe"

def test_chat_completions_non_stream(mock_engine):
    """POST /v1/chat/completions with ``stream=False`` returns one full choice.

    The single choice's message content must equal the text that the
    ``mock_engine`` fixture configures as the ``generate_full`` return value.
    """
    user_message = {"role": "user", "content": "Hello"}
    request_body = {
        "model": "aetheris-hybrid-mamba-moe",
        "messages": [user_message],
        "stream": False,
    }

    resp = client.post("/v1/chat/completions", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["object"] == "chat.completion"
    assert len(body["choices"]) == 1

    only_choice = body["choices"][0]
    assert only_choice["message"]["content"] == "This is a generated response."

def test_chat_completions_stream(mock_engine):
    """POST /v1/chat/completions with ``stream=True`` returns an SSE stream.

    Checks the response is served as ``text/event-stream`` and that the
    ``delta.content`` pieces of all ``data:`` events, concatenated in order,
    equal ``"This is streamed."`` (the chunks produced by the mocked
    ``generate``). The ``data: [DONE]`` line is skipped, not parsed.
    """
    payload = {
        "model": "aetheris-hybrid-mamba-moe",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True
    }
    response = client.post("/v1/chat/completions", json=payload)
    assert response.status_code == 200
    # SSE format checking
    assert "text/event-stream" in response.headers["content-type"]

    data_prefix = "data: "
    done_line = "data: [DONE]"

    pieces = []
    for line in response.iter_lines():
        # Depending on the TestClient version, iter_lines may yield bytes
        # or str; normalise to str before inspecting the line.
        if isinstance(line, bytes):
            line = line.decode("utf-8")

        # Skip blank separator lines, non-data lines and the end-of-stream marker.
        if not line.startswith(data_prefix) or line == done_line:
            continue

        chunk = json.loads(line[len(data_prefix):])
        piece = chunk["choices"][0]["delta"].get("content")
        if piece:
            pieces.append(piece)

    assert "".join(pieces) == "This is streamed."

def test_completions(mock_engine):
    """POST /v1/completions returns a single text completion.

    The single choice's ``text`` must equal the string that the
    ``mock_engine`` fixture configures as the ``generate_full`` return value.
    """
    request_body = {
        "model": "aetheris-hybrid-mamba-moe",
        "prompt": "Once upon a time",
        "max_tokens": 10,
    }

    resp = client.post("/v1/completions", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["object"] == "text_completion"
    assert len(body["choices"]) == 1

    only_choice = body["choices"][0]
    assert only_choice["text"] == "This is a generated response."