schrum2 commited on
Commit
b54bd25
·
verified ·
1 Parent(s): 8dfa94d

debugging

Browse files
Files changed (1) hide show
  1. models/text_diffusion_pipeline.py +24 -9
models/text_diffusion_pipeline.py CHANGED
@@ -22,8 +22,14 @@ class PipelineOutput(NamedTuple):
22
  # Create a custom pipeline for text-conditional generation
23
  class TextConditionalDDPMPipeline(DDPMPipeline):
24
  def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
25
- # Don't call super().__init__() with arguments, call it without arguments first
26
- super(DiffusionPipeline, self).__init__()
 
 
 
 
 
 
27
 
28
  self.text_encoder = text_encoder
29
  self.tokenizer = tokenizer
@@ -35,13 +41,22 @@ class TextConditionalDDPMPipeline(DDPMPipeline):
35
  # Use the tokenizer from the text encoder if not provided
36
  self.tokenizer = self.text_encoder.tokenizer
37
 
38
- # Register ALL modules at once, including the ones from the parent class
39
- self.register_modules(
40
- unet=unet,
41
- scheduler=scheduler,
42
- text_encoder=self.text_encoder,
43
- tokenizer=self.tokenizer,
44
- )
 
 
 
 
 
 
 
 
 
45
 
46
  # Override the to() method to ensure text_encoder is moved to the correct device
47
  def to(self, device=None, dtype=None):
 
22
  # Create a custom pipeline for text-conditional generation
23
  class TextConditionalDDPMPipeline(DDPMPipeline):
24
  def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
25
+ # Debug: Print what we're receiving
26
+ print(f"unet type: {type(unet)}, value: {unet}")
27
+ print(f"scheduler type: {type(scheduler)}, value: {scheduler}")
28
+ print(f"text_encoder type: {type(text_encoder)}, value: {text_encoder}")
29
+ print(f"tokenizer type: {type(tokenizer)}, value: {tokenizer}")
30
+
31
+ # Call DiffusionPipeline.__init__() directly (skipping DDPMPipeline's init)
32
+ DiffusionPipeline.__init__(self)
33
 
34
  self.text_encoder = text_encoder
35
  self.tokenizer = tokenizer
 
41
  # Use the tokenizer from the text encoder if not provided
42
  self.tokenizer = self.text_encoder.tokenizer
43
 
44
+ # Only register modules that are actual objects, not None or lists
45
+ modules_to_register = {}
46
+
47
+ if unet is not None and not isinstance(unet, (list, tuple)):
48
+ modules_to_register['unet'] = unet
49
+ if scheduler is not None and not isinstance(scheduler, (list, tuple)):
50
+ modules_to_register['scheduler'] = scheduler
51
+ if self.text_encoder is not None and not isinstance(self.text_encoder, (list, tuple)):
52
+ modules_to_register['text_encoder'] = self.text_encoder
53
+ if self.tokenizer is not None and not isinstance(self.tokenizer, (list, tuple)):
54
+ modules_to_register['tokenizer'] = self.tokenizer
55
+
56
+ print(f"Registering modules: {list(modules_to_register.keys())}")
57
+
58
+ # Register ALL modules at once
59
+ self.register_modules(**modules_to_register)
60
 
61
  # Override the to() method to ensure text_encoder is moved to the correct device
62
  def to(self, device=None, dtype=None):