File size: 1,384 Bytes
11f4a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# handler.py
from typing import Any, Dict, List
import os
from unsloth import FastLanguageModel

class EndpointHandler:
    """Inference-endpoint handler that serves an Unsloth ``FastLanguageModel``.

    The model and tokenizer are loaded once at construction; each call
    generates one completion per prompt.

    Environment variables:
        MAX_SEQ_LENGTH: ``max_seq_length`` passed to ``from_pretrained``
            at startup (default 1024).
        MAX_NEW_TOKENS: maximum number of tokens generated per prompt
            (default 64); read once per request.
    """

    def __init__(self, model_id: str):
        """Load *model_id* in 4-bit precision and prepare it for generation.

        Args:
            model_id: model repo ID or local path; called once at endpoint
                startup.
        """
        max_seq = int(os.getenv("MAX_SEQ_LENGTH", "1024"))
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_id,
            max_seq_length=max_seq,
            load_in_4bit=True,
        )
        # Switch the Unsloth model into its inference mode before any
        # generate() call (Unsloth's documented step for fast generation).
        FastLanguageModel.for_inference(self.model)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for each prompt in *data*.

        Args:
            data: ``{"inputs": "<str>"}`` or ``{"inputs": ["<str>", ...]}``.
                If the ``"inputs"`` key is missing, the whole dict is treated
                as the input and therefore rejected.

        Returns:
            ``[{"generated_text": "<str>"}, ...]``: one entry per prompt, in
            input order, holding only the newly generated text (the prompt
            tokens are stripped before decoding).

        Raises:
            ValueError: if ``inputs`` is not a string or a list of strings.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            prompts = [inputs]
        elif isinstance(inputs, list):
            prompts = inputs
        else:
            raise ValueError(f"Unsupported inputs type: {type(inputs)}")

        # Reject non-string list items up front instead of failing deep
        # inside the tokenizer/model with an unrelated error.
        for prompt in prompts:
            if not isinstance(prompt, str):
                raise ValueError(
                    f"Unsupported prompt type in inputs list: {type(prompt)}"
                )

        # Parse the env var once per request, not once per prompt.
        max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "64"))

        outputs: List[Dict[str, Any]] = []
        for prompt in prompts:
            # generate() works on token IDs, not raw text: tokenize the prompt
            # and move the tensors onto the model's device.
            encoded = self.tokenizer(prompt, return_tensors="pt").to(
                self.model.device
            )
            input_ids = encoded["input_ids"]
            prompt_len = len(input_ids[0])
            generated = self.model.generate(
                input_ids=input_ids,
                attention_mask=encoded["attention_mask"],
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # For a causal LM, each returned row is prompt + continuation;
            # keep only the continuation and turn it back into text.
            completion = self.tokenizer.decode(
                generated[0][prompt_len:], skip_special_tokens=True
            )
            outputs.append({"generated_text": completion})
        return outputs