import os

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from transformers import AutoConfig, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

def safe_load_t5(model_name, local_path):
    """Load a T5 checkpoint from the HuggingFace Hub, falling back to a local copy."""
    has_local = os.path.exists(local_path)
    try:
        print(f"[INFO] Trying to load {model_name} from HuggingFace…")
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        print("[INFO] Loaded from HF successfully.")
        return model
    except Exception as e:
        print(f"[WARN] HF load failed: {e}")
        if not has_local:
            raise RuntimeError(
                f"No local copy available at {local_path} and HF download failed."
            )
        print("[INFO] Falling back to local Drive copy...")
        return T5ForConditionalGeneration.from_pretrained(local_path)

class VisionT5(nn.Module):
    def __init__(self, vision_encoder, projector, t5_name="t5-small", decoder_params=None):
        super().__init__()
        decoder_params = decoder_params or {}
        self.vision_encoder = vision_encoder
        self.projector = projector

        # Load the full T5 model, but only its decoder is actually used:
        # the projected image features stand in for the text-encoder outputs.
        local_large = "/content/drive/MyDrive/Models/t5-large"
        if t5_name == "t5-large":
            self.t5 = safe_load_t5("t5-large", local_large)
        else:
            self.t5 = T5ForConditionalGeneration.from_pretrained(t5_name)

        self.apply_decoder_options(decoder_params)

        # The T5 text encoder is bypassed entirely, so it is never trained.
        for p in self.t5.encoder.parameters():
            p.requires_grad = False

        self.hidden_size = self.t5.config.d_model

    def apply_decoder_options(self, params):
        # --- LoRA setup ---
        if params.get("use_lora", False):
            lora_rank = params.get("lora_rank", 8)
            lora_alpha = params.get("lora_alpha", 16)
            print(f"[INFO] LoRA enabled for T5 decoder (Rank={lora_rank})")
            # Target the query and value matrices in all T5 attention blocks.
            lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                target_modules=["q", "v"],
                lora_dropout=params.get("lora_dropout", 0.1),
                bias="none",
                task_type="SEQ_2_SEQ_LM",  # T5 is an encoder-decoder (seq2seq) model
            )
            self.t5 = get_peft_model(self.t5, lora_config)
            self.t5.print_trainable_parameters()
            # The freeze_decoder flag (if present) is ignored when using LoRA:
            # get_peft_model freezes the base model and only exposes the adapter weights.
            return

        # --- Partial tuning: train only the last N decoder blocks ---
        num_decoder_layers = self.t5.config.num_decoder_layers
        trainable_layers = params.get("trainable_decoder_layers")
        if trainable_layers is not None:
            num_frozen = num_decoder_layers - trainable_layers
            if num_frozen > 0:
                print(f"[INFO] Partial Tuning: Freezing first {num_frozen} of {num_decoder_layers} decoder blocks.")
            for i, block in enumerate(self.t5.decoder.block):
                if i < num_frozen:
                    for p in block.parameters():
                        p.requires_grad = False
                    print(f" > Block {i} frozen.")
                else:
                    for p in block.parameters():
                        p.requires_grad = True
                    print(f" > Block {i} trainable.")
            if num_frozen > 0:
                for p in self.t5.decoder.embed_tokens.parameters():
                    p.requires_grad = False
                print(" > Decoder embeddings frozen.")
            return
if params.get("freeze_decoder", False):
print("[INFO] Freezing all T5 decoder parameters.")
for p in self.t5.decoder.parameters():
p.requires_grad = False
if params.get("dropout_override") is not None:
self.t5.config.dropout_rate = params["dropout_override"]

    def forward(
        self,
        pixel_values=None,
        input_ids=None,
        attention_mask=None,
        labels=None,
    ):
        # Encode the image(s) and project the features into the T5 hidden space.
        vision_out = self.vision_encoder(pixel_values)
        image_embeds = vision_out["image_embeds"]
        if image_embeds.dim() == 2:
            image_embeds = image_embeds.unsqueeze(1)  # (B, 1, D)
        projected = self.projector(image_embeds)      # (B, S, d_model)
        B, S, _ = projected.shape
        # All projected visual tokens are visible to the decoder's cross-attention.
        encoder_attention_mask = torch.ones(B, S, dtype=torch.long, device=projected.device)
        # Wrap the projected features so T5 treats them as precomputed encoder states.
        encoder_outputs = BaseModelOutput(last_hidden_state=projected)
        decoder_attention_mask = attention_mask
        output = self.t5(
            input_ids=input_ids,
            decoder_attention_mask=decoder_attention_mask,
            attention_mask=encoder_attention_mask,
            encoder_outputs=encoder_outputs,
            labels=labels,
            return_dict=True,
        )
        return output

    @torch.no_grad()
    def generate(self, pixel_values, tokenizer, max_length=32, num_beams=3):
        vision_out = self.vision_encoder(pixel_values)
        image_embeds = vision_out["image_embeds"]
        if image_embeds.dim() == 2:
            image_embeds = image_embeds.unsqueeze(1)  # (B, 1, D)
        projected = self.projector(image_embeds)      # (B, S, d_model)
        # Pass the projected features as precomputed encoder states; generate()
        # builds the decoder start tokens itself from decoder_start_token_id.
        encoder_outputs = BaseModelOutput(last_hidden_state=projected)
        generated_ids = self.t5.generate(
            encoder_outputs=encoder_outputs,
            decoder_start_token_id=self.t5.config.decoder_start_token_id,
            max_length=max_length,
            num_beams=num_beams,
        )
        # Only the first sequence is decoded, so this assumes a batch size of 1.
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    @staticmethod
    def get_t5_hidden_size(t5_name):
        # Read d_model from the model config without loading the weights.
        cfg = AutoConfig.from_pretrained(t5_name)
        return cfg.d_model
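

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `DummyVisionEncoder` and `DummyProjector`
# are hypothetical stand-ins for the project's real vision encoder and
# projector modules; the 512-dim feature size is an assumption for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    class DummyVisionEncoder(nn.Module):
        # Stand-in: returns random 512-dim "image embeddings" for a batch of images.
        def forward(self, pixel_values):
            return {"image_embeds": torch.randn(pixel_values.shape[0], 512)}

    class DummyProjector(nn.Module):
        # Stand-in: maps vision features to the T5 hidden size (d_model).
        def __init__(self, in_dim, d_model):
            super().__init__()
            self.proj = nn.Linear(in_dim, d_model)

        def forward(self, x):
            return self.proj(x)

    d_model = VisionT5.get_t5_hidden_size("t5-small")
    model = VisionT5(
        vision_encoder=DummyVisionEncoder(),
        projector=DummyProjector(512, d_model),
        t5_name="t5-small",
        decoder_params={"use_lora": True, "lora_rank": 8},
    )
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    pixel_values = torch.randn(1, 3, 224, 224)  # one dummy RGB image
    print(model.generate(pixel_values, tokenizer, max_length=32, num_beams=3))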