cp524 committed on
Commit
b49998b
·
1 Parent(s): 3e0672b

Fix device load race condition bug

Browse files
src/smc/inference.py CHANGED
@@ -24,8 +24,8 @@ MIN_GPU_DURATION = 60
24
 
25
 
26
  pipe_build_lock = threading.Lock()
27
- pipe_load_lock = threading.Lock()
28
- reward_model_load_lock = threading.Lock()
29
  lora_load_lock = threading.Lock()
30
 
31
 
@@ -83,11 +83,11 @@ def _get_pretrained_duration(config: PretrainedInferenceConfig, pipe: Pipeline,
83
  def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
84
  if isinstance(device, str):
85
  device = torch.device(device)
86
- with pipe_load_lock:
87
  pipe = pipe.to(device)
88
  reward_bias = 5.0
89
- with reward_model_load_lock:
90
- reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
91
  image_reward_fn = lambda images: reward_fn(
92
  images,
93
  [config.prompt] * len(images)
@@ -174,11 +174,11 @@ def _get_smc_grad_duration(config: SMCGradInferenceConfig, pipe: Pipeline, devic
174
  def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
175
  if isinstance(device, str):
176
  device = torch.device(device)
177
- with pipe_load_lock:
178
  pipe = pipe.to(device)
179
  reward_bias = 5.0
180
- with reward_model_load_lock:
181
- reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
182
  image_reward_fn = lambda images: reward_fn(
183
  images,
184
  [config.prompt] * len(images)
@@ -240,13 +240,13 @@ def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') ->
240
  def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
241
  if isinstance(device, str):
242
  device = torch.device(device)
243
- with pipe_load_lock:
244
  pipe = pipe.to(device)
245
  with lora_load_lock:
246
  load_lora_weights(pipe, config.ckpt_uuid)
247
  reward_bias = 5.0
248
- with reward_model_load_lock:
249
- reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
250
  image_reward_fn = lambda images: reward_fn(
251
  images,
252
  [config.prompt] * len(images)
 
24
 
25
 
26
  pipe_build_lock = threading.Lock()
27
+ reward_model_build_lock = threading.Lock()
28
+ device_load_lock = threading.Lock()
29
  lora_load_lock = threading.Lock()
30
 
31
 
 
83
  def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
84
  if isinstance(device, str):
85
  device = torch.device(device)
86
+ with device_load_lock:
87
  pipe = pipe.to(device)
88
  reward_bias = 5.0
89
+ with reward_model_build_lock:
90
+ reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias), "image_reward_plus_5"
91
  image_reward_fn = lambda images: reward_fn(
92
  images,
93
  [config.prompt] * len(images)
 
174
  def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
175
  if isinstance(device, str):
176
  device = torch.device(device)
177
+ with device_load_lock:
178
  pipe = pipe.to(device)
179
  reward_bias = 5.0
180
+ with reward_model_build_lock:
181
+ reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias), "image_reward_plus_5"
182
  image_reward_fn = lambda images: reward_fn(
183
  images,
184
  [config.prompt] * len(images)
 
240
  def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
241
  if isinstance(device, str):
242
  device = torch.device(device)
243
+ with device_load_lock:
244
  pipe = pipe.to(device)
245
  with lora_load_lock:
246
  load_lora_weights(pipe, config.ckpt_uuid)
247
  reward_bias = 5.0
248
+ with reward_model_build_lock:
249
+ reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias), "image_reward_plus_5"
250
  image_reward_fn = lambda images: reward_fn(
251
  images,
252
  [config.prompt] * len(images)
src/smc/rewards.py CHANGED
@@ -155,13 +155,14 @@ def ImageReward(
155
 
156
  def ImageReward_Fk_Steering(
157
  inference_dtype=None,
158
- device=None,
 
159
  return_loss=False,
160
  bias=None,
161
  ):
162
  from src.smc.scorers.image_reward_utils import rm_load
163
 
164
- scorer = rm_load("ImageReward-v1.0", device=device)
165
 
166
  if not return_loss:
167
  def _fn(images, prompts):
 
155
 
156
  def ImageReward_Fk_Steering(
157
  inference_dtype=None,
158
+ device=None,
159
+ device_load_lock=None,
160
  return_loss=False,
161
  bias=None,
162
  ):
163
  from src.smc.scorers.image_reward_utils import rm_load
164
 
165
+ scorer = rm_load("ImageReward-v1.0", device=device, device_load_lock=device_load_lock)
166
 
167
  if not return_loss:
168
  def _fn(images, prompts):
src/smc/scorers/image_reward_utils.py CHANGED
@@ -261,6 +261,7 @@ class IRSMC(nn.Module):
261
  def rm_load(
262
  name: str = "ImageReward-v1.0",
263
  device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
 
264
  download_root: str = None,
265
  med_config: str = None,
266
  ):
@@ -303,7 +304,12 @@ def rm_load(
303
  download_root or os.path.expanduser("~/.cache/ImageReward"),
304
  )
305
 
306
- model = IRSMC(device=device, med_config=med_config).to(device)
 
 
 
 
 
307
  msg = model.load_state_dict(state_dict, strict=False)
308
  print("checkpoint loaded")
309
  model.eval()
 
261
  def rm_load(
262
  name: str = "ImageReward-v1.0",
263
  device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
264
+ device_load_lock=None,
265
  download_root: str = None,
266
  med_config: str = None,
267
  ):
 
304
  download_root or os.path.expanduser("~/.cache/ImageReward"),
305
  )
306
 
307
+ model = IRSMC(device=device, med_config=med_config)
308
+ if device_load_lock is not None:
309
+ with device_load_lock:
310
+ model = model.to(device)
311
+ else:
312
+ model = model.to(device)
313
  msg = model.load_state_dict(state_dict, strict=False)
314
  print("checkpoint loaded")
315
  model.eval()