# coco-demo / src/utils.py  (uploaded by evanec — "Upload 12 files", rev 1809762, verified)
# utils.py
import os
import yaml
import torch
from datetime import datetime
from transformers import T5TokenizerFast
from models.vision_t5 import VisionT5
import models.encoders as encoders
from models.encoder_projection_t5 import ImageProjection
import inspect
def timestamp():
    """Return the current local time as a 'YYYYMMDD_HHMMSS' string."""
    now = datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")
def save_experiment(model, tokenizer, config, save_dir, notes="", run_name=None, add_timestamp=True):
    """Persist a trained model, tokenizer, config, and run metadata to disk.

    Layout under the (possibly extended) save_dir:
        pytorch_model.bin    -- model.state_dict()
        tokenizer/           -- tokenizer files (save_pretrained)
        config_trained.yaml  -- the config used for training
        metadata.yaml        -- summary: component names, hyperparams, notes

    Args:
        model: torch.nn.Module whose state_dict() is saved.
        tokenizer: object exposing save_pretrained(dir) (e.g. a HF tokenizer).
        config: dict with model.encoder, model.t5_name and
            training.epochs / batch_size / lr entries.
        save_dir: base output directory.
        notes: free-form notes stored in metadata.yaml.
        run_name: optional run label included in the directory name.
        add_timestamp: if True, append a timestamp to the directory name.

    Returns:
        The final directory path the experiment was written to.
    """
    tag = timestamp()
    # Build the run directory. Bug fix: run_name used to be silently ignored
    # when add_timestamp was False; it is now honored in both cases.
    if add_timestamp:
        suffix = f"{run_name}_{tag}" if run_name else tag
        save_dir = os.path.join(save_dir, suffix)
    elif run_name:
        save_dir = os.path.join(save_dir, run_name)
    os.makedirs(save_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))

    tok_dir = os.path.join(save_dir, "tokenizer")
    os.makedirs(tok_dir, exist_ok=True)
    tokenizer.save_pretrained(tok_dir)

    with open(os.path.join(save_dir, "config_trained.yaml"), "w") as f:
        yaml.safe_dump(config, f)

    metadata = {
        "encoder": config["model"]["encoder"],
        "encoder_params": config["model"].get("encoder_params", {}),
        "decoder": config["model"]["t5_name"],
        "decoder_params": config["model"].get("decoder_params", {}),
        "train_epochs": config["training"]["epochs"],
        "batch_size": config["training"]["batch_size"],
        "lr": config["training"]["lr"],
        "notes": notes,
        "run_name": run_name,
        # Reuse the directory's tag so metadata and folder name always agree
        # (previously a second timestamp() call could differ by a second).
        "timestamp": tag,
    }
    with open(os.path.join(save_dir, "metadata.yaml"), "w") as f:
        yaml.safe_dump(metadata, f)

    print(f"[OK] Experiment saved → {save_dir}")
    return save_dir
def load_experiment(checkpoint_dir, device="cpu"):
    """Restore an experiment previously written by save_experiment().

    Rebuilds the model from config_trained.yaml, restores the tokenizer
    (preferring the checkpoint's own tokenizer/ snapshot if present), then
    loads the saved weights and puts the model in eval() mode.

    Args:
        checkpoint_dir: directory created by save_experiment().
        device: torch device the weights are mapped to and the model moved to.

    Returns:
        (model, tokenizer, metadata, config)

    Raises:
        FileNotFoundError: if metadata.yaml or config_trained.yaml is missing.
    """
    # Fix: removed redundant function-local `import yaml, torch, os` that
    # shadowed the module-level imports.
    metadata_path = os.path.join(checkpoint_dir, "metadata.yaml")
    config_path = os.path.join(checkpoint_dir, "config_trained.yaml")
    if not os.path.exists(metadata_path):
        raise FileNotFoundError(f"No metadata.yaml found at {checkpoint_dir}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"No config_trained.yaml found at {checkpoint_dir}")

    with open(metadata_path, "r") as f:
        metadata = yaml.safe_load(f)
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    model, tokenizer = build_model(config)

    # Prefer the tokenizer snapshot stored alongside the checkpoint, if any.
    tok_dir = os.path.join(checkpoint_dir, "tokenizer")
    if os.path.isdir(tok_dir):
        tokenizer = T5TokenizerFast.from_pretrained(tok_dir)

    # NOTE(review): torch.load on a pickled checkpoint can execute arbitrary
    # code; only load trusted files (or pass weights_only=True on torch>=2.0).
    ckpt_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
    weights = torch.load(ckpt_path, map_location=device)
    # strict=False tolerates key mismatches (e.g. partial/LoRA checkpoints)
    # at the cost of silently ignoring missing or unexpected keys.
    model.load_state_dict(weights, strict=False)
    model.to(device)
    model.eval()
    print(f"Loaded experiment from {checkpoint_dir}")
    return model, tokenizer, metadata, config
def filter_kwargs(cls, kwargs):
    """Return the subset of *kwargs* whose keys name parameters of cls.__init__.

    Useful for passing a loosely-specified config dict to a constructor
    without tripping over unexpected-keyword errors.
    """
    accepted = inspect.signature(cls.__init__).parameters
    filtered = {}
    for key, value in kwargs.items():
        if key in accepted:
            filtered[key] = value
    return filtered
def build_model(config):
    """Assemble a VisionT5 model and its tokenizer from a config dict.

    config["model"] must provide "encoder" (a class name looked up in
    models.encoders) and "t5_name"; "encoder_params" and "decoder_params"
    are optional. Encoder params are filtered against the encoder's
    __init__ signature so extra config keys are ignored.

    Returns:
        (model, tokenizer)

    Raises:
        ValueError: if the encoder class name is not found in encoders.py.
    """
    model_cfg = config["model"]
    encoder_name = model_cfg["encoder"]
    t5_name = model_cfg["t5_name"]

    tokenizer = T5TokenizerFast.from_pretrained(t5_name)

    # Resolve the encoder class dynamically by name.
    if not hasattr(encoders, encoder_name):
        raise ValueError(f"Encoder '{encoder_name}' not found in encoders.py")
    EncoderClass = getattr(encoders, encoder_name)

    encoder_kwargs = filter_kwargs(EncoderClass, model_cfg.get("encoder_params", {}))
    vision_encoder = EncoderClass(**encoder_kwargs)

    # Projection maps the encoder's feature dimension onto T5's hidden size.
    projector = ImageProjection(
        encoder_dim=vision_encoder.get_output_dim(),
        t5_hidden_size=VisionT5.get_t5_hidden_size(t5_name),
    )

    model = VisionT5(
        vision_encoder=vision_encoder,
        projector=projector,
        t5_name=t5_name,
        decoder_params=model_cfg.get("decoder_params", {}),
    )
    return model, tokenizer
def load_yaml(path):
    """Parse the YAML file at *path* and return its contents."""
    with open(path, "r") as handle:
        data = yaml.safe_load(handle)
    return data
def count_encoder_decoder_params(model):
    """Break down a VisionT5 model's parameter counts by component.

    Consistency fix: this function previously duplicated the exact
    name-matching rules of classify_param(); it now delegates to it so the
    grouping logic lives in one place. The frozen T5 text encoder
    ("t5_encoder_frozen") is folded into the "other" bucket, matching the
    original accounting.

    Args:
        model: torch.nn.Module exposing named_parameters().

    Returns:
        dict with total/trainable counts and trainable fractions for the
        vision encoder, projector, and decoder (incl. LoRA decoder params),
        plus "other" counts and grand totals. A fraction is None when its
        group has no parameters.
    """
    # totals[group] = [total_params, trainable_params]
    totals = {
        "encoder": [0, 0],
        "projector": [0, 0],
        "decoder": [0, 0],
        "other": [0, 0],
    }
    for name, param in model.named_parameters():
        group = classify_param(name)
        if group == "t5_encoder_frozen":
            # The T5 text encoder is counted with "other", as before.
            group = "other"
        n = param.numel()
        totals[group][0] += n
        if param.requires_grad:
            totals[group][1] += n

    enc_total, enc_train = totals["encoder"]
    proj_total, proj_train = totals["projector"]
    dec_total, dec_train = totals["decoder"]
    other_total, other_train = totals["other"]
    total_params = enc_total + proj_total + dec_total + other_total
    trainable_params = enc_train + proj_train + dec_train + other_train

    return {
        "encoder_total_params": enc_total,
        "encoder_trainable_params": enc_train,
        "encoder_trainable_fraction":
            enc_train / enc_total if enc_total else None,
        "projector_total_params": proj_total,
        "projector_trainable_params": proj_train,
        "projector_trainable_fraction":
            proj_train / proj_total if proj_total else None,
        "decoder_total_params": dec_total,
        "decoder_trainable_params": dec_train,
        "decoder_trainable_fraction":
            dec_train / dec_total if dec_total else None,
        "other_total_params": other_total,
        "other_trainable_params": other_train,
        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_params_fraction":
            trainable_params / total_params if total_params else None,
    }
def classify_param(name):
    """Classify a named_parameters() name into a model component.

    Returns one of: "encoder" (vision encoder), "projector", "decoder"
    (T5 decoder incl. lm_head, shared embeddings and decoder LoRA params),
    "t5_encoder_frozen" (T5 text encoder), or "other". Rules are checked
    in priority order, so e.g. a vision-encoder parameter never falls
    through to the decoder rules.
    """
    if name.startswith("vision_encoder."):
        return "encoder"
    if name.startswith("projector."):
        return "projector"

    decoder_prefixes = ("t5.decoder.", "t5.model.decoder.", "t5.lm_head.", "t5.shared.")
    is_decoder = (
        name.startswith(decoder_prefixes)
        or "decoder.block" in name
        or ("lora_" in name and "decoder" in name)
    )
    if is_decoder:
        return "decoder"

    if name.startswith("t5.encoder."):
        return "t5_encoder_frozen"
    return "other"