import os
import torch
import numpy as np
from PIL import Image

from . import UltralyticsModel
from .backbone import Backbone, PatchEmbedder
from .encoder import ViTEncoder
from .decoder import ViTDecoder
from .tokenizer import Tokenizer
from .trainer import Trainer, save_pred
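
# Pipeline (as implemented below): image -> backbone (YOLO feature maps, or raw
# patch embeddings when use_raw_patches=True) -> ViTEncoder (self-attention over
# the feature tokens) -> ViTDecoder (autoregressive caption-token generation
# between BOS and EOS), with the Tokenizer mapping token ids back to text.
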
class Model(torch.nn.Module):
    def __init__(self,
                 tokenizer,
                 model=None,
                 imgsz=640,
                 dim=512,
                 encoder_depth=3,
                 decoder_depth=3,
                 encoder_num_heads=8,
                 decoder_num_heads=8,
                 dropout=0.1,
                 freeze_backbone=True,
                 device=torch.device('cpu'),
                 use_raw_patches: bool = False,  # if True, embed raw image patches instead of using the backbone
                 patch_size: int = 16  # only used when use_raw_patches is True
                 ):
        super().__init__()
        self.model_name = model.model_name if model is not None else None
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocap_size  # "vocap_size" is the tokenizer's own attribute name
        self.imgsz = max(64, 32 * (imgsz // 32))  # snap to a multiple of 32, with a floor of 64
        self.dim = dim
        self.pad_id = tokenizer.char2idx[tokenizer.PAD]
        self.bos_id = tokenizer.char2idx[tokenizer.BOS]
        self.eos_id = tokenizer.char2idx[tokenizer.EOS]
        self.encoder_depth = encoder_depth
        self.decoder_depth = decoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_num_heads = decoder_num_heads
        self.device = device
        if use_raw_patches:
            # feed patch embeddings straight to the encoder, skipping the CNN backbone
            self.backbone = PatchEmbedder(imgsz=self.imgsz,
                                          out_dim=self.dim,
                                          patch_size=patch_size,
                                          device=self.device)
        else:
            self.backbone = Backbone(model=model,
                                     imgsz=self.imgsz,
                                     device=self.device)
        self.encoder = ViTEncoder(dim=self.dim,
                                  in_shape=self.backbone.bb_out_shape,
                                  num_heads=encoder_num_heads,
                                  num_blocks=encoder_depth,
                                  dropout=dropout,
                                  device=self.device)
        self.decoder = ViTDecoder(vocab_size=self.vocab_size,
                                  pad_id=self.pad_id,
                                  bos_id=self.bos_id,
                                  eos_id=self.eos_id,
                                  dim=self.dim,
                                  depth=decoder_depth,
                                  heads=decoder_num_heads,
                                  dropout=dropout,
                                  device=self.device)
        if freeze_backbone:
            # keep the pretrained backbone weights fixed during training
            for p in self.backbone.parameters():
                p.requires_grad = False

    def _preprocess(self, image_path: os.PathLike):
        # load, resize, and scale an image into a [1,3,H,W] float tensor in [0,1]
        im = Image.open(image_path).convert("RGB").resize((self.imgsz, self.imgsz))
        im = np.array(im)
        im = torch.from_numpy(im).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        return im

    def predict(self, image: torch.Tensor | os.PathLike, maxlen: int | None = None):
        # accept either an image path (preprocessed here) or a ready [B,C,H,W] tensor
        if isinstance(image, (str, os.PathLike)):
            image = self._preprocess(image)
        image = image.to(self.device)
        self.eval()
        assert image.ndim == 4, "image must be [B,C,H,W]"
        with torch.no_grad():
            pred = self.forward(image, max_len=maxlen if maxlen is not None else self.tokenizer.max_seq_len)
        os.makedirs("Inference_outs/preds", exist_ok=True)
        caps = []
        for i in range(pred.size(0)):
            cap = self.tokenizer.decode(pred[i].tolist())
            caps.append(cap)
            save_pred(image[i], cap, os.path.join("Inference_outs/preds", f"{i}.png"))
        return caps

    def forward(self, x, tokens_in=None, max_len=None):
        x = self.backbone(x)
        x = self.encoder(x)
        if self.training:
            # teacher forcing against the ground-truth token sequence
            x = self.decoder.forward(x, tokens_in)
        else:
            # autoregressive generation at inference time
            x = self.decoder.generate(x, max_len=max_len if max_len is not None else self.tokenizer.max_seq_len)
        return x

    @classmethod
    def load_from_checkpoint(cls,
                             path: str,
                             tokenizer: Tokenizer,
                             freeze_backbone: bool = False,
                             dropout: float = 0.1,
                             device=torch.device('cpu')):
        from ultralytics import YOLO
        ckpt = torch.load(path, map_location=device)
        sd = ckpt.get("model_state_dict", None)
        if sd is None:
            raise KeyError(f"no 'model_state_dict' in checkpoint: {path}")
        # hyperparameters fall back to the constructor defaults if absent
        imgsz = ckpt.get("imgsz", 640)
        dim = ckpt.get("dim", 512)
        encoder_depth = ckpt.get("encoder_depth", 3)
        decoder_depth = ckpt.get("decoder_depth", 3)
        encoder_num_heads = ckpt.get("encoder_num_heads", 8)
        decoder_num_heads = ckpt.get("decoder_num_heads", 8)
        vocab = ckpt.get("vocab", None)
        model_name = ckpt.get("model_name", "yolo11n.pt")
        yolo_model = YOLO(model_name)
        model = cls(
            tokenizer=tokenizer.set_vocab(vocab),  # restore the training vocabulary
            model=yolo_model,
            imgsz=imgsz,
            dim=dim,
            encoder_depth=encoder_depth,
            decoder_depth=decoder_depth,
            encoder_num_heads=encoder_num_heads,
            decoder_num_heads=decoder_num_heads,
            dropout=dropout,
            freeze_backbone=freeze_backbone,
            device=device
        )
        res = model.load_state_dict(sd, strict=False)
        print("missing keys:", res.missing_keys)
        print("unexpected keys:", res.unexpected_keys)
        return model

    def train(self, mode: bool = True, **kwargs):
        # behaves like nn.Module.train(mode) unless fit arguments are passed,
        # in which case a full training run is delegated to the Trainer
        if "imagepaths" not in kwargs:
            return super().train(mode)
        super().train(True)
        trainer = Trainer(model=self, device=self.device)
        return trainer.fit(**kwargs, tokenizer=self.tokenizer)

    def export(self, path: str = "model.onnx"):
        import onnx
        import onnxsim
        dummy = torch.rand(1, 3, self.imgsz, self.imgsz, device=self.device)
        self.train(False)  # export in eval mode so forward() takes the generate path
        torch.onnx.export(self, (dummy,), path, dynamo=True)
        # simplify the exported graph in place and re-save it
        model = onnx.load(path)
        model, check = onnxsim.simplify(model)
        assert check, "onnxsim failed to validate the simplified model"
        model = onnx.shape_inference.infer_shapes(model)
        onnx.save(model, path)
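
# Usage sketch (illustrative only -- "runs/best.pt" and "samples/dog.jpg" are
# hypothetical placeholders, and the Tokenizer constructor is assumed to take
# no required arguments):
#
#   from .tokenizer import Tokenizer
#
#   tok = Tokenizer()
#   model = Model.load_from_checkpoint("runs/best.pt", tokenizer=tok,
#                                      device=torch.device("cpu"))
#   caps = model.predict("samples/dog.jpg")   # saves annotated images, returns captions
#   model.export("model.onnx")                # simplified ONNX graph at the given path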