wealthcoders committed on
Commit
7b796cd
·
verified ·
1 Parent(s): 72c9ee9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -8
handler.py CHANGED
@@ -21,18 +21,12 @@ class EndpointHandler:
21
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
  print(f"Using device: {self.device}")
23
 
24
- # Load model WITHOUT flash attention
25
  model_kwargs = {
26
  'trust_remote_code': True,
 
27
  }
28
 
29
- # Use appropriate dtype based on GPU capability
30
- if self.device == 'cuda':
31
- # T4 and L4 work better with float16
32
- model_kwargs['torch_dtype'] = torch.float16
33
- else:
34
- model_kwargs['torch_dtype'] = torch.float32
35
-
36
  # Explicitly disable flash attention
37
  model_kwargs['_attn_implementation'] = 'eager'
38
 
 
21
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
  print(f"Using device: {self.device}")
23
 
24
+ # Load model in float32 to avoid dtype conflicts
25
  model_kwargs = {
26
  'trust_remote_code': True,
27
+ 'torch_dtype': torch.float32 # Use float32 instead of float16
28
  }
29
 
 
 
 
 
 
 
 
30
  # Explicitly disable flash attention
31
  model_kwargs['_attn_implementation'] = 'eager'
32