Enzo8930302 committed
Commit 74d320c · verified · 1 Parent(s): 9e92643

Upload bytedream/model.py with huggingface_hub

Files changed (1):
  bytedream/model.py  +89 -44
bytedream/model.py CHANGED
@@ -33,7 +33,7 @@ class ResnetBlock2D(nn.Module):
         self.dropout = nn.Dropout(0.0)
         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
-        self.nonlinearity = nn.SiLU(inplace=True)
+        self.nonlinearity = nn.SiLU()
 
         if in_channels != out_channels:
             self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
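The commit does not say why inplace=True was dropped here (and on the other SiLU activations below); a common motivation is that in-place activations can overwrite tensors autograd has already saved for the backward pass, which also interacts badly with activation recomputation such as the gradient checkpointing enabled later in this diff. A minimal standalone sketch of that failure mode, using sigmoid purely for illustration (not code from this file):

import torch
import torch.nn as nn

x = torch.randn(4, requires_grad=True)
y = torch.sigmoid(x)            # sigmoid saves its output y for the backward pass
z = nn.SiLU(inplace=True)(y)    # overwrites y in place, invalidating the saved tensor
z.sum().backward()              # RuntimeError: a variable needed for gradient computation was modified

Dropping inplace=True costs one extra activation buffer but keeps the graph safe to recompute.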
@@ -79,15 +79,18 @@ class AttentionBlock(nn.Module):
     ):
         super().__init__()
 
-        inner_dim = num_heads * head_dim if head_dim is not None else query_dim
-        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        # Use head_dim if provided, otherwise calculate from query_dim and num_heads
+        self.head_dim = head_dim if head_dim is not None else query_dim // num_heads
+        inner_dim = self.head_dim * num_heads
+
+        # Use cross_attention_dim if provided, otherwise use query_dim (self-attention)
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
 
         self.num_heads = num_heads
-        self.head_dim = head_dim if head_dim is not None else query_dim // num_heads
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)
 
         self.to_out = nn.ModuleList([
             nn.Linear(inner_dim, query_dim),
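With the reordered constructor, head_dim is always derived before inner_dim, and the key/value projections are sized from the stored cross_attention_dim. A quick shape check with made-up sizes (not taken from this repo) shows how the projections line up for cross-attention:

import torch
import torch.nn as nn

query_dim, num_heads, head_dim, cross_dim = 320, 8, 40, 768    # hypothetical sizes
inner_dim = num_heads * head_dim                               # 320

to_q = nn.Linear(query_dim, inner_dim, bias=False)
to_k = nn.Linear(cross_dim, inner_dim, bias=False)

q = to_q(torch.randn(2, 64, query_dim))                        # (2, 64, 320) from image tokens
k = to_k(torch.randn(2, 77, cross_dim))                        # (2, 77, 320) from text tokens
q = q.view(2, 64, num_heads, head_dim).transpose(1, 2)         # (2, 8, 64, 40), one slice per head
k = k.view(2, 77, num_heads, head_dim).transpose(1, 2)         # (2, 8, 77, 40)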
@@ -282,15 +285,21 @@ class UpBlock2D(nn.Module):
                 align_corners=False
             )
 
-            # Ensure channel dimensions match (project if needed)
-            expected_channels = self.resnets[i].conv1.in_channels - hidden_states.shape[1]
-            if res_hidden_state.shape[1] != expected_channels:
-                # Project skip connection to expected channels
-                res_hidden_state = nn.functional.conv2d(
-                    res_hidden_state,
-                    torch.randn(expected_channels, res_hidden_state.shape[1], 1, 1, device=res_hidden_state.device) * 0.01,
-                    padding=0
-                )
+            # Ensure channel dimensions match
+            # The resnet expects input = hidden_states + res_hidden_state concatenated
+            expected_in_channels = self.resnets[i].conv1.in_channels
+            actual_in_channels = hidden_states.shape[1] + res_hidden_state.shape[1]
+
+            if actual_in_channels != expected_in_channels:
+                # Project skip connection to match expected channels
+                channel_diff = expected_in_channels - hidden_states.shape[1]
+                if channel_diff > 0 and channel_diff != res_hidden_state.shape[1]:
+                    # Need to project skip connection
+                    res_hidden_state = nn.functional.conv2d(
+                        res_hidden_state,
+                        torch.randn(channel_diff, res_hidden_state.shape[1], 1, 1, device=res_hidden_state.device) * 0.01,
+                        padding=0
+                    )
 
             hidden_states = torch.cat([hidden_states, res_hidden_state], dim=1)
 
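The new bookkeeping compares the concatenated width against what the resnet's first conv expects before projecting the skip tensor. With hypothetical channel counts (not from this config) the arithmetic works out as below; note that the 1x1 kernel is drawn from torch.randn on every forward call, so the projection is random rather than learned:

expected_in_channels = 960      # self.resnets[i].conv1.in_channels
hidden_channels = 640           # hidden_states.shape[1]
skip_channels = 256             # res_hidden_state.shape[1]

channel_diff = expected_in_channels - hidden_channels               # 320
mismatch = hidden_channels + skip_channels != expected_in_channels  # True (640 + 256 != 960)
project = channel_diff > 0 and channel_diff != skip_channels        # True: project skip from 256 to 320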
@@ -325,7 +334,7 @@ class TimestepEmbedding(nn.Module):
 
         # Projection layers
         self.linear_1 = nn.Linear(in_features, time_embed_dim)
-        self.activation = nn.SiLU(inplace=True)
+        self.activation = nn.SiLU()
         self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
 
     def forward(self, timestep: torch.Tensor) -> torch.Tensor:
@@ -362,6 +371,7 @@ class UNet2DConditionModel(nn.Module):
         attention_head_dim: int = 8,
         cross_attention_dim: int = 768,
         use_linear_projection: bool = True,
+        use_gradient_checkpointing: bool = False,
     ):
         super().__init__()
 
@@ -369,6 +379,7 @@ class UNet2DConditionModel(nn.Module):
         self.block_out_channels = block_out_channels
         self.layers_per_block = layers_per_block
         self.cross_attention_dim = cross_attention_dim
+        self.use_gradient_checkpointing = use_gradient_checkpointing
 
         # Time embedding
         time_embed_dim = block_out_channels[0] * 4
@@ -445,7 +456,7 @@ class UNet2DConditionModel(nn.Module):
 
         # Output
         self.conv_norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[0], eps=1e-6)
-        self.conv_act = nn.SiLU(inplace=True)
+        self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)
 
     def forward(
@@ -465,31 +476,59 @@ class UNet2DConditionModel(nn.Module):
         down_block_res_samples = (hidden_states,)
 
         for downsample_block in self.down_blocks:
-            hidden_states, res_samples = downsample_block(
-                hidden_states=hidden_states,
-                temb=temb,
-                encoder_hidden_states=encoder_hidden_states,
-            )
+            if self.use_gradient_checkpointing and self.training:
+                hidden_states, res_samples = torch.utils.checkpoint.checkpoint(
+                    lambda hs, t, ehs: downsample_block(hs, t, ehs),
+                    hidden_states, temb, encoder_hidden_states,
+                    use_reentrant=False
+                )
+            else:
+                hidden_states, res_samples = downsample_block(
+                    hidden_states=hidden_states,
+                    temb=temb,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
             down_block_res_samples += res_samples
 
         # Middle
         for layer in self.mid_block:
-            if isinstance(layer, ResnetBlock2D):
-                hidden_states = layer(hidden_states, temb)
+            if self.use_gradient_checkpointing and self.training:
+                if isinstance(layer, ResnetBlock2D):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        lambda hs, t: layer(hs, t),
+                        hidden_states, temb,
+                        use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        lambda hs, ehs: layer(hs, ehs),
+                        hidden_states, encoder_hidden_states,
+                        use_reentrant=False
+                    )
             else:
-                hidden_states = layer(hidden_states, encoder_hidden_states)
+                if isinstance(layer, ResnetBlock2D):
+                    hidden_states = layer(hidden_states, temb)
+                else:
+                    hidden_states = layer(hidden_states, encoder_hidden_states)
 
         # Up sampling path
         for upsample_block in self.up_blocks:
             res_samples = down_block_res_samples[-len(upsample_block.resnets):]
             down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]
 
-            hidden_states = upsample_block(
-                hidden_states=hidden_states,
-                res_hidden_states_tuple=res_samples,
-                temb=temb,
-                encoder_hidden_states=encoder_hidden_states,
-            )
+            if self.use_gradient_checkpointing and self.training:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    lambda hs, res, t, ehs: upsample_block(hs, res, t, ehs),
+                    hidden_states, res_samples, temb, encoder_hidden_states,
+                    use_reentrant=False
+                )
+            else:
+                hidden_states = upsample_block(
+                    hidden_states=hidden_states,
+                    res_hidden_states_tuple=res_samples,
+                    temb=temb,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
 
         # Output
         hidden_states = self.conv_norm_out(hidden_states)
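torch.utils.checkpoint.checkpoint with use_reentrant=False frees the intermediate activations of the wrapped call and recomputes them during the backward pass, trading extra compute for lower peak memory; the self.training guard keeps inference unaffected, since there is no backward pass to save memory for. A self-contained sketch of the same pattern on a toy block (unrelated to this model):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(64, 128), nn.SiLU(), nn.Linear(128, 64))
x = torch.randn(8, 64, requires_grad=True)

if block.training:
    y = checkpoint(block, x, use_reentrant=False)   # activations inside block are recomputed in backward
else:
    y = block(x)
y.sum().backward()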
@@ -513,40 +552,42 @@ class AutoencoderKL(nn.Module):
         up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",) * 4,
         latent_channels: int = 4,
         sample_size: int = 512,
+        block_out_channels: Tuple[int, ...] = (64, 128, 256, 512),
     ):
         super().__init__()
 
         self.sample_size = sample_size
+        self.block_out_channels = block_out_channels
 
-        # Encoder
+        # Encoder - using reduced channels for memory efficiency
         self.encoder = nn.ModuleList()
-        channels = [in_channels, 128, 256, 512, 512]
+        channels = [in_channels] + list(block_out_channels)
 
         for i in range(len(down_block_types)):
             block = nn.Sequential(
                 nn.Conv2d(channels[i], channels[i+1], kernel_size=3, stride=2, padding=1),
-                nn.GroupNorm(num_groups=32, num_channels=channels[i+1], eps=1e-6),
-                nn.SiLU(inplace=True),
+                nn.GroupNorm(num_groups=min(32, channels[i+1]), num_channels=channels[i+1], eps=1e-6),
+                nn.SiLU(),
             )
             self.encoder.append(block)
 
         # Latent space projection
-        self.quant_conv = nn.Conv2d(512, latent_channels * 2, kernel_size=1)
+        self.quant_conv = nn.Conv2d(block_out_channels[-1], latent_channels * 2, kernel_size=1)
 
-        # Decoder
+        # Decoder - using reduced channels for memory efficiency
         self.decoder = nn.ModuleList()
-        decoder_channels = [latent_channels, 512, 512, 256, 128]
+        decoder_channels = [latent_channels] + list(reversed(block_out_channels))
 
         for i in range(len(up_block_types)):
            block = nn.Sequential(
                nn.ConvTranspose2d(decoder_channels[i], decoder_channels[i+1], kernel_size=4, stride=2, padding=1),
-               nn.GroupNorm(num_groups=32, num_channels=decoder_channels[i+1], eps=1e-6),
-               nn.SiLU(inplace=True),
+               nn.GroupNorm(num_groups=min(32, decoder_channels[i+1]), num_channels=decoder_channels[i+1], eps=1e-6),
+               nn.SiLU(),
            )
            self.decoder.append(block)
 
-        self.post_quant_conv = nn.Conv2d(latent_channels, 512, kernel_size=1)
-        self.conv_out = nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)
+        self.post_quant_conv = nn.Conv2d(latent_channels, block_out_channels[-1], kernel_size=1)
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)
 
     def encode(self, x: torch.Tensor) -> torch.Tensor:
         """Encode image to latent space"""
@@ -576,15 +617,17 @@ class CLIPTextModel(nn.Module):
     Extracts semantic features from text for conditioning
     """
 
-    def __init__(self, model_name: str = "openai/clip-vit-large-patch14", max_length: int = 77):
+    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", max_length: int = 77):
         super().__init__()
 
         try:
             from transformers import CLIPTextModel as HFCLIPTextModel, CLIPTokenizer
 
+            print(f"Loading CLIP text encoder: {model_name}...")
             self.model = HFCLIPTextModel.from_pretrained(model_name)
             self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
             self.max_length = max_length
+            print(f"✓ CLIP text encoder loaded successfully on CPU")
         except ImportError:
             print("Warning: transformers not installed. Using dummy text encoder.")
             self.model = None
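Switching from openai/clip-vit-large-patch14 to openai/clip-vit-base-patch32 changes the text encoder's hidden width from 768 to 512, which is why the dummy fallback in the next hunk now returns torch.zeros(1, 77, 512); whatever width the encoder produces must match the cross_attention_dim the UNet's to_k/to_v projections were built with. A quick way to confirm the width (requires transformers):

from transformers import CLIPTextConfig

cfg = CLIPTextConfig.from_pretrained("openai/clip-vit-base-patch32")
print(cfg.hidden_size)   # 512 for ViT-B/32, versus 768 for openai/clip-vit-large-patch14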
@@ -603,7 +646,7 @@ class CLIPTextModel(nn.Module):
         """
         if self.model is None:
             # Dummy implementation if transformers not available
-            return torch.zeros(1, 77, 768)
+            return torch.zeros(1, 77, 512)
 
         inputs = self.tokenizer(
             text,
@@ -631,6 +674,7 @@ def create_unet(config):
         attention_head_dim=unet_config['attention_head_dim'],
         cross_attention_dim=unet_config['cross_attention_dim'],
         use_linear_projection=unet_config['use_linear_projection'],
+        use_gradient_checkpointing=True,  # Enable for memory efficiency
     )
 
 
@@ -644,6 +688,7 @@ def create_vae(config):
         up_block_types=tuple(vae_config['up_block_types']),
         latent_channels=vae_config['latent_channels'],
         sample_size=vae_config['sample_size'],
+        block_out_channels=tuple(vae_config.get('block_out_channels', [64, 128, 256, 512])),
     )
 
 
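Because block_out_channels is read with .get(...), configs written before this commit keep working and simply fall back to the default. A hypothetical vae_config fragment (key names taken from the call above; the surrounding nesting is assumed, not from this repo):

vae_config = {
    "up_block_types": ["UpDecoderBlock2D"] * 4,
    "latent_channels": 4,
    "sample_size": 512,
    # optional: omit this key to fall back to (64, 128, 256, 512)
    "block_out_channels": [64, 128, 256, 512],
}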