import logging
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer when the endpoint starts.

        Args:
            path (str): Path to the model files
        """
        logger.info(f"Loading model from {path}")

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Try to load without quantization first
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                load_in_8bit=False,
                load_in_4bit=False
            )
        except Exception as e:
            logger.warning(f"Failed to load without quantization: {e}")
            # Fallback: try with different settings
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                use_safetensors=True
            )

        # Set pad token if it doesn't exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Model loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data (Dict[str, Any]): Request data containing:
                - inputs (str): The input text/prompt
                - parameters (dict, optional): Generation parameters:
                    - max_new_tokens (int): Maximum tokens to generate (default: 256)
                    - temperature (float): Sampling temperature (default: 0.7)
                    - top_p (float): Top-p sampling (default: 0.9)
                    - do_sample (bool): Whether to use sampling (default: True)
                    - repetition_penalty (float): Repetition penalty (default: 1.1)
                    - return_full_text (bool): Return full text including input (default: False)

        Returns:
            List[Dict[str, Any]]: Generated text response
        """
        try:
            # Extract inputs
            inputs = data.get("inputs", "")
            if not inputs:
                return [{"error": "No input text provided"}]

            # Extract generation parameters
            parameters = data.get("parameters", {})
            max_new_tokens = parameters.get("max_new_tokens", 256)
            temperature = parameters.get("temperature", 0.7)
            top_p = parameters.get("top_p", 0.9)
            do_sample = parameters.get("do_sample", True)
            repetition_penalty = parameters.get("repetition_penalty", 1.1)
            return_full_text = parameters.get("return_full_text", False)

            # Wrap the prompt in instruction tags unless it already carries
            # chat/instruction formatting, e.g. "Hello" -> "[INST] Hello [/INST]",
            # while "<s>[INST] Hi [/INST]" passes through unchanged
            if not any(marker in inputs.lower() for marker in ["[inst]", "<s>", "### instruction", "user:", "assistant:"]):
                formatted_input = f"[INST] {inputs} [/INST]"
            else:
                formatted_input = inputs

            # Tokenize input; keep the attention mask so generate() can tell
            # padding from content when pad_token == eos_token
            encoded = self.tokenizer(
                formatted_input,
                return_tensors="pt",
                truncation=True,
                max_length=2048  # Reasonable limit for input
            )
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]

            # Move inputs to the model's device if a GPU is available
            if torch.cuda.is_available():
                input_ids = input_ids.to(self.model.device)
                attention_mask = attention_mask.to(self.model.device)

            # Generate response
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True
                )

            # Decode the response
            if return_full_text:
                generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            else:
                # Only return the newly generated tokens
                new_tokens = output_ids[0][input_ids.shape[-1]:]
                generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            # Clean up the response
            generated_text = generated_text.strip()

            # Return in the expected format
            return [{
                "generated_text": generated_text,
                "input_length": input_ids.shape[-1],
                "output_length": len(output_ids[0]) - input_ids.shape[-1]
            }]

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}"}]

    def __del__(self):
        """Clean up resources when the handler is destroyed."""
        if hasattr(self, 'model'):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
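

# --- Local smoke test ---
# A minimal sketch of how the endpoint invokes the handler, assuming the model
# files live at a hypothetical local path; the payload shape mirrors the
# __call__ docstring. Inference Endpoints instantiates EndpointHandler itself,
# so this block is illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical path; adjust as needed
    response = handler({
        "inputs": "Explain what an inference endpoint handler does.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    })
    print(response)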