# NeMo_Canary/tests/deploy/test_deployment_service.py
# (uploaded via huggingface_hub; commit b386992, verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from fastapi.testclient import TestClient
from nemo.deploy.service.fastapi_interface_to_pytriton import (
ChatCompletionRequest,
CompletionRequest,
TritonSettings,
_helper_fun,
app,
convert_numpy,
dict_to_str,
query_llm_async,
)
from nemo.deploy.service.rest_model_api import CompletionRequest as RestCompletionRequest
from nemo.deploy.service.rest_model_api import TritonSettings as RestTritonSettings
from nemo.deploy.service.rest_model_api import app as rest_app
@pytest.fixture
def client():
    """Return a FastAPI test client bound to the pytriton-interface app."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def mock_triton_settings():
    """Patch TritonSettings in the pytriton interface module.

    Yields a mock settings instance that points at localhost:8000.
    """
    target = 'nemo.deploy.service.fastapi_interface_to_pytriton.TritonSettings'
    with patch(target) as settings_cls:
        settings = settings_cls.return_value
        settings.triton_service_port = 8000
        settings.triton_service_ip = "localhost"
        yield settings
@pytest.fixture
def rest_client():
    """Return a FastAPI test client bound to the REST model API app."""
    test_client = TestClient(rest_app)
    return test_client
@pytest.fixture
def mock_rest_triton_settings():
    """Patch TritonSettings in the REST model API module.

    Yields a mock settings instance configured for localhost:8080 with a
    60s timeout and OpenAI-format / generation-logits output disabled.
    """
    target = 'nemo.deploy.service.rest_model_api.TritonSettings'
    with patch(target) as settings_cls:
        settings = settings_cls.return_value
        settings.triton_service_port = 8080
        settings.triton_service_ip = "localhost"
        settings.triton_request_timeout = 60
        settings.openai_format_response = False
        settings.output_generation_logits = False
        yield settings
class TestTritonSettings:
    """Environment-variable handling of the pytriton TritonSettings."""

    def test_default_values(self):
        """With no env vars set, settings default to 0.0.0.0:8000."""
        with patch.dict(os.environ, {}, clear=True):
            cfg = TritonSettings()
            assert cfg.triton_service_ip == "0.0.0.0"
            assert cfg.triton_service_port == 8000

    def test_custom_values(self):
        """TRITON_PORT / TRITON_HTTP_ADDRESS override the defaults."""
        env = {'TRITON_PORT': '9000', 'TRITON_HTTP_ADDRESS': '127.0.0.1'}
        with patch.dict(os.environ, env, clear=True):
            cfg = TritonSettings()
            assert cfg.triton_service_ip == "127.0.0.1"
            assert cfg.triton_service_port == 9000
class TestCompletionRequest:
    """Default and derived field values of the completion request models."""

    def test_default_completions_values(self):
        """A minimal CompletionRequest carries the documented sampling defaults."""
        req = CompletionRequest(model="test_model", prompt="test prompt")
        assert req.model == "test_model"
        assert req.prompt == "test prompt"
        assert req.max_tokens == 512
        assert req.temperature == 1.0
        assert req.top_p == 0.0
        assert req.top_k == 0
        assert req.logprobs is None
        assert req.echo is False

    def test_default_chat_values(self):
        """A minimal ChatCompletionRequest carries the same sampling defaults."""
        messages = [{"role": "user", "content": "test message"}]
        req = ChatCompletionRequest(model="test_model", messages=messages)
        assert req.model == "test_model"
        assert req.messages == messages
        assert req.max_tokens == 512
        assert req.temperature == 1.0
        assert req.top_p == 0.0
        assert req.top_k == 0

    def test_greedy_params(self):
        """temperature=0 with top_p=0 coerces top_k to 1 (greedy decoding)."""
        greedy = CompletionRequest(model="test_model", prompt="test prompt", temperature=0.0, top_p=0.0)
        assert greedy.top_k == 1
class TestHealthEndpoints:
    """Health-probe endpoint of the pytriton-interface app."""

    def test_health_check(self, client):
        """GET /v1/health returns 200 with an 'ok' status payload."""
        resp = client.get("/v1/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body == {"status": "ok"}
class TestUtilityFunctions:
    """Tests for the numpy-conversion and dict-serialization helpers."""

    def test_convert_numpy(self):
        """convert_numpy recursively turns ndarrays into plain Python lists."""
        # Bare ndarray becomes a list.
        assert convert_numpy(np.array([1, 2, 3])) == [1, 2, 3]

        # Arrays nested inside dicts are converted in place.
        nested_input = {"a": np.array([1, 2]), "b": {"c": np.array([3, 4])}}
        expected = {"a": [1, 2], "b": {"c": [3, 4]}}
        assert convert_numpy(nested_input) == expected

        # Arrays inside a list are converted element-wise.
        assert convert_numpy([np.array([1, 2]), np.array([3, 4])]) == [[1, 2], [3, 4]]

    def test_dict_to_str(self):
        """dict_to_str produces a JSON string that round-trips to the input."""
        payload = {"key": "value", "number": 42}
        serialized = dict_to_str(payload)
        assert isinstance(serialized, str)
        assert json.loads(serialized) == payload
class TestLLMQueryFunctions:
    """Tests for the query helper functions of the pytriton interface."""

    def test_helper_fun(self):
        """_helper_fun forwards its arguments to NemoQueryLLMPyTorch.query_llm
        and returns the query result unchanged."""
        mock_nq = MagicMock()
        mock_nq.query_llm.return_value = {"test": "response"}
        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.NemoQueryLLMPyTorch', return_value=mock_nq):
            result = _helper_fun(
                url="http://test",
                model="test_model",
                prompts=["test prompt"],
                temperature=0.7,
                top_k=10,
                top_p=0.9,
                compute_logprob=True,
                max_length=100,
                apply_chat_template=False,
                echo=False,
                n_top_logprobs=0,
            )
            assert result == {"test": "response"}
            mock_nq.query_llm.assert_called_once()

    def test_query_llm_async(self):
        """query_llm_async awaits _helper_fun and returns its result.

        Uses asyncio.run() instead of the deprecated
        asyncio.get_event_loop().run_until_complete() pattern:
        get_event_loop() emits a DeprecationWarning on Python 3.10+ when
        called without a running loop, and the implicit loop creation is
        slated for removal.
        """
        import asyncio

        mock_result = {"test": "response"}
        with patch('nemo.deploy.service.fastapi_interface_to_pytriton._helper_fun', return_value=mock_result):
            result = asyncio.run(
                query_llm_async(
                    url="http://test",
                    model="test_model",
                    prompts=["test prompt"],
                    temperature=0.7,
                    top_k=10,
                    top_p=0.9,
                    compute_logprob=True,
                    max_length=100,
                    apply_chat_template=False,
                    echo=False,
                    n_top_logprobs=0,
                )
            )
            assert result == mock_result
class TestAPIEndpoints:
    """End-to-end tests of the OpenAI-style completion endpoints."""

    def test_completions_v1(self, client):
        """POST /v1/completions/ unwraps the nested text and keeps logprobs."""
        backend_output = {
            "choices": [
                {
                    "text": [["test response"]],
                    "logprobs": {"token_logprobs": [[1.0, 2.0]], "top_logprobs": [[{"a": 0.5}, {"b": 0.5}]]},
                }
            ]
        }
        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=backend_output):
            resp = client.post(
                "/v1/completions/", json={"model": "test_model", "prompt": "test prompt", "logprobs": 1}
            )
        assert resp.status_code == 200
        body = resp.json()
        assert body["choices"][0]["text"] == "test response"
        assert "logprobs" in body["choices"][0]

    def test_chat_completions_v1(self, client):
        """POST /v1/chat/completions/ wraps the text in an assistant message."""
        backend_output = {"choices": [{"text": [["test response"]]}]}
        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=backend_output):
            resp = client.post(
                "/v1/chat/completions/",
                json={"model": "test_model", "messages": [{"role": "user", "content": "test message"}]},
            )
        assert resp.status_code == 200
        message = resp.json()["choices"][0]["message"]
        assert message["role"] == "assistant"
        assert message["content"] == "test response"
class TestRestTritonSettings:
    """Environment-variable handling of the REST API TritonSettings."""

    def test_default_values(self):
        """With no env vars set, settings default to 0.0.0.0:8080 and flags off."""
        with patch.dict(os.environ, {}, clear=True):
            cfg = RestTritonSettings()
            assert cfg.triton_service_ip == "0.0.0.0"
            assert cfg.triton_service_port == 8080
            assert cfg.triton_request_timeout == 60
            assert cfg.openai_format_response is False
            assert cfg.output_generation_logits is False

    def test_custom_values(self):
        """Every setting is overridable from the environment."""
        env = {
            'TRITON_PORT': '9000',
            'TRITON_HTTP_ADDRESS': '127.0.0.1',
            'TRITON_REQUEST_TIMEOUT': '120',
            'OPENAI_FORMAT_RESPONSE': 'True',
            'OUTPUT_GENERATION_LOGITS': 'True',
        }
        with patch.dict(os.environ, env, clear=True):
            cfg = RestTritonSettings()
            assert cfg.triton_service_ip == "127.0.0.1"
            assert cfg.triton_service_port == 9000
            assert cfg.triton_request_timeout == 120
            assert cfg.openai_format_response is True
            assert cfg.output_generation_logits is True
class TestRestCompletionRequest:
    """Default field values of the REST CompletionRequest model."""

    def test_default_values(self):
        """A minimal request carries the documented REST-specific defaults
        (note top_k defaults to 1 here, unlike the pytriton interface)."""
        req = RestCompletionRequest(model="test_model", prompt="test prompt")
        assert req.model == "test_model"
        assert req.prompt == "test prompt"
        assert req.max_tokens == 512
        assert req.temperature == 1.0
        assert req.top_p == 0.0
        assert req.top_k == 1
        assert req.stream is False
        assert req.stop is None
        assert req.frequency_penalty == 1.0
class TestRestHealthEndpoints:
    """Health endpoints of the REST model API."""

    def test_health_check(self, rest_client):
        """GET /v1/health responds 200/ok without contacting Triton."""
        resp = rest_client.get("/v1/health")
        assert resp.status_code == 200
        assert resp.json() == {"status": "ok"}

    def test_triton_health_success(self, rest_client):
        """GET /v1/triton_health reports success when Triton answers 200."""
        with patch('requests.get') as mocked_get:
            ok_response = MagicMock()
            ok_response.status_code = 200
            mocked_get.return_value = ok_response
            resp = rest_client.get("/v1/triton_health")
        assert resp.status_code == 200
        assert resp.json() == {"status": "Triton server is reachable and ready"}
class TestRestCompletionsEndpoint:
    """Tests of the REST /v1/completions/ endpoint."""

    def test_completions_success(self, rest_client):
        """A full request returns the unwrapped text under the 'output' key."""
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_llm.return_value.query_llm.return_value = [["test response"]]
            payload = {
                "model": "test_model",
                "prompt": "test prompt",
                "max_tokens": 100,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 10,
            }
            resp = rest_client.post(
                "/v1/completions/",
                json=payload,
            )
        assert resp.status_code == 200
        assert resp.json() == {"output": "test response"}

    def test_completions_standard_format(self, rest_client, mock_rest_triton_settings):
        """With OpenAI formatting disabled, the plain 'output' shape is used."""
        mock_rest_triton_settings.openai_format_response = False
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_llm.return_value.query_llm.return_value = [["test response"]]
            resp = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
        assert resp.status_code == 200
        assert resp.json() == {"output": "test response"}

    def test_completions_error_handling(self, rest_client):
        """A backend exception yields a 200 response with a generic error body."""
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_llm.return_value.query_llm.side_effect = Exception("Test error")
            resp = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
        assert resp.status_code == 200
        assert resp.json() == {"error": "An exception occurred"}