File size: 1,414 Bytes
f6fdf6a c77ec91 f6fdf6a c77ec91 f6fdf6a c77ec91 f6fdf6a c77ec91 f6fdf6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from fastapi.testclient import TestClient
from app.main import app
# Single test client shared by every test in this module, wrapping the
# FastAPI ``app`` imported above.
# NOTE(review): the client is created outside a ``with`` block, so app
# lifespan/startup handlers presumably do not run -- confirm that is intended.
client = TestClient(app=app)
def test_models(monkeypatch):
    """GET /v1/models must return the model list produced by the provider.

    ``transformers_provider.list_models`` is replaced with an async stub
    (presumably so no real model backend is touched). The response is then
    checked to contain the stubbed model id. This proves the stub was actually
    used, not merely that the response has some ``data`` key.
    """
    from app.providers import transformers_provider

    model_id = "DragonLLM/LLM-Pro-Finance-Small"

    async def fake_list_models():
        return {"data": [{"id": model_id}]}

    # NOTE(review): patching the module attribute only takes effect if the
    # route looks up ``transformers_provider.list_models`` at call time
    # (not through a ``from ... import list_models`` binding) -- confirm.
    monkeypatch.setattr(transformers_provider, "list_models", fake_list_models)

    r = client.get("/v1/models")
    assert r.status_code == 200
    j = r.json()
    assert "data" in j
    # If the stubbed id is missing, the monkeypatch had no effect and the
    # endpoint answered from somewhere else.
    assert any(m.get("id") == model_id for m in j["data"])
def test_chat_completions(monkeypatch):
    """POST /v1/chat/completions must surface the provider's assistant reply.

    The provider's ``chat`` coroutine is swapped for a stub that echoes the
    requested model and always answers "Hello". The endpoint's response must
    expose that text in ``choices[0].message.content``.
    """
    from app.providers import transformers_provider

    async def stub_chat(payload, stream=False):
        requested_model = payload["model"]
        # The route is expected to forward a non-empty model name.
        assert requested_model
        reply = {"role": "assistant", "content": "Hello"}
        only_choice = {"index": 0, "message": reply, "finish_reason": "stop"}
        return {
            "id": "cmpl-1",
            "object": "chat.completion",
            "created": 0,
            "model": requested_model,
            "choices": [only_choice],
        }

    monkeypatch.setattr(transformers_provider, "chat", stub_chat)

    request_body = {
        "model": "DragonLLM/LLM-Pro-Finance-Small",
        "messages": [{"role": "user", "content": "Hi"}],
    }
    response = client.post("/v1/chat/completions", json=request_body)

    assert response.status_code == 200
    first_choice = response.json()["choices"][0]
    assert first_choice["message"]["content"] == "Hello"
|