import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.
    It loads the model and tokenizer once at startup and serves each inference request.
    """
    def __init__(self, path=""):
        """
        Loads the tokenizer, model, and text-generation pipeline. Inference
        Endpoints calls this once, when the endpoint starts.

        Args:
            path (str, optional): Local path to the repository, passed in by
                                  Inference Endpoints. It is not used here: the
                                  model is loaded by ID from the HF_MODEL_ID
                                  environment variable, falling back to
                                  "Pragmanic0/Nomadic-ICDU-v8" if it is unset.
        """
        # Read the model ID from the HF_MODEL_ID environment variable, falling back
        # to the Nomadic-ICDU-v8 repository on the Hugging Face Hub
        model_id = os.environ.get("HF_MODEL_ID", "Pragmanic0/Nomadic-ICDU-v8")
        
        print(f"Loading model: {model_id}...")

        # Load the tokenizer from the pretrained model
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Load the model in bfloat16, which halves memory use compared with float32
        # and runs efficiently on GPUs with native bf16 support (e.g., Ampere and newer).
        # device_map="auto" (requires the `accelerate` package) spreads the weights
        # across the available GPUs, spilling over to CPU memory if they do not fit.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

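        # Wrap the model and tokenizer in a text-generation pipeline, which handles
        # tokenization, generation, and decoding. No `device` argument is passed
        # because device_map="auto" has already placed the model.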
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )
        
        print("Model and pipeline loaded successfully.")

    def __call__(self, data: dict) -> list:
        """
        Handles a single inference request.

        Args:
            data (dict): The request payload, e.g.
                         {"inputs": "Your prompt", "parameters": {"max_new_tokens": 256}}.
                         The "parameters" key is optional.

        Returns:
            list: A one-element list holding a dict with the "generated_text" key.
        """
        # Extract the prompt from the input data
        prompt = data.get("inputs", "")
        
        # Extract generation parameters; any key the request omits falls back to
        # the default below. `or {}` also covers an explicit "parameters": null.
        parameters = data.get("parameters") or {}
        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)
        do_sample = parameters.get("do_sample", True)

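        # Apply the [INST] ... [/INST] prompt template required by the Nomadic-ICDU-v8 model;
        # instruction-tuned models respond best when prompted in the format they were
        # fine-tuned on. For example, "Hello" becomes "<s>[INST] Hello [/INST]".
        # If the tokenizer also prepends its own BOS token, the literal "<s>" may end up
        # duplicated; inspecting the tokenized input is a quick way to confirm.
        formatted_prompt = f"<s>[INST] {prompt} [/INST]"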

        print(f"Generating text for prompt: '{prompt}'")
        
        # Use the pipeline to generate text
        # We pass the formatted prompt and the generation parameters
        try:
            generated = self.pipeline(
                formatted_prompt,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                return_full_text=False, # Only return the generated part, not the prompt
            )
            
            # The pipeline returns a list of dicts, one per generated sequence;
            # keep the first, which holds the 'generated_text' key
            result = generated[0]

        except Exception as e:
            print(f"An error occurred during generation: {e}")
            # Return an error message in the expected format
            result = {"generated_text": f"Error: {e}"}
            
        print(f"Generated text: {result['generated_text']}")

        # Return the result in a list, as expected by the Inference Endpoints framework
        return [result]
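

# Hypothetical local smoke test (not part of the Inference Endpoints contract).
# EndpointHandler.__init__ downloads and loads the full model, so this sketch
# skips it via __new__ and swaps in a stub pipeline (_EchoPipeline, an
# illustrative stand-in) that echoes its input. It checks only the request
# parsing, prompt template, and response shape of __call__, without a GPU.
# Run it with `python handler.py` (assuming this file is saved as handler.py).
if __name__ == "__main__":

    class _EchoPipeline:
        """Stand-in for the transformers pipeline that echoes the prompt back."""

        def __call__(self, text, **generate_kwargs):
            # Mimic the pipeline's output shape: a list of dicts with "generated_text"
            return [{"generated_text": f"echo: {text}"}]

    handler = EndpointHandler.__new__(EndpointHandler)  # bypass __init__ (no model load)
    handler.pipeline = _EchoPipeline()

    response = handler({"inputs": "Hello", "parameters": {"max_new_tokens": 8}})
    assert response == [{"generated_text": "echo: <s>[INST] Hello [/INST]"}]
    print("Smoke test passed:", response)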