Rithankoushik commited on
Commit
a2a6945
·
verified ·
1 Parent(s): f2a8c76

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +2 -1
inference.py CHANGED
@@ -21,11 +21,12 @@ def load_model_and_tokenizer():
21
  model = AutoModelForCausalLM.from_pretrained(
22
  MODEL_REPO,
23
  trust_remote_code=True,
24
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
25
  device_map="auto"
26
  )
27
 
28
 
 
29
  return tokenizer, model
30
 
31
 
 
21
  model = AutoModelForCausalLM.from_pretrained(
22
  MODEL_REPO,
23
  trust_remote_code=True,
24
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
25
  device_map="auto"
26
  )
27
 
28
 
29
+
30
  return tokenizer, model
31
 
32