File size: 7,127 Bytes
ec5bb4e
0b6ae9b
051c5a5
d36359f
051c5a5
72ed73b
ec5bb4e
 
051c5a5
 
d36359f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
051c5a5
ec5bb4e
051c5a5
 
 
 
 
 
 
 
 
72ed73b
051c5a5
 
72ed73b
051c5a5
 
72ed73b
051c5a5
 
72ed73b
051c5a5
 
72ed73b
051c5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
1b1b06a
051c5a5
 
1b1b06a
051c5a5
 
 
 
 
 
 
 
 
 
 
1b1b06a
051c5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b1b06a
051c5a5
 
 
 
 
ec5bb4e
051c5a5
 
bb64432
051c5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b6ae9b
051c5a5
 
 
 
 
 
 
 
 
 
 
 
 
0b6ae9b
051c5a5
 
 
 
 
 
 
 
 
 
 
 
 
0b6ae9b
051c5a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import json
import os
import time
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import AutoConfig, AutoTokenizer, pipeline

class EndpointHandler:
    """OpenAI-style chat-completion handler around a local text-generation model.

    Loads a Hugging Face ``text-generation`` pipeline from ``path`` (or the
    ``MODEL_PATH`` environment variable) and serves requests shaped like the
    OpenAI Chat Completions API: ``{"messages": [...], "max_tokens": ..., ...}``.
    """

    def __init__(self, path=""):
        # Explicit constructor argument wins over the MODEL_PATH env var.
        self.model_path = path if path else os.environ.get("MODEL_PATH", "")

        # Best-effort patch for a known config defect where
        # rope_scaling["short_factor"] has length 48 but the model expects 64.
        try:
            config = AutoConfig.from_pretrained(self.model_path)

            # rope_scaling may be missing or None; guard before the membership
            # test (`"x" in None` raises TypeError and would skip the patch).
            rope_scaling = getattr(config, "rope_scaling", None)
            if rope_scaling and "short_factor" in rope_scaling:
                short_factor = rope_scaling["short_factor"]
                if len(short_factor) == 48:  # the problematic length
                    print("Fixing rope_scaling short_factor length from 48 to 64")
                    # Pad with zeros up to the expected length of 64.
                    rope_scaling["short_factor"] = (
                        list(short_factor) + [0.0] * (64 - len(short_factor))
                    )
                    # Persist the repaired config so the pipeline below loads it.
                    config.save_pretrained(self.model_path)
                    print("Fixed config saved")
        except Exception as e:
            # Never fail startup over this patch; the model may still load fine.
            print(f"Warning: Could not fix RoPE scaling configuration: {str(e)}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        # Create text generation pipeline
        self.pipe = pipeline(
            "text-generation",
            model=self.model_path,
            tokenizer=self.tokenizer,
            torch_dtype=torch.float16,
            device_map="auto",
            return_full_text=False  # Only return the generated text, not the prompt
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request in OpenAI-like format.

        On any failure, returns an OpenAI-style error envelope instead of
        raising, so the serving layer always receives a JSON-serializable dict.
        """
        try:
            # Parse input data
            inputs = self._parse_input(data)

            # Generate response
            outputs = self._generate(inputs)

            # Format response in OpenAI-like format
            return self._format_response(outputs, inputs)
        except Exception as e:
            return {
                "error": {
                    "message": str(e),
                    "type": "invalid_request_error",
                    "code": 400
                }
            }

    def _parse_input(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the request and extract the prompt + generation parameters.

        Raises:
            ValueError: if ``messages`` is missing or empty.
        """
        messages = data.get("messages", [])
        if not messages:
            raise ValueError("No messages provided")

        # Convert messages to prompt
        prompt = self._convert_messages_to_prompt(messages)

        # OpenAI allows "stop" to be a single string OR a list of strings.
        # Normalize to a list (or None) so downstream truncation iterates stop
        # sequences, never the characters of a bare string.
        stop = data.get("stop", None)
        if isinstance(stop, str):
            stop = [stop]

        # Extract generation parameters with OpenAI-compatible defaults.
        generation_params = {
            "max_tokens": data.get("max_tokens", 256),
            "temperature": data.get("temperature", 0.7),
            "top_p": data.get("top_p", 1.0),
            "n": data.get("n", 1),
            "stream": data.get("stream", False),  # accepted but not honored here
            "stop": stop,
            "presence_penalty": data.get("presence_penalty", 0.0),
            "frequency_penalty": data.get("frequency_penalty", 0.0),
        }

        return {
            "prompt": prompt,
            "messages": messages,
            "generation_params": generation_params
        }

    def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Flatten chat messages into a single "Role: content" prompt string."""
        prompt = ""
        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")

            # Unknown roles are silently dropped.
            if role == "system":
                prompt += f"System: {content}\n\n"
            elif role == "user":
                prompt += f"User: {content}\n\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n\n"

        # Cue the model to respond as the assistant.
        prompt += "Assistant: "
        return prompt

    def _generate(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the pipeline; return generated texts plus token counts."""
        prompt = inputs["prompt"]
        params = inputs["generation_params"]

        # Count input tokens
        input_tokens = len(self.tokenizer.encode(prompt))

        # Map OpenAI-style parameters onto generate() kwargs. Presence/frequency
        # penalties have no direct pipeline equivalent and are ignored.
        generation_kwargs = {
            "max_new_tokens": params["max_tokens"],
            "temperature": params["temperature"],
            "top_p": params["top_p"],
            "num_return_sequences": params["n"],
            "do_sample": params["temperature"] > 0,  # greedy when temperature == 0
        }

        # Stop sequences are enforced by truncating the generated text below.
        # Do NOT forward them as `stopping_criteria`: the pipeline expects
        # StoppingCriteria objects, and raw strings would crash generation.

        # Generate output using the pipeline
        pipeline_outputs = self.pipe(
            prompt,
            **generation_kwargs
        )

        # Truncate each sequence at the first stop sequence found, if any.
        generated_texts = []
        for output in pipeline_outputs:
            gen_text = output["generated_text"]

            if params["stop"]:
                for stop in params["stop"]:
                    if stop in gen_text:
                        gen_text = gen_text[:gen_text.find(stop)]

            generated_texts.append(gen_text)

        # Count completion tokens per returned sequence.
        completion_tokens = [len(self.tokenizer.encode(text)) for text in generated_texts]

        return {
            "generated_texts": generated_texts,
            "prompt_tokens": input_tokens,
            "completion_tokens": completion_tokens,
        }

    def _format_response(self, outputs: Dict[str, Any], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble an OpenAI-like ``chat.completion`` response dict."""
        generated_texts = outputs["generated_texts"]
        prompt_tokens = outputs["prompt_tokens"]
        completion_tokens = outputs["completion_tokens"]

        choices = []
        for i, text in enumerate(generated_texts):
            choices.append({
                "index": i,
                "message": {
                    "role": "assistant",
                    "content": text
                },
                # Length-based finish detection is not tracked here, so "stop"
                # is reported unconditionally.
                "finish_reason": "stop"
            })

        return {
            "id": f"cmpl-{hash(inputs['prompt']) % 10000}",
            "object": "chat.completion",
            # Per the OpenAI spec "created" is a Unix timestamp. (The previous
            # code put the CUDA device index here, which is not a time at all.)
            "created": int(time.time()),
            "model": os.path.basename(self.model_path),
            "choices": choices,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": sum(completion_tokens),
                "total_tokens": prompt_tokens + sum(completion_tokens)
            }
        }