File size: 1,333 Bytes
507c889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import logging

from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

class EndpointHandler:
    """Serve text generation from a causal LM through a transformers pipeline.

    Each request is a dict whose ``"inputs"`` entry is forwarded to a
    ``TextGenerationPipeline`` together with fixed generation settings.
    """

    # Generation settings applied to every request; callers cannot override them.
    # NOTE(review): ``temperature`` only takes effect when sampling is enabled
    # (``do_sample=True`` here or in the model's generation config); under
    # greedy decoding transformers ignores it — confirm which mode is intended.
    _GENERATION_PARAMS = {
        "max_new_tokens": 512,
        "temperature": 0.1,
    }

    def __init__(self, model_path="djangodevloper/llama3-70b-4bit-medqa"):
        """Load the tokenizer and model and build the generation pipeline.

        Args:
            model_path: Hub repository id or local directory passed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        # trust_remote_code=True executes Python code shipped with the model
        # repository; only point model_path at repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",  # let transformers/accelerate place the weights
            trust_remote_code=True,
        )
        self.pipeline = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer)

    def _stop_token_ids(self):
        """Return the token ids that should terminate generation.

        Includes the tokenizer's EOS id (when set) and the ``<|eot_id|>``
        end-of-turn token when the vocabulary actually contains it. Unknown
        tokens come back from ``convert_tokens_to_ids`` as ``None`` or as the
        unknown-token id; neither is a valid stop id, so both are dropped.
        Duplicates are removed while keeping order.
        """
        stop_ids = []
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            stop_ids.append(eos_id)

        eot_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        if eot_id is not None and eot_id != unk_id and eot_id not in stop_ids:
            stop_ids.append(eot_id)
        return stop_ids

    def __call__(self, data):
        """Run generation for a single request.

        Args:
            data: Request payload. Its ``"inputs"`` entry must be a non-empty
                string or list (e.g. a list of chat messages) and is passed to
                the pipeline unchanged.

        Returns:
            ``{"generated_text": ...}`` holding the ``generated_text`` field of
            the first pipeline result. By default the transformers pipeline
            includes the prompt for plain-string input and returns the message
            list for chat input. On an invalid request or an inference failure
            it returns ``{"error": message}`` instead.
        """
        # Validate the request before touching the model.
        if not isinstance(data, dict):
            return {"error": "Invalid request: expected a JSON object with an 'inputs' field."}
        messages = data.get("inputs")
        if not isinstance(messages, (str, list)) or not messages:
            return {"error": "Invalid request: 'inputs' must be a non-empty string or list."}

        try:
            params = dict(self._GENERATION_PARAMS)
            stop_ids = self._stop_token_ids()
            # With no valid stop ids, leave eos_token_id unset so the model's
            # own generation config applies instead of an empty/None list.
            if stop_ids:
                params["eos_token_id"] = stop_ids
            outputs = self.pipeline(messages, **params)
            generated_text = outputs[0]["generated_text"]
        except Exception as e:
            # Request boundary: record the full traceback server-side and
            # report a message to the client instead of crashing the endpoint.
            logging.getLogger(__name__).exception("Inference failed")
            return {"error": f"An error occurred during inference: {e}"}
        return {"generated_text": generated_text}