cp524 committed on
Commit
99a6b27
·
1 Parent(s): ba9def6

Ensure reward model has float dtype

Browse files
src/smc/scorers/image_reward_utils.py CHANGED
@@ -313,6 +313,9 @@ def rm_load(
313
  else:
314
  model = model.to_empty(device=device)
315
  model.load_state_dict(state_dict, strict=False)
 
 
 
316
 
317
  print("checkpoint loaded")
318
  model.eval()
 
313
  else:
314
  model = model.to_empty(device=device)
315
  model.load_state_dict(state_dict, strict=False)
316
+
317
+ # For some reason, sometimes the model params are loaded as bfloat16
318
+ model = model.float()
319
 
320
  print("checkpoint loaded")
321
  model.eval()