Burf commited on
Commit
7c81d0a
·
1 Parent(s): 541e9bd

fix weight path issue

Browse files
Files changed (1) hide show
  1. wrapper.py +11 -7
wrapper.py CHANGED
@@ -183,7 +183,7 @@ def peca(pipeline, save_path = "./weight", n_layer = 10):
183
  return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip
184
 
185
  class DrUM(DiffusionPipeline):
186
- def __init__(self, pipeline, repo_id = "Burf/DrUM"):
187
  """
188
  DrUM for various diffusion models
189
 
@@ -194,7 +194,7 @@ class DrUM(DiffusionPipeline):
194
  self.pipeline = pipeline
195
  self.repo_id = repo_id
196
 
197
- self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(pipeline, repo_id)
198
 
199
  @classmethod
200
  def from_pretrained(cls, model_id, repo_id = "Burf/DrUM", torch_dtype = torch.bfloat16, device = "cuda"):
@@ -220,7 +220,7 @@ class DrUM(DiffusionPipeline):
220
  #pipeline.safety_checker = lambda images, clip_input: (images, [False] * len(images))
221
  return cls(pipeline, repo_id)
222
 
223
- def load_weight(self, pipeline, repo_id = "Burf/DrUM"):
224
  name = pipeline.config._name_or_path.split("/")[-1].lower()
225
 
226
  weights = []
@@ -236,12 +236,16 @@ class DrUM(DiffusionPipeline):
236
  weights = ["L.safetensors"]
237
 
238
  for weight_file in weights:
239
- safetensor_path = hf_hub_download(repo_id = repo_id, filename = weight_file)
240
- weight_path = os.path.dirname(safetensor_path)
 
 
 
 
241
  return weight_path
242
 
243
- def load_peca(self, pipeline, repo_id = "Burf/DrUM"):
244
- adapter, feature_encoder, size, num_inference_steps, skip = peca(pipeline, save_path = self.load_weight(pipeline, repo_id))
245
  return adapter, feature_encoder, size, num_inference_steps, skip
246
 
247
  def __call__(self, prompt, ref = None, weight = None, alpha = 0.3, skip = None, sampling = False, seed = 42,
 
183
  return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip
184
 
185
  class DrUM(DiffusionPipeline):
186
+ def __init__(self, pipeline, repo_id = "Burf/DrUM", weight = None):
187
  """
188
  DrUM for various diffusion models
189
 
 
194
  self.pipeline = pipeline
195
  self.repo_id = repo_id
196
 
197
+ self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(pipeline, repo_id, weight)
198
 
199
  @classmethod
200
  def from_pretrained(cls, model_id, repo_id = "Burf/DrUM", torch_dtype = torch.bfloat16, device = "cuda"):
 
220
  #pipeline.safety_checker = lambda images, clip_input: (images, [False] * len(images))
221
  return cls(pipeline, repo_id)
222
 
223
+ def load_weight(self, pipeline, repo_id = "Burf/DrUM", weight = None):
224
  name = pipeline.config._name_or_path.split("/")[-1].lower()
225
 
226
  weights = []
 
236
  weights = ["L.safetensors"]
237
 
238
  for weight_file in weights:
239
+ if isinstance(weight, str) and os.path.exists(os.path.join(weight, weight_file)):
240
+ weight_path = weight
241
+ break
242
+ else:
243
+ safetensor_path = hf_hub_download(repo_id = repo_id, filename = "weight/" + weight_file)
244
+ weight_path = os.path.dirname(safetensor_path)
245
  return weight_path
246
 
247
+ def load_peca(self, pipeline, repo_id = "Burf/DrUM", weight = None):
248
+ adapter, feature_encoder, size, num_inference_steps, skip = peca(pipeline, save_path = self.load_weight(pipeline, repo_id, weight))
249
  return adapter, feature_encoder, size, num_inference_steps, skip
250
 
251
  def __call__(self, prompt, ref = None, weight = None, alpha = 0.3, skip = None, sampling = False, seed = 42,