File size: 2,956 Bytes
dfa426f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union

import torch
import json
from typing import Any

from unsloth import FastLanguageModel
from vllm import SamplingParams



# Prompt scaffold for financial question answering; intended to be filled via
# str.format(context=..., question=...).
# NOTE(review): not referenced anywhere in this chunk — presumably consumed by
# a caller that builds the "inputs"/"messages" payload; verify before removing.
prompt_template = """Please answer the given financial question based on the context.
**Context:** {context}
**Question:** {question}"""


class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints.

    Loads a PEFT LoRA adapter on a 4-bit base model (via Unsloth) and performs
    chat-style text generation through vLLM fast inference.
    """

    def __init__(self, path: str):
        """
        Initialize tokenizer, base model, and LoRA adapter.

        Args:
            path: Repo directory mounted by the service. The tokenizer and
                model are loaded from this path; the loader reads
                adapter_config.json to find the base model.
        """
        # Default generation settings. Kept as a plain dict so per-request
        # "parameters" overrides can be merged in __call__; the prebuilt
        # SamplingParams is reused on the (common) no-override path.
        self.default_sampling_kwargs: Dict[str, Any] = {
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 20,
            "max_tokens": 7 * 1024,
        }
        self.sampling_params = SamplingParams(**self.default_sampling_kwargs)

        ### Policy Model ###
        model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=8192,
            load_in_4bit=True,  # False for LoRA 16bit
            fast_inference=True,  # Enable vLLM fast inference
            max_lora_rank=128,
            gpu_memory_utilization=0.5,  # Reduce if out of memory
            full_finetuning=False,
        )

        self.model = FastLanguageModel.get_peft_model(
            model,
            r=128,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_alpha=128 * 2,  # *2 speeds up training
            use_gradient_checkpointing="unsloth",  # Reduces memory usage
            random_state=3407,
            use_rslora=True,  # We support rank stabilized LoRA
            loftq_config=None,  # And LoftQ
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Generate a completion for one request.

        Request format:
          {
            "inputs": "optional raw prompt string",
            "messages": [{"role": "system/user/assistant", "content": "..."}],  # optional
            "parameters": { ... generation overrides ... }
          }

        Returns:
          [ { "generated_text": "<model reply>" } ]

        Raises:
            KeyError: if neither "messages" nor "inputs" is present.
        """
        # Prefer structured chat messages; a raw "inputs" prompt string is
        # wrapped as a single user turn so apply_chat_template always
        # receives a message list (it does not accept a bare string).
        messages = data.get("messages")
        if messages is None:
            raw = data["inputs"]
            messages = [{"role": "user", "content": raw}] if isinstance(raw, str) else raw

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,  # True is the default value for enable_thinking
        )

        # Honor the documented per-request "parameters" overrides (the
        # original silently ignored them); fall back to the prebuilt defaults.
        overrides = data.get("parameters") or {}
        if overrides:
            sampling_params = SamplingParams(
                **{**self.default_sampling_kwargs, **overrides}
            )
        else:
            sampling_params = self.sampling_params

        output = (
            self.model.fast_generate(
                [text], sampling_params=sampling_params, lora_request=None, use_tqdm=False
            )[0]
            .outputs[0]
            .text
        )

        # Fix: the original returned the bare string, contradicting both the
        # return annotation and the documented Inference Endpoints contract.
        return [{"generated_text": output}]