from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig


class EndpointHandler:
    """Inference handler that serves a PEFT adapter on top of its base causal LM.

    The handler is constructed once with the path of the adapter repository
    and is then called once per request with the decoded JSON payload.
    """

    def __init__(self, path: str = ""):
        """Load the base model, its tokenizer and the fine-tuned adapter.

        Args:
            path: Directory (or repo id) containing the PEFT adapter config
                and weights.
        """
        # 1. Load the adapter config; it records which base model the
        #    adapter was trained against.
        self.peft_config = PeftConfig.from_pretrained(path)
        base_model_name = self.peft_config.base_model_name_or_path

        # 2. Load the base model.
        # Half precision is used when a CUDA device is present; without one
        # we fall back to float32 because half-precision CPU kernels are not
        # available for every op.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the base model repository. Keep it only for trusted base models.
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            return_dict=True,
            torch_dtype=dtype,
            device_map="auto",  # let accelerate place the weights on the available device(s)
            trust_remote_code=True,
        )

        # 3. Load the tokenizer that matches the base model.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_name,
            trust_remote_code=True,
        )

        # 4. Attach the adapter (fine-tuned weights) and switch to eval mode
        #    so dropout and similar training-only layers are disabled.
        self.model = PeftModel.from_pretrained(self.base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for the prompt in ``data``.

        Args:
            data: Request payload. Must contain the key ``"inputs"`` holding
                the prompt text (a string, or a list whose first element is
                the prompt). May contain ``"parameters"``, a dict with any of
                ``max_new_tokens`` (default 512), ``temperature`` (default
                0.7), ``top_p`` (default 0.9), ``do_sample`` (default True),
                ``top_k`` and ``repetition_penalty``. The payload is not
                modified.

        Returns:
            A one-element list ``[{"generated_text": <completion>}]`` where
            the completion excludes the prompt.

        Raises:
            ValueError: If ``"inputs"`` is missing, empty, or not text, or if
                ``"parameters"`` is not a dict.
        """
        prompt = self._extract_prompt(data)
        generation_kwargs = self._build_generation_kwargs(data.get("parameters"))

        # Tokenize and move the tensors to the device the model lives on.
        device = self.model.device
        encoded = self.tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(device)

        # Pass the attention mask explicitly: pad_token_id is set to the EOS
        # id below, so generate() cannot reliably infer the mask on its own.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            generation_kwargs["attention_mask"] = attention_mask.to(device)

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                pad_token_id=self.tokenizer.eos_token_id,
                **generation_kwargs,
            )

        # The returned sequence starts with the prompt tokens; slice them off
        # so that only the newly generated response is decoded.
        prompt_length = input_ids.shape[1]
        generated_text = self.tokenizer.decode(
            output_ids[0][prompt_length:],
            skip_special_tokens=True,
        )
        return [{"generated_text": generated_text}]

    @staticmethod
    def _extract_prompt(data: Dict[str, Any]) -> str:
        """Return the prompt string from the payload without mutating it."""
        if "inputs" not in data:
            raise ValueError(
                "Request payload must contain an 'inputs' key with the prompt text."
            )

        inputs = data["inputs"]
        if isinstance(inputs, list):
            if not inputs:
                raise ValueError("'inputs' must not be an empty list.")
            # Simplification for single-turn use: only the first entry is used.
            inputs = inputs[0]

        if not isinstance(inputs, str):
            raise ValueError(
                f"'inputs' must be a string prompt, got {type(inputs).__name__}."
            )
        return inputs

    @staticmethod
    def _build_generation_kwargs(parameters: Any) -> Dict[str, Any]:
        """Translate request parameters into keyword arguments for generate()."""
        if parameters is None:
            # Covers both an absent key and an explicit JSON null.
            parameters = {}
        if not isinstance(parameters, dict):
            raise ValueError("'parameters' must be a JSON object (dict).")

        temperature = float(parameters.get("temperature", 0.7))
        # Sampling requires a strictly positive temperature; a temperature of
        # zero (or below) is treated as a request for greedy decoding.
        do_sample = bool(parameters.get("do_sample", True)) and temperature > 0

        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": int(parameters.get("max_new_tokens", 512)),
            "do_sample": do_sample,
        }

        if do_sample:
            # Sampling-only knobs are omitted for greedy decoding, where they
            # have no effect.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = float(parameters.get("top_p", 0.9))
            if parameters.get("top_k") is not None:
                generation_kwargs["top_k"] = int(parameters["top_k"])

        if parameters.get("repetition_penalty") is not None:
            generation_kwargs["repetition_penalty"] = float(
                parameters["repetition_penalty"]
            )

        return generation_kwargs