File size: 4,740 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from pathlib import Path

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pydantic_settings import BaseSettings

from nemo.deploy.nlp import NemoQueryLLM
from nemo.utils import logging


class TritonSettings(BaseSettings):
    """
    Triton connection and response-format settings, read from environment variables.

    Environment variables read when an instance is constructed:
        TRITON_PORT: integer port of the Triton server (default 8080).
        TRITON_HTTP_ADDRESS: address of the Triton server (default '0.0.0.0').
        TRITON_REQUEST_TIMEOUT: integer request timeout (default 60).
        OPENAI_FORMAT_RESPONSE: enabled only when the value is 'true', case-insensitive (default disabled).
        OUTPUT_GENERATION_LOGITS: enabled only when the value is 'true', case-insensitive (default disabled).

    An integer variable holding a non-integer value is reported through the logger and replaced by its
    default, so every property is always readable after construction.
    """

    _triton_service_port: int
    _triton_service_ip: str
    _triton_request_timeout: int
    _openai_format_response: bool
    _output_generation_logits: bool

    def __init__(self):
        super().__init__()

        def read_int(name, default):
            # Fall back to the default on a malformed value instead of leaving the attribute unset,
            # which would only surface later as an AttributeError inside a request handler.
            raw_value = os.environ.get(name)
            if raw_value is None:
                return default
            try:
                return int(raw_value)
            except ValueError as error:
                logging.error(
                    f"An exception occurred trying to retrieve set args in TritonSettings class. "
                    f"Invalid integer {raw_value!r} for {name}, using default {default}. Error: {error}"
                )
                return default

        def read_flag(name):
            # Only the literal string 'true' (any casing) turns a flag on.
            return os.environ.get(name, 'False').lower() == 'true'

        self._triton_service_port = read_int('TRITON_PORT', 8080)
        self._triton_service_ip = os.environ.get('TRITON_HTTP_ADDRESS', '0.0.0.0')
        self._triton_request_timeout = read_int('TRITON_REQUEST_TIMEOUT', 60)
        self._openai_format_response = read_flag('OPENAI_FORMAT_RESPONSE')
        self._output_generation_logits = read_flag('OUTPUT_GENERATION_LOGITS')

    @property
    def triton_service_port(self):
        """
        Returns the Triton server port taken from TRITON_PORT.
        """
        return self._triton_service_port

    @property
    def triton_service_ip(self):
        """
        Returns the Triton server address taken from TRITON_HTTP_ADDRESS.
        """
        return self._triton_service_ip

    @property
    def triton_request_timeout(self):
        """
        Returns the request timeout taken from TRITON_REQUEST_TIMEOUT.
        """
        return self._triton_request_timeout

    @property
    def openai_format_response(self):
        """
        Returns the response from Triton server in OpenAI compatible format if set to True.
        """
        return self._openai_format_response

    @property
    def output_generation_logits(self):
        """
        Returns the generation logits along with text in Triton server output if set to True.
        """
        return self._output_generation_logits


# FastAPI application object; the route decorators in this module register their endpoints on it.
app = FastAPI()
# Settings instance built once at import time and read by the endpoint handlers in this module.
triton_settings = TritonSettings()


class CompletionRequest(BaseModel):
    """
    Request body accepted by the "/v1/completions/" endpoint.

    The completions_v1 handler forwards model, prompt, max_tokens, temperature, top_p and top_k to the
    query call. stream, stop and frequency_penalty are accepted in the body but are not read by that
    handler.
    """

    model: str  # passed as model_name when the NemoQueryLLM client is constructed
    prompt: str  # sent as a single-element prompts list
    max_tokens: int = 512  # forwarded as max_output_len
    temperature: float = 1.0
    top_p: float = 0.0
    top_k: int = 1
    stream: bool = False  # not read by completions_v1
    stop: str | None = None  # not read by completions_v1
    frequency_penalty: float = 1.0  # not read by completions_v1


@app.get("/v1/health")
def health_check():
    """
    Liveness endpoint for this REST service; it always answers with a fixed "ok" status and does not
    contact the Triton server.
    """
    status_payload = {"status": "ok"}
    return status_payload


@app.get("/v1/triton_health")
async def check_triton_health():
    """
    This method exposes endpoint "/v1/triton_health" which can be used to verify if Triton server is accessible
    while running the REST or FastAPI application.
    Verify by running: curl http://service_http_address:service_port/v1/triton_health and the returned status
    should inform if the server is accessible.

    Returns:
        dict: {"status": "Triton server is reachable and ready"} when the Triton readiness URL answers with
        HTTP 200.

    Raises:
        HTTPException: status 503 when the readiness URL answers with any other status code, or when the
        request itself fails (connection error, timeout, ...).
    """
    triton_url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}/v2/health/ready"
    logging.info(f"Attempting to connect to Triton server at: {triton_url}")
    try:
        # requests.get is a blocking call; running it in a worker thread keeps the event loop free to
        # serve other requests while waiting for Triton (up to the 5 second timeout).
        response = await asyncio.to_thread(requests.get, triton_url, timeout=5)
    except requests.RequestException as e:
        raise HTTPException(status_code=503, detail=f"Cannot reach Triton server: {str(e)}") from e

    if response.status_code != 200:
        raise HTTPException(status_code=503, detail="Triton server is not ready")
    return {"status": "Triton server is reachable and ready"}


@app.post("/v1/completions/")
def completions_v1(request: CompletionRequest):
    """
    Forward a completion request to the Triton server through NemoQueryLLM.

    Args:
        request (CompletionRequest): request body; model selects the deployed model and prompt is sent as a
            single-element prompt list.

    Returns:
        The query output unchanged when triton_settings.openai_format_response is True, otherwise
        {"output": output[0][0]}. If anything raises while querying or unpacking the output, the error is
        logged and {"error": "An exception occurred"} is returned instead.
    """
    try:
        url = f"{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}"
        nq = NemoQueryLLM(url=url, model_name=request.model)
        output = nq.query_llm(
            prompts=[request.prompt],
            max_output_len=request.max_tokens,
            # Sampling parameters are forwarded exactly as received; CompletionRequest supplies defaults
            # for any the client omitted.
            top_k=request.top_k,
            top_p=request.top_p,
            temperature=request.temperature,
            init_timeout=triton_settings.triton_request_timeout,
            openai_format_response=triton_settings.openai_format_response,
            output_generation_logits=triton_settings.output_generation_logits,
        )
        if triton_settings.openai_format_response:
            return output
        return {
            "output": output[0][0],
        }
    except Exception as error:
        # The error is interpolated into the message itself: passing it as an extra positional argument
        # without a matching placeholder makes the log record fail to format and hides the real cause.
        logging.error(f"An exception occurred with the post request to /v1/completions/ endpoint: {error}")
        return {"error": "An exception occurred"}