# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the FastAPI apps in ``nemo.deploy.service``.

Covers ``fastapi_interface_to_pytriton`` (settings, request models, utility
helpers, query helpers and the completion endpoints) and ``rest_model_api``
(settings, request model, health and completion endpoints). Calls that would
reach a Triton server are replaced with mocks.
"""

import asyncio
import json
import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from fastapi.testclient import TestClient

from nemo.deploy.service.fastapi_interface_to_pytriton import (
    ChatCompletionRequest,
    CompletionRequest,
    TritonSettings,
    _helper_fun,
    app,
    convert_numpy,
    dict_to_str,
    query_llm_async,
)
from nemo.deploy.service.rest_model_api import CompletionRequest as RestCompletionRequest
from nemo.deploy.service.rest_model_api import TritonSettings as RestTritonSettings
from nemo.deploy.service.rest_model_api import app as rest_app


@pytest.fixture
def client():
    """Return a ``TestClient`` bound to the ``fastapi_interface_to_pytriton`` app."""
    return TestClient(app)


@pytest.fixture
def mock_triton_settings():
    """Patch ``TritonSettings`` in ``fastapi_interface_to_pytriton`` and yield the mocked instance."""
    with patch('nemo.deploy.service.fastapi_interface_to_pytriton.TritonSettings') as mock:
        instance = mock.return_value
        instance.triton_service_port = 8000
        instance.triton_service_ip = "localhost"
        yield instance


@pytest.fixture
def rest_client():
    """Return a ``TestClient`` bound to the ``rest_model_api`` app."""
    return TestClient(rest_app)


@pytest.fixture
def mock_rest_triton_settings():
    """Patch ``TritonSettings`` in ``rest_model_api`` and yield the mocked instance."""
    with patch('nemo.deploy.service.rest_model_api.TritonSettings') as mock:
        instance = mock.return_value
        instance.triton_service_port = 8080
        instance.triton_service_ip = "localhost"
        instance.triton_request_timeout = 60
        instance.openai_format_response = False
        instance.output_generation_logits = False
        yield instance


class TestTritonSettings:
    """Checks ``TritonSettings`` values with and without environment overrides."""

    def test_default_values(self):
        """With an empty environment the settings expose port 8000 and ip 0.0.0.0."""
        with patch.dict(os.environ, {}, clear=True):
            settings = TritonSettings()
            assert settings.triton_service_port == 8000
            assert settings.triton_service_ip == "0.0.0.0"

    def test_custom_values(self):
        """``TRITON_PORT`` and ``TRITON_HTTP_ADDRESS`` are reflected in the settings."""
        with patch.dict(os.environ, {'TRITON_PORT': '9000', 'TRITON_HTTP_ADDRESS': '127.0.0.1'}, clear=True):
            settings = TritonSettings()
            assert settings.triton_service_port == 9000
            assert settings.triton_service_ip == "127.0.0.1"


class TestCompletionRequest:
    """Checks field values of ``CompletionRequest`` and ``ChatCompletionRequest``."""

    def test_default_completions_values(self):
        """A request built from only ``model`` and ``prompt`` carries the expected field values."""
        request = CompletionRequest(model="test_model", prompt="test prompt")
        assert request.model == "test_model"
        assert request.prompt == "test prompt"
        assert request.max_tokens == 512
        assert request.temperature == 1.0
        assert request.top_p == 0.0
        assert request.top_k == 0
        assert request.logprobs is None
        assert request.echo is False

    def test_default_chat_values(self):
        """A chat request built from only ``model`` and ``messages`` carries the expected field values."""
        request = ChatCompletionRequest(model="test_model", messages=[{"role": "user", "content": "test message"}])
        assert request.model == "test_model"
        assert request.messages == [{"role": "user", "content": "test message"}]
        assert request.max_tokens == 512
        assert request.temperature == 1.0
        assert request.top_p == 0.0
        assert request.top_k == 0

    def test_greedy_params(self):
        """``temperature=0.0`` together with ``top_p=0.0`` results in ``top_k == 1``."""
        request = CompletionRequest(model="test_model", prompt="test prompt", temperature=0.0, top_p=0.0)
        assert request.top_k == 1


class TestHealthEndpoints:
    """Checks the health route of the ``fastapi_interface_to_pytriton`` app."""

    def test_health_check(self, client):
        """``GET /v1/health`` answers 200 with ``{"status": "ok"}``."""
        response = client.get("/v1/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}


class TestUtilityFunctions:
    """Checks the ``convert_numpy`` and ``dict_to_str`` helpers."""

    def test_convert_numpy(self):
        """Arrays are turned into lists at the top level, inside nested dicts and inside lists."""
        # Test with numpy array
        arr = np.array([1, 2, 3])
        assert convert_numpy(arr) == [1, 2, 3]

        # Test with nested dictionary
        nested = {"a": np.array([1, 2]), "b": {"c": np.array([3, 4])}}
        assert convert_numpy(nested) == {"a": [1, 2], "b": {"c": [3, 4]}}

        # Test with list
        lst = [np.array([1, 2]), np.array([3, 4])]
        assert convert_numpy(lst) == [[1, 2], [3, 4]]

    def test_dict_to_str(self):
        """The result is a string that ``json.loads`` parses back to the input dict."""
        test_dict = {"key": "value", "number": 42}
        result = dict_to_str(test_dict)
        assert isinstance(result, str)
        assert json.loads(result) == test_dict


class TestLLMQueryFunctions:
    """Checks ``_helper_fun`` and ``query_llm_async`` with the underlying query mocked."""

    def test_helper_fun(self):
        """``_helper_fun`` returns what the mocked ``NemoQueryLLMPyTorch.query_llm`` returns."""
        mock_nq = MagicMock()
        mock_nq.query_llm.return_value = {"test": "response"}

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.NemoQueryLLMPyTorch', return_value=mock_nq):
            result = _helper_fun(
                url="http://test",
                model="test_model",
                prompts=["test prompt"],
                temperature=0.7,
                top_k=10,
                top_p=0.9,
                compute_logprob=True,
                max_length=100,
                apply_chat_template=False,
                echo=False,
                n_top_logprobs=0,
            )
            assert result == {"test": "response"}
            mock_nq.query_llm.assert_called_once()

    def test_query_llm_async(self):
        """``query_llm_async`` resolves to the value produced by the patched ``_helper_fun``."""
        mock_result = {"test": "response"}

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton._helper_fun', return_value=mock_result):
            # Run the coroutine on a private loop. asyncio.get_event_loop() is
            # deprecated when no loop is running and fails if an earlier test
            # cleared the current loop; a fresh loop neither depends on nor
            # changes that shared state, and it is closed once we are done.
            loop = asyncio.new_event_loop()
            try:
                result = loop.run_until_complete(
                    query_llm_async(
                        url="http://test",
                        model="test_model",
                        prompts=["test prompt"],
                        temperature=0.7,
                        top_k=10,
                        top_p=0.9,
                        compute_logprob=True,
                        max_length=100,
                        apply_chat_template=False,
                        echo=False,
                        n_top_logprobs=0,
                    )
                )
            finally:
                loop.close()
            assert result == mock_result


class TestAPIEndpoints:
    """Checks the completion routes of the ``fastapi_interface_to_pytriton`` app."""

    def test_completions_v1(self, client):
        """``POST /v1/completions/`` returns the mocked text and includes ``logprobs``."""
        mock_output = {
            "choices": [
                {
                    "text": [["test response"]],
                    "logprobs": {"token_logprobs": [[1.0, 2.0]], "top_logprobs": [[{"a": 0.5}, {"b": 0.5}]]},
                }
            ]
        }

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=mock_output):
            response = client.post(
                "/v1/completions/", json={"model": "test_model", "prompt": "test prompt", "logprobs": 1}
            )
            assert response.status_code == 200
            data = response.json()
            assert data["choices"][0]["text"] == "test response"
            assert "logprobs" in data["choices"][0]

    def test_chat_completions_v1(self, client):
        """``POST /v1/chat/completions/`` returns an assistant message with the mocked text."""
        mock_output = {"choices": [{"text": [["test response"]]}]}

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=mock_output):
            response = client.post(
                "/v1/chat/completions/",
                json={"model": "test_model", "messages": [{"role": "user", "content": "test message"}]},
            )
            assert response.status_code == 200
            data = response.json()
            assert data["choices"][0]["message"]["role"] == "assistant"
            assert data["choices"][0]["message"]["content"] == "test response"


class TestRestTritonSettings:
    """Checks ``rest_model_api.TritonSettings`` values with and without environment overrides."""

    def test_default_values(self):
        """With an empty environment the settings expose the expected fallback values."""
        with patch.dict(os.environ, {}, clear=True):
            settings = RestTritonSettings()
            assert settings.triton_service_port == 8080
            assert settings.triton_service_ip == "0.0.0.0"
            assert settings.triton_request_timeout == 60
            assert settings.openai_format_response is False
            assert settings.output_generation_logits is False

    def test_custom_values(self):
        """Each environment variable set here is reflected in the matching setting."""
        with patch.dict(
            os.environ,
            {
                'TRITON_PORT': '9000',
                'TRITON_HTTP_ADDRESS': '127.0.0.1',
                'TRITON_REQUEST_TIMEOUT': '120',
                'OPENAI_FORMAT_RESPONSE': 'True',
                'OUTPUT_GENERATION_LOGITS': 'True',
            },
            clear=True,
        ):
            settings = RestTritonSettings()
            assert settings.triton_service_port == 9000
            assert settings.triton_service_ip == "127.0.0.1"
            assert settings.triton_request_timeout == 120
            assert settings.openai_format_response is True
            assert settings.output_generation_logits is True


class TestRestCompletionRequest:
    """Checks field values of ``rest_model_api.CompletionRequest``."""

    def test_default_values(self):
        """A request built from only ``model`` and ``prompt`` carries the expected field values."""
        request = RestCompletionRequest(model="test_model", prompt="test prompt")
        assert request.model == "test_model"
        assert request.prompt == "test prompt"
        assert request.max_tokens == 512
        assert request.temperature == 1.0
        assert request.top_p == 0.0
        assert request.top_k == 1
        assert request.stream is False
        assert request.stop is None
        assert request.frequency_penalty == 1.0


class TestRestHealthEndpoints:
    """Checks the health routes of the ``rest_model_api`` app."""

    def test_health_check(self, rest_client):
        """``GET /v1/health`` answers 200 with ``{"status": "ok"}``."""
        response = rest_client.get("/v1/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    def test_triton_health_success(self, rest_client):
        """``GET /v1/triton_health`` reports readiness when the patched ``requests.get`` returns 200."""
        with patch('requests.get') as mock_get:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_get.return_value = mock_response

            response = rest_client.get("/v1/triton_health")
            assert response.status_code == 200
            assert response.json() == {"status": "Triton server is reachable and ready"}


class TestRestCompletionsEndpoint:
    """Checks ``POST /v1/completions/`` of the ``rest_model_api`` app with ``NemoQueryLLM`` mocked."""

    def test_completions_success(self, rest_client):
        """The mocked query output is returned under the ``output`` key."""
        mock_output = [["test response"]]

        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_instance = mock_llm.return_value
            mock_instance.query_llm.return_value = mock_output

            response = rest_client.post(
                "/v1/completions/",
                json={
                    "model": "test_model",
                    "prompt": "test prompt",
                    "max_tokens": 100,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 10,
                },
            )
            assert response.status_code == 200
            assert response.json() == {"output": "test response"}

    def test_completions_standard_format(self, rest_client, mock_rest_triton_settings):
        """With ``openai_format_response`` False the body is ``{"output": ...}``."""
        mock_output = [["test response"]]
        mock_rest_triton_settings.openai_format_response = False

        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_instance = mock_llm.return_value
            mock_instance.query_llm.return_value = mock_output

            response = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
            assert response.status_code == 200
            assert response.json() == {"output": "test response"}

    def test_completions_error_handling(self, rest_client):
        """An exception from ``query_llm`` yields HTTP 200 with an ``error`` payload."""
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_instance = mock_llm.return_value
            mock_instance.query_llm.side_effect = Exception("Test error")

            response = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
            assert response.status_code == 200
            assert response.json() == {"error": "An exception occurred"}