import os, math, json, time, random from dataclasses import dataclass import unicodedata import torch import torch.nn as nn import torch.nn.functional as F from torch.amp import autocast, GradScaler from torch.utils.data import Dataset, DataLoader from PIL import Image, ImageDraw, ImageFont import textwrap import numpy as np from .tokenizer import Tokenizer def _load_img(path, imgsz): if not isinstance(path, str): # PIL img = path.convert("RGB").resize((imgsz, imgsz)) data = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).float() / 255.0 return data assert os.path.exists(path) img = np.array(Image.open(path).convert("RGB").resize((imgsz, imgsz))) data = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() / 255.0 return data # [B, C, H, W] def load_images(paths, imgsz, device=torch.device('cpu')): if isinstance(paths, dict): # For huggingface datasets my example "ituperceptron/image-captioning-turkish" try: return torch.cat([_load_img(p, imgsz) for p in paths["image"]], dim=0).to(device=device) except Exception as e: print(e) print("If you are using Huggingface dataset or something else, please implement your own Trainer.load_images and Trainer._load_img function.") exit(1) elif isinstance(paths[0], dict): return torch.cat([_load_img(p["image"], imgsz) for p in paths], dim=0).to(device=device) return torch.cat([_load_img(p, imgsz) for p in paths], dim=0).to(device=device) def load_captions(captions:list[list[int]], max_len, device=torch.device('cpu')): """ Used to apply teacher forcing for training """ tin = [] tout = [] for cap in captions: cap = cap[:max_len] tin.append(cap[:-1]) tout.append(cap[1:]) tin = torch.tensor(tin).long().to(device=device) tout = torch.tensor(tout).long().to(device=device) return tin, tout def remove_accents(text: str) -> str: normalized = unicodedata.normalize("NFKD", text) return "".join(ch for ch in normalized if not unicodedata.combining(ch)) def turkish_ascii(text: str) -> str: """ Used to prediction captions visualization in Turkish sentences. """ text = remove_accents(text) return text.replace("ı","i").replace("İ","I") def save_pred(img: torch.Tensor, caption: str, save_path: str = "pred.png"): """ This function is used to save image and caption for validation. """ if img.ndim != 3 or img.shape[0] not in (1,3): raise ValueError(f"img shape must be [C,H,W] with C=1 or 3, got {tuple(img.shape)}") C, H, W = img.shape imgsz = W caption = turkish_ascii(caption) if C == 1: img = img.repeat(3, 1, 1) img_np = (img.clamp(0,1).permute(1,2,0) * 255.0).cpu().numpy().astype(np.uint8) panel_pil = Image.new("RGB", (imgsz, imgsz), (0,0,0)) draw = ImageDraw.Draw(panel_pil) font = ImageFont.load_default(size=26) wrapped = textwrap.fill(caption, width=38) x0, y0 = 24, 24 bbox = draw.multiline_textbbox((x0, y0), wrapped, font=font, spacing=6) pad = 12 draw.rectangle([bbox[0]-pad, bbox[1]-pad, bbox[2]+pad, bbox[3]+pad], fill=(0,0,0)) draw.multiline_text((x0, y0), wrapped, font=font, fill=(255,255,255), spacing=6, stroke_width=2, stroke_fill=(0,0,0)) panel_np = np.array(panel_pil) frame = np.concatenate([img_np, panel_np], axis=1) Image.fromarray(frame).save(save_path) @dataclass class TrainConfig: epoch:int = 20 batch_size:int = 64 lr:float = 2e-4 weight_decay:float = 1e-2 grad_clip:float = 1.0 save_dir:str = "checkpoints" sentence_lim:int = 128 max_seq_len:int = 128 min_seq_len:int = 1 avg_seq_len:int = 0 std_seq_len:int = 0 recommended_seq_len:int = 0 def __str__(self): return f""" epoch: {self.epoch} batch_size: {self.batch_size} lr: {self.lr} weight_decay: {self.weight_decay} grad_clip: {self.grad_clip} save_dir: {self.save_dir} min sentece length: {self.min_seq_len} max sentece length: {self.max_seq_len} avg sentece length: {self.avg_seq_len:.2f} std sentece length: {self.std_seq_len:.2f} recommended sentece length: {self.recommended_seq_len} Training sentece limit: {self.sentence_lim} """ class Trainer: def __init__(self, model:nn.Module, device:torch.device=torch.device('cpu')): self.model = model self.device = device @torch.no_grad() def validate(self, imagepaths:list[str], tokenizer:Tokenizer ) -> str: images = load_images(imagepaths, self.model.imgsz, device=self.device) # [B, C, H, W] preds = [] for img in images: p = self.model.forward(img.unsqueeze(0)) # [B, S] preds.append(p.squeeze(0)) for i,c in enumerate(preds): cap = tokenizer.decode(c.tolist()) os.makedirs(f"{self.cfg.save_dir}/preds", exist_ok=True) save_pred(images[i], cap, os.path.join(f"{self.cfg.save_dir}/preds", f"{i}.png")) def fit(self, tokenizer:Tokenizer, imagepaths:list[str], epoch:int=20, batch_size:int=64, lr:float=2e-4, weight_decay:float=1e-2, grad_clip:float=1.0, save_dir:str="checkpoints", max_len = None ): len_limit = max_len if max_len is not None else tokenizer.recommended_seq_len self.cfg = TrainConfig(epoch=epoch, batch_size=batch_size, lr=lr, weight_decay=weight_decay, grad_clip=grad_clip, save_dir=save_dir, sentence_lim=len_limit, max_seq_len=tokenizer.max_seq_len, min_seq_len=tokenizer.min_seq_len, avg_seq_len=tokenizer.avg_len, std_seq_len=tokenizer.leng_std, recommended_seq_len=tokenizer.recommended_seq_len) cfg = self.cfg print(cfg) os.makedirs(cfg.save_dir, exist_ok=True) self.captions = tokenizer.tokenized_captions global_step = 0 total_step = cfg.epoch*math.ceil(len(self.captions)/cfg.batch_size) opt = torch.optim.AdamW([ {'params': self.model.backbone.parameters(), 'lr': 1e-5}, {'params': self.model.encoder.parameters()}, {'params': self.model.decoder.parameters()}, ], lr=cfg.lr, weight_decay=cfg.weight_decay) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_step) criterion = nn.CrossEntropyLoss(ignore_index=int(self.model.pad_id)) best_loss = float("inf") self.model.to(self.device) self.model.train(True) scaler = GradScaler(device=self.device.type, enabled=True) for ep in range(1, cfg.epoch+1): global_loss = 0.0 seen = 0 t0 = time.time() for batch_start_idx in range(0, len(self.captions), cfg.batch_size): batch_size = min(cfg.batch_size, len(self.captions)-batch_start_idx) paths = imagepaths[batch_start_idx:batch_start_idx+batch_size] captions = self.captions[batch_start_idx:batch_start_idx+batch_size] imgs = load_images(paths, imgsz=self.model.imgsz, device=self.device) tins, touts = load_captions(captions, len_limit, device=self.device) # 2x [B, T, V] with autocast(device_type="cuda", dtype=torch.float16): logits = self.model.forward(imgs, tins) loss = criterion.forward(logits.reshape(-1, logits.size(-1)), touts.reshape(-1)) opt.zero_grad() scaler.scale(loss).backward() nn.utils.clip_grad_norm_(self.model.parameters(), cfg.grad_clip) old_scale = scaler.get_scale() scaler.step(opt) scaler.update() stepped = scaler.get_scale() >= old_scale if stepped: sched.step() bsz = imgs.size(0) global_loss += loss.item() * bsz seen += bsz global_step += 1 print(f"[%{100*((batch_start_idx+batch_size)/len(self.captions)):>6.2f}]Epoch {ep:03d} loss={global_loss/seen:<6.3f} lr={sched.get_last_lr()[-1]:<.6f}", end="\r") random_idx:list = np.random.randint(0, len(self.captions), cfg.batch_size).tolist() paths = [] for idx in random_idx: paths.append(imagepaths[idx]) self.model.train(False) self.validate(paths, tokenizer) self.model.train(True) ep_loss = global_loss/seen dt = time.time()-t0 print(f"[EPOCH {ep:03d}] loss={ep_loss:.4f} time={dt:.1f}s") ckpt = { "model_name": self.model.model_name if self.model is not None else "ViT Transformer", "model_state_dict": self.model.state_dict(), "vocab": tokenizer.vocap, "imgsz": self.model.imgsz, "dim": self.model.dim, "encoder_depth": self.model.encoder_depth, "decoder_depth": self.model.decoder_depth, "encoder_num_heads": self.model.encoder_num_heads, "decoder_num_heads": self.model.decoder_num_heads, "last_avg_loss": ep_loss } last_path = os.path.join(cfg.save_dir, "last.pt") torch.save(ckpt, last_path) if ep_loss < best_loss: best_loss = ep_loss torch.save(ckpt, os.path.join(cfg.save_dir, "best.pt")) print(f"[TRAIN DONE] best_loss={best_loss:.4f} ckpts => {cfg.save_dir}") return {"best_loss": best_loss, "save_dir": cfg.save_dir}