File size: 992 Bytes
8c26326
cca2aed
 
 
 
3dfb844
cca2aed
339211f
8c26326
 
3dfb844
cca2aed
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class EndpointHandler:
    """Chat/text-generation handler backed by a ``transformers`` pipeline.

    On construction it loads a causal-LM model and tokenizer and wraps them in a
    ``"text-generation"`` pipeline. Each call runs the pipeline on the request's
    messages and returns the first result's ``generated_text``.
    """

    # Generation settings applied to every request. Any keys supplied in the
    # request's ``generation_args`` override these individually, so a partial
    # override (e.g. only ``max_new_tokens``) keeps ``return_full_text=False``.
    _DEFAULT_GENERATION_ARGS: Dict[str, Any] = {
        "max_new_tokens": 500,
        "return_full_text": False,
        "temperature": 0.0,
        "do_sample": False,
    }

    def __init__(
        self,
        path: str = "",
        model_id: str = "hyperspaceai/hyperEngine_phi3_128k",
        tokenizer_id: str = "microsoft/Phi-3-mini-128k-instruct",
    ) -> None:
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            path: Accepted for handler-interface compatibility; not used here.
                The model is loaded from ``model_id`` instead.
            model_id: Model repository id passed to
                ``AutoModelForCausalLM.from_pretrained``.
            tokenizer_id: Tokenizer repository id passed to
                ``AutoTokenizer.from_pretrained``.
        """
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype="auto",
            # Executes Python code shipped in the model repository; only use
            # with a trusted ``model_id``.
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run generation for one request payload.

        Args:
            data: Request payload. The prompt is read from ``"messages"``,
                falling back to ``"inputs"``. An optional
                ``"generation_args"`` dict overrides individual default
                generation settings. The payload is not modified.

        Returns:
            The ``generated_text`` field of the pipeline's first result.

        Raises:
            ValueError: If the payload contains neither ``"messages"`` nor
                ``"inputs"``.
        """
        messages = data.get("messages")
        if messages is None:
            # "inputs" is the conventional key for Inference Endpoints payloads.
            messages = data.get("inputs")
        if messages is None:
            raise ValueError(
                "Request payload must contain a 'messages' (or 'inputs') field."
            )

        # Merge rather than replace, so defaults survive partial overrides.
        overrides = data.get("generation_args") or {}
        generation_args = {**self._DEFAULT_GENERATION_ARGS, **overrides}

        output = self.pipe(messages, **generation_args)
        return output[0]["generated_text"]