# ambari-7b-finetuned / handler.py
# Uploaded by Akshaymp with huggingface_hub (commit abc64f0).
from typing import Dict, List, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
class EndpointHandler:
    """Custom Inference Endpoint handler serving a PEFT-adapted causal LM.

    The directory passed to ``__init__`` holds a PEFT adapter. The base model
    named in that adapter's config is loaded first, and the adapter weights
    are applied on top of it.
    """

    # Generation defaults used when the request does not override them.
    DEFAULT_MAX_NEW_TOKENS = 512
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_TOP_P = 0.9

    def __init__(self, path=""):
        """Load the base model, its tokenizer and the fine-tuned adapter.

        Args:
            path: Local directory containing the adapter repository
                (adapter config + adapter weights).
        """
        # 1. Adapter config: records which base model the adapter was trained on.
        self.peft_config = PeftConfig.from_pretrained(path)

        # 2. Base model. device_map="auto" places the weights on the available
        #    accelerator(s); float16 halves memory versus float32.
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the base-model repository -- keep it only if that repository is trusted.
        self.base_model = AutoModelForCausalLM.from_pretrained(
            self.peft_config.base_model_name_or_path,
            return_dict=True,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # 3. Tokenizer belonging to the base model.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.peft_config.base_model_name_or_path,
            trust_remote_code=True,
        )

        # 4. Apply the fine-tuned adapter weights and switch to inference mode.
        self.model = PeftModel.from_pretrained(self.base_model, path)
        self.model.eval()

    @staticmethod
    def _parse_request(data: Dict[str, Any]):
        """Extract ``(prompt, parameters)`` from a request payload without mutating it.

        ``inputs`` may be a string or a list; for a list only the first element
        is used (single-turn simplification). A missing or ``None``
        ``parameters`` entry becomes an empty dict.

        Raises:
            ValueError: if ``inputs`` is an empty list or does not resolve to a string.
        """
        # Same fallback as the classic handler template: a payload without an
        # "inputs" key is treated as the input itself (and rejected below unless
        # it is a string).
        inputs = data.get("inputs", data)
        # A JSON `"parameters": null` must not crash the `.get` lookups later on.
        parameters = data.get("parameters") or {}

        if isinstance(inputs, list):
            if not inputs:
                raise ValueError("'inputs' must not be an empty list")
            inputs = inputs[0]  # Simplification for single-turn requests.

        if not isinstance(inputs, str):
            raise ValueError(
                "'inputs' must be a string or a non-empty list whose first element is a string"
            )
        return inputs, parameters

    @classmethod
    def _generation_kwargs(cls, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``generate`` keyword arguments from request ``parameters``.

        Only ``max_new_tokens``, ``temperature`` and ``top_p`` are read; any
        other keys are ignored. A ``temperature`` of zero or below selects
        greedy decoding, because transformers rejects non-positive temperatures
        when sampling.
        """
        max_new_tokens = parameters.get("max_new_tokens", cls.DEFAULT_MAX_NEW_TOKENS)
        temperature = parameters.get("temperature", cls.DEFAULT_TEMPERATURE)
        top_p = parameters.get("top_p", cls.DEFAULT_TOP_P)

        if temperature is not None and temperature <= 0:
            # Greedy decoding; sampling knobs are omitted so generate() does not
            # warn about unused temperature/top_p.
            return {"max_new_tokens": max_new_tokens, "do_sample": False}

        return {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for the prompt in ``data``.

        Args:
            data (:obj: `Dict[str, Any]`):
                Request payload. ``inputs`` holds the prompt text (a string, or
                a list whose first element is used). The optional ``parameters``
                dict may set ``max_new_tokens`` (default 512), ``temperature``
                (default 0.7; ``<= 0`` means greedy decoding) and ``top_p``
                (default 0.9). The payload is not modified.

        Returns:
            A one-element list ``[{"generated_text": text}]``, where ``text``
            contains only the newly generated tokens, not the prompt.

        Raises:
            ValueError: if the payload carries no usable prompt.
        """
        prompt, parameters = self._parse_request(data)

        # Keep the whole encoding (not just input_ids): the attention mask must
        # reach generate(). Otherwise generate() infers it from pad_token_id,
        # which is the EOS id here, so EOS tokens in the prompt would be masked.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_ids = encoded["input_ids"]

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=encoded.get("attention_mask"),
                pad_token_id=self.tokenizer.eos_token_id,
                **self._generation_kwargs(parameters),
            )

        # Slice off the prompt so only the generated continuation is returned.
        prompt_length = input_ids.shape[1]
        generated_text = self.tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        )
        return [{"generated_text": generated_text}]