schrum2 committed · verified
Commit 634e6bc · 1 Parent(s): b54bd25

just use snapshot download?

models/pipeline_loader.py CHANGED
@@ -2,6 +2,7 @@ from models.text_diffusion_pipeline import TextConditionalDDPMPipeline
 from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline
 import os
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from huggingface_hub import snapshot_download
 
 
 def get_pipeline(model_path):
@@ -16,24 +17,16 @@ def get_pipeline(model_path):
         #If it has no text encoder, use the unconditional diffusion model
         pipe = UnconditionalDDPMPipeline.from_pretrained(model_path)
     else:
-        # Assume it's a Hugging Face Hub model ID
-        # Try to load config to determine if it's text-conditional
-        config = DiffusionPipeline.load_config(model_path)
-        has_text_encoder = "text_encoder" in config
-
+        # For HF Hub models, download first then load locally
+        print(f"Downloading model {model_path}...")
+        local_path = snapshot_download(repo_id=model_path, cache_dir="./temp_model_cache")
+
+        # Check what components exist
+        has_text_encoder = os.path.exists(os.path.join(local_path, "text_encoder"))
+
         if has_text_encoder:
-            # Use the local pipeline file for custom_pipeline
-            pipe = DiffusionPipeline.from_pretrained(
-                model_path,
-                custom_pipeline=model_path,
-                trust_remote_code=True,
-            )
+            pipe = TextConditionalDDPMPipeline.from_pretrained(local_path)
         else:
-            # Fallback: try unconditional
-            pipe = DiffusionPipeline.from_pretrained(
-                model_path,
-                custom_pipeline=model_path,
-                trust_remote_code=True,
-            )
+            pipe = UnconditionalDDPMPipeline.from_pretrained(local_path)
 
     return pipe
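
Note: the new Hub path in get_pipeline first materializes the repo on disk and then dispatches on whether a text_encoder/ subfolder is present. The sketch below mirrors that else branch under those assumptions; the helper name resolve_hub_model is introduced here for illustration and is not part of this commit.

import os
from huggingface_hub import snapshot_download

def resolve_hub_model(repo_id: str, cache_dir: str = "./temp_model_cache"):
    """Hypothetical helper: download a Hub repo and report whether it ships a text encoder."""
    # snapshot_download fetches (or reuses a cached copy of) the full repo
    # and returns the local directory holding its files
    local_path = snapshot_download(repo_id=repo_id, cache_dir=cache_dir)
    # Presence of a text_encoder/ subfolder decides which pipeline class to load
    has_text_encoder = os.path.exists(os.path.join(local_path, "text_encoder"))
    return local_path, has_text_encoder

get_pipeline then calls TextConditionalDDPMPipeline.from_pretrained(local_path) when the subfolder exists, and UnconditionalDDPMPipeline.from_pretrained(local_path) otherwise, instead of routing through DiffusionPipeline with custom_pipeline and trust_remote_code.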
models/text_diffusion_pipeline.py CHANGED
@@ -18,18 +18,11 @@ class PipelineOutput(NamedTuple):
     images: torch.Tensor
 
 
-
 # Create a custom pipeline for text-conditional generation
 class TextConditionalDDPMPipeline(DDPMPipeline):
     def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
-        # Debug: Print what we're receiving
-        print(f"unet type: {type(unet)}, value: {unet}")
-        print(f"scheduler type: {type(scheduler)}, value: {scheduler}")
-        print(f"text_encoder type: {type(text_encoder)}, value: {text_encoder}")
-        print(f"tokenizer type: {type(tokenizer)}, value: {tokenizer}")
-
-        # Call DiffusionPipeline.__init__() directly (skipping DDPMPipeline's init)
-        DiffusionPipeline.__init__(self)
+        # Call parent class init normally
+        super().__init__(unet=unet, scheduler=scheduler)
 
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
@@ -39,24 +32,18 @@ class TextConditionalDDPMPipeline(DDPMPipeline):
 
         if self.tokenizer is None and self.text_encoder is not None:
             # Use the tokenizer from the text encoder if not provided
-            self.tokenizer = self.text_encoder.tokenizer
-
-        # Only register modules that are actual objects, not None or lists
-        modules_to_register = {}
-
-        if unet is not None and not isinstance(unet, (list, tuple)):
-            modules_to_register['unet'] = unet
-        if scheduler is not None and not isinstance(scheduler, (list, tuple)):
-            modules_to_register['scheduler'] = scheduler
-        if self.text_encoder is not None and not isinstance(self.text_encoder, (list, tuple)):
-            modules_to_register['text_encoder'] = self.text_encoder
-        if self.tokenizer is not None and not isinstance(self.tokenizer, (list, tuple)):
-            modules_to_register['tokenizer'] = self.tokenizer
+            if hasattr(self.text_encoder, 'tokenizer'):
+                self.tokenizer = self.text_encoder.tokenizer
 
-        print(f"Registering modules: {list(modules_to_register.keys())}")
-
-        # Register ALL modules at once
-        self.register_modules(**modules_to_register)
+        # Register additional modules if they exist
+        additional_modules = {}
+        if self.text_encoder is not None:
+            additional_modules['text_encoder'] = self.text_encoder
+        if self.tokenizer is not None:
+            additional_modules['tokenizer'] = self.tokenizer
+
+        if additional_modules:
+            self.register_modules(**additional_modules)
 
     # Override the to() method to ensure text_encoder is moved to the correct device
    def to(self, device=None, dtype=None):
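
Note: the simplified __init__ now lets DDPMPipeline register unet and scheduler itself and only adds text_encoder and tokenizer afterwards via register_modules, which both sets the attribute and records the component in the pipeline config for save/load. A minimal sketch of that pattern under stated assumptions: DummyTextEncoder and the small UNet configuration below are placeholders for illustration, not classes or settings from this repo.

import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

class DummyTextEncoder(torch.nn.Module):
    """Stand-in text encoder used only to demonstrate module registration."""
    def forward(self, x):
        return x

# Tiny UNet and scheduler so the example constructs quickly
unet = UNet2DModel(
    sample_size=32, in_channels=3, out_channels=3, layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(num_train_timesteps=10)

# What super().__init__(unet=unet, scheduler=scheduler) does in the new code
pipe = DDPMPipeline(unet=unet, scheduler=scheduler)

# What the "Register additional modules" block adds on top
pipe.register_modules(text_encoder=DummyTextEncoder())

print(type(pipe.text_encoder).__name__)   # DummyTextEncoder: attribute is set
print("text_encoder" in pipe.config)      # True: component recorded in the config

Because the extra components are registered rather than stored as plain attributes, the overridden to() method only needs to make sure text_encoder follows the pipeline to the requested device.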