# handler.py
from __future__ import annotations

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Json = Dict[str, Any]
Messages = List[Dict[str, str]]  # [{"role": "user|assistant|system", "content": "..."}]


def _is_messages(x: Any) -> bool:
    return (
        isinstance(x, list)
        and len(x) > 0
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
    )


class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler.

    Expects:
      - request body is a dict
      - always contains `inputs`
      - may contain `parameters` for generation
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir

        # Pick dtype/device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            # bfloat16 is usually safe on A100/H100; if your instance doesn't support bf16, change to float16
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32

        # IMPORTANT: trust_remote_code=True because the repo contains AsteriskForCausalLM.py + auto_map
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=True,
        )

        # Make sure a pad token exists (your config uses pad_token_id=2, which equals eos_token_id in many llama-like models)
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
        )
        if self.device != "cuda":
            self.model.to(self.device)
        self.model.eval()

    @torch.inference_mode()
    def __call__(self, data: Json) -> Union[Json, List[Json]]:
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        # Generation defaults (can be overridden via `parameters`)
        max_new_tokens = int(params.get("max_new_tokens", 256))
        temperature = float(params.get("temperature", 0.7))
        top_p = float(params.get("top_p", 0.95))
        top_k = int(params.get("top_k", 0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        do_sample = bool(params.get("do_sample", temperature > 0))
        num_beams = int(params.get("num_beams", 1))

        def _one(item: Any) -> Json:
            # Accept:
            #   1) string prompt
            #   2) messages list: [{"role": "user", "content": "..."}]
            #   3) dict {"messages": [...]} (common chat style)
            if isinstance(item, dict) and "messages" in item:
                item = item["messages"]

            if _is_messages(item):
                # Chat template path: tokenizer.apply_chat_template will use the repo's template if configured
                input_ids = self.tokenizer.apply_chat_template(
                    item,
                    return_tensors="pt",
                    add_generation_prompt=True,
                )
            else:
                if not isinstance(item, str):
                    item = str(item)
                enc = self.tokenizer(item, return_tensors="pt")
                input_ids = enc["input_ids"]

            input_ids = input_ids.to(self.model.device)
            input_len = input_ids.shape[-1]

            gen_ids = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                top_k=top_k if do_sample and top_k > 0 else None,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

            # Only return newly generated tokens
            new_tokens = gen_ids[0, input_len:]
            text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return {"generated_text": text}

        # Batch support: a list of prompts (that is not itself a messages list) yields a list of results
        if isinstance(inputs, list) and not _is_messages(inputs):
            return [_one(x) for x in inputs]
        else:
            return _one(inputs)
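

# ---------------------------------------------------------------------------
# Optional local smoke test. Inference Endpoints never runs this block; it only
# instantiates EndpointHandler(model_dir) and calls it once per request. This
# is a minimal sketch under the assumption that "./model" is a local checkout
# of the model repo (the path is hypothetical; point it at your own copy).
# It exercises the three accepted input shapes handled by _one() above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler("./model")  # assumed local model directory

    # 1) Plain string prompt
    print(handler({
        "inputs": "Write a haiku about GPUs.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    }))

    # 2) Chat-style messages (routed through the tokenizer's chat template if configured)
    print(handler({
        "inputs": [{"role": "user", "content": "Hello!"}],
    }))

    # 3) Batch of string prompts -> list of {"generated_text": ...} results
    print(handler({
        "inputs": ["First prompt", "Second prompt"],
    }))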