devranx committed on
Commit
bf2f4a4
·
1 Parent(s): 4d16b73

Optimize memory usage to bfloat16

Browse files
Files changed (1) hide show
  1. utils.py +2 -1
utils.py CHANGED
@@ -21,7 +21,8 @@ def load_model():
21
  print(f"Loading model: {MODEL_ID}...")
22
  try:
23
  device_type = "cuda" if torch.cuda.is_available() else "cpu"
24
- torch_dtype = torch.float16 if device_type == "cuda" else torch.float32
 
25
  print(f"Using device: {device_type}, dtype: {torch_dtype}")
26
 
27
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
21
  print(f"Loading model: {MODEL_ID}...")
22
  try:
23
  device_type = "cuda" if torch.cuda.is_available() else "cpu"
24
+ # Use bfloat16 for CPU to save memory (4B params * 4 bytes is too big for 16GB)
25
+ torch_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
26
  print(f"Using device: {device_type}, dtype: {torch_dtype}")
27
 
28
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)