wealthcoders committed on
Commit
82c7d5c
·
verified ·
1 Parent(s): 59c9346

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +11 -8
handler.py CHANGED
@@ -14,24 +14,27 @@ class EndpointHandler:
14
  self.tokenizer = AutoTokenizer.from_pretrained(
15
  model_path,
16
  trust_remote_code=True,
17
- local_files_only=bool(model_dir) # Only use local files if model_dir is provided
18
  )
19
 
20
  # Check if CUDA is available
21
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
22
 
23
- # Load model with appropriate settings
24
  model_kwargs = {
25
  'trust_remote_code': True,
26
- 'torch_dtype': torch.bfloat16 if self.device == 'cuda' else torch.float32
27
  }
28
 
29
- # Add flash attention if available and on CUDA
30
  if self.device == 'cuda':
31
- try:
32
- model_kwargs['_attn_implementation'] = 'flash_attention_2'
33
- except:
34
- pass # Fall back to default if flash attention not available
 
 
 
35
 
36
  self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
37
  self.model = self.model.eval()
 
14
  self.tokenizer = AutoTokenizer.from_pretrained(
15
  model_path,
16
  trust_remote_code=True,
17
+ local_files_only=bool(model_dir)
18
  )
19
 
20
  # Check if CUDA is available
21
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+ print(f"Using device: {self.device}")
23
 
24
+ # Load model WITHOUT flash attention
25
  model_kwargs = {
26
  'trust_remote_code': True,
 
27
  }
28
 
29
+ # Use appropriate dtype based on GPU capability
30
  if self.device == 'cuda':
31
+ # T4 and L4 work better with float16
32
+ model_kwargs['torch_dtype'] = torch.float16
33
+ else:
34
+ model_kwargs['torch_dtype'] = torch.float32
35
+
36
+ # Explicitly disable flash attention
37
+ model_kwargs['_attn_implementation'] = 'eager'
38
 
39
  self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
40
  self.model = self.model.eval()