chiemekakalu committed
Commit 3149176 · verified · 1 Parent(s): c2f6ac2

Update handler.py

Files changed (1):
  1. handler.py +71 -19
handler.py CHANGED
@@ -45,32 +45,49 @@ class EndpointHandler:
         # Find input length to properly calculate output length
         input_length = inputs.shape[1]
 
-        # Generate with optimized parameters for GPU performance
-        outputs = self.model.generate(
-            inputs,
-            attention_mask=attention_mask,
-            max_new_tokens=max_new_tokens,
+        # Basic generation parameters
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": attention_mask,
+            "max_new_tokens": max_new_tokens,
 
             # Performance options
-            use_cache=True,  # Use KV cache for faster generation
+            "use_cache": True,  # Use KV cache for faster generation
 
             # Quality vs. speed tradeoff
-            temperature=0.7 if self.use_sampling else 1.0,
-            top_p=0.9 if self.use_sampling else 1.0,
-            do_sample=self.use_sampling,  # Sampling is slightly slower but better quality
-            num_beams=1,  # Beam search is slower but better quality (1 = no beam search)
+            "temperature": 0.7 if self.use_sampling else 1.0,
+            "top_p": 0.9 if self.use_sampling else 1.0,
+            "do_sample": self.use_sampling,  # Sampling is slightly slower but better quality
+            "num_beams": 1,  # Beam search is slower but better quality (1 = no beam search)
 
             # Token handling
-            pad_token_id=self.tokenizer.pad_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
+            "pad_token_id": self.tokenizer.pad_token_id,
+            "eos_token_id": self.tokenizer.eos_token_id,
 
             # Content quality
-            repetition_penalty=1.1,  # Reduce repetition
-
-            # Memory optimization - enabled only if supported
-            flash_attn=self.flash_attention_supported,
-            flash_attn_cross_entropy=self.flash_attention_supported
-        )
+            "repetition_penalty": 1.1,  # Reduce repetition
+        }
+
+        # Add Flash Attention parameters only if the installed transformers
+        # version supports them; the version is checked defensively
+        try:
+            import importlib
+            transformers_version = importlib.import_module('transformers').__version__
+            major, minor = map(int, transformers_version.split('.')[:2])
+
+            if major > 4 or (major == 4 and minor >= 32):
+                # Flash Attention support was added in transformers 4.32.0
+                if self.flash_attention_supported:
+                    print("Using Flash Attention in generation")
+                    generation_kwargs["flash_attn"] = True
+                    generation_kwargs["flash_attn_cross_entropy"] = True
+            else:
+                print(f"Flash Attention not added - transformers version {transformers_version} doesn't support it")
+        except Exception as e:
+            print(f"Error checking transformers version, skipping Flash Attention: {e}")
+
+        # Generate with optimized parameters for GPU performance
+        outputs = self.model.generate(**generation_kwargs)
 
         return outputs, input_length

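Note: the hunk above gates the flash_attn kwargs on a hand-rolled major/minor parse. For reference, the same gate can be written with packaging.version, which also copes with pre-release strings such as 4.32.0.dev0; a minimal standalone sketch, not part of handler.py (the helper name supports_flash_attn_kwargs is hypothetical):

    # Sketch: version gate equivalent to the split('.') check above.
    # Assumes `packaging` is installed (it is a transformers dependency).
    from packaging import version

    def supports_flash_attn_kwargs() -> bool:
        try:
            import transformers
            # Threshold taken from the hunk above: transformers >= 4.32.0
            return version.parse(transformers.__version__) >= version.parse("4.32.0")
        except Exception:
            return False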
@@ -190,6 +207,41 @@ class EndpointHandler:
                 device_map="auto",
             )
 
+            # Check for Flash Attention support with better error handling
+            try:
+                # First check if the transformers version supports it
+                import importlib
+                transformers_version = importlib.import_module('transformers').__version__
+                major, minor = map(int, transformers_version.split('.')[:2])
+
+                if major > 4 or (major == 4 and minor >= 32):
+                    # Flash Attention support was added in transformers 4.32.0
+                    try:
+                        import flash_attn
+                        self.flash_attention_supported = True
+                        print(f"Flash Attention {flash_attn.__version__} detected and will be used if available!")
+                    except ImportError:
+                        print("Flash Attention library not installed. Using standard attention mechanism.")
+                        self.flash_attention_supported = False
+                else:
+                    print(f"Transformers version {transformers_version} doesn't support Flash Attention parameters. Using standard attention.")
+                    self.flash_attention_supported = False
+            except Exception as e:
+                print(f"Error checking Flash Attention support: {e}")
+                print("Falling back to standard attention mechanism.")
+                self.flash_attention_supported = False
+
+            # Enable TF32 precision for higher performance on newer NVIDIA GPUs
+            if self.device == "cuda":
+                # Only available on Ampere+ GPUs (A100, RTX 3090, etc.)
+                try:
+                    if torch.cuda.get_device_capability()[0] >= 8:
+                        print("Enabling TF32 precision for faster matrix operations")
+                        torch.backends.cuda.matmul.allow_tf32 = True
+                        torch.backends.cudnn.allow_tf32 = True
+                except Exception as e:
+                    print(f"Error enabling TF32 precision: {e}")
+
             print(f"Model loaded successfully on {self.device}")
             return True
         except Exception as e:
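Note: the TF32 block in this hunk keys off torch.cuda.get_device_capability(), whose first element is the compute-capability major version (8 and up means Ampere or newer). A minimal standalone sketch of the same switch (the function name enable_tf32_if_ampere is hypothetical):

    # Sketch: enable TF32 matmuls on Ampere+ GPUs, as the hunk above does.
    # Requires a CUDA-enabled PyTorch build; returns False on CPU-only machines.
    import torch

    def enable_tf32_if_ampere() -> bool:
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        if major >= 8:  # Ampere (A100, RTX 30xx) and newer
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            return True
        return False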
@@ -879,7 +931,7 @@ Return a JSON array containing ONLY the candidate numbers (starting from 1) that
         return {
             "team_analysis": team_analysis,
             "model_info": {
-                "device": str(self.device),
+                "x": str(self.device),
                 "model_type": "phi-2-qlora-finetuned"
             }
         }
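Note: this hunk renames the model_info key from "device" to "x", so the handler's response now looks roughly like the sketch below (values are illustrative); any client still reading model_info["device"] would need the new key:

    # Illustrative response shape after this commit (values made up).
    example_response = {
        "team_analysis": "...",  # model-generated analysis text
        "model_info": {
            "x": "cuda",  # device string, formerly under the key "device"
            "model_type": "phi-2-qlora-finetuned",
        },
    }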
 