|
|
|
|
|
import os |
|
|
import yaml |
|
|
import torch |
|
|
from datetime import datetime |
|
|
|
|
|
from transformers import T5TokenizerFast |
|
|
from models.vision_t5 import VisionT5 |
|
|
import models.encoders as encoders |
|
|
from models.encoder_projection_t5 import ImageProjection |
|
|
import inspect |
|
|
|
|
|
|
|
|
|
|
|
def timestamp():
    """Return the current local time as a ``YYYYMMDD_HHMMSS`` string.

    Used to build unique, sortable run-directory names.
    """
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
|
|
|
|
|
|
|
|
|
|
|
def save_experiment(model, tokenizer, config, save_dir, notes="", run_name=None, add_timestamp=True):
    """Persist model weights, tokenizer, config, and run metadata to disk.

    When ``add_timestamp`` is True, the experiment is written into a
    timestamped subdirectory of ``save_dir`` (prefixed with ``run_name`` if
    given); otherwise ``save_dir`` is used as-is. Returns the directory the
    experiment was written to.

    Layout produced:
        pytorch_model.bin    — model state dict
        tokenizer/           — tokenizer files
        config_trained.yaml  — the full config, as passed in
        metadata.yaml        — compact run summary (encoder/decoder, hparams, notes)
    """
    if add_timestamp:
        tag = timestamp()
        subdir = f"{run_name}_{tag}" if run_name else tag
        save_dir = os.path.join(save_dir, subdir)

    os.makedirs(save_dir, exist_ok=True)

    # Model weights.
    torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))

    # Tokenizer files live in their own subdirectory.
    tok_dir = os.path.join(save_dir, "tokenizer")
    os.makedirs(tok_dir, exist_ok=True)
    tokenizer.save_pretrained(tok_dir)

    # Full training config, exactly as passed in.
    with open(os.path.join(save_dir, "config_trained.yaml"), "w") as f:
        yaml.safe_dump(config, f)

    # Compact summary for quick inspection without parsing the full config.
    model_cfg = config["model"]
    train_cfg = config["training"]
    metadata = {
        "encoder": model_cfg["encoder"],
        "encoder_params": model_cfg.get("encoder_params", {}),
        "decoder": model_cfg["t5_name"],
        "decoder_params": model_cfg.get("decoder_params", {}),
        "train_epochs": train_cfg["epochs"],
        "batch_size": train_cfg["batch_size"],
        "lr": train_cfg["lr"],
        "notes": notes,
        "run_name": run_name,
        "timestamp": timestamp(),
    }

    with open(os.path.join(save_dir, "metadata.yaml"), "w") as f:
        yaml.safe_dump(metadata, f)

    print(f"[OK] Experiment saved → {save_dir}")
    return save_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_experiment(checkpoint_dir, device="cpu"):
    """Rebuild a (model, tokenizer) pair saved by ``save_experiment``.

    Reads ``metadata.yaml`` and ``config_trained.yaml`` from
    ``checkpoint_dir``, reconstructs the architecture via ``build_model``,
    restores the saved tokenizer (when one was stored alongside the
    checkpoint) and the saved weights, then moves the model to ``device``
    and puts it in eval mode.

    Returns (model, tokenizer, metadata, config).
    Raises FileNotFoundError if either YAML file is missing.
    """
    # NOTE: the previous version re-imported yaml/torch/os locally; they are
    # already imported at module level, so the redundant import was removed.
    metadata_path = os.path.join(checkpoint_dir, "metadata.yaml")
    config_path = os.path.join(checkpoint_dir, "config_trained.yaml")

    if not os.path.exists(metadata_path):
        raise FileNotFoundError(f"No metadata.yaml found at {checkpoint_dir}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"No config_trained.yaml found at {checkpoint_dir}")

    with open(metadata_path, "r") as f:
        metadata = yaml.safe_load(f)

    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Rebuild the architecture from the saved config, then prefer the
    # tokenizer stored with the checkpoint over the freshly downloaded one.
    model, tokenizer = build_model(config)

    tok_dir = os.path.join(checkpoint_dir, "tokenizer")
    if os.path.isdir(tok_dir):
        tokenizer = T5TokenizerFast.from_pretrained(tok_dir)

    ckpt_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
    weights = torch.load(ckpt_path, map_location=device)
    # strict=False tolerates missing/unexpected keys (e.g. submodules that
    # were frozen or renamed) — but mismatches are silently ignored, so the
    # saved config must stay in sync with the architecture.
    model.load_state_dict(weights, strict=False)

    model.to(device)
    model.eval()

    print(f"Loaded experiment from {checkpoint_dir}")
    return model, tokenizer, metadata, config
|
|
|
|
|
|
|
|
|
|
|
def filter_kwargs(cls, kwargs):
    """Return the subset of ``kwargs`` accepted by ``cls.__init__``.

    Keys that do not match a parameter name of the constructor are dropped,
    so a config dict can carry extra fields without breaking instantiation.
    """
    accepted = inspect.signature(cls.__init__).parameters
    filtered = {}
    for key, value in kwargs.items():
        if key in accepted:
            filtered[key] = value
    return filtered
|
|
|
|
|
|
|
|
|
|
|
def build_model(config):
    """Construct a VisionT5 model and its tokenizer from a config dict.

    Expects ``config["model"]`` to provide:
      - "encoder": name of a class defined in encoders.py
      - "t5_name": HuggingFace T5 model identifier
      - optional "encoder_params" / "decoder_params" dicts

    Raises ValueError when the encoder class is not found.
    Returns (model, tokenizer).
    """
    model_cfg = config["model"]
    encoder_name = model_cfg["encoder"]
    raw_encoder_params = model_cfg.get("encoder_params", {})
    t5_name = model_cfg["t5_name"]
    decoder_params = model_cfg.get("decoder_params", {})

    tokenizer = T5TokenizerFast.from_pretrained(t5_name)

    # Resolve the encoder class by name, then drop any config keys its
    # constructor does not accept.
    if not hasattr(encoders, encoder_name):
        raise ValueError(f"Encoder '{encoder_name}' not found in encoders.py")
    EncoderClass = getattr(encoders, encoder_name)
    encoder_params = filter_kwargs(EncoderClass, raw_encoder_params)
    vision_encoder = EncoderClass(**encoder_params)

    # Projection bridges the encoder feature dimension to the T5 hidden size.
    projector = ImageProjection(
        encoder_dim=vision_encoder.get_output_dim(),
        t5_hidden_size=VisionT5.get_t5_hidden_size(t5_name),
    )

    model = VisionT5(
        vision_encoder=vision_encoder,
        projector=projector,
        t5_name=t5_name,
        decoder_params=decoder_params,
    )

    return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
def load_yaml(path):
    """Parse the YAML file at ``path`` and return its contents."""
    with open(path, "r") as handle:
        return yaml.safe_load(handle)
|
|
|
|
|
|
|
|
|
|
|
def count_encoder_decoder_params(model):
    """Break down the model's parameter counts by component.

    Parameters are grouped into:
      - "encoder":   the vision encoder (``vision_encoder.*``)
      - "projector": the image projection (``projector.*``)
      - "decoder":   the T5 decoder stack, lm_head, shared embeddings, and
                     decoder-side LoRA adapters
      - "other":     everything else, including the T5 text encoder

    Returns a dict with total/trainable counts and trainable fractions per
    group plus overall totals; a fraction is None when its group is empty.

    The previous implementation repeated the name-matching logic in four
    hand-rolled counting branches; it is now centralized in ``_param_group``
    (which mirrors the public ``classify_param`` helper).
    """
    totals = {"encoder": 0, "projector": 0, "decoder": 0, "other": 0}
    trainable = {"encoder": 0, "projector": 0, "decoder": 0, "other": 0}

    for name, p in model.named_parameters():
        group = _param_group(name)
        n = p.numel()
        totals[group] += n
        if p.requires_grad:
            trainable[group] += n

    total_params = sum(totals.values())
    trainable_params = sum(trainable.values())

    def frac(group):
        # None (not 0) signals "group has no parameters at all".
        return trainable[group] / totals[group] if totals[group] else None

    return {
        "encoder_total_params": totals["encoder"],
        "encoder_trainable_params": trainable["encoder"],
        "encoder_trainable_fraction": frac("encoder"),

        "projector_total_params": totals["projector"],
        "projector_trainable_params": trainable["projector"],
        "projector_trainable_fraction": frac("projector"),

        "decoder_total_params": totals["decoder"],
        "decoder_trainable_params": trainable["decoder"],
        "decoder_trainable_fraction": frac("decoder"),

        "other_total_params": totals["other"],
        "other_trainable_params": trainable["other"],

        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_params_fraction":
            trainable_params / total_params if total_params else None,
    }


def _param_group(name):
    """Map a parameter name to "encoder", "projector", "decoder", or "other".

    Same matching rules as ``classify_param``, except the frozen T5 text
    encoder collapses into "other" (matching how the counts were
    historically reported).
    """
    if name.startswith("vision_encoder."):
        return "encoder"
    if name.startswith("projector."):
        return "projector"
    if (
        name.startswith("t5.decoder.") or
        "decoder.block" in name or
        name.startswith("t5.model.decoder.") or
        name.startswith("t5.lm_head.") or
        name.startswith("t5.shared.") or
        ("lora_" in name and "decoder" in name)
    ):
        return "decoder"
    # Everything else — including t5.encoder.* — is reported as "other".
    return "other"
|
|
|
|
|
|
|
|
def classify_param(name):
    """Classify a parameter name by the model component it belongs to.

    Returns one of: "encoder", "projector", "decoder",
    "t5_encoder_frozen", or "other".
    """
    if name.startswith("vision_encoder."):
        return "encoder"

    if name.startswith("projector."):
        return "projector"

    decoder_prefixes = (
        "t5.decoder.",
        "t5.model.decoder.",
        "t5.lm_head.",
        "t5.shared.",
    )
    is_decoder = (
        name.startswith(decoder_prefixes)
        or "decoder.block" in name
        or ("lora_" in name and "decoder" in name)
    )
    if is_decoder:
        return "decoder"

    if name.startswith("t5.encoder."):
        return "t5_encoder_frozen"

    return "other"
|
|
|
|
|
|