File size: 1,414 Bytes
f6fdf6a
 
 
 
 
 
 
 
 
 
 
 
c77ec91
f6fdf6a
c77ec91
f6fdf6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c77ec91
f6fdf6a
c77ec91
f6fdf6a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from fastapi.testclient import TestClient

from app.main import app


# Single in-process HTTP client wrapping the FastAPI `app`; shared by the
# tests in this module.
client = TestClient(app=app)


def test_models(monkeypatch):
    """GET /v1/models succeeds and exposes the stubbed provider listing.

    ``transformers_provider.list_models`` is replaced with an async stub that
    returns a single model entry. The test then checks that the endpoint
    answers 200 and that the stubbed model id appears in ``data``, so the
    assertion depends on the stub's output rather than only on the key name.
    """
    model_id = "DragonLLM/LLM-Pro-Finance-Small"

    async def fake_list_models():
        return {"data": [{"id": model_id}]}

    from app.providers import transformers_provider

    monkeypatch.setattr(transformers_provider, "list_models", fake_list_models)

    r = client.get("/v1/models")
    assert r.status_code == 200
    j = r.json()
    assert "data" in j
    # Checking for the stubbed id (not just the "data" key) makes the test fail
    # if the route stops surfacing what transformers_provider.list_models
    # returns.
    assert any(m.get("id") == model_id for m in j["data"])


def test_chat_completions(monkeypatch):
    """POST /v1/chat/completions relays the stubbed provider's reply.

    ``transformers_provider.chat`` is swapped for an async stub that returns a
    fixed single-choice completion; the test checks the endpoint answers 200
    and that the assistant text from the stub comes back to the client.
    """
    from app.providers import transformers_provider

    async def fake_chat(payload, stream=False):
        # The route is expected to forward a truthy model name to the provider.
        assert payload["model"]
        reply = {"role": "assistant", "content": "Hello"}
        choice = {"index": 0, "message": reply, "finish_reason": "stop"}
        return {
            "id": "cmpl-1",
            "object": "chat.completion",
            "created": 0,
            "model": payload["model"],
            "choices": [choice],
        }

    monkeypatch.setattr(transformers_provider, "chat", fake_chat)

    request_body = {
        "model": "DragonLLM/LLM-Pro-Finance-Small",
        "messages": [{"role": "user", "content": "Hi"}],
    }
    response = client.post("/v1/chat/completions", json=request_body)

    assert response.status_code == 200
    first_choice = response.json()["choices"][0]
    assert first_choice["message"]["content"] == "Hello"