from __future__ import annotations

from typing import Any, Dict, List

from unsloth import FastLanguageModel
from vllm import SamplingParams

prompt_template = """Please answer the given financial question based on the context.

**Context:** {context}

**Question:** {question}"""


class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints.
    Loads a PEFT LoRA adapter on a 4-bit base model and performs text generation.
    """

    def __init__(self, path: str):
        """
        `path` points to the repo directory mounted by the service. The
        tokenizer and the 4-bit base model are loaded from `path` via
        Unsloth's FastLanguageModel, and LoRA adapters are attached with
        get_peft_model.
        """
        # Default decoding settings; per-request overrides are merged in __call__.
        self.default_params: Dict[str, Any] = dict(
            temperature=0.7,
            top_p=0.95,
            top_k=20,
            max_tokens=7 * 1024,
        )
        self.sampling_params = SamplingParams(**self.default_params)

        ### Policy Model ###
        model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=8192,
            load_in_4bit=True,  # False for LoRA 16bit
            fast_inference=True,  # Enable vLLM fast inference
            max_lora_rank=128,
            gpu_memory_utilization=0.5,  # Reduce if out of memory
            full_finetuning=False,
        )
        self.model = FastLanguageModel.get_peft_model(
            model,
            r=128,  # Choose any number > 0! Suggested: 8, 16, 32, 64, 128
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=128 * 2,  # *2 speeds up training
            use_gradient_checkpointing="unsloth",  # Reduces memory usage
            random_state=3407,
            use_rslora=True,  # We support rank-stabilized LoRA
            loftq_config=None,  # And LoftQ
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Request format:
            {
                "inputs": "optional raw prompt string",
                "messages": [{"role": "system/user/assistant", "content": "..."}],  # optional
                "parameters": { ... generation overrides ... }  # optional
            }
        Returns: [ { "generated_text": "..." } ]
        """
        # Prefer an explicit message list; otherwise wrap the raw prompt
        # string as a single user message so apply_chat_template receives
        # the message format it expects.
        messages = data.get("messages") or [{"role": "user", "content": data["inputs"]}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,  # True is the default for enable_thinking
        )

        # Merge any per-request generation overrides into the defaults.
        overrides = data.get("parameters") or {}
        sampling_params = (
            SamplingParams(**{**self.default_params, **overrides})
            if overrides
            else self.sampling_params
        )

        output = (
            self.model.fast_generate(
                [text],
                sampling_params=sampling_params,
                lora_request=None,
                use_tqdm=False,
            )[0]
            .outputs[0]
            .text
        )
        return [{"generated_text": output}]
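
# --- Local smoke test: a minimal sketch, not part of the endpoint contract. ---
# Assumes a GPU machine with unsloth and vLLM installed; the "./model" path
# and the sample payload below are illustrative placeholders only.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local repo path
    payload = {
        "inputs": prompt_template.format(
            context="Revenue was $10M in 2022 and $12M in 2023.",
            question="What was the year-over-year revenue growth?",
        ),
    }
    # Prints [{"generated_text": "..."}], mirroring the documented response shape.
    print(handler(payload))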