# Source: Hugging Face upload "Upload 7 files" by mevlt01001 (commit 9ec3d0b, verified).
# (The three lines above were web-page chrome accidentally captured into this file;
# converted to a comment so the module parses.)
import os
import torch
import numpy as np
from PIL import Image
from . import UltralyticsModel
from .backbone import Backbone, PatchEmbedder
from .encoder import ViTEncoder
from .decoder import ViTDecoder
from .tokenizer import Tokenizer
from .trainer import Trainer, save_pred
class Model(torch.nn.Module):
def __init__(self,
tokenizer,
model=None,
imgsz=640,
dim=512,
encoder_depth=3,
decoder_depth=3,
encoder_num_heads=8,
decoder_num_heads=8,
dropout=0.1,
freeze_backbone=True,
device=torch.device('cpu'),
use_raw_patches: bool = False, # if True, use raw patches not included backbone.
patch_size: int = 16 # if use_raw_patches is True
):
super().__init__()
self.model_name = model.model_name if model is not None else None
self.tokenizer = tokenizer
self.vocap_size = tokenizer.vocap_size
self.imgsz = max(64, 32 * (imgsz // 32))
self.dim = dim
self.pad_id = tokenizer.char2idx[tokenizer.PAD]
self.bos_id = tokenizer.char2idx[tokenizer.BOS]
self.eos_id = tokenizer.char2idx[tokenizer.EOS]
self.encoder_depth = encoder_depth
self.decoder_depth = decoder_depth
self.encoder_num_heads = encoder_num_heads
self.decoder_num_heads = decoder_num_heads
self.device = device
if use_raw_patches:
self.backbone = PatchEmbedder(imgsz=self.imgsz,
out_dim=self.dim,
patch_size=patch_size,
device=self.device)
else:
self.backbone = Backbone(model=model,
imgsz=self.imgsz,
device=self.device)
self.encoder = ViTEncoder(dim=self.dim,
in_shape=self.backbone.bb_out_shape,
num_heads=encoder_num_heads,
num_blocks=encoder_depth,
dropout=dropout,
device=self.device)
self.decoder = ViTDecoder(vocab_size=self.vocap_size,
pad_id=self.pad_id,
bos_id=self.bos_id,
eos_id=self.eos_id,
dim=self.dim,
depth=decoder_depth,
heads=decoder_num_heads,
dropout=dropout,
device=self.device)
if freeze_backbone:
for p in self.backbone.parameters():
p.requires_grad = False
def _preprocess(self, image_path:os.PathLike):
im = Image.open(image_path).convert("RGB").resize((self.imgsz, self.imgsz))
im = np.array(im)
im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).float() / 255.0
return im
@torch.no_grad()
def predict(self, image:torch.Tensor|os.PathLike, maxlen:int=None):
if isinstance(image,str):
image = self._preprocess(image).to(self.device)
self.eval()
assert image.ndim == 4, "x must be [B,C,H,W]"
pred = self.forward(image, max_len=maxlen if maxlen is not None else self.tokenizer.max_seq_len)
B = pred.size(0)
for i in range(B):
cap = self.tokenizer.decode(pred[i].tolist())
os.makedirs(f"Inference_outs/preds", exist_ok=True)
save_pred(image[i], cap, os.path.join(f"Inference_outs/preds", f"{i}.png"))
def forward(self, x, tokens_in=None, max_len=None):
x = self.backbone(x)
x = self.encoder(x)
x = self.decoder.forward(x, tokens_in) if self.training else self.decoder.generate(x, max_len=max_len if max_len is not None else self.tokenizer.max_seq_len)
return x
@classmethod
def load_from_checkpoint(cls,
path:str,
tokenizer:Tokenizer,
freeze_backbone:bool=False,
dropout:float=0.1,
device=torch.device('cpu')):
from ultralytics import YOLO
ckpt = torch.load(path, map_location=device)
sd = ckpt.get("model_state_dict", None)
imgsz = ckpt.get("imgsz",640)
dim = ckpt.get("dim", 512)
encoder_depth = ckpt.get("encoder_depth", 3)
decoder_depth = ckpt.get("decoder_depth", 3)
encoder_num_heads = ckpt.get("encoder_num_heads", 8)
decoder_num_heads = ckpt.get("decoder_num_heads", 8)
vocab = ckpt.get("vocab", None)
model_name = ckpt.get("model_name", "yolo11n.pt")
yolo_model = YOLO(model_name)
model = cls(
tokenizer=tokenizer.set_vocab(vocab), # load training vocabulary
model=yolo_model,
imgsz=imgsz,
dim=dim,
encoder_depth=encoder_depth,
decoder_depth=decoder_depth,
encoder_num_heads=encoder_num_heads,
decoder_num_heads=decoder_num_heads,
dropout=dropout,
freeze_backbone=freeze_backbone,
device=device
)
res = model.load_state_dict(sd, strict=False)
print(res.missing_keys)
print(res.unexpected_keys)
return model
def train(self, mode:bool=True, **kwargs):
has_fit_args = ("imagepaths" in kwargs)
if not has_fit_args:
return super().train(mode)
super().train(True)
trainer = Trainer(model=self, device=self.device)
return trainer.fit(**kwargs, tokenizer=self.tokenizer)
def export(self, path:str="model.onnx"):
import onnx, onnxsim
dummy = torch.rand(1, 3, self.imgsz, self.imgsz, device=self.device)
self.train(False)
torch.onnx.export(self, (dummy), path, dynamo=True)
model = onnx.load(path)
model, check = onnxsim.simplify(model)
model = onnx.shape_inference.infer_shapes(model)
onnx.save(model, path)