|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import os |
|
|
from unittest.mock import MagicMock, patch |
|
|
|
|
|
import numpy as np |
|
|
import pytest |
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
from nemo.deploy.service.fastapi_interface_to_pytriton import ( |
|
|
ChatCompletionRequest, |
|
|
CompletionRequest, |
|
|
TritonSettings, |
|
|
_helper_fun, |
|
|
app, |
|
|
convert_numpy, |
|
|
dict_to_str, |
|
|
query_llm_async, |
|
|
) |
|
|
from nemo.deploy.service.rest_model_api import CompletionRequest as RestCompletionRequest |
|
|
from nemo.deploy.service.rest_model_api import TritonSettings as RestTritonSettings |
|
|
from nemo.deploy.service.rest_model_api import app as rest_app |
|
|
|
|
|
|
|
|
@pytest.fixture
def client():
    """Return a TestClient bound to the FastAPI-to-PyTriton app."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_triton_settings():
    """Yield a mocked TritonSettings instance configured for localhost:8000."""
    target = 'nemo.deploy.service.fastapi_interface_to_pytriton.TritonSettings'
    with patch(target) as settings_cls:
        settings = settings_cls.return_value
        settings.triton_service_port = 8000
        settings.triton_service_ip = "localhost"
        yield settings
|
|
|
|
|
|
|
|
@pytest.fixture
def rest_client():
    """Return a TestClient bound to the REST model API app."""
    test_client = TestClient(rest_app)
    return test_client
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_rest_triton_settings():
    """Yield a mocked REST TritonSettings instance with non-OpenAI defaults."""
    target = 'nemo.deploy.service.rest_model_api.TritonSettings'
    with patch(target) as settings_cls:
        settings = settings_cls.return_value
        settings.triton_service_port = 8080
        settings.triton_service_ip = "localhost"
        settings.triton_request_timeout = 60
        settings.openai_format_response = False
        settings.output_generation_logits = False
        yield settings
|
|
|
|
|
|
|
|
class TestTritonSettings:
    """Environment-variable configuration tests for TritonSettings."""

    def test_default_values(self):
        # With a cleared environment, the built-in defaults must apply.
        with patch.dict(os.environ, {}, clear=True):
            cfg = TritonSettings()
            assert cfg.triton_service_port == 8000
            assert cfg.triton_service_ip == "0.0.0.0"

    def test_custom_values(self):
        # Port and address are read from TRITON_PORT / TRITON_HTTP_ADDRESS.
        env = {'TRITON_PORT': '9000', 'TRITON_HTTP_ADDRESS': '127.0.0.1'}
        with patch.dict(os.environ, env, clear=True):
            cfg = TritonSettings()
            assert cfg.triton_service_port == 9000
            assert cfg.triton_service_ip == "127.0.0.1"
|
|
|
|
|
|
|
|
class TestCompletionRequest:
    """Default and derived-value tests for the OpenAI-style request models."""

    def test_default_completions_values(self):
        req = CompletionRequest(model="test_model", prompt="test prompt")
        assert req.model == "test_model"
        assert req.prompt == "test prompt"
        assert req.max_tokens == 512
        assert req.temperature == 1.0
        assert req.top_p == 0.0
        assert req.top_k == 0
        assert req.logprobs is None
        assert req.echo is False

    def test_default_chat_values(self):
        chat_messages = [{"role": "user", "content": "test message"}]
        req = ChatCompletionRequest(model="test_model", messages=chat_messages)
        assert req.model == "test_model"
        assert req.messages == [{"role": "user", "content": "test message"}]
        assert req.max_tokens == 512
        assert req.temperature == 1.0
        assert req.top_p == 0.0
        assert req.top_k == 0

    def test_greedy_params(self):
        # temperature == 0 with top_p == 0 should force greedy decoding (top_k == 1).
        req = CompletionRequest(model="test_model", prompt="test prompt", temperature=0.0, top_p=0.0)
        assert req.top_k == 1
|
|
|
|
|
|
|
|
class TestHealthEndpoints:
    """Health-check endpoint tests for the FastAPI-to-PyTriton app."""

    def test_health_check(self, client):
        resp = client.get("/v1/health")
        assert resp.status_code == 200
        assert resp.json() == {"status": "ok"}
|
|
|
|
|
|
|
|
class TestUtilityFunctions:
    """Tests for the numpy-conversion and dict-serialization helpers."""

    def test_convert_numpy(self):
        # A plain ndarray is converted to a plain list.
        assert convert_numpy(np.array([1, 2, 3])) == [1, 2, 3]

        # Arrays nested inside dicts are converted recursively.
        nested_input = {"a": np.array([1, 2]), "b": {"c": np.array([3, 4])}}
        assert convert_numpy(nested_input) == {"a": [1, 2], "b": {"c": [3, 4]}}

        # Arrays inside lists are converted element-wise.
        list_input = [np.array([1, 2]), np.array([3, 4])]
        assert convert_numpy(list_input) == [[1, 2], [3, 4]]

    def test_dict_to_str(self):
        payload = {"key": "value", "number": 42}
        serialized = dict_to_str(payload)
        assert isinstance(serialized, str)
        # Round-tripping through json recovers the original dict.
        assert json.loads(serialized) == payload
|
|
|
|
|
|
|
|
class TestLLMQueryFunctions:
    """Tests for the synchronous query helper and its async wrapper."""

    def test_helper_fun(self):
        """_helper_fun forwards the query to NemoQueryLLMPyTorch and returns its result."""
        mock_nq = MagicMock()
        mock_nq.query_llm.return_value = {"test": "response"}

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.NemoQueryLLMPyTorch', return_value=mock_nq):
            result = _helper_fun(
                url="http://test",
                model="test_model",
                prompts=["test prompt"],
                temperature=0.7,
                top_k=10,
                top_p=0.9,
                compute_logprob=True,
                max_length=100,
                apply_chat_template=False,
                echo=False,
                n_top_logprobs=0,
            )
            assert result == {"test": "response"}
            mock_nq.query_llm.assert_called_once()

    def test_query_llm_async(self):
        """query_llm_async awaits _helper_fun and propagates its result.

        patch() auto-detects async targets and substitutes an AsyncMock,
        so the mocked _helper_fun remains awaitable.
        """
        mock_result = {"test": "response"}
        with patch('nemo.deploy.service.fastapi_interface_to_pytriton._helper_fun', return_value=mock_result):
            import asyncio

            # asyncio.run() creates and tears down a fresh event loop. The previous
            # get_event_loop()/run_until_complete pattern is deprecated since
            # Python 3.10 and fails on newer Pythons when the thread has no
            # current event loop.
            result = asyncio.run(
                query_llm_async(
                    url="http://test",
                    model="test_model",
                    prompts=["test prompt"],
                    temperature=0.7,
                    top_k=10,
                    top_p=0.9,
                    compute_logprob=True,
                    max_length=100,
                    apply_chat_template=False,
                    echo=False,
                    n_top_logprobs=0,
                )
            )
            assert result == mock_result
|
|
|
|
|
|
|
|
class TestAPIEndpoints:
    """Endpoint tests for the OpenAI-style completions and chat routes."""

    def test_completions_v1(self, client):
        backend_output = {
            "choices": [
                {
                    "text": [["test response"]],
                    "logprobs": {"token_logprobs": [[1.0, 2.0]], "top_logprobs": [[{"a": 0.5}, {"b": 0.5}]]},
                }
            ]
        }

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=backend_output):
            resp = client.post(
                "/v1/completions/", json={"model": "test_model", "prompt": "test prompt", "logprobs": 1}
            )
            assert resp.status_code == 200
            body = resp.json()
            # The nested text list is flattened into a single string in the response.
            assert body["choices"][0]["text"] == "test response"
            assert "logprobs" in body["choices"][0]

    def test_chat_completions_v1(self, client):
        backend_output = {"choices": [{"text": [["test response"]]}]}

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=backend_output):
            resp = client.post(
                "/v1/chat/completions/",
                json={"model": "test_model", "messages": [{"role": "user", "content": "test message"}]},
            )
            assert resp.status_code == 200
            body = resp.json()
            # Chat responses wrap generated text in an assistant message object.
            assert body["choices"][0]["message"]["role"] == "assistant"
            assert body["choices"][0]["message"]["content"] == "test response"
|
|
|
|
|
|
|
|
class TestRestTritonSettings:
    """Environment-variable configuration tests for the REST TritonSettings."""

    def test_default_values(self):
        # With a cleared environment, all settings fall back to their defaults.
        with patch.dict(os.environ, {}, clear=True):
            cfg = RestTritonSettings()
            assert cfg.triton_service_port == 8080
            assert cfg.triton_service_ip == "0.0.0.0"
            assert cfg.triton_request_timeout == 60
            assert cfg.openai_format_response is False
            assert cfg.output_generation_logits is False

    def test_custom_values(self):
        env = {
            'TRITON_PORT': '9000',
            'TRITON_HTTP_ADDRESS': '127.0.0.1',
            'TRITON_REQUEST_TIMEOUT': '120',
            'OPENAI_FORMAT_RESPONSE': 'True',
            'OUTPUT_GENERATION_LOGITS': 'True',
        }
        with patch.dict(os.environ, env, clear=True):
            cfg = RestTritonSettings()
            assert cfg.triton_service_port == 9000
            assert cfg.triton_service_ip == "127.0.0.1"
            assert cfg.triton_request_timeout == 120
            assert cfg.openai_format_response is True
            assert cfg.output_generation_logits is True
|
|
|
|
|
|
|
|
class TestRestCompletionRequest:
    """Default-value tests for the REST CompletionRequest model."""

    def test_default_values(self):
        req = RestCompletionRequest(model="test_model", prompt="test prompt")
        assert req.model == "test_model"
        assert req.prompt == "test prompt"
        assert req.max_tokens == 512
        assert req.temperature == 1.0
        assert req.top_p == 0.0
        # The REST model defaults to greedy decoding, unlike the OpenAI-style model.
        assert req.top_k == 1
        assert req.stream is False
        assert req.stop is None
        assert req.frequency_penalty == 1.0
|
|
|
|
|
|
|
|
class TestRestHealthEndpoints:
    """Health-check endpoint tests for the REST model API."""

    def test_health_check(self, rest_client):
        resp = rest_client.get("/v1/health")
        assert resp.status_code == 200
        assert resp.json() == {"status": "ok"}

    def test_triton_health_success(self, rest_client):
        # Simulate a reachable Triton server by mocking the outbound HTTP probe.
        with patch('requests.get') as mock_get:
            probe_response = MagicMock()
            probe_response.status_code = 200
            mock_get.return_value = probe_response

            resp = rest_client.get("/v1/triton_health")
            assert resp.status_code == 200
            assert resp.json() == {"status": "Triton server is reachable and ready"}
|
|
|
|
|
|
|
|
class TestRestCompletionsEndpoint:
    """Tests for the REST /v1/completions/ route."""

    def test_completions_success(self, rest_client):
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as llm_cls:
            llm_cls.return_value.query_llm.return_value = [["test response"]]

            resp = rest_client.post(
                "/v1/completions/",
                json={
                    "model": "test_model",
                    "prompt": "test prompt",
                    "max_tokens": 100,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 10,
                },
            )
            assert resp.status_code == 200
            assert resp.json() == {"output": "test response"}

    def test_completions_standard_format(self, rest_client, mock_rest_triton_settings):
        # Force the non-OpenAI response shape.
        mock_rest_triton_settings.openai_format_response = False

        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as llm_cls:
            llm_cls.return_value.query_llm.return_value = [["test response"]]

            resp = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
            assert resp.status_code == 200
            assert resp.json() == {"output": "test response"}

    def test_completions_error_handling(self, rest_client):
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as llm_cls:
            llm_cls.return_value.query_llm.side_effect = Exception("Test error")

            resp = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
            # The endpoint reports failures in-band rather than via HTTP error codes.
            assert resp.status_code == 200
            assert resp.json() == {"error": "An exception occurred"}
|
|
|