Yong Liu committed on
Commit
4aa4d08
·
1 Parent(s): 093ad9c

update handler

Browse files
Files changed (1) hide show
  1. handler.py +54 -7
handler.py CHANGED
@@ -25,14 +25,37 @@ class EndpointHandler:
25
  # Load tokenizer
26
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
27
 
 
 
 
 
28
  # Load model directly without pipeline
29
  self.model = AutoModelForCausalLM.from_pretrained(
30
  self.model_path,
31
  torch_dtype=torch.float16,
32
  device_map="auto"
33
  )
 
 
 
 
34
  print("Model loaded successfully")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
37
  """Handle inference request in OpenAI-like format or HuggingFace Inference API format"""
38
  try:
@@ -142,8 +165,13 @@ class EndpointHandler:
142
  prompt = inputs["prompt"]
143
  params = inputs["generation_params"]
144
 
145
- # Tokenize input
146
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
 
 
 
 
 
147
 
148
  # Count input tokens
149
  input_tokens = input_ids.shape[1]
@@ -159,11 +187,30 @@ class EndpointHandler:
159
  }
160
 
161
  # Generate output
162
- with torch.no_grad():
163
- outputs = self.model.generate(
164
- input_ids,
165
- **generation_kwargs
166
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # Decode output
169
  generated_texts = []
 
25
  # Load tokenizer
26
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
27
 
28
+ # Determine the device to use
29
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ print(f"Using device: {self.device}")
31
+
32
  # Load model directly without pipeline
33
  self.model = AutoModelForCausalLM.from_pretrained(
34
  self.model_path,
35
  torch_dtype=torch.float16,
36
  device_map="auto"
37
  )
38
+ # Ensure model is on the correct device
39
+ if torch.cuda.is_available():
40
+ self.model = self.model.cuda()
41
+
42
  print("Model loaded successfully")
43
 
44
+ # For Phi3 models, monkey patch the RotaryEmbedding
45
+ try:
46
+ from transformers.models.phi3.modeling_phi3 import PhiRotaryEmbedding
47
+ original_forward = PhiRotaryEmbedding.forward
48
+
49
+ def patched_forward(self, position_ids, query, key, value=None):
50
+ # Ensure position_ids is on the same device as query
51
+ position_ids = position_ids.to(query.device)
52
+ return original_forward(self, position_ids, query, key, value)
53
+
54
+ PhiRotaryEmbedding.forward = patched_forward
55
+ print("Successfully patched PhiRotaryEmbedding.forward")
56
+ except Exception as e:
57
+ print(f"Could not patch PhiRotaryEmbedding: {str(e)}")
58
+
59
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
60
  """Handle inference request in OpenAI-like format or HuggingFace Inference API format"""
61
  try:
 
165
  prompt = inputs["prompt"]
166
  params = inputs["generation_params"]
167
 
168
+ # Get the model's device
169
+ device = next(self.model.parameters()).device
170
+ print(f"Model is on device: {device}")
171
+
172
+ # Tokenize input and ensure it's on the correct device
173
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
174
+ print(f"Input tensor device: {input_ids.device}")
175
 
176
  # Count input tokens
177
  input_tokens = input_ids.shape[1]
 
187
  }
188
 
189
  # Generate output
190
+ try:
191
+ with torch.no_grad():
192
+ outputs = self.model.generate(
193
+ input_ids,
194
+ **generation_kwargs
195
+ )
196
+ print(f"Output tensor device: {outputs.device}")
197
+ except RuntimeError as e:
198
+ if "Expected all tensors to be on the same device" in str(e):
199
+ print("Caught device mismatch error, trying to fix...")
200
+ # A more drastic approach: move the model completely to CPU if there's a device issue
201
+ if torch.cuda.is_available():
202
+ print("Moving everything to CPU as a fallback")
203
+ self.model = self.model.cpu()
204
+ input_ids = input_ids.cpu()
205
+ with torch.no_grad():
206
+ outputs = self.model.generate(
207
+ input_ids,
208
+ **generation_kwargs
209
+ )
210
+ else:
211
+ raise
212
+ else:
213
+ raise
214
 
215
  # Decode output
216
  generated_texts = []