File size: 2,149 Bytes
009ab0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class EndpointHandler():
    def __init__(self, path=""):
        """
        Initialize the model and tokenizer from the local path.

        Uses Zenith Coder v1.1 custom code (modeling_deepseek.py,
        configuration_deepseek.py, tokenization_deepseek_fast.py),
        hence trust_remote_code=True on both loads.

        Args:
            path: Local directory containing the model weights and
                tokenizer files (supplied by the inference toolkit).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            # bf16 on GPU for memory/speed; fp32 fallback on CPU.
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle one inference request.

        Args:
            data: Request payload. The prompt is read from "inputs"
                (preferred) or "prompt". Optional sampling parameters:
                max_new_tokens (default 256), temperature (1.0),
                top_p (0.95), top_k (50).

        Returns:
            [{"generated_text": ...}] on success, or
            [{"error": ...}] when no valid string prompt is supplied.
        """
        prompt = data.get("inputs") or data.get("prompt")
        if not prompt or not isinstance(prompt, str):
            return [{"error": "No valid input prompt provided."}]

        max_new_tokens = int(data.get("max_new_tokens", 256))
        temperature = float(data.get("temperature", 1.0))
        top_p = float(data.get("top_p", 0.95))
        top_k = int(data.get("top_k", 50))

        # Keep the attention mask: without it generate() has to infer
        # the mask from pad tokens, which is unreliable when pad == eos.
        encoded = self.tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the model may not be on cuda:0, so move
        # inputs to the model's actual (first) device rather than .cuda().
        device = next(self.model.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        # Fix: many causal-LM tokenizers (DeepSeek included) define no
        # pad token; fall back to eos so generate() does not receive None.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                do_sample=True,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                pad_token_id=pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens (skip the prompt echo).
        gen_text = self.tokenizer.decode(
            generated_ids[0][input_ids.shape[1]:],
            skip_special_tokens=True
        )
        return [{"generated_text": gen_text}]