yol146 committed
Commit 72ed73b · Parent: c3375d0

modify the handler

Files changed (1)
  1. handler.py +168 -111
handler.py CHANGED
@@ -1,9 +1,17 @@
 import os
 import torch
-from typing import Dict, List, Any
+import logging
+from typing import Dict, List, Any, Union, Generator
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread

+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
 class EndpointHandler:
     def __init__(self, path=""):
         """
@@ -13,7 +21,7 @@ class EndpointHandler:
             path (str): Path to the model directory
         """
         # Set default parameters for inference
-        self.max_new_tokens = 4096
+        self.max_new_tokens = 1024  # Reduced from 4096 to avoid memory issues
         self.temperature = 0.7
         self.top_p = 0.9
         self.do_sample = True
@@ -22,41 +30,53 @@
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

-        # Load tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-
-        # Load model with appropriate settings
-        self.model = AutoModelForCausalLM.from_pretrained(
-            path,
-            torch_dtype=self.dtype,
-            device_map="auto" if self.device == "cuda" else None,
-            trust_remote_code=True
-        )
-
-        # Move model to device if CPU
-        if self.device == "cpu":
-            self.model = self.model.to(self.device)
-
-        # Set model to evaluation mode
-        self.model.eval()
-
-        print(f"Model loaded on {self.device} using {self.dtype}")
-
-    def format_prompt(self, prompt: str) -> str:
-        """
-        Format the user prompt for Phi-4 model.
-
-        Args:
-            prompt (str): User input prompt
+        logger.info(f"Initializing model from {path} on {self.device}")
+
+        try:
+            # Load tokenizer - use original model ID as fallback
+            # This helps with common tokenizer mismatch issues
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(path)
+                logger.info(f"Loaded tokenizer from local path")
+            except Exception as e:
+                logger.warning(f"Failed to load tokenizer from local path: {e}")
+                self.tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
+                logger.info("Loaded tokenizer from microsoft/Phi-4-mini-instruct")

-        Returns:
-            str: Formatted prompt
-        """
-        # For Phi-4-mini-instruct, the prompt format is simple
-        # You may need to adjust this based on your specific fine-tuning
-        return prompt
+            # Ensure tokenizer has EOS token set
+            if self.tokenizer.eos_token_id is None:
+                logger.warning("EOS token not set in tokenizer, using default")
+                self.tokenizer.eos_token_id = 199999  # Phi-4's default EOS token
+
+            # Load model with appropriate settings
+            self.model = AutoModelForCausalLM.from_pretrained(
+                path,
+                torch_dtype=self.dtype,
+                device_map="auto" if self.device == "cuda" else None,
+                trust_remote_code=True
+            )
+
+            # Move model to device if CPU
+            if self.device == "cpu":
+                self.model = self.model.to(self.device)
+
+            # Set model to evaluation mode
+            self.model.eval()
+
+            # Print diagnostic information
+            logger.info(f"Model loaded on {self.device} using {self.dtype}")
+            logger.info(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
+            logger.info(f"Model vocabulary size: {self.model.config.vocab_size}")
+            logger.info(f"Model embedding size: {self.model.get_input_embeddings().weight.shape}")
+
+            if len(self.tokenizer) != self.model.config.vocab_size:
+                logger.warning(f"Tokenizer vocab size ({len(self.tokenizer)}) doesn't match model vocab size ({self.model.config.vocab_size})")
+
+        except Exception as e:
+            logger.error(f"Error during model initialization: {e}")
+            raise

-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, str], Generator]:
         """
         Process the input data and generate a response using the Phi-4 model.

@@ -64,89 +84,126 @@ class EndpointHandler:
             data (Dict[str, Any]): Input data containing the prompt and generation parameters

         Returns:
-            Dict[str, Any]: Model response
+            Dict[str, str] or Generator: Model response or stream
         """
-        # Extract input parameters with defaults
-        prompt = data.pop("inputs", "")
-        parameters = data.pop("parameters", {})
-
-        # Get generation parameters with fallbacks to defaults
-        max_new_tokens = parameters.get("max_new_tokens", self.max_new_tokens)
-        temperature = parameters.get("temperature", self.temperature)
-        top_p = parameters.get("top_p", self.top_p)
-        do_sample = parameters.get("do_sample", self.do_sample)
-        stream = parameters.get("stream", False)
-
-        # Format the prompt according to model requirements
-        formatted_prompt = self.format_prompt(prompt)
-
-        # Tokenize the input
-        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
-
-        # Handle streaming if requested
-        if stream:
-            return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
-        else:
-            return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)
+        try:
+            # Extract input parameters with defaults
+            if "inputs" not in data:
+                logger.warning("No 'inputs' field in request data")
+                return {"error": "Missing 'inputs' field in request"}
+
+            prompt = data.get("inputs", "")
+            parameters = data.get("parameters", {})
+
+            logger.info(f"Processing input with {len(prompt)} characters")
+
+            # Get generation parameters with fallbacks to defaults
+            max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 2048)
+            temperature = parameters.get("temperature", self.temperature)
+            top_p = parameters.get("top_p", self.top_p)
+            do_sample = parameters.get("do_sample", self.do_sample)
+            stream = parameters.get("stream", False)
+
+            # Tokenize the input safely
+            inputs = self.tokenizer(prompt, return_tensors="pt")
+            logger.info(f"Input tokens shape: {inputs.input_ids.shape}")
+
+            # Move to device (keep the BatchEncoding so attribute access still works in _generate)
+            inputs = inputs.to(self.device)
+
+            # Handle streaming if requested
+            if stream:
+                return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
+            else:
+                return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)
+
+        except Exception as e:
+            logger.error(f"Error during generation: {e}")
+            return {"error": str(e)}

     def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
         """Generate text non-streaming mode"""
-        with torch.no_grad():
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                do_sample=do_sample,
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-
-        # Decode the generated text
-        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Return only the newly generated text (without the prompt)
-        prompt_length = len(self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))
-        response_text = generated_text[prompt_length:]
-
-        return {"generated_text": response_text}
+        try:
+            with torch.no_grad():
+                generation_config = {
+                    "max_new_tokens": max_new_tokens,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "do_sample": do_sample,
+                    "pad_token_id": self.tokenizer.eos_token_id
+                }
+
+                logger.info(f"Generating with config: {generation_config}")
+
+                outputs = self.model.generate(
+                    inputs.input_ids,
+                    attention_mask=inputs.attention_mask if hasattr(inputs, 'attention_mask') else None,
+                    **generation_config
+                )
+
+            # Decode the generated text
+            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # Return only the newly generated text (without the prompt)
+            input_text = self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
+
+            if generated_text.startswith(input_text):
+                response_text = generated_text[len(input_text):]
+            else:
+                # Fallback if the decoded text doesn't start with the input
+                response_text = generated_text
+
+            logger.info(f"Generated {len(response_text)} characters")
+            return {"generated_text": response_text}
+
+        except Exception as e:
+            logger.error(f"Error in _generate: {e}")
+            return {"error": str(e)}

     def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
         """Generate text in streaming mode"""
-        # Create a streamer object
-        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
-
-        # Set up generation in a separate thread
-        generation_kwargs = dict(
-            **inputs,
-            streamer=streamer,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            do_sample=do_sample,
-            pad_token_id=self.tokenizer.eos_token_id
-        )
-
-        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        # Determine input text length to strip it from outputs
-        prompt_text = self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
-        prompt_length = len(prompt_text)
-
-        # Stream the output
-        def generate_stream():
-            # Skip the prompt part in the first chunk
-            first_chunk = True
-            for text in streamer:
-                if first_chunk:
-                    # Only yield new tokens, not the original prompt
-                    if len(text) > prompt_length:
-                        yield {"generated_text": text[prompt_length:]}
-                    first_chunk = False
-                else:
-                    yield {"generated_text": text}
-
-        return generate_stream()
+        try:
+            # Create a streamer object
+            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+
+            # Set up generation in a separate thread
+            generation_kwargs = {
+                "input_ids": inputs.input_ids,
+                "attention_mask": inputs.attention_mask if hasattr(inputs, 'attention_mask') else None,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": do_sample,
+                "pad_token_id": self.tokenizer.eos_token_id
+            }
+
+            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+            thread.start()
+
+            # Determine input text length to strip it from outputs
+            input_text = self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
+
+            # Stream the output
+            def generate_stream():
+                # Skip the prompt part in the first chunk
+                full_text = ""
+                for text in streamer:
+                    full_text += text
+                    # Only return the part after the prompt
+                    if full_text.startswith(input_text):
+                        current_response = full_text[len(input_text):]
+                    else:
+                        current_response = full_text
+                    yield {"generated_text": current_response}
+
+            return generate_stream()
+
+        except Exception as e:
+            logger.error(f"Error in _generate_stream: {e}")
+            def error_stream():
+                yield {"error": str(e)}
+            return error_stream()

 # For local testing
 if __name__ == "__main__":
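
The diff ends at the `if __name__ == "__main__":` guard, so the local-testing body itself is not shown. As a rough illustration of the request/response contract the updated handler expects, a minimal smoke test could look like the sketch below; the model path, prompts, and parameter values are placeholders and are not part of the commit.

# Hypothetical smoke test for the handler above; "./phi-4-mini-finetuned",
# the prompts, and the parameter values are illustrative placeholders.
from handler import EndpointHandler

handler = EndpointHandler(path="./phi-4-mini-finetuned")

# Non-streaming call: payload is {"inputs": ..., "parameters": {...}} and the
# result is {"generated_text": "..."} or {"error": "..."}.
result = handler({
    "inputs": "Explain what an inference endpoint handler does.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.7}
})
print(result)

# Streaming call: "stream": True makes __call__ return a generator yielding
# {"generated_text": ...} chunks (the accumulated text after the prompt).
for chunk in handler({
    "inputs": "Write one sentence about Phi-4-mini-instruct.",
    "parameters": {"max_new_tokens": 32, "stream": True}
}):
    print(chunk)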