Spaces: Sleeping
Fix device load race condition bug
Browse files
- src/smc/inference.py +11 -11
- src/smc/rewards.py +3 -2
- src/smc/scorers/image_reward_utils.py +7 -1
src/smc/inference.py
CHANGED
|
@@ -24,8 +24,8 @@ MIN_GPU_DURATION = 60
|
|
| 24 |
|
| 25 |
|
| 26 |
pipe_build_lock = threading.Lock()
|
| 27 |
-
|
| 28 |
-
|
| 29 |
lora_load_lock = threading.Lock()
|
| 30 |
|
| 31 |
|
|
@@ -83,11 +83,11 @@ def _get_pretrained_duration(config: PretrainedInferenceConfig, pipe: Pipeline,
|
|
| 83 |
def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
|
| 84 |
if isinstance(device, str):
|
| 85 |
device = torch.device(device)
|
| 86 |
-
with
|
| 87 |
pipe = pipe.to(device)
|
| 88 |
reward_bias = 5.0
|
| 89 |
-
with
|
| 90 |
-
reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
|
| 91 |
image_reward_fn = lambda images: reward_fn(
|
| 92 |
images,
|
| 93 |
[config.prompt] * len(images)
|
|
@@ -174,11 +174,11 @@ def _get_smc_grad_duration(config: SMCGradInferenceConfig, pipe: Pipeline, devic
|
|
| 174 |
def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
|
| 175 |
if isinstance(device, str):
|
| 176 |
device = torch.device(device)
|
| 177 |
-
with
|
| 178 |
pipe = pipe.to(device)
|
| 179 |
reward_bias = 5.0
|
| 180 |
-
with
|
| 181 |
-
reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
|
| 182 |
image_reward_fn = lambda images: reward_fn(
|
| 183 |
images,
|
| 184 |
[config.prompt] * len(images)
|
|
@@ -240,13 +240,13 @@ def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') ->
|
|
| 240 |
def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
|
| 241 |
if isinstance(device, str):
|
| 242 |
device = torch.device(device)
|
| 243 |
-
with
|
| 244 |
pipe = pipe.to(device)
|
| 245 |
with lora_load_lock:
|
| 246 |
load_lora_weights(pipe, config.ckpt_uuid)
|
| 247 |
reward_bias = 5.0
|
| 248 |
-
with
|
| 249 |
-
reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
|
| 250 |
image_reward_fn = lambda images: reward_fn(
|
| 251 |
images,
|
| 252 |
[config.prompt] * len(images)
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
pipe_build_lock = threading.Lock()
|
| 27 |
+
reward_model_build_lock = threading.Lock()
|
| 28 |
+
device_load_lock = threading.Lock()
|
| 29 |
lora_load_lock = threading.Lock()
|
| 30 |
|
| 31 |
|
|
|
|
| 83 |
def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
|
| 84 |
if isinstance(device, str):
|
| 85 |
device = torch.device(device)
|
| 86 |
+
with device_load_lock:
|
| 87 |
pipe = pipe.to(device)
|
| 88 |
reward_bias = 5.0
|
| 89 |
+
with reward_model_build_lock:
|
| 90 |
+
reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias), "image_reward_plus_5"
|
| 91 |
image_reward_fn = lambda images: reward_fn(
|
| 92 |
images,
|
| 93 |
[config.prompt] * len(images)
|
|
|
|
| 174 |
def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
|
| 175 |
if isinstance(device, str):
|
| 176 |
device = torch.device(device)
|
| 177 |
+
with device_load_lock:
|
| 178 |
pipe = pipe.to(device)
|
| 179 |
reward_bias = 5.0
|
| 180 |
+
with reward_model_build_lock:
|
| 181 |
+
reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias), "image_reward_plus_5"
|
| 182 |
image_reward_fn = lambda images: reward_fn(
|
| 183 |
images,
|
| 184 |
[config.prompt] * len(images)
|
|
|
|
| 240 |
def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
|
| 241 |
if isinstance(device, str):
|
| 242 |
device = torch.device(device)
|
| 243 |
+
with device_load_lock:
|
| 244 |
pipe = pipe.to(device)
|
| 245 |
with lora_load_lock:
|
| 246 |
load_lora_weights(pipe, config.ckpt_uuid)
|
| 247 |
reward_bias = 5.0
|
| 248 |
+
with reward_model_build_lock:
|
| 249 |
+
reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias), "image_reward_plus_5"
|
| 250 |
image_reward_fn = lambda images: reward_fn(
|
| 251 |
images,
|
| 252 |
[config.prompt] * len(images)
|
src/smc/rewards.py
CHANGED
|
@@ -155,13 +155,14 @@ def ImageReward(
|
|
| 155 |
|
| 156 |
def ImageReward_Fk_Steering(
|
| 157 |
inference_dtype=None,
|
| 158 |
-
device=None,
|
|
|
|
| 159 |
return_loss=False,
|
| 160 |
bias=None,
|
| 161 |
):
|
| 162 |
from src.smc.scorers.image_reward_utils import rm_load
|
| 163 |
|
| 164 |
-
scorer = rm_load("ImageReward-v1.0", device=device)
|
| 165 |
|
| 166 |
if not return_loss:
|
| 167 |
def _fn(images, prompts):
|
|
|
|
| 155 |
|
| 156 |
def ImageReward_Fk_Steering(
|
| 157 |
inference_dtype=None,
|
| 158 |
+
device=None,
|
| 159 |
+
device_load_lock=None,
|
| 160 |
return_loss=False,
|
| 161 |
bias=None,
|
| 162 |
):
|
| 163 |
from src.smc.scorers.image_reward_utils import rm_load
|
| 164 |
|
| 165 |
+
scorer = rm_load("ImageReward-v1.0", device=device, device_load_lock=device_load_lock)
|
| 166 |
|
| 167 |
if not return_loss:
|
| 168 |
def _fn(images, prompts):
|
src/smc/scorers/image_reward_utils.py
CHANGED
|
@@ -261,6 +261,7 @@ class IRSMC(nn.Module):
|
|
| 261 |
def rm_load(
|
| 262 |
name: str = "ImageReward-v1.0",
|
| 263 |
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
|
|
|
| 264 |
download_root: str = None,
|
| 265 |
med_config: str = None,
|
| 266 |
):
|
|
@@ -303,7 +304,12 @@ def rm_load(
|
|
| 303 |
download_root or os.path.expanduser("~/.cache/ImageReward"),
|
| 304 |
)
|
| 305 |
|
| 306 |
-
model = IRSMC(device=device, med_config=med_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
msg = model.load_state_dict(state_dict, strict=False)
|
| 308 |
print("checkpoint loaded")
|
| 309 |
model.eval()
|
|
|
|
| 261 |
def rm_load(
|
| 262 |
name: str = "ImageReward-v1.0",
|
| 263 |
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
| 264 |
+
device_load_lock=None,
|
| 265 |
download_root: str = None,
|
| 266 |
med_config: str = None,
|
| 267 |
):
|
|
|
|
| 304 |
download_root or os.path.expanduser("~/.cache/ImageReward"),
|
| 305 |
)
|
| 306 |
|
| 307 |
+
model = IRSMC(device=device, med_config=med_config)
|
| 308 |
+
if device_load_lock is not None:
|
| 309 |
+
with device_load_lock:
|
| 310 |
+
model = model.to(device)
|
| 311 |
+
else:
|
| 312 |
+
model = model.to(device)
|
| 313 |
msg = model.load_state_dict(state_dict, strict=False)
|
| 314 |
print("checkpoint loaded")
|
| 315 |
model.eval()
|