JMLizano committed on
Commit
dfa426f
·
1 Parent(s): ebeb048

add custom handler

Browse files
Files changed (2) hide show
  1. handler.py +94 -0
  2. requirements.txt +2 -0
handler.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ import torch
5
+ import json
6
+ from typing import Any
7
+
8
+ from unsloth import FastLanguageModel
9
+ from vllm import SamplingParams
10
+
11
+
12
+
13
# Prompt template for financial Q&A requests (context + question slots).
# NOTE(review): not referenced anywhere else in this file — presumably callers
# (or a future code path) format request payloads with it; confirm before removing.
prompt_template = """Please answer the given financial question based on the context.
**Context:** {context}
**Question:** {question}"""
16
+
17
+
18
class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints.

    Loads a 4-bit base model via Unsloth's FastLanguageModel (with the vLLM
    fast-inference backend), attaches a LoRA configuration, and serves
    chat-style text generation.
    """

    def __init__(self, path: str):
        """
        Load tokenizer and model.

        Args:
            path: Repo directory mounted by the Inference Endpoints service;
                  passed to ``FastLanguageModel.from_pretrained`` as the
                  model name, so both tokenizer and weights are read from it.
        """
        # Fixed decoding configuration reused for every request.
        self.sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.95,
            top_k=20,
            max_tokens=7 * 1024,
        )

        ### Policy Model ###
        # fast_inference=True enables the vLLM backend used by
        # self.model.fast_generate in __call__.
        model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=8192,
            load_in_4bit=True,  # False for LoRA 16bit
            fast_inference=True,  # Enable vLLM fast inference
            max_lora_rank=128,
            gpu_memory_utilization=0.5,  # Reduce if out of memory
            full_finetuning=False,
        )

        # NOTE(review): get_peft_model attaches a LoRA configuration at load
        # time. If `path` already contains trained adapter weights, confirm
        # they are loaded by from_pretrained above and not re-initialized
        # here. Gradient checkpointing is a training-time option and is
        # presumably inert for inference — verify against Unsloth docs.
        self.model = FastLanguageModel.get_peft_model(
            model,
            r=128,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_alpha=128 * 2,  # *2 speeds up training
            use_gradient_checkpointing="unsloth",  # Reduces memory usage
            random_state=3407,
            use_rslora=True,  # We support rank stabilized LoRA
            loftq_config=None,  # And LoftQ
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Handle one generation request.

        Request format:
            {
                "inputs": <chat messages or prompt accepted by
                           tokenizer.apply_chat_template>
            }
            (Only the "inputs" key is read; no other keys are consumed.)

        Returns:
            [ { "generated_text": "<model reply>" } ]

        Raises:
            KeyError: if "inputs" is missing from the payload.
        """
        # NOTE(review): apply_chat_template conventionally expects a list of
        # {"role": ..., "content": ...} messages; confirm callers send that
        # shape rather than a raw string.
        text = self.tokenizer.apply_chat_template(
            data["inputs"],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,  # True is the default value for enable_thinking
        )

        output = (
            self.model.fast_generate(
                [text], sampling_params=self.sampling_params, lora_request=None, use_tqdm=False
            )[0]
            .outputs[0]
            .text
        )

        # Fix: wrap the bare string in the documented envelope so the return
        # value matches the declared List[Dict[str, str]] contract.
        return [{"generated_text": output}]
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ vllm
2
+ unsloth