Spaces:
Sleeping
Sleeping
Ensure reward model has float dtype
Browse files
src/smc/scorers/image_reward_utils.py
CHANGED
|
@@ -313,6 +313,9 @@ def rm_load(
|
|
| 313 |
else:
|
| 314 |
model = model.to_empty(device=device)
|
| 315 |
model.load_state_dict(state_dict, strict=False)
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
print("checkpoint loaded")
|
| 318 |
model.eval()
|
|
|
|
| 313 |
else:
|
| 314 |
model = model.to_empty(device=device)
|
| 315 |
model.load_state_dict(state_dict, strict=False)
|
| 316 |
+
|
| 317 |
+
# For some reason, sometimes the model params are loaded as bfloat16
|
| 318 |
+
model = model.float()
|
| 319 |
|
| 320 |
print("checkpoint loaded")
|
| 321 |
model.eval()
|