Update handler.py

handler.py (+71 −19)

@@ -45,32 +45,49 @@ class EndpointHandler:
         # Find input length to properly calculate output length
         input_length = inputs.shape[1]

-        #
-
-            inputs,
-            attention_mask
-            max_new_tokens
+        # Basic generation parameters
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": attention_mask,
+            "max_new_tokens": max_new_tokens,

             # Performance options
-            use_cache
+            "use_cache": True,  # Use KV cache for faster generation

             # Quality vs. speed tradeoff
-            temperature
-            top_p
-            do_sample
-            num_beams
+            "temperature": 0.7 if self.use_sampling else 1.0,
+            "top_p": 0.9 if self.use_sampling else 1.0,
+            "do_sample": self.use_sampling,  # Sampling is slightly slower but better quality
+            "num_beams": 1,  # Beam search is slower but better quality (1 = no beam search)

             # Token handling
-            pad_token_id
-            eos_token_id
+            "pad_token_id": self.tokenizer.pad_token_id,
+            "eos_token_id": self.tokenizer.eos_token_id,

             # Content quality
-            repetition_penalty
-
-
-
-
-
+            "repetition_penalty": 1.1,  # Reduce repetition
+        }
+
+        # Add Flash Attention parameters only if supported by the transformers version
+        # We check the transformers version by testing in a safe way
+        try:
+            import importlib
+            transformers_version = importlib.import_module('transformers').__version__
+            major, minor = map(int, transformers_version.split('.')[:2])
+
+            if major > 4 or (major == 4 and minor >= 32):
+                # Flash Attention support was added in transformers 4.32.0
+                if self.flash_attention_supported:
+                    print("Using Flash Attention in generation")
+                    generation_kwargs["flash_attn"] = True
+                    generation_kwargs["flash_attn_cross_entropy"] = True
+            else:
+                print(f"Flash Attention not added - transformers version {transformers_version} doesn't support it")
+        except Exception as e:
+            print(f"Error checking transformers version, skipping Flash Attention: {e}")
+
+        # Generate with optimized parameters for GPU performance
+        outputs = self.model.generate(**generation_kwargs)

         return outputs, input_length

@@ -190,6 +207,41 @@ class EndpointHandler:
                 device_map="auto",
             )

+            # Check for Flash Attention support with better error handling
+            try:
+                # First check if the transformers version supports it
+                import importlib
+                transformers_version = importlib.import_module('transformers').__version__
+                major, minor = map(int, transformers_version.split('.')[:2])
+
+                if major > 4 or (major == 4 and minor >= 32):
+                    # Flash Attention support was added in transformers 4.32.0
+                    try:
+                        import flash_attn
+                        self.flash_attention_supported = True
+                        print(f"Flash Attention {flash_attn.__version__} detected and will be used if available!")
+                    except ImportError:
+                        print("Flash Attention library not installed. Using standard attention mechanism.")
+                        self.flash_attention_supported = False
+                else:
+                    print(f"Transformers version {transformers_version} doesn't support Flash Attention parameters. Using standard attention.")
+                    self.flash_attention_supported = False
+            except Exception as e:
+                print(f"Error checking Flash Attention support: {e}")
+                print("Falling back to standard attention mechanism.")
+                self.flash_attention_supported = False
+
+            # Enable TF32 precision for higher performance on newer NVIDIA GPUs
+            if self.device == "cuda":
+                # Only available on Ampere+ GPUs (A100, RTX 3090, etc.)
+                try:
+                    if torch.cuda.get_device_capability()[0] >= 8:
+                        print("Enabling TF32 precision for faster matrix operations")
+                        torch.backends.cuda.matmul.allow_tf32 = True
+                        torch.backends.cudnn.allow_tf32 = True
+                except Exception as e:
+                    print(f"Error enabling TF32 precision: {e}")
+
             print(f"Model loaded successfully on {self.device}")
             return True
         except Exception as e:

@@ -879,7 +931,7 @@ Return a JSON array containing ONLY the candidate numbers (starting from 1) that
         return {
             "team_analysis": team_analysis,
             "model_info": {
-                "
+                "x": str(self.device),
                 "model_type": "phi-2-qlora-finetuned"
             }
         }
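
Both of the larger hunks above gate the Flash Attention flags on the installed transformers version by splitting `__version__` by hand. A minimal standalone sketch of the same check, written with `packaging.version` purely as an illustration: the 4.32.0 threshold is taken from the commit's own comments, and the function name is hypothetical.

```python
from packaging import version  # packaging is a dependency of transformers
import transformers


def supports_flash_attention_kwargs(min_version: str = "4.32.0") -> bool:
    """Return True if the installed transformers version is at least min_version.

    Mirrors the major/minor comparison in the diff, but also handles
    versions such as "4.40.0.dev0" without manual string splitting.
    """
    return version.parse(transformers.__version__) >= version.parse(min_version)


if __name__ == "__main__":
    print(supports_flash_attention_kwargs())
```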
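
The TF32 toggle added in the second hunk can also be exercised in isolation. A small sketch, assuming only that PyTorch is installed; the helper name is made up for illustration, and compute capability 8 corresponds to Ampere-class GPUs (A100, RTX 30xx), matching the commit's comment.

```python
import torch


def enable_tf32_if_supported() -> bool:
    """Enable TF32 matmul/cuDNN paths on Ampere-or-newer GPUs; return whether it was enabled."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    if major >= 8:  # Ampere and newer report compute capability >= 8
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        return True
    return False


if __name__ == "__main__":
    print(f"TF32 enabled: {enable_tf32_if_supported()}")
```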