Shaoan committed on
Commit ead4126 · verified · 1 Parent(s): 523214a

Upload folder using huggingface_hub

Files changed (1)
  1. text_encoder.py +5 -1018
text_encoder.py CHANGED
@@ -207,694 +207,6 @@ import random
207
 
208
  from torch.utils.checkpoint import checkpoint
209
  from peft import LoraConfig, set_peft_model_state_dict
210
- class LoraT5EmbedderNoGradientCheck(torch.nn.Module):
211
- def __init__(self, device, rank=64, max_length=300):
212
- super().__init__()
213
- self.device = device
214
- self.max_length = max_length
215
- dtype = torch.bfloat16
216
- self.dtype = dtype
217
- t5_version = './t5-v1_1-xxl'
218
- self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
219
- self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device).to(dtype)
220
- self.t5_encoder.gradient_checkpointing_enable()
221
- self.t5_encoder.config.gradient_checkpointing = True
222
- self.t5_encoder.requires_grad_(False)
223
- self.t5_encoder.eval()
224
- # Add LoRA adapters to the T5 model
225
- text_lora_config = LoraConfig(
226
- r=rank,
227
- lora_alpha=rank,
228
- lora_dropout=0.0,
229
- init_lora_weights="gaussian",
230
- target_modules=["SelfAttention.q", "SelfAttention.k", "SelfAttention.v", "SelfAttention.o", "DenseReluDense.wi", "DenseReluDense.wo"],
231
- )
232
- self.t5_encoder.add_adapter(text_lora_config)
233
- #self.t5_encoder.encoder.embed_tokens.weight.requires_grad = True
234
- print(f"Gradient checkpointing enabled: {self.t5_encoder.is_gradient_checkpointing}")
235
-
236
- image_encoder_path = 'openai/clip-vit-large-patch14'
237
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device=device).to(torch.bfloat16)
238
- self.image_encoder = self.image_encoder.eval().requires_grad_(False)
239
-
240
- def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
241
- """
242
- Compute group lasso for non-pad non-change tokens, L1 for change tokens,
243
- and group sparsity for pad non-change tokens.
244
-
245
- Args:
246
- prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
247
- perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
248
- replaced_ids: List of replaced token indices for each sample in batch
249
- batch_encoding: The tokenizer output containing input_ids
250
-
251
- Returns:
252
- l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
253
- l1_loss: L1 loss for change tokens (scalar tensor)
254
- pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
255
- """
256
- batch_size = prompt_embeds.size(0)
257
- pad_token_id = self.t5_tokenizer.pad_token_id
258
- input_ids = batch_encoding["input_ids"]
259
-
260
- l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
261
- l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
262
- pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
263
-
264
- # Track valid samples for each loss type separately
265
- l1_valid_samples = 0
266
- l2_valid_samples = 0
267
- pad_valid_samples = 0
268
-
269
- for i in range(batch_size):
270
- # Get the replaced index for this sample
271
- replaced_idx = replaced_ids[i]
272
-
273
- if replaced_idx is None:
274
- # No replacement happened (all padding), skip
275
- continue
276
-
277
- # Find padding and non-padding token indices
278
- pad_mask = input_ids[i] == pad_token_id
279
- non_pad_mask = ~pad_mask
280
-
281
- pad_indices = torch.where(pad_mask)[0]
282
- non_pad_indices = torch.where(non_pad_mask)[0]
283
-
284
- # Filter out the replaced index from non-padding indices (non-pad non-change)
285
- non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]
286
-
287
- # Compute L1 loss on selected (replaced) index - CHANGE TOKEN
288
- selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
289
- l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
290
- l1_valid_samples += 1
291
-
292
- # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
293
- if len(non_selected_non_pad_indices) > 0:
294
- non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[
295
- i, non_selected_non_pad_indices]
296
- l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
297
- l2_loss_total = l2_loss_total + l2_per_token.mean()
298
- l2_valid_samples += 1
299
-
300
- # Compute group sparsity loss on PAD NON-CHANGE tokens
301
- if len(pad_indices) > 0:
302
- pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
303
- # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
304
- pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
305
- pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
306
- pad_valid_samples += 1
307
-
308
- # Average over valid samples for each loss type
309
- l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0,
310
- device=prompt_embeds.device)
311
- l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0,
312
- device=prompt_embeds.device)
313
- pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0,
314
- device=prompt_embeds.device)
315
-
316
- return l2_loss, l1_loss, pad_group_loss
317
-
318
-
319
-
320
-
321
- def forward(self, text, image=None):
322
- if isinstance(text, str):
323
- text = [text]
324
- batch_encoding = self.t5_tokenizer(
325
- text,
326
- truncation=True,
327
- max_length=self.max_length,
328
- return_length=False,
329
- return_overflowing_tokens=False,
330
- padding="max_length",
331
- return_tensors="pt",
332
- )
333
- prompt_embeds = self.t5_encoder(
334
- input_ids=batch_encoding["input_ids"].to(self.device),
335
- attention_mask=None,
336
- output_hidden_states=False,
337
- )['last_hidden_state']
338
-
339
- # Get input_ids and create a copy to modify
340
- input_ids = batch_encoding["input_ids"].clone()
341
- batch_size = input_ids.size(0)
342
-
343
- # Get the padding token id
344
- pad_token_id = self.t5_tokenizer.pad_token_id
345
-
346
- replaced_ids = []
347
- # For each sample in the batch
348
- for i in range(batch_size):
349
- # Find indices of non-padding tokens
350
- non_pad_mask = input_ids[i] != pad_token_id
351
- non_pad_indices = torch.where(non_pad_mask)[0]
352
-
353
- # If there are meaningful tokens, randomly select one to replace
354
- if len(non_pad_indices) > 0:
355
- # Randomly select an index from non-padding tokens
356
- random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
357
- # Replace with padding token
358
- input_ids[i, random_idx] = pad_token_id
359
- replaced_ids.append(random_idx.item())
360
- else:
361
- replaced_ids.append(None) # No replacement if all tokens are padding
362
-
363
-
364
- perturbed_prompt_embeds = self.t5_encoder(
365
- input_ids=input_ids.to(self.device),
366
- attention_mask=None,
367
- output_hidden_states=False,
368
- )['last_hidden_state']
369
-
370
- l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
371
- prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
372
- )
373
-
374
- with torch.no_grad():
375
- if image is not None:
376
- clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
377
- else:
378
- clip_image_embeds = None
379
-
380
-
381
- return prompt_embeds, l2_loss, l1_loss, pad_loss,clip_image_embeds
382
-
383
-
384
- from peft import LoraConfig, set_peft_model_state_dict
385
- import torch.utils.checkpoint as checkpoint
386
- from transformers import CLIPVisionModelWithProjection
387
-
388
- class LoraT5Embedder(torch.nn.Module):
389
- def __init__(self, device, rank=128, max_length=300, use_gradient_checkpointing=True):
390
- super().__init__()
391
- self.device = device
392
- self.max_length = max_length
393
- self.use_gradient_checkpointing = use_gradient_checkpointing
394
- dtype = torch.bfloat16
395
- self.dtype = dtype
396
- t5_version = './t5-v1_1-xxl'
397
- self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
398
-
399
- self.t5_encoder = T5EncoderModel.from_pretrained(
400
- t5_version,
401
- torch_dtype=dtype
402
- ).to(device=device).to(dtype)
403
-
404
- self.t5_encoder.requires_grad_(False)
405
-
406
- # Add LoRA adapters to the T5 model
407
- text_lora_config = LoraConfig(
408
- r=rank,
409
- lora_alpha=rank,
410
- lora_dropout=0.0,
411
- init_lora_weights="gaussian",
412
- target_modules=["q", "k", "v", "o", "wi", "wo"],
413
- )
414
- self.t5_encoder.add_adapter(text_lora_config)
415
- self.t5_encoder.encoder.embed_tokens.weight.requires_grad_(True)
416
-
417
- # Manually implement gradient checkpointing for T5 encoder blocks
418
- if self.use_gradient_checkpointing:
419
- self._enable_gradient_checkpointing()
420
-
421
- print(f"Gradient checkpointing enabled: {self.use_gradient_checkpointing}")
422
-
423
- image_encoder_path = './clip-vit-large-patch14'
424
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
425
- image_encoder_path
426
- ).to(device=device).to(torch.bfloat16)
427
- self.image_encoder = self.image_encoder.eval().requires_grad_(False)
428
-
429
- def _enable_gradient_checkpointing(self):
430
- """
431
- Manually wrap T5 encoder blocks with gradient checkpointing.
432
- """
433
-
434
- def create_custom_forward(module):
435
- def custom_forward(*inputs):
436
- return module(*inputs)
437
-
438
- return custom_forward
439
-
440
- # Wrap each T5 block with checkpointing
441
- for block in self.t5_encoder.encoder.block:
442
- # Store original forward
443
- block._original_forward = block.forward
444
-
445
- # Create checkpointed forward
446
- def make_checkpointed_forward(blk):
447
- def checkpointed_forward(*args, **kwargs):
448
- # Checkpoint requires a function that takes tensors as input
449
- def forward_wrapper(*inputs):
450
- # Reconstruct kwargs from inputs
451
- hidden_states = inputs[0]
452
- attention_mask = inputs[1] if len(inputs) > 1 else None
453
- position_bias = inputs[2] if len(inputs) > 2 else None
454
-
455
- return blk._original_forward(
456
- hidden_states=hidden_states,
457
- attention_mask=attention_mask,
458
- position_bias=position_bias,
459
- **{k: v for k, v in kwargs.items() if
460
- k not in ['hidden_states', 'attention_mask', 'position_bias']}
461
- )
462
-
463
- # Prepare inputs for checkpointing
464
- hidden_states = kwargs.get('hidden_states', args[0] if args else None)
465
- attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None)
466
- position_bias = kwargs.get('position_bias', args[2] if len(args) > 2 else None)
467
-
468
- # Use checkpoint
469
- checkpoint_inputs = [hidden_states]
470
- if attention_mask is not None:
471
- checkpoint_inputs.append(attention_mask)
472
- if position_bias is not None:
473
- checkpoint_inputs.append(position_bias)
474
-
475
- return checkpoint.checkpoint(
476
- forward_wrapper,
477
- *checkpoint_inputs,
478
- use_reentrant=False
479
- )
480
-
481
- return checkpointed_forward
482
-
483
- block.forward = make_checkpointed_forward(block)
484
-
485
- def _encode_text(self, input_ids):
486
- """Helper function to encode text through T5."""
487
- return self.t5_encoder(
488
- input_ids=input_ids.to(self.device),
489
- attention_mask=None,
490
- output_hidden_states=False,
491
- )['last_hidden_state']
492
-
493
- def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
494
- """
495
- Compute group lasso for non-pad non-change tokens, L1 for change tokens,
496
- and group sparsity for pad non-change tokens.
497
-
498
- Args:
499
- prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
500
- perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
501
- replaced_ids: List of replaced token indices for each sample in batch
502
- batch_encoding: The tokenizer output containing input_ids
503
-
504
- Returns:
505
- l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
506
- l1_loss: L1 loss for change tokens (scalar tensor)
507
- pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
508
- """
509
- batch_size = prompt_embeds.size(0)
510
- pad_token_id = self.t5_tokenizer.pad_token_id
511
- input_ids = batch_encoding["input_ids"]
512
-
513
- l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
514
- l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
515
- pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
516
-
517
- # Track valid samples for each loss type separately
518
- l1_valid_samples = 0
519
- l2_valid_samples = 0
520
- pad_valid_samples = 0
521
-
522
- for i in range(batch_size):
523
- # Get the replaced index for this sample
524
- replaced_idx = replaced_ids[i]
525
-
526
- if replaced_idx is None:
527
- # No replacement happened (all padding), skip
528
- continue
529
-
530
- # Find padding and non-padding token indices
531
- pad_mask = input_ids[i] == pad_token_id
532
- non_pad_mask = ~pad_mask
533
-
534
- pad_indices = torch.where(pad_mask)[0]
535
- non_pad_indices = torch.where(non_pad_mask)[0]
536
-
537
- # Filter out the replaced index from non-padding indices (non-pad non-change)
538
- non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]
539
-
540
- # Compute L1 loss on selected (replaced) index - CHANGE TOKEN
541
- selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
542
- l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
543
- l1_valid_samples += 1
544
-
545
- # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
546
- if len(non_selected_non_pad_indices) > 0:
547
- non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[
548
- i, non_selected_non_pad_indices]
549
- l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
550
- l2_loss_total = l2_loss_total + l2_per_token.mean()
551
- l2_valid_samples += 1
552
-
553
- # Compute group sparsity loss on PAD NON-CHANGE tokens
554
- if len(pad_indices) > 0:
555
- pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
556
- # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
557
- pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
558
- pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
559
- pad_valid_samples += 1
560
-
561
- # Average over valid samples for each loss type
562
- l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0,
563
- device=prompt_embeds.device)
564
- l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0,
565
- device=prompt_embeds.device)
566
- pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0,
567
- device=prompt_embeds.device)
568
-
569
- return l2_loss, l1_loss, pad_group_loss
570
-
571
- def forward(self, text, image=None):
572
- if isinstance(text, str):
573
- text = [text]
574
- batch_encoding = self.t5_tokenizer(
575
- text,
576
- truncation=True,
577
- max_length=self.max_length,
578
- return_length=False,
579
- return_overflowing_tokens=False,
580
- padding="max_length",
581
- return_tensors="pt",
582
- )
583
- attn_mask = batch_encoding["attention_mask"].to(self.device)
584
-
585
- # First encoding
586
- prompt_embeds = self._encode_text(batch_encoding["input_ids"])
587
-
588
- # Get input_ids and create a copy to modify
589
- input_ids = batch_encoding["input_ids"].clone()
590
- batch_size = input_ids.size(0)
591
-
592
- # Get the padding token id
593
- # get the id for the first sentinel token
594
- mask_token = "<extra_id_0>"
595
- mask_token_id = self.t5_tokenizer.convert_tokens_to_ids(mask_token)
596
- pad_token_id = self.t5_tokenizer.pad_token_id
597
-
598
- replaced_ids = []
599
- # For each sample in the batch
600
- for i in range(batch_size):
601
- # Find indices of non-padding tokens
602
- non_pad_mask = input_ids[i] != pad_token_id
603
- non_pad_indices = torch.where(non_pad_mask)[0]
604
-
605
- # If there are meaningful tokens, randomly select one to replace
606
- if len(non_pad_indices) > 0:
607
- # Randomly select an index from non-padding tokens
608
- random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
609
- random_idx2 = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
610
- # Replace with padding token
611
- input_ids[i, random_idx] = mask_token_id
612
- replaced_ids.append(random_idx.item())
613
- else:
614
- replaced_ids.append(None) # No replacement if all tokens are padding
615
-
616
- # Second encoding with perturbed input
617
- perturbed_prompt_embeds = self._encode_text(input_ids)
618
-
619
- """
620
- l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
621
- prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
622
- )
623
- """
624
-
625
- with torch.no_grad():
626
- if image is not None:
627
- clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
628
- else:
629
- clip_image_embeds = None
630
-
631
- #return prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask
632
- return prompt_embeds, clip_image_embeds, perturbed_prompt_embeds, replaced_ids, self.t5_tokenizer, batch_encoding
633
-
634
-
635
- import torch.func as func
636
-
637
- class FullJacobianLoraT5Embedder(torch.nn.Module):
638
- def __init__(self, device, rank=64, max_length=512, use_gradient_checkpointing=True,
639
- num_jacobian_samples=1):
640
- super().__init__()
641
- self.device = device
642
- self.max_length = max_length
643
- self.use_gradient_checkpointing = use_gradient_checkpointing
644
- self.num_jacobian_samples = num_jacobian_samples # Number of random columns to sample
645
-
646
- dtype = torch.bfloat16
647
- self.dtype = dtype
648
- t5_version = './t5-v1_1-xxl'
649
- self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
650
-
651
- self.t5_encoder = T5EncoderModel.from_pretrained(
652
- t5_version,
653
- dtype=dtype
654
- ).to(device=device).to(dtype)
655
-
656
- self.t5_encoder.requires_grad_(False)
657
-
658
- # Add LoRA adapters to the T5 model
659
- text_lora_config = LoraConfig(
660
- r=rank,
661
- lora_alpha=rank,
662
- lora_dropout=0.0,
663
- init_lora_weights="gaussian",
664
- target_modules=["q", "k", "v", "o", "wi", "wo"],
665
- )
666
- self.t5_encoder.add_adapter(text_lora_config)
667
- self.t5_encoder.encoder.embed_tokens.weight.requires_grad_(True)
668
-
669
- # Manually implement gradient checkpointing for T5 encoder blocks
670
- if self.use_gradient_checkpointing:
671
- self._enable_gradient_checkpointing()
672
-
673
- print(f"Gradient checkpointing enabled: {self.use_gradient_checkpointing}")
674
- print(f"Jacobian samples per batch: {self.num_jacobian_samples}")
675
-
676
- image_encoder_path = './clip-vit-large-patch14'
677
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
678
- image_encoder_path
679
- ).to(device=device).to(torch.bfloat16)
680
- self.image_encoder = self.image_encoder.eval().requires_grad_(False)
681
-
682
- def compute_jacobian_loss(self, input_embeds, attention_mask):
683
- """
684
- Compute L1 Jacobian sparsity loss using forward-mode AD (JVP).
685
-
686
- Note: Temporarily disables gradient checkpointing as it's incompatible with JVP.
687
- """
688
- batch_size, seq_len, hidden_dim = input_embeds.shape
689
- input_embeds = input_embeds[:1]
690
- attention_mask = attention_mask[:1]
691
-
692
- # Temporarily disable gradient checkpointing
693
- original_checkpointing = self.use_gradient_checkpointing
694
- if original_checkpointing:
695
- self._disable_gradient_checkpointing()
696
-
697
- if True:
698
- if True:
699
- """
700
- Compute same-token and cross-token Jacobian sparsity losses.
701
- Assumes left-aligned mask: attention_mask[b] = [1...1, 0...0]
702
- Probes one (token, dim) per batch element per JVP sample.
703
- """
704
- B, S, H = input_embeds.shape
705
- device = input_embeds.device
706
-
707
- # Count valid tokens per batch element
708
- lengths = attention_mask.sum(dim=1) # [B]
709
- valid_batch = lengths>0
710
- if valid_batch.sum() == 0:
711
- z = input_embeds.new_zeros(())
712
- return z, z
713
-
714
- same_token_loss = input_embeds.new_zeros(())
715
- cross_token_loss = input_embeds.new_zeros(())
716
-
717
- def model_fn(embeds):
718
- return self.t5_encoder.encoder(
719
- inputs_embeds=embeds,
720
- attention_mask=None,
721
- output_hidden_states=False,
722
- ).last_hidden_state
723
-
724
- batch_idx = torch.arange(B, device=device)
725
-
726
- for _ in range(self.num_jacobian_samples):
727
- # Sample one valid token position per batch element
728
- t = torch.zeros(B, dtype=torch.long, device=device)
729
- u = torch.rand(B, device=device)
730
- # For valid batches: uniform over [0, lengths[b])
731
- # For invalid batches: stays 0 (doesn't matter, will be masked out)
732
- t[valid_batch] = (u[valid_batch] * lengths[valid_batch].float()).long()
733
-
734
- # Sample one hidden dim per batch element
735
- k = torch.randint(0, H, (B,), device=device)
736
-
737
- # Tangent: one scalar per batch element at position [b, t[b], k[b]]
738
- tangent = torch.zeros_like(input_embeds)
739
- tangent[batch_idx, t, k] = 1.0
740
-
741
- # JVP
742
- _, jvp = func.jvp(model_fn, (input_embeds,), (tangent,))
743
- abs_jvp = jvp.abs() # [B, S, H]
744
-
745
- # SAME-token: diagonal element for each batch
746
- diag = abs_jvp[batch_idx, t, :] # [B, H]
747
- same_token_loss = same_token_loss + diag[valid_batch].sum()
748
-
749
- # CROSS-token: all valid positions except diagonal
750
- # Create position mask: valid positions are [0, lengths[b])
751
- pos = torch.arange(S, device=device).unsqueeze(0) # [1, S]
752
- valid_pos_mask = pos < lengths.unsqueeze(1) # [B, S]
753
-
754
- # Exclude diagonal
755
- cross_mask = valid_pos_mask.clone()
756
- cross_mask[batch_idx, t] = False
757
-
758
- cross_token_loss = cross_token_loss + abs_jvp[cross_mask].sum()
759
-
760
- # ---- Normalization (keep as tensors for AMP) ----
761
- num_valid_batches = valid_batch.sum() # Keep as tensor
762
-
763
- # Same-token: mean per output element over (num_samples × num_valid_batches × H)
764
- same_token_loss = same_token_loss / (self.num_jacobian_samples * num_valid_batches)
765
-
766
- # Cross-token: mean per output element over (num_samples × total_cross_positions × H)
767
- # total_cross_positions = sum over valid batches of (lengths[b] - 1)
768
- cross_counts = (lengths[valid_batch] - 1).clamp(min=0).sum() # Keep as tensor
769
-
770
- if cross_counts > 0:
771
- cross_token_loss = cross_token_loss / (self.num_jacobian_samples * cross_counts)
772
- else:
773
- cross_token_loss = input_embeds.new_zeros(())
774
-
775
- # Re-enable gradient checkpointing
776
- if original_checkpointing:
777
- self._enable_gradient_checkpointing()
778
-
779
- return same_token_loss, cross_token_loss
780
-
781
- def _disable_gradient_checkpointing(self):
782
- """Restore original forward methods without checkpointing."""
783
- for block in self.t5_encoder.encoder.block:
784
- if hasattr(block, '_original_forward'):
785
- block.forward = block._original_forward
786
-
787
- def _enable_gradient_checkpointing(self):
788
- """Manually wrap T5 encoder blocks with gradient checkpointing."""
789
- from torch.utils.checkpoint import checkpoint as cp
790
-
791
- # Wrap each T5 block with checkpointing
792
- for block in self.t5_encoder.encoder.block:
793
- # Store original forward if not already stored
794
- if not hasattr(block, '_original_forward'):
795
- block._original_forward = block.forward
796
-
797
- # Create checkpointed forward
798
- def make_checkpointed_forward(blk):
799
- def checkpointed_forward(*args, **kwargs):
800
- def forward_wrapper(*inputs):
801
- hidden_states = inputs[0]
802
- attention_mask = inputs[1] if len(inputs) > 1 else None
803
- position_bias = inputs[2] if len(inputs) > 2 else None
804
-
805
- return blk._original_forward(
806
- hidden_states=hidden_states,
807
- attention_mask=attention_mask,
808
- position_bias=position_bias,
809
- **{k: v for k, v in kwargs.items() if
810
- k not in ['hidden_states', 'attention_mask', 'position_bias']}
811
- )
812
-
813
- hidden_states = kwargs.get('hidden_states', args[0] if args else None)
814
- attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None)
815
- position_bias = kwargs.get('position_bias', args[2] if len(args) > 2 else None)
816
-
817
- checkpoint_inputs = [hidden_states]
818
- if attention_mask is not None:
819
- checkpoint_inputs.append(attention_mask)
820
- if position_bias is not None:
821
- checkpoint_inputs.append(position_bias)
822
-
823
- return cp(
824
- forward_wrapper,
825
- *checkpoint_inputs,
826
- use_reentrant=False
827
- )
828
-
829
- return checkpointed_forward
830
-
831
- block.forward = make_checkpointed_forward(block)
832
-
833
- def forward(self, text, image=None, compute_jacobian=False):
834
- """
835
- Forward pass with optional Jacobian regularization.
836
-
837
- Args:
838
- text: Input text (string or list of strings)
839
- image: Optional image input
840
- compute_jacobian: Whether to compute Jacobian loss (set False during inference)
841
-
842
- Returns:
843
- prompt_embeds: T5 encoder output
844
- clip_image_embeds: CLIP image embeddings (if image provided)
845
- jacobian_loss: Jacobian sparsity loss (if compute_jacobian=True)
846
- attn_mask: Attention mask
847
- """
848
- if isinstance(text, str):
849
- text = [text]
850
-
851
- batch_encoding = self.t5_tokenizer(
852
- text,
853
- truncation=True,
854
- max_length=self.max_length,
855
- return_length=False,
856
- return_overflowing_tokens=False,
857
- padding="max_length",
858
- return_tensors="pt",
859
- )
860
- attn_mask = batch_encoding["attention_mask"].to(self.device)
861
-
862
- # Get input embeddings
863
- input_ids = batch_encoding["input_ids"].to(self.device)
864
- input_embeds = self.t5_encoder.encoder.embed_tokens(input_ids)
865
-
866
- # Forward pass through encoder
867
- prompt_embeds = self.t5_encoder.encoder(
868
- inputs_embeds=input_embeds,
869
- attention_mask=None,
870
- output_hidden_states=False,
871
- ).last_hidden_state
872
-
873
- # Compute Jacobian loss if requested (during training)
874
- jacobian_loss = {}
875
- if compute_jacobian:
876
- jacobian_same_loss, jacobian_cross_loss = self.compute_jacobian_loss(input_embeds, attn_mask)
877
- jacobian_loss["same_token"] = jacobian_same_loss
878
- jacobian_loss["cross_token"] = jacobian_cross_loss
879
- else:
880
- jacobian_loss['same_token'] = torch.tensor(0.0, device=self.device)
881
- jacobian_loss['cross_token'] = torch.tensor(0.0, device=self.device)
882
-
883
- # Encode image
884
- with torch.no_grad():
885
- if image is not None:
886
- clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
887
- else:
888
- clip_image_embeds = None
889
-
890
- return prompt_embeds, clip_image_embeds, jacobian_loss, attn_mask
891
-
892
-
893
- import torch
894
- from torch import nn, func
895
- from typing import Optional
896
- from transformers import T5Tokenizer, CLIPVisionModelWithProjection
897
- from transformers.models.t5.modeling_t5 import T5PreTrainedModel, T5Stack, T5Config
898
 
899
  import torch
900
  from torch import nn, func
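
The hunk above removes compute_perturbation_loss, which mixes an L1 penalty on the perturbed (replaced) token with per-token L2 (group-lasso) penalties on the remaining non-pad tokens and on the pad tokens. A minimal sketch of that decomposition on toy tensors, for a single sample; the function name and shapes here are illustrative, not the deleted API:

import torch

def perturbation_losses(orig, pert, replaced_idx, pad_mask):
    # orig, pert: [seq_len, hidden]; pad_mask: bool [seq_len]
    diff = orig - pert
    l1 = diff[replaced_idx].abs().mean()                      # change token: L1
    keep = ~pad_mask
    keep[replaced_idx] = False                                # non-pad, non-change tokens
    l2 = diff[keep].pow(2).sum(-1).sqrt().mean() if keep.any() else diff.new_zeros(())
    pad = diff[pad_mask].pow(2).sum(-1).sqrt().mean() if pad_mask.any() else diff.new_zeros(())
    return l2, l1, pad                                        # group lasso, L1, pad group sparsity

# purely illustrative data
orig = torch.randn(8, 16)
pert = orig + 0.01 * torch.randn(8, 16)
pad_mask = torch.tensor([False] * 5 + [True] * 3)
print(perturbation_losses(orig, pert, replaced_idx=2, pad_mask=pad_mask))
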
@@ -1007,101 +319,6 @@ class JacobianT5Encoder(T5PreTrainedModel):
1007
 
1008
  return hidden_states, position_bias, cache_position
1009
 
1010
- def compute_jacobian_loss(self, second_last_output, position_bias, cache_position, attention_mask):
1011
- """
1012
- Compute L1 Jacobian sparsity loss using forward-mode AD (JVP).
1013
- Only computes through the last block + final layer norm.
1014
-
1015
- attention_mask is ONLY used for sampling valid tokens, NOT for masking during forward.
1016
- """
1017
- batch_size, seq_len, hidden_dim = second_last_output.shape
1018
-
1019
- # Use only first sample for Jacobian
1020
- second_last_output = second_last_output[:8]
1021
- position_bias_sample = position_bias[:8] if position_bias is not None else None
1022
- attention_mask = attention_mask[:8]
1023
-
1024
- last_block = self.encoder.block[-1]
1025
- final_layer_norm = self.encoder.final_layer_norm
1026
-
1027
- B, S, H = second_last_output.shape
1028
- device = second_last_output.device
1029
-
1030
- # Use attention_mask ONLY to determine valid tokens for sampling
1031
- lengths = attention_mask.sum(dim=1)
1032
- valid_batch = lengths > 0
1033
-
1034
- if valid_batch.sum() == 0:
1035
- z = second_last_output.new_zeros(())
1036
- return z, z
1037
-
1038
- same_token_loss = second_last_output.new_zeros(())
1039
- cross_token_loss = second_last_output.new_zeros(())
1040
-
1041
- def model_fn(embeds):
1042
- """Forward through ONLY the last block + final layer norm (NO MASKING)"""
1043
- layer_outputs = last_block(
1044
- embeds,
1045
- None, # No attention mask - all tokens attend to all
1046
- position_bias_sample,
1047
- None, None, None,
1048
- past_key_values=None,
1049
- use_cache=False,
1050
- output_attentions=False,
1051
- return_dict=True,
1052
- cache_position=cache_position,
1053
- )
1054
- hidden = layer_outputs[0]
1055
- hidden = final_layer_norm(hidden)
1056
- hidden = self.encoder.dropout(hidden)
1057
- return hidden
1058
-
1059
- batch_idx = torch.arange(B, device=device)
1060
-
1061
- for _ in range(self.num_jacobian_samples):
1062
- # Sample one valid token position per batch element
1063
- # Use attention_mask to know which tokens are valid (not padding)
1064
- t = torch.zeros(B, dtype=torch.long, device=device)
1065
- u = torch.rand(B, device=device)
1066
- t[valid_batch] = (u[valid_batch] * lengths[valid_batch].float()).long()
1067
-
1068
- # Sample one hidden dim per batch element
1069
- k = torch.randint(0, H, (B,), device=device)
1070
-
1071
- # Tangent: one scalar per batch element at position [b, t[b], k[b]]
1072
- tangent = torch.zeros_like(second_last_output)
1073
- tangent[batch_idx, t, k] = 1.0
1074
-
1075
- # JVP through ONLY the last block
1076
- _, jvp = func.jvp(model_fn, (second_last_output,), (tangent,))
1077
- abs_jvp = jvp.abs()
1078
-
1079
- # SAME-token: diagonal element for each batch
1080
- diag = abs_jvp[batch_idx, t, :]
1081
- same_token_loss = same_token_loss + diag[valid_batch].sum()
1082
-
1083
- # CROSS-token: all valid positions except diagonal
1084
- # Use attention_mask to know which positions are valid
1085
- pos = torch.arange(S, device=device).unsqueeze(0)
1086
- valid_pos_mask = pos < lengths.unsqueeze(1)
1087
-
1088
- # Exclude diagonal
1089
- cross_mask = valid_pos_mask.clone()
1090
- cross_mask[batch_idx, t] = False
1091
-
1092
- cross_token_loss = cross_token_loss + abs_jvp[cross_mask].sum()
1093
-
1094
- # Normalization
1095
- num_valid_batches = valid_batch.sum()
1096
- same_token_loss = same_token_loss / (self.num_jacobian_samples * num_valid_batches)
1097
-
1098
- cross_counts = (lengths[valid_batch] - 1).clamp(min=0).sum()
1099
- if cross_counts > 0:
1100
- cross_token_loss = cross_token_loss / (self.num_jacobian_samples * cross_counts)
1101
- else:
1102
- cross_token_loss = second_last_output.new_zeros(())
1103
-
1104
- return same_token_loss, cross_token_loss
1105
 
1106
  def forward(
1107
  self,
@@ -1180,20 +397,7 @@ class JacobianT5Encoder(T5PreTrainedModel):
1180
  hidden_states = self.encoder.dropout(hidden_states)
1181
 
1182
  # Compute Jacobian loss if requested
1183
- jacobian_loss = {}
1184
- if compute_jacobian:
1185
- jacobian_same_loss, jacobian_cross_loss = self.compute_jacobian_loss(
1186
- second_last_output,
1187
- position_bias,
1188
- cache_position,
1189
- attention_mask # Used ONLY for sampling valid tokens
1190
- )
1191
- jacobian_loss = {
1192
- "same_token": jacobian_same_loss,
1193
- "cross_token": jacobian_cross_loss
1194
- }
1195
- else:
1196
- jacobian_loss = {
1197
  "same_token": torch.tensor(0.0, device=input_ids.device),
1198
  "cross_token": torch.tensor(0.0, device=input_ids.device)
1199
  }
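
The removed compute_jacobian_loss above estimates same-token and cross-token Jacobian magnitudes with forward-mode AD: each JVP probes one (token, dim) input direction and reads off how much every output token moves. A minimal sketch of that probe with torch.func.jvp on a toy attention block (the block, shapes, and names are stand-ins, not the T5 encoder):

import torch
import torch.nn as nn
from torch.func import jvp

class ToyBlock(nn.Module):
    # stand-in for one transformer block: attention, so tokens actually mix
    def __init__(self, d=16):
        super().__init__()
        self.q, self.k, self.v, self.out = (nn.Linear(d, d) for _ in range(4))
    def forward(self, x):                                   # x: [batch, seq, d]
        att = torch.softmax(self.q(x) @ self.k(x).transpose(-1, -2) / x.shape[-1] ** 0.5, dim=-1)
        return self.out(att @ self.v(x)) + x

model = ToyBlock()
x = torch.randn(1, 6, 16)
t, k = 2, 5                                                 # probe input token t, hidden dim k
tangent = torch.zeros_like(x)
tangent[0, t, k] = 1.0
_, j = jvp(lambda inp: model(inp), (x,), (tangent,))        # one Jacobian column

same_token = j[0, t].abs().sum()                            # token t's output w.r.t. its own input
others = torch.ones(x.shape[1], dtype=torch.bool); others[t] = False
cross_token = j[0, others].abs().sum()                      # leakage into the other tokens' outputs
print(same_token.item(), cross_token.item())
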
@@ -1212,11 +416,11 @@ class JacobianLoraT5Embedder(nn.Module):
1212
 
1213
  # Load T5 config
1214
  from transformers import T5Config
1215
- config = T5Config.from_pretrained('./t5-v1_1-xxl')
1216
 
1217
  # Create encoder model
1218
  self.t5_encoder = JacobianT5Encoder.from_pretrained(
1219
- './t5-v1_1-xxl',
1220
  config=config,
1221
  num_jacobian_samples=num_jacobian_samples,
1222
  max_length=max_length
@@ -1224,18 +428,15 @@ class JacobianLoraT5Embedder(nn.Module):
1224
  self.dtype = torch.bfloat16
1225
 
1226
  # Tokenizer
1227
- self.t5_tokenizer = T5Tokenizer.from_pretrained('./t5-v1_1-xxl', max_length=max_length)
1228
 
1229
  # Image encoder
1230
- image_encoder_path = './clip-vit-large-patch14'
1231
  self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
1232
  image_encoder_path
1233
  ).to(device=device).to(torch.bfloat16)
1234
  self.image_encoder = self.image_encoder.eval().requires_grad_(False)
1235
 
1236
- print(f"Gradient checkpointing: {use_gradient_checkpointing} (using T5's built-in)")
1237
- print(f"Jacobian samples per batch: {num_jacobian_samples}")
1238
- print(f"NO ATTENTION MASKING during forward pass - all tokens attend to all tokens")
1239
 
1240
  def forward(self, text, image=None, compute_jacobian=False):
1241
  """
@@ -1285,217 +486,3 @@ class JacobianLoraT5Embedder(nn.Module):
1285
  return prompt_embeds, clip_image_embeds, jacobian_loss, attn_mask
1286
 
1287
 
1288
-
1289
- import gc
1290
- from PIL import Image
1291
- from transformers import AutoProcessor
1292
- import numpy as np
1293
-
1294
-
1295
- def get_gpu_memory():
1296
- """Get current GPU memory usage in MB"""
1297
- return torch.cuda.memory_allocated() / 1024 ** 2
1298
-
1299
-
1300
- def get_peak_memory():
1301
- """Get peak GPU memory usage in MB"""
1302
- return torch.cuda.max_memory_allocated() / 1024 ** 2
1303
-
1304
-
1305
- def reset_peak_memory():
1306
- """Reset peak memory counter"""
1307
- torch.cuda.reset_peak_memory_stats()
1308
-
1309
-
1310
- def clear_memory():
1311
- """Clear GPU cache and run garbage collection"""
1312
- gc.collect()
1313
- torch.cuda.empty_cache()
1314
- torch.cuda.reset_peak_memory_stats()
1315
-
1316
-
1317
- def test_memory_usage():
1318
- """Test memory usage with and without Jacobian loss"""
1319
-
1320
-
1321
- # Initialize model
1322
- print("=" * 80)
1323
- print("Initializing model...")
1324
- clear_memory()
1325
-
1326
- model = JacobianLoraT5Embedder(
1327
- device="cuda:0",
1328
- use_gradient_checkpointing=True,
1329
- num_jacobian_samples=10
1330
- )
1331
-
1332
- clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14", use_fast=True)
1333
-
1334
- init_memory = get_gpu_memory()
1335
- print(f"Memory after model init: {init_memory:.2f} MB")
1336
- print("=" * 80)
1337
-
1338
- # Prepare inputs
1339
- image = Image.open('example512.jpg').convert('RGB')
1340
- prompt = """A heartwarming 3D rendered scene of
1341
- an elderly farmer and a tiny orange
1342
- kitten. The farmer, with a gentle smile,
1343
- walks alongside the kitten in a lush,
1344
- green garden filled with thriving plants,
1345
- showcasing a fruitful harvest. The
1346
- intricate details of the overalls and the
1347
- farmer's worn, weathered face tell a
1348
- story of years spent tending to the land, the farmer is wearing a blue shirt"""
1349
-
1350
- # Test different batch sizes
1351
- batch_sizes = [1, 2, 5, 10]
1352
-
1353
- results = []
1354
-
1355
- for batch_size in batch_sizes:
1356
- print(f"\n{'=' * 80}")
1357
- print(f"BATCH SIZE: {batch_size}")
1358
- print(f"{'=' * 80}")
1359
-
1360
- text_batch = [prompt] * batch_size
1361
- pixel_values = clip_processor(
1362
- images=image,
1363
- return_tensors="pt"
1364
- ).pixel_values.to("cuda:0").to(torch.bfloat16)
1365
-
1366
- # Test WITHOUT Jacobian
1367
- print(f"\n--- WITHOUT Jacobian Loss ---")
1368
- clear_memory()
1369
- reset_peak_memory()
1370
-
1371
- mem_before = get_gpu_memory()
1372
- print(f"Memory before forward: {mem_before:.2f} MB")
1373
-
1374
- with torch.no_grad():
1375
- prompt_embeds, clip_image_embeds, jacobian_loss, attn_mask = model(
1376
- text_batch,
1377
- image=pixel_values,
1378
- compute_jacobian=False
1379
- )
1380
-
1381
- mem_after = get_gpu_memory()
1382
- peak_mem = get_peak_memory()
1383
-
1384
- print(f"Memory after forward: {mem_after:.2f} MB")
1385
- print(f"Peak memory: {peak_mem:.2f} MB")
1386
- print(f"Memory increase: {mem_after - mem_before:.2f} MB")
1387
- print(f"Peak increase: {peak_mem - mem_before:.2f} MB")
1388
-
1389
- no_jac_peak = peak_mem - mem_before
1390
-
1391
- # Clean up
1392
- del prompt_embeds, clip_image_embeds, jacobian_loss, attn_mask
1393
-
1394
- # Test WITH Jacobian (requires grad)
1395
- print(f"\n--- WITH Jacobian Loss ---")
1396
- clear_memory()
1397
- reset_peak_memory()
1398
-
1399
- mem_before = get_gpu_memory()
1400
- print(f"Memory before forward: {mem_before:.2f} MB")
1401
-
1402
- try:
1403
- prompt_embeds, clip_image_embeds, jacobian_loss, attn_mask = model(
1404
- text_batch,
1405
- image=pixel_values,
1406
- compute_jacobian=True
1407
- )
1408
-
1409
- mem_after = get_gpu_memory()
1410
- peak_mem = get_peak_memory()
1411
-
1412
- print(f"Memory after forward: {mem_after:.2f} MB")
1413
- print(f"Peak memory: {peak_mem:.2f} MB")
1414
- print(f"Memory increase: {mem_after - mem_before:.2f} MB")
1415
- print(f"Peak increase: {peak_mem - mem_before:.2f} MB")
1416
-
1417
- if jacobian_loss is not None:
1418
- print(f"\nJacobian Loss Values:")
1419
- print(f" Same-token loss: {jacobian_loss['same_token'].item():.6f}")
1420
- print(f" Cross-token loss: {jacobian_loss['cross_token'].item():.6f}")
1421
-
1422
- with_jac_peak = peak_mem - mem_before
1423
-
1424
- print(f"\n{'*' * 60}")
1425
- print(f"JACOBIAN OVERHEAD: {with_jac_peak - no_jac_peak:.2f} MB")
1426
- print(f"MEMORY MULTIPLIER: {with_jac_peak / no_jac_peak:.2f}x")
1427
- print(f"{'*' * 60}")
1428
-
1429
- results.append({
1430
- 'batch_size': batch_size,
1431
- 'no_jacobian_mb': no_jac_peak,
1432
- 'with_jacobian_mb': with_jac_peak,
1433
- 'overhead_mb': with_jac_peak - no_jac_peak,
1434
- 'multiplier': with_jac_peak / no_jac_peak
1435
- })
1436
-
1437
- except RuntimeError as e:
1438
- print(f"❌ CUDA OUT OF MEMORY with Jacobian at batch_size={batch_size}")
1439
- print(f"Error: {str(e)}")
1440
- results.append({
1441
- 'batch_size': batch_size,
1442
- 'no_jacobian_mb': no_jac_peak,
1443
- 'with_jacobian_mb': float('inf'),
1444
- 'overhead_mb': float('inf'),
1445
- 'multiplier': float('inf')
1446
- })
1447
-
1448
- # Clean up
1449
- del prompt_embeds, clip_image_embeds, jacobian_loss, attn_mask
1450
- clear_memory()
1451
-
1452
- # Print summary table
1453
- print(f"\n\n{'=' * 80}")
1454
- print("SUMMARY TABLE")
1455
- print(f"{'=' * 80}")
1456
- print(f"{'Batch':>6} | {'No Jacobian':>12} | {'With Jacobian':>14} | {'Overhead':>10} | {'Multiplier':>10}")
1457
- print(f"{'Size':>6} | {'(MB)':>12} | {'(MB)':>14} | {'(MB)':>10} | {'':>10}")
1458
- print(f"{'-' * 80}")
1459
-
1460
- for r in results:
1461
- batch = r['batch_size']
1462
- no_jac = r['no_jacobian_mb']
1463
- with_jac = r['with_jacobian_mb']
1464
- overhead = r['overhead_mb']
1465
- mult = r['multiplier']
1466
-
1467
- if overhead == float('inf'):
1468
- print(f"{batch:>6} | {no_jac:>11.2f} | {'OOM':>14} | {'OOM':>10} | {'OOM':>10}")
1469
- else:
1470
- print(f"{batch:>6} | {no_jac:>11.2f} | {with_jac:>13.2f} | {overhead:>9.2f} | {mult:>9.2f}x")
1471
-
1472
- print(f"{'=' * 80}")
1473
-
1474
- # Comparison with original
1475
- print(f"\n\n{'=' * 80}")
1476
- print("COMPARISON WITH ORIGINAL IMPLEMENTATION")
1477
- print(f"{'=' * 80}")
1478
- print("\nORIGINAL (all 24 blocks in Jacobian):")
1479
- print(" Batch 1: 30,900 MB overhead, 144x multiplier")
1480
- print(" Batch 10: 30,328 MB overhead, 15x multiplier")
1481
- print("\nNEW (only last block in Jacobian):")
1482
- if len(results) > 0:
1483
- r1 = results[0]
1484
- r10 = results[-1] if len(results) >= 4 else results[-1]
1485
- print(f" Batch 1: {r1['overhead_mb']:>6.0f} MB overhead, {r1['multiplier']:>4.1f}x multiplier")
1486
- print(f" Batch 10: {r10['overhead_mb']:>6.0f} MB overhead, {r10['multiplier']:>4.1f}x multiplier")
1487
-
1488
- if r1['overhead_mb'] != float('inf'):
1489
- reduction = 30900 / r1['overhead_mb']
1490
- print(f"\n🎉 MEMORY REDUCTION: {reduction:.1f}x improvement!")
1491
-
1492
- print(f"{'=' * 80}")
1493
-
1494
-
1495
- if __name__ == "__main__":
1496
- # Set random seed for reproducibility
1497
- torch.manual_seed(42)
1498
- np.random.seed(42)
1499
-
1500
- # Run test
1501
- test_memory_usage()
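
The removed benchmark above brackets each forward pass with torch.cuda memory counters. The same measurement pattern, reduced to a small helper (the helper name is mine; the CUDA calls are the ones the deleted code uses):

import gc
import torch

def peak_mb(fn):
    # run fn() and return the peak CUDA memory it allocated above the starting point, in MB
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    before = torch.cuda.memory_allocated()
    fn()
    return (torch.cuda.max_memory_allocated() - before) / 1024 ** 2
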
 
207
 
208
  from torch.utils.checkpoint import checkpoint
209
  from peft import LoraConfig, set_peft_model_state_dict
210
 
211
  import torch
212
  from torch import nn, func
 
319
 
320
  return hidden_states, position_bias, cache_position
321
 
322
 
323
  def forward(
324
  self,
 
397
  hidden_states = self.encoder.dropout(hidden_states)
398
 
399
  # Compute Jacobian loss if requested
400
+ jacobian_loss = {
401
  "same_token": torch.tensor(0.0, device=input_ids.device),
402
  "cross_token": torch.tensor(0.0, device=input_ids.device)
403
  }
 
416
 
417
  # Load T5 config
418
  from transformers import T5Config
419
+ config = T5Config.from_pretrained('google/t5-v1_1-xxl')
420
 
421
  # Create encoder model
422
  self.t5_encoder = JacobianT5Encoder.from_pretrained(
423
+ 'google/t5-v1_1-xxl',
424
  config=config,
425
  num_jacobian_samples=num_jacobian_samples,
426
  max_length=max_length
 
428
  self.dtype = torch.bfloat16
429
 
430
  # Tokenizer
431
+ self.t5_tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-xxl', max_length=max_length)
432
 
433
  # Image encoder
434
+ image_encoder_path = 'openai/clip-vit-large-patch14'
435
  self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
436
  image_encoder_path
437
  ).to(device=device).to(torch.bfloat16)
438
  self.image_encoder = self.image_encoder.eval().requires_grad_(False)
439
 
440
 
441
  def forward(self, text, image=None, compute_jacobian=False):
442
  """
 
486
  return prompt_embeds, clip_image_embeds, jacobian_loss, attn_mask
487
 
488
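
Among the added lines, the checkpoint paths switch from local folders ('./t5-v1_1-xxl', './clip-vit-large-patch14') to Hub IDs. A minimal sketch of loading those components straight from the Hub, assuming the same bfloat16 setup as the surrounding code; the device choice and test prompt are illustrative:

import torch
from transformers import T5Tokenizer, T5EncoderModel, CLIPVisionModelWithProjection

device = "cuda" if torch.cuda.is_available() else "cpu"

t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
t5_encoder = T5EncoderModel.from_pretrained(
    "google/t5-v1_1-xxl", torch_dtype=torch.bfloat16
).to(device).eval()
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16
).to(device).eval()

tokens = t5_tokenizer(["a test prompt"], padding="max_length", max_length=300,
                      truncation=True, return_tensors="pt").to(device)
with torch.no_grad():
    prompt_embeds = t5_encoder(input_ids=tokens["input_ids"]).last_hidden_state
print(prompt_embeds.shape)   # [1, 300, 4096] for t5-v1_1-xxl
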