import logging
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer when the endpoint starts.
        
        Args:
            path (str): Path to the model files
        """
        logger.info(f"Loading model from {path}")
        
        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        
        # Try to load without quantization first
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                load_in_8bit=False,
                load_in_4bit=False
            )
        except Exception as e:
            logger.warning(f"Failed to load without quantization: {e}")
            # Fallback: try with different settings
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                use_safetensors=True
            )
        
        # Set pad token if it doesn't exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        logger.info("Model loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.
        
        Args:
            data (Dict[str, Any]): Request data containing:
                - inputs (str): The input text/prompt
                - parameters (dict, optional): Generation parameters
                    - max_new_tokens (int): Maximum tokens to generate (default: 256)
                    - temperature (float): Sampling temperature (default: 0.7)
                    - top_p (float): Top-p sampling (default: 0.9)
                    - do_sample (bool): Whether to use sampling (default: True)
                    - repetition_penalty (float): Repetition penalty (default: 1.1)
                    - return_full_text (bool): Return full text including input (default: False)
                
        Returns:
            List[Dict[str, Any]]: Generated text response
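
        Example (illustrative payload only; values are placeholders):
            {"inputs": "Hello!", "parameters": {"max_new_tokens": 64}}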
        """
        try:
            # Extract inputs
            inputs = data.get("inputs", "")
            if not inputs:
                return [{"error": "No input text provided"}]
            
            # Extract generation parameters
            parameters = data.get("parameters", {})
            max_new_tokens = parameters.get("max_new_tokens", 256)
            temperature = parameters.get("temperature", 0.7)
            top_p = parameters.get("top_p", 0.9)
            do_sample = parameters.get("do_sample", True)
            repetition_penalty = parameters.get("repetition_penalty", 1.1)
            return_full_text = parameters.get("return_full_text", False)
            
            # Format the input as a chat message if it doesn't already contain instruction formatting
            if not any(marker in inputs.lower() for marker in ["[inst]", "<s>", "### instruction", "user:", "assistant:"]):
                formatted_input = f"[INST] {inputs} [/INST]"
            else:
                formatted_input = inputs
            
            # Tokenize input; keep the attention mask so generate() behaves correctly
            # when pad_token_id has been set to eos_token_id
            encoded = self.tokenizer(
                formatted_input,
                return_tensors="pt",
                truncation=True,
                max_length=2048  # Reasonable limit for input
            )
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]
            
            # Move to GPU if available
            if torch.cuda.is_available():
                input_ids = input_ids.cuda()
                attention_mask = attention_mask.cuda()
            
            # Generate response
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True
                )
            
            # Decode the response
            if return_full_text:
                generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            else:
                # Only return the newly generated tokens
                new_tokens = output_ids[0][input_ids.shape[-1]:]
                generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            
            # Clean up the response
            generated_text = generated_text.strip()
            
            # Return in the expected format
            return [{
                "generated_text": generated_text,
                "input_length": input_ids.shape[-1],
                "output_length": len(output_ids[0]) - input_ids.shape[-1]
            }]
            
        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}"}]

    def __del__(self):
        """Clean up resources when the handler is destroyed."""
        if hasattr(self, 'model'):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
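

# Minimal local smoke test (a sketch only, not part of the Inference Endpoints
# request contract). Assumes the model files are available at the path given
# below; the path and prompt are illustrative placeholders.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    example_request = {
        "inputs": "Explain what this endpoint handler does.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    }
    print(handler(example_request))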