Lior-0618 committed on
Commit 64b796f · 1 Parent(s): b7f8db8

fix: use bfloat16 on CPU to halve memory usage (~6 GB vs ~12 GB)


float32 exceeded the HF Spaces 16 GiB memory limit. bfloat16 has been
supported on PyTorch CPU since 1.12.
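
As a sanity check on the two claims above, a minimal sketch: bfloat16 tensors
compute natively on CPU in PyTorch >= 1.12, and halving bytes per parameter
roughly halves weight memory. The ~3e9 parameter count is an assumption backed
out of the commit's ~12 GB float32 figure, not taken from the repo.

import torch

# bfloat16 matmul runs natively on CPU in PyTorch >= 1.12
x = torch.randn(8, 8, dtype=torch.bfloat16)  # plain CPU tensor
print((x @ x).dtype)  # torch.bfloat16

# Rough weight-memory math (assumed ~3e9 params: 12e9 bytes / 4 bytes each)
n_params = 3e9
print(f"float32:  ~{n_params * 4 / 1e9:.0f} GB")   # ~12 GB
print(f"bfloat16: ~{n_params * 2 / 1e9:.0f} GB")   # ~6 GB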

Files changed (1):
  api/main.py (+6 -5)
api/main.py CHANGED

@@ -37,12 +37,13 @@ def _init_model() -> None:
     from transformers import VoxtralForConditionalGeneration, AutoProcessor
     from peft import PeftModel
 
+    # bfloat16 on both GPU and CPU: halves memory vs float32 (~6 GB vs ~12 GB)
+    # PyTorch CPU supports bfloat16 natively since 1.12
+    _model_dtype = torch.bfloat16
     if torch.cuda.is_available():
-        _model_dtype = torch.bfloat16
-        device_map = "auto"
+        device_map = "auto"
     else:
-        _model_dtype = torch.float32
-        device_map = "cpu"
+        device_map = "cpu"
 
     print(f"[voxtral] Loading processor {MODEL_ID} ...")
     _processor = AutoProcessor.from_pretrained(MODEL_ID)
@@ -50,7 +51,7 @@ def _init_model() -> None:
     print(f"[voxtral] Loading base model {MODEL_ID} (dtype={_model_dtype}) ...")
     base_model = VoxtralForConditionalGeneration.from_pretrained(
         MODEL_ID,
-        torch_dtype=_model_dtype,
+        dtype=_model_dtype,
         device_map=device_map,
     )
 
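
Two notes on the hunks above. The second hunk renames torch_dtype= to dtype=,
tracking the keyword rename in recent transformers releases (torch_dtype is
deprecated in favor of dtype in from_pretrained). And the PeftModel import
suggests a LoRA adapter is attached just below the visible diff; a sketch of
the standard PEFT pattern, where ADAPTER_ID is a hypothetical placeholder
rather than the repo's actual adapter path:

from peft import PeftModel

# ADAPTER_ID is a hypothetical placeholder; the real adapter path is defined
# elsewhere in api/main.py and does not appear in this diff.
ADAPTER_ID = "Lior-0618/voxtral-lora-adapter"

# Wrap the base model with the trained LoRA weights, then switch to inference.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.eval()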