import os

import numpy as np
import torch
from PIL import Image

from . import UltralyticsModel
from .backbone import Backbone, PatchEmbedder
from .decoder import ViTDecoder
from .encoder import ViTEncoder
from .tokenizer import Tokenizer
from .trainer import Trainer, save_pred


class Model(torch.nn.Module):
    """Image captioning model.

    Pipeline: (optionally frozen) feature backbone -> ViT encoder ->
    autoregressive ViT decoder over the tokenizer's vocabulary.
    """

    def __init__(
        self,
        tokenizer,
        model=None,
        imgsz=640,
        dim=512,
        encoder_depth=3,
        decoder_depth=3,
        encoder_num_heads=8,
        decoder_num_heads=8,
        dropout=0.1,
        freeze_backbone=True,
        device=torch.device('cpu'),
        use_raw_patches: bool = False,  # if True, embed raw image patches; backbone is not used
        patch_size: int = 16,  # only used when use_raw_patches is True
    ):
        super().__init__()
        self.model_name = model.model_name if model is not None else None
        self.tokenizer = tokenizer
        # NOTE(review): "vocap_size" mirrors the tokenizer's attribute spelling;
        # renaming it would break existing checkpoints and callers.
        self.vocap_size = tokenizer.vocap_size
        # Round imgsz down to a multiple of 32 with a floor of 64 — presumably
        # the backbone's stride requirement; confirm against Backbone.
        self.imgsz = max(64, 32 * (imgsz // 32))
        self.dim = dim
        self.pad_id = tokenizer.char2idx[tokenizer.PAD]
        self.bos_id = tokenizer.char2idx[tokenizer.BOS]
        self.eos_id = tokenizer.char2idx[tokenizer.EOS]
        self.encoder_depth = encoder_depth
        self.decoder_depth = decoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_num_heads = decoder_num_heads
        self.device = device

        if use_raw_patches:
            self.backbone = PatchEmbedder(
                imgsz=self.imgsz, out_dim=self.dim,
                patch_size=patch_size, device=self.device,
            )
        else:
            self.backbone = Backbone(model=model, imgsz=self.imgsz, device=self.device)

        self.encoder = ViTEncoder(
            dim=self.dim,
            in_shape=self.backbone.bb_out_shape,
            num_heads=encoder_num_heads,
            num_blocks=encoder_depth,
            dropout=dropout,
            device=self.device,
        )
        self.decoder = ViTDecoder(
            vocab_size=self.vocap_size,
            pad_id=self.pad_id,
            bos_id=self.bos_id,
            eos_id=self.eos_id,
            dim=self.dim,
            depth=decoder_depth,
            heads=decoder_num_heads,
            dropout=dropout,
            device=self.device,
        )

        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def _preprocess(self, image_path: os.PathLike) -> torch.Tensor:
        """Load an image file and return a [1, 3, imgsz, imgsz] float tensor in [0, 1]."""
        im = Image.open(image_path).convert("RGB").resize((self.imgsz, self.imgsz))
        im = np.array(im)
        return torch.from_numpy(im).permute(2, 0, 1).unsqueeze(0).float() / 255.0

    @torch.no_grad()
    def predict(self, image: torch.Tensor | os.PathLike, maxlen: int | None = None):
        """Caption a batch of images and save each rendered prediction to disk.

        Args:
            image: either a path to an image file, or a [B, C, H, W] float tensor.
            maxlen: maximum generated sequence length; defaults to
                ``tokenizer.max_seq_len``.

        Returns:
            List of decoded caption strings, one per batch element.
        """
        # BUGFIX: the annotation allows any path-like, but the original only
        # accepted `str` paths; also, tensor inputs were never moved to device.
        if isinstance(image, (str, os.PathLike)):
            image = self._preprocess(image)
        image = image.to(self.device)
        self.eval()
        assert image.ndim == 4, "x must be [B,C,H,W]"
        pred = self.forward(
            image,
            max_len=maxlen if maxlen is not None else self.tokenizer.max_seq_len,
        )
        out_dir = "Inference_outs/preds"
        os.makedirs(out_dir, exist_ok=True)  # hoisted out of the per-image loop
        captions = []
        for i in range(pred.size(0)):
            cap = self.tokenizer.decode(pred[i].tolist())
            captions.append(cap)
            save_pred(image[i], cap, os.path.join(out_dir, f"{i}.png"))
        return captions

    def forward(self, x, tokens_in=None, max_len=None):
        """Training mode: teacher-forced decoder output from ``tokens_in``.
        Eval mode: autoregressive generation up to ``max_len`` tokens."""
        x = self.backbone(x)
        x = self.encoder(x)
        if self.training:
            return self.decoder.forward(x, tokens_in)
        return self.decoder.generate(
            x,
            max_len=max_len if max_len is not None else self.tokenizer.max_seq_len,
        )

    @classmethod
    def load_from_checkpoint(
        cls,
        path: str,
        tokenizer: Tokenizer,
        freeze_backbone: bool = False,
        dropout: float = 0.1,
        device=torch.device('cpu'),
    ):
        """Rebuild a Model from a training checkpoint.

        The checkpoint is expected to carry the architecture hyper-parameters,
        the training vocabulary, the backbone model name, and the weights
        under ``model_state_dict``.
        """
        from ultralytics import YOLO
        ckpt = torch.load(path, map_location=device)
        sd = ckpt.get("model_state_dict", None)
        if sd is None:
            # BUGFIX: fail with a clear message instead of passing None to
            # load_state_dict (which raises an opaque TypeError).
            raise KeyError(f"checkpoint '{path}' has no 'model_state_dict' entry")
        model = cls(
            tokenizer=tokenizer.set_vocab(ckpt.get("vocab", None)),  # restore training vocabulary
            model=YOLO(ckpt.get("model_name", "yolo11n.pt")),
            imgsz=ckpt.get("imgsz", 640),
            dim=ckpt.get("dim", 512),
            encoder_depth=ckpt.get("encoder_depth", 3),
            decoder_depth=ckpt.get("decoder_depth", 3),
            encoder_num_heads=ckpt.get("encoder_num_heads", 8),
            decoder_num_heads=ckpt.get("decoder_num_heads", 8),
            dropout=dropout,
            freeze_backbone=freeze_backbone,
            device=device,
        )
        # strict=False: tolerate architecture drift, but surface what differed.
        res = model.load_state_dict(sd, strict=False)
        print(res.missing_keys)
        print(res.unexpected_keys)
        return model

    def train(self, mode: bool = True, **kwargs):
        """Dual-purpose override.

        Called with no fit kwargs it behaves exactly like
        ``nn.Module.train(mode)`` (so ``.eval()`` still works). Called with
        ``imagepaths=...`` (plus any other Trainer.fit kwargs) it switches to
        training mode and runs a full fit.
        """
        if "imagepaths" not in kwargs:
            return super().train(mode)
        super().train(True)
        trainer = Trainer(model=self, device=self.device)
        return trainer.fit(**kwargs, tokenizer=self.tokenizer)

    def export(self, path: str = "model.onnx"):
        """Export the model to ONNX, then simplify and shape-infer it in place."""
        import onnx
        import onnxsim
        dummy = torch.rand(1, 3, self.imgsz, self.imgsz, device=self.device)
        self.train(False)  # eval mode -> forward() takes the generate() path
        # BUGFIX: `(dummy)` is not a tuple; the documented args form is `(dummy,)`.
        torch.onnx.export(self, (dummy,), path, dynamo=True)
        model = onnx.load(path)
        model, check = onnxsim.simplify(model)
        # NOTE(review): `check` (simplification validity flag) is ignored —
        # consider asserting it before overwriting the exported file.
        model = onnx.shape_inference.infer_shapes(model)
        onnx.save(model, path)