|
|
import os |
|
|
from llama_cpp import Llama |
|
|
from typing import Dict, List, Any |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler serving a GGUF model
    through llama-cpp-python.

    The model file is resolved from the ``GGUF_MODEL_PATH`` environment
    variable when set; otherwise a default filename inside the deployed
    repository directory is used.
    """

    # Fallback model filename used when GGUF_MODEL_PATH is not set.
    DEFAULT_MODEL_NAME = "Llama3_1_SCB_FT_Q8_0.gguf"

    def __init__(self, path: str = ""):
        """Resolve the model file and load it into a ``Llama`` instance.

        Args:
            path: Repository directory supplied by the endpoint runtime;
                used to locate the default model file when
                ``GGUF_MODEL_PATH`` is unset.

        Raises:
            FileNotFoundError: If the resolved model file does not exist.
                Raised eagerly so deployment fails with an actionable
                message instead of an opaque llama.cpp load error.
        """
        model_path = os.environ.get("GGUF_MODEL_PATH")

        if not model_path:
            # Fall back to the model file bundled with the repository.
            model_path = os.path.join(path, self.DEFAULT_MODEL_NAME)

        # Fail fast: a bad path would otherwise surface as a cryptic
        # native abort inside llama.cpp.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"GGUF model file not found: {model_path}")

        print(f"Loading GGUF model from: {model_path}")

        self.llama = Llama(
            model_path=model_path,
            n_gpu_layers=-1,  # offload every layer to GPU when available
            n_ctx=4096,
            verbose=True,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an inference request.

        Args:
            data: Request payload. Required key: ``inputs`` (the prompt
                string). Optional sampling keys: ``max_new_tokens``
                (default 256), ``temperature`` (default 0.7) and
                ``top_p`` (default 0.95).

        Returns:
            The raw completion dict produced by llama-cpp-python, or an
            ``{"error": ...}`` dict when ``inputs`` is missing.
        """
        # .get() instead of .pop(): avoid mutating the caller's payload
        # dict as a side effect.
        inputs = data.get("inputs")

        if inputs is None:
            return {"error": "No 'inputs' key found in the request payload."}

        max_new_tokens = data.get("max_new_tokens", 256)
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.95)

        return self.llama(
            inputs,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            echo=False,
        )
|
|
|