# handler.py — LvcidPsyche, commit 114f486 ("Fix syntax error in handler.py string literal")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from typing import Dict, Any
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a LoRA-adapted Nanbeige model.

    On startup it loads the base causal LM plus a PEFT/LoRA adapter stored in
    the endpoint repository, then serves text-generation requests through
    ``__call__`` using the ChatML prompt format.
    """

    def __init__(self, path=""):
        """Load tokenizer, base model, and the LoRA adapter found at *path*.

        Parameters
        ----------
        path : str
            Endpoint repository root; expected to contain the PEFT adapter
            weights (Inference Endpoints passes this automatically).
        """
        # Base checkpoint the LoRA adapter was trained against.
        base_model_id = "Nanbeige/Nanbeige4.1-3B"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

        # Guard the bf16 probe: torch.cuda.is_bf16_supported() is only
        # meaningful (and safe to call) when CUDA is actually available.
        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

        # device_map="auto" lets accelerate place/shard the model across
        # whatever devices the endpoint exposes.
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
            trust_remote_code=True,
        )

        # Attach the LoRA adapter from the endpoint path on top of the base model.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> list:
        """Generate a completion for one request.

        Parameters
        ----------
        data : dict
            ``{"inputs": <prompt string>, "parameters": {...}}``. Recognized
            parameters: ``max_new_tokens``, ``temperature``, ``top_p``,
            ``repetition_penalty``, ``do_sample``.

        Returns
        -------
        list
            A single-element list ``[{"generated_text": <completion>}]``,
            the standard Inference Endpoints response shape.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Wrap a bare user prompt in ChatML; pass already-formatted prompts through.
        if isinstance(inputs, str) and not inputs.startswith("<|im_start|>"):
            system_prompt = "You are OpenClaw, a highly capable principal engineer and autonomous AI agent. You reason step-by-step, utilize tools effectively, and synthesize cross-domain knowledge to solve complex problems."
            prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"
        else:
            prompt = inputs

        # Send inputs to wherever accelerate placed the model, rather than a
        # hardcoded "cuda" string that can disagree with device_map="auto".
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Many causal-LM tokenizers have no pad token; fall back to EOS so
        # generate() does not warn/fail on a None pad_token_id.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            # Without do_sample=True, greedy decoding silently ignores
            # temperature/top_p — enable sampling by default since those
            # defaults clearly intend it; callers may still override.
            "do_sample": parameters.get("do_sample", True),
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": pad_token_id,
        }

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # Strip the prompt tokens; decode only the newly generated suffix.
        input_length = encoded.input_ids.shape[1]
        response = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
        return [{"generated_text": response}]