import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from typing import Any, Dict, List


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the base model and tokenizer
        base_model_id = "Nanbeige/Nanbeige4.1-3B"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

        # Load the base model, preferring bfloat16 where the GPU supports it
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16,
            trust_remote_code=True,
        )

        # Load the LoRA adapter from the endpoint path and switch to inference mode
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        # Pop "parameters" first so it never leaks into the prompt when
        # "inputs" is absent and the remaining payload is used as the input
        parameters = data.pop("parameters", {})
        inputs = data.pop("inputs", data)

        # Wrap a plain string in the ChatML template unless it is already formatted
        if isinstance(inputs, str) and not inputs.startswith("<|im_start|>"):
            system_prompt = (
                "You are OpenClaw, a highly capable principal engineer and autonomous "
                "AI agent. You reason step-by-step, utilize tools effectively, and "
                "synthesize cross-domain knowledge to solve complex problems."
            )
            prompt = (
                f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
                f"<|im_start|>user\n{inputs}<|im_end|>\n"
                f"<|im_start|>assistant\n"
            )
        else:
            prompt = inputs

        # Tokenize and move tensors to the device the model was loaded onto
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Default generation parameters. do_sample must be True for temperature
        # and top_p to take effect; pad_token_id falls back to eos_token_id for
        # tokenizers that do not define a pad token.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "do_sample": parameters.get("do_sample", True),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
        }

        # Generate without tracking gradients
        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # Decode only the newly generated tokens, skipping the prompt
        input_length = encoded.input_ids.shape[1]
        response = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

        return [{"generated_text": response}]
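
# --- Local smoke test (a minimal sketch; not used by the Inference Endpoints
# runtime, which instantiates EndpointHandler itself and calls it with the
# request payload). The adapter path "." and the prompt below are illustrative
# assumptions, not part of the original handler.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler(
        {
            "inputs": "Explain the difference between a mutex and a semaphore.",
            "parameters": {"max_new_tokens": 128, "temperature": 0.5},
        }
    )
    print(result[0]["generated_text"])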