import logging
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer when the endpoint starts.

        Args:
            path (str): Path to the model files
        """
        logger.info(f"Loading model from {path}")

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Prefer fp16 on GPU and fp32 on CPU; device_map="auto" lets accelerate place
        # the weights on the available GPU(s). Quantization is explicitly disabled.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                load_in_8bit=False,
                load_in_4bit=False
            )
        except Exception as e:
            logger.warning(f"Initial model load failed ({e}); retrying with use_safetensors=True")
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                use_safetensors=True
            )

        # Many causal LM tokenizers ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Model loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data (Dict[str, Any]): Request data containing:
                - inputs (str): The input text/prompt
                - parameters (dict, optional): Generation parameters
                    - max_new_tokens (int): Maximum tokens to generate (default: 256)
                    - temperature (float): Sampling temperature (default: 0.7)
                    - top_p (float): Top-p sampling (default: 0.9)
                    - do_sample (bool): Whether to use sampling (default: True)
                    - repetition_penalty (float): Repetition penalty (default: 1.1)
                    - return_full_text (bool): Return full text including input (default: False)

        Returns:
            List[Dict[str, Any]]: Generated text response
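
        Example (illustrative payload; generated text and token counts vary by model):
            data = {
                "inputs": "Explain what a custom inference handler does.",
                "parameters": {"max_new_tokens": 64, "temperature": 0.7}
            }
            handler(data)
            # -> [{"generated_text": "...", "input_length": ..., "output_length": ...}]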
        """
        try:
            # Extract and validate the prompt
            inputs = data.get("inputs", "")
            if not inputs:
                return [{"error": "No input text provided"}]

            # Generation parameters with sensible defaults
            parameters = data.get("parameters", {})
            max_new_tokens = parameters.get("max_new_tokens", 256)
            temperature = parameters.get("temperature", 0.7)
            top_p = parameters.get("top_p", 0.9)
            do_sample = parameters.get("do_sample", True)
            repetition_penalty = parameters.get("repetition_penalty", 1.1)
            return_full_text = parameters.get("return_full_text", False)

            # Wrap bare prompts in [INST] ... [/INST]; leave already-formatted chats untouched
            if not any(marker in inputs.lower() for marker in ["[inst]", "<s>", "### instruction", "user:", "assistant:"]):
                formatted_input = f"[INST] {inputs} [/INST]"
            else:
                formatted_input = inputs

            # Tokenize (truncating long prompts) and keep the attention mask so
            # generate() does not have to infer it from the pad token
            encoded = self.tokenizer(
                formatted_input,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            )
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]

            if torch.cuda.is_available():
                input_ids = input_ids.cuda()
                attention_mask = attention_mask.cuda()

            # Generate without gradient tracking
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True
                )

            if return_full_text:
                generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            else:
                # Decode only the newly generated tokens, dropping the echoed prompt
                new_tokens = output_ids[0][input_ids.shape[-1]:]
                generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            generated_text = generated_text.strip()

            return [{
                "generated_text": generated_text,
                "input_length": input_ids.shape[-1],
                "output_length": len(output_ids[0]) - input_ids.shape[-1]
            }]

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}"}]

    def __del__(self):
        """Clean up resources when the handler is destroyed."""
        if hasattr(self, 'model'):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
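

# Minimal local smoke test, a sketch rather than part of the endpoint contract:
# it assumes the model files live at "./model" (hypothetical path; any local or
# Hub path works) and exercises the handler the way Inference Endpoints would.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler({
        "inputs": "Write one sentence about custom inference handlers.",
        "parameters": {"max_new_tokens": 32}
    })
    print(result)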