qwen-capybara-sft / handler.py
nahf's picture
Upload handler.py with huggingface_hub
45d4c43 verified
"""Custom handler for HF Inference Endpoints.
Loads the Qwen2.5-0.5B base model, applies the LoRA adapter from this repo,
merges weights for faster inference, and serves predictions.
"""
from typing import Any, Dict, List, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
class EndpointHandler:
    """Request handler that serves text generation from a LoRA-merged model.

    On construction the base model is loaded in float16, the LoRA adapter
    located at ``path`` is applied and merged into the base weights, and the
    tokenizer of the base model is loaded. Calling the instance with a request
    payload returns the generated continuation.
    """

    def __init__(self, path: str = ""):
        """Load the base model, merge the adapter, and prepare the tokenizer.

        Args:
            path: Location of the LoRA adapter, forwarded to
                ``PeftModel.from_pretrained``.
        """
        base_model_id = "Qwen/Qwen2.5-0.5B"

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Apply LoRA adapter from this repo and merge it into the base weights
        # so that inference runs on a plain model without adapter layers.
        model = PeftModel.from_pretrained(base_model, path)
        self.model = model.merge_and_unload()
        self.model.eval()

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_id, trust_remote_code=True
        )
        # generate() is given pad_token_id below, so make sure one exists.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _param(params: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``params[key]``, or *default* when it is missing or ``None``."""
        value = params.get(key)
        return default if value is None else value

    def _build_prompt(self, inputs: Any) -> str:
        """Turn the request's ``inputs`` value into a prompt string.

        A string is used as is, a list is treated as chat messages
        (``[{"role": "user", "content": "..."}]``) and rendered with the
        tokenizer's chat template, and anything else is converted with ``str``.
        """
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, list):
            return self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )
        return str(inputs)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate a continuation for the request payload.

        Args:
            data: Request payload. ``data["inputs"]`` holds the prompt (plain
                string or list of chat messages). ``data["parameters"]`` may
                hold ``max_new_tokens`` (default 256), ``temperature``
                (default 0.7) and ``top_p`` (default 0.9); a missing or
                ``None`` entry falls back to its default.

        Returns:
            A one-element list ``[{"generated_text": ...}]`` containing only
            the newly generated text, without the prompt.
        """
        inputs = data.get("inputs", "")
        # "parameters" can be absent or explicitly null in the payload.
        params = data.get("parameters") or {}
        max_new_tokens = self._param(params, "max_new_tokens", 256)
        temperature = self._param(params, "temperature", 0.7)
        top_p = self._param(params, "top_p", 0.9)

        prompt = self._build_prompt(inputs)
        tokenized = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        do_sample = temperature > 0
        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Sampling-only settings are forwarded only when sampling: a
        # non-positive temperature selects greedy decoding and is not a valid
        # sampling temperature.
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.no_grad():
            output_ids = self.model.generate(**tokenized, **generate_kwargs)

        # Decode only the generated tokens (skip the prompt)
        prompt_length = tokenized["input_ids"].shape[1]
        new_tokens = output_ids[0][prompt_length:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"generated_text": generated_text}]