File size: 1,129 Bytes
2cc78fb
1505272
2cc78fb
1505272
 
 
 
 
0e3d77c
1505272
 
 
 
2cc78fb
1505272
 
 
 
 
 
 
2cc78fb
1505272
2cc78fb
1505272
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    Loads a causal-LM + tokenizer once at startup and serves text
    generation requests of the form ``{"inputs": "...", "parameters": {...}}``.
    """

    def __init__(self, model_dir: str, **kwargs):
        """
        Initialize the handler. This is required by Hugging Face Inference Endpoints.

        Args:
            model_dir: Path to the locally downloaded model repository.
                NOTE(review): currently ignored — the handler re-downloads the
                hardcoded hub id below. Kept for backward compatibility;
                consider loading from ``model_dir`` instead.
        """
        self.model_id = "vrouco/jais-13b-custom"

        # trust_remote_code=True is required because jais ships custom model code.
        # Security note: this executes code from the hub repo — only safe because
        # the repo id is pinned and owned by us.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True)

        # Move to GPU when available and switch to inference mode; the original
        # left the 13B model on CPU in training mode.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        self.model.eval()

    def __call__(self, data):
        """
        Process one inference request.

        Args:
            data: Request payload; ``"inputs"`` holds the prompt string and an
                optional ``"parameters"`` dict is forwarded to ``generate``.

        Returns:
            ``{"generated_text": str}`` on success, ``{"error": str}`` when
            no prompt was provided.
        """
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "No input text provided"}

        # Optional generation overrides per the standard endpoint request schema.
        gen_kwargs = {"max_new_tokens": 200}
        gen_kwargs.update(data.get("parameters") or {})

        # Keep the full tokenizer output so attention_mask is passed to
        # generate (omitting it triggers a warning and can skew results),
        # and move tensors to the model's device.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            # max_new_tokens (not max_length) so long prompts still get a
            # full-length completion instead of being silently starved.
            output_ids = self.model.generate(**encoded, **gen_kwargs)

        response_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return {"generated_text": response_text}