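"""Custom Hugging Face Inference Endpoints handler.

Mimics the local LLM interface that RemoteLLM expects, applying the same
chat template and decoding behavior as BrickGPT uses locally.
"""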
import os
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        # Get the HuggingFace token for gated model access
        hf_token = os.getenv("HF_TOKEN")

        # Load the tokenizer and model with authentication
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            token=hf_token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token,
        )

        # Set a pad token if the tokenizer does not define one
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Simple handler that mimics local LLM behavior for RemoteLLM.

        Accepts ``{"inputs": {"messages": [...]}, "parameters": {...}}``,
        a bare list of chat messages, or plain text.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Handle the different input formats that RemoteLLM sends
        if isinstance(inputs, dict) and "messages" in inputs:
            messages = inputs["messages"]
        elif isinstance(inputs, list):
            messages = inputs
        else:
            # Fallback: treat the payload as direct text
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": str(inputs)},
            ]

        # Check if this is a continuation (the history contains an assistant message)
        has_assistant = any(msg.get("role") == "assistant" for msg in messages)

        # Apply the chat template exactly like BrickGPT does locally
        if has_assistant:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                continue_final_message=True,
                return_tensors="pt",
            )
        else:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt",
            )

        # Move to the model's device
        input_ids = prompt.to(self.model.device)
        attention_mask = torch.ones_like(input_ids)

        # Generation parameters - use BrickGPT defaults
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 10),
            "temperature": parameters.get("temperature", 0.6),
            "top_k": parameters.get("top_k", 20),
            "top_p": parameters.get("top_p", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
            "do_sample": True,
            "num_return_sequences": 1,
            "return_dict_in_generate": True,
        }

        # Generate
        with torch.no_grad():
            output_dict = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                **generation_params,
            )

        # Extract only the newly generated tokens
        input_length = input_ids.shape[1]
        result_ids = output_dict["sequences"][0][input_length:]

        # CRITICAL: decode exactly like the local LLM (no skip_special_tokens argument)
        generated_text = self.tokenizer.decode(result_ids)

        # Return in the format RemoteLLM expects
        return [{"generated_text": generated_text}]