Nekochu committed
Commit 562fa54 · 1 Parent(s): dece91f

force float32 on training model for CPU

Files changed (1): app.py (+2 -1)
app.py CHANGED
@@ -304,13 +304,14 @@ def train_lora(
     from acestep.training_v2.trainer_fixed import FixedLoRATrainer
     from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
 
-    # Load model for training
+    # Load model for training (force float32 for CPU)
     model = load_decoder_for_training(
         checkpoint_dir=CHECKPOINT_DIR,
         variant="turbo",
         device="cpu",
         precision="float32",
     )
+    model = model.float()
 
     adapter_cfg = LoRAConfigV2(
         r=rank,
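
A minimal sketch of what the extra model.float() cast does (plain PyTorch semantics; the toy Linear module below is only a stand-in for the real decoder, not code from this repository): even when a loader is asked for precision="float32", individual parameters or buffers can still arrive in float16/bfloat16, and many half-precision kernels are unsupported or slow on CPU, so casting the whole module once keeps training entirely in float32.

    import torch

    # Stand-in for a decoder whose weights were (partly) loaded in half precision.
    model = torch.nn.Linear(8, 8).half()

    # .float() converts every floating-point parameter and buffer to torch.float32.
    model = model.float()
    assert all(p.dtype == torch.float32 for p in model.parameters())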