import os
import torch
import numpy as np
from PIL import Image
from . import UltralyticsModel
from .backbone import Backbone, PatchEmbedder
from .encoder import ViTEncoder
from .decoder import ViTDecoder
from .tokenizer import Tokenizer
from .trainer import Trainer, save_pred
class Model(torch.nn.Module):
    """Image-captioning model: backbone -> ViT encoder -> autoregressive ViT decoder.

    The backbone is either a (frozen) YOLO-derived CNN (`Backbone`) or a simple
    raw-patch embedder (`PatchEmbedder`); its features are refined by a
    `ViTEncoder` and decoded into token sequences by a `ViTDecoder` using the
    supplied tokenizer's vocabulary and special-token ids.
    """

    def __init__(self,
                 tokenizer,
                 model=None,
                 imgsz=640,
                 dim=512,
                 encoder_depth=3,
                 decoder_depth=3,
                 encoder_num_heads=8,
                 decoder_num_heads=8,
                 dropout=0.1,
                 freeze_backbone=True,
                 device=torch.device('cpu'),
                 use_raw_patches: bool = False,  # if True, embed raw patches; no CNN backbone
                 patch_size: int = 16  # only consulted when use_raw_patches is True
                 ):
        """Build the captioning pipeline.

        Args:
            tokenizer: project Tokenizer providing `vocap_size`, `char2idx`,
                `max_seq_len`, `decode`, and the PAD/BOS/EOS special tokens.
            model: optional Ultralytics model whose layers seed the CNN backbone;
                unused when `use_raw_patches` is True.
            imgsz: requested square input size; snapped to a multiple of 32
                (minimum 64) to satisfy the backbone's stride.
            dim: embedding width shared by encoder and decoder.
            encoder_depth / decoder_depth: number of transformer blocks.
            encoder_num_heads / decoder_num_heads: attention heads per block.
            dropout: dropout probability in encoder and decoder.
            freeze_backbone: if True, backbone parameters get requires_grad=False.
            device: torch device the submodules are created on.
            use_raw_patches: bypass the CNN backbone with a patch embedder.
            patch_size: patch edge length for the raw-patch embedder.
        """
        super().__init__()
        self.model_name = model.model_name if model is not None else None
        self.tokenizer = tokenizer
        # NOTE: "vocap_size" is a typo inherited from the Tokenizer API; kept
        # as-is so existing callers and checkpoints keep working.
        self.vocap_size = tokenizer.vocap_size
        # Snap to a multiple of 32 (backbone stride), but never below 64.
        self.imgsz = max(64, 32 * (imgsz // 32))
        self.dim = dim
        self.pad_id = tokenizer.char2idx[tokenizer.PAD]
        self.bos_id = tokenizer.char2idx[tokenizer.BOS]
        self.eos_id = tokenizer.char2idx[tokenizer.EOS]
        self.encoder_depth = encoder_depth
        self.decoder_depth = decoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_num_heads = decoder_num_heads
        self.device = device
        if use_raw_patches:
            self.backbone = PatchEmbedder(imgsz=self.imgsz,
                                          out_dim=self.dim,
                                          patch_size=patch_size,
                                          device=self.device)
        else:
            self.backbone = Backbone(model=model,
                                     imgsz=self.imgsz,
                                     device=self.device)
        self.encoder = ViTEncoder(dim=self.dim,
                                  in_shape=self.backbone.bb_out_shape,
                                  num_heads=encoder_num_heads,
                                  num_blocks=encoder_depth,
                                  dropout=dropout,
                                  device=self.device)
        self.decoder = ViTDecoder(vocab_size=self.vocap_size,
                                  pad_id=self.pad_id,
                                  bos_id=self.bos_id,
                                  eos_id=self.eos_id,
                                  dim=self.dim,
                                  depth=decoder_depth,
                                  heads=decoder_num_heads,
                                  dropout=dropout,
                                  device=self.device)
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def _preprocess(self, image_path: os.PathLike) -> torch.Tensor:
        """Load an image file and return a [1,3,imgsz,imgsz] float tensor in [0,1]."""
        im = Image.open(image_path).convert("RGB").resize((self.imgsz, self.imgsz))
        im = np.array(im)
        im = torch.from_numpy(im).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        return im

    @torch.no_grad()
    def predict(self, image: torch.Tensor | os.PathLike, maxlen: int = None):
        """Generate captions for a batch of images and save annotated outputs.

        Args:
            image: either a path (str or os.PathLike) to a single image, or a
                preprocessed [B,C,H,W] float tensor already on `self.device`.
            maxlen: maximum generated sequence length; defaults to the
                tokenizer's `max_seq_len`.

        Returns:
            list[str]: one decoded caption per batch element. (Annotated
            images are also written to Inference_outs/preds/<i>.png.)
        """
        # BUGFIX: the original only tested `isinstance(image, str)`, so a
        # pathlib.Path (valid per the os.PathLike annotation) skipped
        # preprocessing and crashed inside forward().
        if isinstance(image, (str, os.PathLike)):
            image = self._preprocess(image).to(self.device)
        self.eval()
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if image.ndim != 4:
            raise ValueError("x must be [B,C,H,W]")
        pred = self.forward(image, max_len=maxlen if maxlen is not None else self.tokenizer.max_seq_len)
        out_dir = "Inference_outs/preds"
        os.makedirs(out_dir, exist_ok=True)  # hoisted out of the per-image loop
        captions = []
        for i in range(pred.size(0)):
            cap = self.tokenizer.decode(pred[i].tolist())
            captions.append(cap)
            save_pred(image[i], cap, os.path.join(out_dir, f"{i}.png"))
        # Backward-compatible improvement: the original returned None implicitly.
        return captions

    def forward(self, x, tokens_in=None, max_len=None):
        """Run backbone + encoder, then teacher-forced decode (training) or
        autoregressive generation (eval).

        Args:
            x: image batch [B,C,H,W].
            tokens_in: ground-truth token ids for teacher forcing (training only).
            max_len: generation length cap (eval only); defaults to the
                tokenizer's `max_seq_len`.
        """
        x = self.backbone(x)
        x = self.encoder(x)
        if self.training:
            x = self.decoder.forward(x, tokens_in)
        else:
            x = self.decoder.generate(x, max_len=max_len if max_len is not None else self.tokenizer.max_seq_len)
        return x

    @classmethod
    def load_from_checkpoint(cls,
                             path: str,
                             tokenizer: Tokenizer,
                             freeze_backbone: bool = False,
                             dropout: float = 0.1,
                             device=torch.device('cpu')):
        """Rebuild a Model from a training checkpoint.

        The checkpoint dict supplies the architecture hyperparameters, the
        training vocabulary (restored into `tokenizer`), and the weights.
        Missing/unexpected keys from the non-strict load are printed so a
        partial restore is visible.

        Raises:
            KeyError: if the checkpoint has no "model_state_dict" entry.
        """
        from ultralytics import YOLO
        ckpt = torch.load(path, map_location=device)
        sd = ckpt.get("model_state_dict", None)
        # BUGFIX: the original passed a possibly-None state dict straight to
        # load_state_dict, which fails with an obscure AttributeError.
        if sd is None:
            raise KeyError(f"checkpoint {path!r} has no 'model_state_dict' entry")
        imgsz = ckpt.get("imgsz", 640)
        dim = ckpt.get("dim", 512)
        encoder_depth = ckpt.get("encoder_depth", 3)
        decoder_depth = ckpt.get("decoder_depth", 3)
        encoder_num_heads = ckpt.get("encoder_num_heads", 8)
        decoder_num_heads = ckpt.get("decoder_num_heads", 8)
        vocab = ckpt.get("vocab", None)
        model_name = ckpt.get("model_name", "yolo11n.pt")
        yolo_model = YOLO(model_name)
        model = cls(
            tokenizer=tokenizer.set_vocab(vocab),  # restore training vocabulary
            model=yolo_model,
            imgsz=imgsz,
            dim=dim,
            encoder_depth=encoder_depth,
            decoder_depth=decoder_depth,
            encoder_num_heads=encoder_num_heads,
            decoder_num_heads=decoder_num_heads,
            dropout=dropout,
            freeze_backbone=freeze_backbone,
            device=device
        )
        res = model.load_state_dict(sd, strict=False)
        print(res.missing_keys)
        print(res.unexpected_keys)
        return model

    def train(self, mode: bool = True, **kwargs):
        """Dual-purpose override: with no extra kwargs this is the standard
        nn.Module.train(mode); when fit-style kwargs (e.g. `imagepaths`) are
        given it launches a full training run via Trainer.fit.
        """
        has_fit_args = ("imagepaths" in kwargs)
        if not has_fit_args:
            return super().train(mode)
        super().train(True)
        trainer = Trainer(model=self, device=self.device)
        return trainer.fit(**kwargs, tokenizer=self.tokenizer)

    def export(self, path: str = "model.onnx"):
        """Export the model (in eval/generate mode) to a simplified ONNX file.

        Traces with a random [1,3,imgsz,imgsz] input, then runs onnxsim and
        shape inference before saving back to `path`.
        """
        import onnx, onnxsim
        dummy = torch.rand(1, 3, self.imgsz, self.imgsz, device=self.device)
        self.train(False)  # ensure forward() takes the generate() branch
        torch.onnx.export(self, (dummy), path, dynamo=True)
        model = onnx.load(path)
        model, check = onnxsim.simplify(model)
        model = onnx.shape_inference.infer_shapes(model)
        onnx.save(model, path)