File size: 2,956 Bytes
dfa426f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
import torch
import json
from typing import Any
from unsloth import FastLanguageModel
from vllm import SamplingParams
# Prompt skeleton for context-grounded financial Q&A; fill with
# `prompt_template.format(context=..., question=...)`.
# NOTE(review): not referenced anywhere in this file's visible code —
# confirm external callers use it before removing.
prompt_template = """Please answer the given financial question based on the context.
**Context:** {context}
**Question:** {question}"""
class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints.

    Loads a 4-bit base model via Unsloth's FastLanguageModel (with vLLM fast
    inference enabled), attaches LoRA adapter modules, and performs chat-style
    text generation.
    """

    def __init__(self, path: str):
        """
        Load tokenizer and model from the mounted repo directory.

        Args:
            path: Repo directory mounted by the inference service; passed to
                ``FastLanguageModel.from_pretrained`` as the model name/path.
        """
        # Default generation settings, kept as a plain dict so per-request
        # "parameters" overrides can be merged in __call__.
        self._gen_defaults: Dict[str, Any] = {
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 20,
            "max_tokens": 7 * 1024,
        }
        self.sampling_params = SamplingParams(**self._gen_defaults)

        ### Policy Model ###
        model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=8192,
            load_in_4bit=True,   # False for LoRA 16bit
            fast_inference=True,  # Enable vLLM fast inference
            max_lora_rank=128,
            gpu_memory_utilization=0.5,  # Reduce if out of memory
            full_finetuning=False,
        )
        # NOTE(review): get_peft_model attaches *fresh* LoRA adapter modules,
        # which is a training-time setup. If `path` already contains trained
        # adapter weights, confirm they are actually loaded here rather than
        # re-initialized from scratch.
        self.model = FastLanguageModel.get_peft_model(
            model,
            r=128,  # LoRA rank; suggested 8, 16, 32, 64, 128
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_alpha=128 * 2,  # *2 speeds up training
            use_gradient_checkpointing="unsloth",  # Reduces memory usage
            random_state=3407,
            use_rslora=True,  # We support rank stabilized LoRA
            loftq_config=None,  # And LoftQ
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Handle one generation request.

        Request format::

            {
              "inputs": "raw prompt string (or a pre-built message list)",
              "messages": [{"role": "...", "content": "..."}],  # optional
              "parameters": { ... SamplingParams overrides ... }  # optional
            }

        ``messages`` takes precedence over ``inputs``; a string ``inputs`` is
        wrapped as a single user turn so the chat template can be applied.

        Returns:
            ``[{"generated_text": "<model reply>"}]`` — the standard HF
            Inference Endpoints response shape.
        """
        messages = data.get("messages")
        if messages is None:
            raw = data["inputs"]
            # Accept either a pre-built chat message list or a raw prompt
            # string; apply_chat_template requires the message-list form.
            messages = raw if isinstance(raw, list) else [{"role": "user", "content": raw}]

        # Merge per-request generation overrides over the defaults.
        overrides = data.get("parameters") or {}
        sampling = (
            SamplingParams(**{**self._gen_defaults, **overrides})
            if overrides
            else self.sampling_params
        )

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,  # True is the default value for enable_thinking
        )
        output = (
            self.model.fast_generate(
                [text], sampling_params=sampling, lora_request=None, use_tqdm=False
            )[0]
            .outputs[0]
            .text
        )
        # Return the documented list-of-dict shape (previously returned the
        # bare string, contradicting the annotation and the endpoint contract).
        return [{"generated_text": output}]
|