chiemekakalu commited on
Commit
48fa2fe
·
verified ·
1 Parent(s): 2e1f144

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +293 -49
handler.py CHANGED
@@ -1,8 +1,17 @@
1
  import os
2
  import json
3
  import torch
4
- from typing import Dict, List, Any
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  class EndpointHandler:
@@ -17,39 +26,175 @@ class EndpointHandler:
17
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
  self.model_dir = model_dir or os.getenv("MODEL_PATH", "/model")
19
 
 
 
 
 
20
  # Load model immediately
21
  self.load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def load_model(self):
24
  """Load the finetuned model and tokenizer."""
25
  try:
26
  print(f"Loading model from {self.model_dir} to {self.device}...")
27
 
28
- # Load tokenizer
29
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
30
-
31
- # Try to load model with quantization, fall back to standard loading if bitsandbytes is missing
32
  try:
33
- # First try with 4-bit quantization
34
- self.model = AutoModelForCausalLM.from_pretrained(
35
  self.model_dir,
36
- torch_dtype=torch.float16, # Use FP16 for efficiency
37
- device_map="auto", # Auto-assign to available devices
38
- load_in_4bit=True, # Use 4-bit quantization for memory efficiency
39
  )
40
- except ImportError as e:
41
- print(f"Warning: Could not use quantization, falling back to standard loading: {e}")
42
- # Fallback to standard loading without quantization
43
- self.model = AutoModelForCausalLM.from_pretrained(
44
- self.model_dir,
45
- torch_dtype=torch.float16, # Still use FP16 for efficiency
46
- device_map="auto", # Auto-assign to available devices
 
 
 
 
 
 
 
 
47
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  print(f"Model loaded successfully on {self.device}")
50
  return True
51
  except Exception as e:
52
  print(f"Error loading model: {e}")
 
 
53
  return False
54
 
55
  def format_candidates_for_prompt(self, candidates: List[Dict[str, Any]]) -> str:
@@ -131,21 +276,54 @@ Format your response carefully with clear headings and make it comprehensive eno
131
  return_tensors="pt"
132
  ).to(self.device)
133
 
134
- # Generate
135
  with torch.no_grad():
136
- outputs = self.model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  inputs,
138
- max_length=4096,
139
- temperature=0.7,
140
- top_p=0.9,
141
- do_sample=True,
142
  )
143
 
144
- # Decode
145
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- # Extract the assistant's response (everything after the user's prompt)
148
- assistant_response = response.split(prompt)[-1].strip()
 
 
 
149
 
150
  return assistant_response
151
 
@@ -219,21 +397,54 @@ Please provide:
219
  return_tensors="pt"
220
  ).to(self.device)
221
 
222
- # Generate
223
  with torch.no_grad():
224
- outputs = self.model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  inputs,
226
- max_length=3072,
227
- temperature=0.7,
228
- top_p=0.9,
229
- do_sample=True,
230
  )
231
 
232
- # Decode
233
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
- # Extract the assistant's response
236
- assistant_response = response.split(prompt)[-1].strip()
 
 
 
237
 
238
  return assistant_response
239
 
@@ -337,21 +548,54 @@ Format your analysis with clear sections and detailed insights to help assess th
337
  return_tensors="pt"
338
  ).to(self.device)
339
 
340
- # Generate
341
  with torch.no_grad():
342
- outputs = self.model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  inputs,
344
- max_length=3072,
345
- temperature=0.7,
346
- top_p=0.9,
347
- do_sample=True,
348
  )
349
 
350
- # Decode
351
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
352
-
353
- # Extract the assistant's response
354
- assistant_response = response.split(prompt)[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
  return assistant_response
357
 
 
1
  import os
2
  import json
3
  import torch
4
+ from typing import Dict, List, Any, Optional, Union
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
+
7
# Optional dependency: PEFT is required to load LoRA/adapter weights on top
# of the base model. Record availability instead of failing at import time so
# the handler can still serve the plain base model.
try:
    import peft
    from peft import PeftModel, PeftConfig
except ImportError:
    PEFT_AVAILABLE = False
    print("Warning: PEFT library not available. Adapter loading may fail.")
else:
    PEFT_AVAILABLE = True
15
 
16
 
17
  class EndpointHandler:
 
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.model_dir = model_dir or os.getenv("MODEL_PATH", "/model")
28
 
29
+ # GPU performance optimization flags
30
+ self.flash_attention_supported = False # Will be set during model loading
31
+ self.use_sampling = True # Better quality but slightly slower than greedy
32
+
33
  # Load model immediately
34
  self.load_model()
35
+
36
def generate_optimized(self, inputs, attention_mask=None, max_new_tokens=512):
    """Generate a completion for ``inputs`` with GPU-friendly settings.

    Args:
        inputs: Token-id tensor of shape (batch, seq_len), already on the
            target device.
        attention_mask: Optional mask; derived from the pad token when absent.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        Tuple ``(outputs, input_length)`` so callers can slice the prompt
        tokens off ``outputs`` before decoding.
    """
    # Build an attention mask from padding if the caller did not supply one.
    if attention_mask is None:
        attention_mask = inputs.ne(self.tokenizer.pad_token_id).long()

    # Remember the prompt length so callers can isolate the generated part.
    input_length = inputs.shape[1]

    # Common generation arguments; use_cache enables the KV cache for
    # faster autoregressive decoding.
    gen_kwargs = {
        "attention_mask": attention_mask,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "num_beams": 1,  # no beam search: cheaper, usually good enough
        "pad_token_id": self.tokenizer.pad_token_id,
        "eos_token_id": self.tokenizer.eos_token_id,
        "repetition_penalty": 1.1,  # mild penalty to reduce repetition
    }

    # Only pass sampling parameters when sampling is enabled: supplying
    # temperature/top_p alongside do_sample=False is ignored and triggers
    # warnings in recent transformers releases.
    if self.use_sampling:
        gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    else:
        gen_kwargs["do_sample"] = False

    # BUG FIX: the previous version passed flash_attn /
    # flash_attn_cross_entropy here, but generate() rejects unknown kwargs
    # (ValueError for unused model_kwargs). Flash attention must be chosen
    # at model-load time (attn_implementation), not per generate() call.
    outputs = self.model.generate(inputs, **gen_kwargs)

    return outputs, input_length
76
 
77
def load_model(self):
    """Load the finetuned model and tokenizer onto the configured device.

    Tries, in order: 4-bit quantized loading (applying a PEFT adapter if
    ``adapter_model.safetensors`` exists in ``self.model_dir``), plain FP16
    loading from ``self.model_dir``, and finally the base
    ``microsoft/phi-2`` model.

    Returns:
        True on success, False if loading failed (the error is printed).
    """
    try:
        print(f"Loading model from {self.model_dir} to {self.device}...")

        # --- Tokenizer --------------------------------------------------
        # Left padding is required for batched causal-LM generation.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                padding_side="left",
                trust_remote_code=False
            )
            # A pad token is needed to build attention masks; Phi-2 ships
            # without one, so reuse the EOS token.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("Set pad_token to eos_token")
        except Exception as tokenizer_error:
            print(f"Error loading tokenizer from {self.model_dir}: {tokenizer_error}")
            print("Attempting to load base Phi-2 tokenizer...")
            # Fall back to the base Phi-2 tokenizer if the one in the model
            # directory is missing or corrupt.
            self.tokenizer = AutoTokenizer.from_pretrained(
                "microsoft/phi-2",
                padding_side="left",
                trust_remote_code=False
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        # --- Model ------------------------------------------------------
        try:
            # Probe for bitsandbytes; an ImportError here drops us into the
            # non-quantized fallback path below.
            from bitsandbytes.nn import Linear4bit  # noqa: F401

            print("Using 4-bit quantization with float16 compute type")
            # Use float16 consistently for compute and parameters.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,  # match model dtype
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            if os.path.exists(os.path.join(self.model_dir, "adapter_model.safetensors")):
                print("Found adapter model, loading Phi-2 base with adapter")

                # Load the quantized base model first, then the adapter.
                base_model = AutoModelForCausalLM.from_pretrained(
                    "microsoft/phi-2",
                    quantization_config=quantization_config,
                    torch_dtype=torch.float16,
                    device_map="auto"
                )
                try:
                    # BUG FIX: the previous version assigned the module-level
                    # PEFT_AVAILABLE flag and re-imported PeftModel inside a
                    # conditional branch of this function. Both assignments
                    # made those names function-local everywhere in the
                    # method, raising UnboundLocalError even when PEFT *was*
                    # installed — which silently skipped the adapter. Import
                    # PeftModel unconditionally, right before use; an
                    # ImportError falls through to the base-model fallback.
                    from peft import PeftModel
                    self.model = PeftModel.from_pretrained(
                        base_model,
                        self.model_dir,
                        torch_dtype=torch.float16,
                        device_map="auto"
                    )
                    print("Successfully loaded adapter model")
                except Exception as adapter_error:
                    print(f"Error loading adapter: {adapter_error}")
                    # Fall back to just using the base model.
                    print("Falling back to base model without adapter")
                    self.model = base_model
            else:
                # No adapter present: load the directory as a full model.
                print("Loading model directly from directory")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_dir,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    quantization_config=quantization_config
                )

        except ImportError as e:
            print(f"Warning: Could not use bitsandbytes quantization, falling back to standard loading: {e}")
            # Fallback to standard FP16 loading without quantization.
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_dir,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            except Exception as model_error:
                print(f"Error loading from model directory: {model_error}")
                print("Attempting to load base Phi-2 model...")
                # Final fallback — load just the base model.
                self.model = AutoModelForCausalLM.from_pretrained(
                    "microsoft/phi-2",
                    torch_dtype=torch.float16,
                    device_map="auto",
                )

        print(f"Model loaded successfully on {self.device}")
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        print(traceback.format_exc())
        return False
199
 
200
  def format_candidates_for_prompt(self, candidates: List[Dict[str, Any]]) -> str:
 
276
  return_tensors="pt"
277
  ).to(self.device)
278
 
279
+ # Generate with proper context limits and attention masks
280
  with torch.no_grad():
281
+ # Find input length to set appropriate output length
282
+ input_length = inputs.shape[1]
283
+ # Phi-2 has a context limit of 2048
284
+ max_context_length = 2048
285
+
286
+ # Calculate max new tokens to avoid exceeding model's context limits
287
+ max_new_tokens = max(100, min(1024, max_context_length - input_length))
288
+
289
+ print(f"Input length: {input_length}, Max new tokens: {max_new_tokens}")
290
+
291
+ # Create attention mask (explicitly handle padding)
292
+ attention_mask = inputs.ne(self.tokenizer.pad_token_id).long()
293
+
294
+ # Use the optimized generator instead of direct model.generate call
295
+ outputs, input_length = self.generate_optimized(
296
  inputs,
297
+ attention_mask=attention_mask,
298
+ max_new_tokens=max_new_tokens
 
 
299
  )
300
 
301
+ # Decode more carefully
302
+ try:
303
+ # Get only the generated part (exclude input tokens)
304
+ generated_output = outputs[0][input_length:]
305
+
306
+ # Decode just the new tokens
307
+ generated_text = self.tokenizer.decode(
308
+ generated_output,
309
+ skip_special_tokens=True,
310
+ clean_up_tokenization_spaces=True
311
+ )
312
+
313
+ # Remove any model-specific artifacts
314
+ generated_text = generated_text.replace("<|im_end|>", "").replace("<|im_start|>", "")
315
+ assistant_response = generated_text.strip()
316
+
317
+ # If that failed, try traditional approach
318
+ if not assistant_response:
319
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
320
+ assistant_response = full_response.split(prompt)[-1].strip()
321
 
322
+ except Exception as decode_error:
323
+ print(f"Error decoding response: {decode_error}")
324
+ # Fallback to simpler decoding
325
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
326
+ assistant_response = full_response.split(prompt)[-1].strip()
327
 
328
  return assistant_response
329
 
 
397
  return_tensors="pt"
398
  ).to(self.device)
399
 
400
+ # Generate with proper context limits and attention masks
401
  with torch.no_grad():
402
+ # Find input length to set appropriate output length
403
+ input_length = inputs.shape[1]
404
+ # Phi-2 has a context limit of 2048
405
+ max_context_length = 2048
406
+
407
+ # Calculate max new tokens to avoid exceeding model's context limits
408
+ max_new_tokens = max(100, min(1024, max_context_length - input_length))
409
+
410
+ print(f"Team analysis - Input length: {input_length}, Max new tokens: {max_new_tokens}")
411
+
412
+ # Create attention mask (explicitly handle padding)
413
+ attention_mask = inputs.ne(self.tokenizer.pad_token_id).long()
414
+
415
+ # Use the optimized generator instead of direct model.generate call
416
+ outputs, input_length = self.generate_optimized(
417
  inputs,
418
+ attention_mask=attention_mask,
419
+ max_new_tokens=max_new_tokens
 
 
420
  )
421
 
422
+ # Decode more carefully
423
+ try:
424
+ # Get only the generated part (exclude input tokens)
425
+ generated_output = outputs[0][input_length:]
426
+
427
+ # Decode just the new tokens
428
+ generated_text = self.tokenizer.decode(
429
+ generated_output,
430
+ skip_special_tokens=True,
431
+ clean_up_tokenization_spaces=True
432
+ )
433
+
434
+ # Remove any model-specific artifacts
435
+ generated_text = generated_text.replace("<|im_end|>", "").replace("<|im_start|>", "")
436
+ assistant_response = generated_text.strip()
437
+
438
+ # If that failed, try traditional approach
439
+ if not assistant_response:
440
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
441
+ assistant_response = full_response.split(prompt)[-1].strip()
442
 
443
+ except Exception as decode_error:
444
+ print(f"Error decoding team analysis response: {decode_error}")
445
+ # Fallback to simpler decoding
446
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
447
+ assistant_response = full_response.split(prompt)[-1].strip()
448
 
449
  return assistant_response
450
 
 
548
  return_tensors="pt"
549
  ).to(self.device)
550
 
551
+ # Generate with proper context limits and attention masks
552
  with torch.no_grad():
553
+ # Find input length to set appropriate output length
554
+ input_length = inputs.shape[1]
555
+ # Phi-2 has a context limit of 2048
556
+ max_context_length = 2048
557
+
558
+ # Calculate max new tokens to avoid exceeding model's context limits
559
+ max_new_tokens = max(100, min(1024, max_context_length - input_length))
560
+
561
+ print(f"Candidate analysis - Input length: {input_length}, Max new tokens: {max_new_tokens}")
562
+
563
+ # Create attention mask (explicitly handle padding)
564
+ attention_mask = inputs.ne(self.tokenizer.pad_token_id).long()
565
+
566
+ # Use the optimized generator instead of direct model.generate call
567
+ outputs, input_length = self.generate_optimized(
568
  inputs,
569
+ attention_mask=attention_mask,
570
+ max_new_tokens=max_new_tokens
 
 
571
  )
572
 
573
+ # Decode more carefully
574
+ try:
575
+ # Get only the generated part (exclude input tokens)
576
+ generated_output = outputs[0][input_length:]
577
+
578
+ # Decode just the new tokens
579
+ generated_text = self.tokenizer.decode(
580
+ generated_output,
581
+ skip_special_tokens=True,
582
+ clean_up_tokenization_spaces=True
583
+ )
584
+
585
+ # Remove any model-specific artifacts
586
+ generated_text = generated_text.replace("<|im_end|>", "").replace("<|im_start|>", "")
587
+ assistant_response = generated_text.strip()
588
+
589
+ # If that failed, try traditional approach
590
+ if not assistant_response:
591
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
592
+ assistant_response = full_response.split(prompt)[-1].strip()
593
+
594
+ except Exception as decode_error:
595
+ print(f"Error decoding candidate analysis response: {decode_error}")
596
+ # Fallback to simpler decoding
597
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
598
+ assistant_response = full_response.split(prompt)[-1].strip()
599
 
600
  return assistant_response
601