dizza01
/

File size: 2,862 Bytes
42fd383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10f9067
 
 
 
42fd383
 
 
 
 
10f9067
42fd383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    """Load a causal language model once and serve text-generation requests.

    Instances are callable: ``handler({"inputs": ..., "parameters": {...}})``
    returns ``{"generated_text": <str>}`` holding only the newly generated
    continuation (the prompt tokens are stripped before decoding).
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model weights.

        Args:
            path: Model directory to load from; ``"/repository"`` is used
                when it is empty.
        """
        model_dir = path or "/repository"

        # NOTE: trust_remote_code=True executes Python code shipped with the
        # model files, so model_dir must come from a trusted source.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
        )
        # generate() needs a pad id; reuse EOS when the tokenizer defines none.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # float16 is only dependable on GPU; many CPU kernels lack half
        # precision support, so fall back to float32 without CUDA.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
        self.model.eval()

    def _has_chat_template(self) -> bool:
        """Return True when the tokenizer can render chat messages itself."""
        return bool(
            hasattr(self.tokenizer, "apply_chat_template")
            and getattr(self.tokenizer, "chat_template", None)
        )

    def _messages_to_prompt(self, inputs: list) -> str:
        """Render a list of ``{"role": ..., "content": ...}`` dicts as a prompt.

        Uses the tokenizer's chat template when one is configured. Otherwise
        falls back to a plain layout of ``[ROLE]`` headers followed by the
        message content, ending with an open ``[ASSISTANT]`` turn.
        """
        if self._has_chat_template():
            return self.tokenizer.apply_chat_template(
                inputs,
                tokenize=False,
                add_generation_prompt=True,
            )
        # Fallback for plain causal LMs with no chat_template (e.g. MedAlpaca)
        parts = []
        for msg in inputs:
            role = (msg.get("role") or "user").upper()
            # `or ""` so a null content renders as empty text, not "None".
            content = msg.get("content") or ""
            parts.append(f"[{role}]\n{content}")
        parts.append("[ASSISTANT]\n")
        return "\n\n".join(parts)

    @staticmethod
    def _as_bool(value) -> bool:
        """Interpret a flag that may arrive as a string (``bool("false")`` is True)."""
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "yes", "on"}
        return bool(value)

    def _generation_kwargs(self, params: dict) -> dict:
        """Translate request ``parameters`` into ``model.generate`` kwargs.

        Sampling-only knobs (``temperature``, ``top_p``) are forwarded only
        when sampling, since greedy decoding ignores them and transformers
        warns about them. A non-positive temperature selects greedy decoding,
        because transformers rejects ``temperature <= 0`` with
        ``do_sample=True``. Without an explicit temperature, sampling uses the
        model's own generation-config default.
        """
        raw_temperature = params.get("temperature")
        temperature = None if raw_temperature is None else float(raw_temperature)
        # Default: sample only when a positive temperature was requested.
        do_sample = self._as_bool(
            params.get("do_sample", temperature is not None and temperature > 0)
        )
        if temperature is not None and temperature <= 0:
            do_sample = False

        kwargs = {
            "max_new_tokens": int(params.get("max_new_tokens", 128)),
            "do_sample": do_sample,
            "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
            "no_repeat_ngram_size": int(params.get("no_repeat_ngram_size", 0)),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            kwargs["top_p"] = float(params.get("top_p", 1.0))
            if temperature is not None:
                kwargs["temperature"] = temperature
        return kwargs

    def __call__(self, data: dict) -> dict:
        """Generate a completion for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is either a list of chat
                message dicts or a value used as a raw prompt via ``str()``.
                ``data["parameters"]`` optionally holds ``max_new_tokens``,
                ``temperature``, ``top_p``, ``do_sample``,
                ``repetition_penalty`` and ``no_repeat_ngram_size``.

        Returns:
            ``{"generated_text": str}`` with special tokens skipped.
        """
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        if isinstance(inputs, list):
            prompt = self._messages_to_prompt(inputs)
            # A rendered chat template already contains BOS/special tokens;
            # letting the tokenizer add them again would duplicate BOS.
            add_special_tokens = not self._has_chat_template()
        else:
            prompt = str(inputs)
            add_special_tokens = True

        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        ).to(self.model.device)

        with torch.inference_mode():
            out = self.model.generate(**enc, **self._generation_kwargs(params))

        # generate() echoes the prompt; keep only the newly generated ids.
        generated_ids = out[0][enc["input_ids"].shape[-1]:]
        text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        return {"generated_text": text}