import os, math, time
from dataclasses import dataclass
import unicodedata
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw, ImageFont
import textwrap
import numpy as np
from .tokenizer import Tokenizer
def _load_img(path, imgsz):
    if not isinstance(path, str):  # already a PIL.Image instance
        img = path.convert("RGB").resize((imgsz, imgsz))
        data = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).float() / 255.0
        return data
    assert os.path.exists(path), f"image not found: {path}"
    img = np.array(Image.open(path).convert("RGB").resize((imgsz, imgsz)))
    data = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() / 255.0
    return data  # [1, C, H, W]
def load_images(paths, imgsz, device=torch.device('cpu')):
    if isinstance(paths, dict):  # Hugging Face datasets batch, e.g. "ituperceptron/image-captioning-turkish"
        try:
            return torch.cat([_load_img(p, imgsz) for p in paths["image"]], dim=0).to(device=device)
        except Exception as e:
            print(e)
            print("Unsupported dataset layout: adapt load_images and _load_img to your dataset.")
            exit(1)
    elif isinstance(paths[0], dict):  # list of {"image": ...} records
        return torch.cat([_load_img(p["image"], imgsz) for p in paths], dim=0).to(device=device)
    return torch.cat([_load_img(p, imgsz) for p in paths], dim=0).to(device=device)  # list of file paths
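# Usage sketch (hypothetical file names; any imgsz works, 224 assumed here):
#   batch = load_images(["dog.jpg", "cat.jpg"], imgsz=224)        # -> [2, 3, 224, 224]
#   batch = load_images([{"image": pil_img}], imgsz=224)          # list-of-dict records
#   batch = load_images({"image": [pil_a, pil_b]}, imgsz=224)     # HF datasets batch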
def load_captions(captions:list[list[int]], max_len, device=torch.device('cpu')):
    """
    Builds teacher-forcing pairs: the decoder input is the caption without its
    last token, and the target is the caption without its first token. Assumes
    the tokenizer pads all captions to a uniform length, so truncating to
    max_len yields a rectangular batch.
    """
    tin = []
    tout = []
    for cap in captions:
        cap = cap[:max_len]
        tin.append(cap[:-1])  # what the decoder sees
        tout.append(cap[1:])  # what it must predict at each step
    tin = torch.tensor(tin).long().to(device=device)
    tout = torch.tensor(tout).long().to(device=device)
    return tin, tout
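# Worked example with made-up token ids (1=BOS, 2=EOS):
#   tin, tout = load_captions([[1, 7, 9, 2]], max_len=4)
#   tin  -> tensor([[1, 7, 9]])
#   tout -> tensor([[7, 9, 2]])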
def remove_accents(text: str) -> str:
normalized = unicodedata.normalize("NFKD", text)
return "".join(ch for ch in normalized if not unicodedata.combining(ch))
def turkish_ascii(text: str) -> str:
    """
    ASCII-folds Turkish text so predicted captions can be rendered with PIL's
    default font. NFKD strips combining accents (ü/ş/ğ/ç/ö); the dotless ı
    does not decompose, so it is replaced explicitly.
    """
    text = remove_accents(text)
    return text.replace("ı","i").replace("İ","I")
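# Example: turkish_ascii("Güneşli bir günde çocuklar oynuyor")
#   -> "Gunesli bir gunde cocuklar oynuyor"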
def save_pred(img: torch.Tensor, caption: str, save_path: str = "pred.png"):
"""
This function is used to save image and caption for validation.
"""
if img.ndim != 3 or img.shape[0] not in (1,3):
raise ValueError(f"img shape must be [C,H,W] with C=1 or 3, got {tuple(img.shape)}")
C, H, W = img.shape
imgsz = W
caption = turkish_ascii(caption)
if C == 1:
img = img.repeat(3, 1, 1)
img_np = (img.clamp(0,1).permute(1,2,0) * 255.0).cpu().numpy().astype(np.uint8)
panel_pil = Image.new("RGB", (imgsz, imgsz), (0,0,0))
draw = ImageDraw.Draw(panel_pil)
    font = ImageFont.load_default(size=26)  # the size argument needs Pillow >= 10.1
wrapped = textwrap.fill(caption, width=38)
x0, y0 = 24, 24
bbox = draw.multiline_textbbox((x0, y0), wrapped, font=font, spacing=6)
pad = 12
draw.rectangle([bbox[0]-pad, bbox[1]-pad, bbox[2]+pad, bbox[3]+pad], fill=(0,0,0))
draw.multiline_text((x0, y0), wrapped, font=font, fill=(255,255,255),
spacing=6, stroke_width=2, stroke_fill=(0,0,0))
panel_np = np.array(panel_pil)
frame = np.concatenate([img_np, panel_np], axis=1)
Image.fromarray(frame).save(save_path)
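# Usage sketch (hypothetical tensor; img must be [C,H,W] with values in [0,1]):
#   save_pred(batch[0].cpu(), "bir köpek çimenlerde koşuyor", "pred.png")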
@dataclass
class TrainConfig:
epoch:int = 20
batch_size:int = 64
lr:float = 2e-4
weight_decay:float = 1e-2
grad_clip:float = 1.0
save_dir:str = "checkpoints"
sentence_lim:int = 128
max_seq_len:int = 128
min_seq_len:int = 1
avg_seq_len:int = 0
std_seq_len:int = 0
recommended_seq_len:int = 0
def __str__(self):
return f"""
epoch: {self.epoch}
batch_size: {self.batch_size}
lr: {self.lr}
weight_decay: {self.weight_decay}
grad_clip: {self.grad_clip}
save_dir: {self.save_dir}
min sentence length: {self.min_seq_len}
max sentence length: {self.max_seq_len}
avg sentence length: {self.avg_seq_len:.2f}
std sentence length: {self.std_seq_len:.2f}
recommended sentence length: {self.recommended_seq_len}
Training sentence limit: {self.sentence_lim}
"""
class Trainer:
def __init__(self, model:nn.Module, device:torch.device=torch.device('cpu')):
self.model = model
self.device = device
    @torch.no_grad()
    def validate(self,
                 imagepaths:list[str],
                 tokenizer:Tokenizer
                 ) -> None:
        """Runs inference on a batch of images and saves each prediction as an
        annotated PNG under save_dir/preds. Assumes fit() has set self.cfg."""
        images = load_images(imagepaths, self.model.imgsz, device=self.device)  # [B, C, H, W]
        preds = []
        for img in images:
            p = self.model(img.unsqueeze(0))  # [1, S] token ids (no caption given => inference path)
            preds.append(p.squeeze(0))
        os.makedirs(f"{self.cfg.save_dir}/preds", exist_ok=True)
        for i, c in enumerate(preds):
            cap = tokenizer.decode(c.tolist())
            save_pred(images[i], cap, os.path.join(f"{self.cfg.save_dir}/preds", f"{i}.png"))
def fit(self,
tokenizer:Tokenizer,
imagepaths:list[str],
epoch:int=20,
batch_size:int=64,
lr:float=2e-4,
weight_decay:float=1e-2,
grad_clip:float=1.0,
save_dir:str="checkpoints",
max_len = None
):
len_limit = max_len if max_len is not None else tokenizer.recommended_seq_len
self.cfg = TrainConfig(epoch=epoch,
batch_size=batch_size,
lr=lr,
weight_decay=weight_decay,
grad_clip=grad_clip,
save_dir=save_dir,
sentence_lim=len_limit,
max_seq_len=tokenizer.max_seq_len,
min_seq_len=tokenizer.min_seq_len,
avg_seq_len=tokenizer.avg_len,
std_seq_len=tokenizer.leng_std,
recommended_seq_len=tokenizer.recommended_seq_len)
cfg = self.cfg
print(cfg)
os.makedirs(cfg.save_dir, exist_ok=True)
self.captions = tokenizer.tokenized_captions
global_step = 0
total_step = cfg.epoch*math.ceil(len(self.captions)/cfg.batch_size)
        opt = torch.optim.AdamW([
            {'params': self.model.backbone.parameters(), 'lr': 1e-5},  # backbone gets a much smaller LR (typical for a pretrained backbone)
            {'params': self.model.encoder.parameters()},
            {'params': self.model.decoder.parameters()},
        ], lr=cfg.lr, weight_decay=cfg.weight_decay)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_step)
criterion = nn.CrossEntropyLoss(ignore_index=int(self.model.pad_id))
best_loss = float("inf")
self.model.to(self.device)
self.model.train(True)
scaler = GradScaler(device=self.device.type, enabled=True)
for ep in range(1, cfg.epoch+1):
global_loss = 0.0
seen = 0
t0 = time.time()
for batch_start_idx in range(0, len(self.captions), cfg.batch_size):
batch_size = min(cfg.batch_size, len(self.captions)-batch_start_idx)
paths = imagepaths[batch_start_idx:batch_start_idx+batch_size]
captions = self.captions[batch_start_idx:batch_start_idx+batch_size]
imgs = load_images(paths, imgsz=self.model.imgsz, device=self.device)
                tins, touts = load_captions(captions, len_limit, device=self.device)  # 2x [B, T]
                with autocast(device_type=self.device.type, dtype=torch.float16 if self.device.type == "cuda" else torch.bfloat16):
                    logits = self.model(imgs, tins)  # [B, T, V]
                    loss = criterion(logits.reshape(-1, logits.size(-1)), touts.reshape(-1))
                opt.zero_grad()
                scaler.scale(loss).backward()
                scaler.unscale_(opt)  # unscale before clipping so the threshold applies to the true gradients
                nn.utils.clip_grad_norm_(self.model.parameters(), cfg.grad_clip)
                old_scale = scaler.get_scale()
                scaler.step(opt)
                scaler.update()
                stepped = scaler.get_scale() >= old_scale  # the scale only shrinks when the step was skipped (inf/nan grads)
                if stepped:
                    sched.step()
bsz = imgs.size(0)
global_loss += loss.item() * bsz
seen += bsz
global_step += 1
print(f"[%{100*((batch_start_idx+batch_size)/len(self.captions)):>6.2f}]Epoch {ep:03d} loss={global_loss/seen:<6.3f} lr={sched.get_last_lr()[-1]:<.6f}", end="\r")
            # sample a random batch for qualitative validation at the end of each epoch
            random_idx:list = np.random.randint(0, len(self.captions), cfg.batch_size).tolist()
            paths = [imagepaths[idx] for idx in random_idx]
            self.model.train(False)
            self.validate(paths, tokenizer)
            self.model.train(True)
ep_loss = global_loss/seen
dt = time.time()-t0
print(f"[EPOCH {ep:03d}] loss={ep_loss:.4f} time={dt:.1f}s")
ckpt = {
"model_name": self.model.model_name if self.model is not None else "ViT Transformer",
"model_state_dict": self.model.state_dict(),
"vocab": tokenizer.vocap,
"imgsz": self.model.imgsz,
"dim": self.model.dim,
"encoder_depth": self.model.encoder_depth,
"decoder_depth": self.model.decoder_depth,
"encoder_num_heads": self.model.encoder_num_heads,
"decoder_num_heads": self.model.decoder_num_heads,
"last_avg_loss": ep_loss
}
last_path = os.path.join(cfg.save_dir, "last.pt")
torch.save(ckpt, last_path)
if ep_loss < best_loss:
best_loss = ep_loss
torch.save(ckpt, os.path.join(cfg.save_dir, "best.pt"))
print(f"[TRAIN DONE] best_loss={best_loss:.4f} ckpts => {cfg.save_dir}")
return {"best_loss": best_loss, "save_dir": cfg.save_dir}