image-moderator-v0 / code /inference.py
TomLEE2026's picture
Upload 200 files
0bfcee2 verified
Raw
History Blame Contribute Delete
2.28 kB
import torch
import os
from engine.core import YAMLConfig
from io import BytesIO
from PIL import Image
import torchvision.transforms as T
import json
import base64
class Model(torch.nn.Module):
    """Chain a deployed detector with its deployed postprocessor.

    ``cfg`` must expose ``model`` and ``postprocessor`` attributes, each with
    a ``deploy()`` method. The objects those calls return are kept as
    submodules and run back-to-back in :meth:`forward`.
    """

    def __init__(self, cfg):
        super().__init__()
        # deploy() runs exactly once per wrapper, detector first.
        self.model = cfg.model.deploy()
        self.postprocessor = cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        """Detect on ``images``, then post-process using ``orig_target_sizes``."""
        raw_predictions = self.model(images)
        return self.postprocessor(raw_predictions, orig_target_sizes)
def model_fn(model_dir, context):
    """Build the detector wrapper from files stored under ``model_dir``.

    Reads the config ``code/configs/deim_dfine/deim_hgnetv2_l_coco.yml`` and
    the checkpoint ``deim_dfine_hgnetv2_l_coco_50e.pth``. It copies the
    checkpoint's ``'model'`` state dict into ``cfg.model`` and then wraps the
    config in :class:`Model`, which calls ``deploy()`` on model and
    postprocessor.

    Args:
        model_dir: Directory holding the checkpoint file and the ``code/`` tree.
        context: Serving-framework context object; not used here.

    Returns:
        The :class:`Model` wrapper to hand to ``predict_fn``.

    Raises:
        KeyError: if the loaded checkpoint has no ``'model'`` entry.
    """
    ckpt_file = os.path.join(model_dir, 'deim_dfine_hgnetv2_l_coco_50e.pth')
    cfg_file = os.path.join(
        model_dir, 'code', 'configs', 'deim_dfine', 'deim_hgnetv2_l_coco.yml'
    )
    cfg = YAMLConfig(cfg_file, resume=ckpt_file)

    # Disable the backbone's 'pretrained' flag before cfg.model is first
    # accessed below. The weights come from the checkpoint instead.
    # NOTE(review): presumably YAMLConfig builds the model lazily from
    # yaml_cfg, which is why this must come first — confirm.
    if 'HGNetv2' in cfg.yaml_cfg:
        cfg.yaml_cfg['HGNetv2']['pretrained'] = False

    # map_location='cpu' keeps loading independent of where tensors were saved.
    train_state = torch.load(ckpt_file, map_location='cpu')['model']

    # Weights go into the train-mode model. Model(...) then calls deploy()
    # to obtain the modules it actually runs.
    cfg.model.load_state_dict(train_state)
    return Model(cfg)
def input_fn(request_body, request_content_type, context):
    """Decode a JSON request carrying a base64 image into model inputs.

    The body must be a JSON object whose ``"image"`` field holds the
    base64-encoded bytes of an image that PIL can open.

    Args:
        request_body: Raw request payload (``str`` or ``bytes``) containing JSON.
        request_content_type: Declared MIME type. Its media type must be
            ``application/json``. Letter case and parameters such as
            ``; charset=utf-8`` are ignored, since media types are
            case-insensitive and may carry parameters.
        context: Serving-framework context object; not used here.

    Returns:
        ``(im_data, orig_size)``:
        - ``im_data`` is the image resized to 640x640 by ``T.Resize``,
          converted by ``T.ToTensor``, and given a leading batch dimension.
        - ``orig_size`` is a ``[[width, height]]`` tensor of the original
          image size.

    Raises:
        ValueError: if the content type is not JSON. This is a subclass of
            ``Exception``, so callers that caught the old bare ``Exception``
            still work.
        KeyError: if the JSON object has no ``"image"`` field.
    """
    # Compare only the media type (before any ';' parameters), case-insensitively.
    media_type = (request_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Invalid content type: {request_content_type}")

    body = json.loads(request_body)
    img_bytes = base64.b64decode(body["image"])

    # Force RGB so grayscale / RGBA / palette inputs all yield 3 channels.
    pil_image = Image.open(BytesIO(img_bytes)).convert('RGB')
    w, h = pil_image.size
    # PIL reports (width, height); that same [w, h] order is passed downstream.
    orig_size = torch.tensor([[w, h]])

    # Fixed-size resize (aspect ratio is not preserved), then to tensor.
    transforms = T.Compose([T.Resize((640, 640)), T.ToTensor()])
    im_data = transforms(pil_image).unsqueeze(0)
    return (im_data, orig_size)
def predict_fn(input_data, model):
    """Run ``model`` on one preprocessed request and return plain lists.

    Args:
        input_data: ``(im_data, orig_size)`` pair, e.g. as built by ``input_fn``.
        model: Module called as ``model(images, orig_target_sizes)``. Its output
            must be indexable at positions 0, 1 and 2.

    Returns:
        dict with keys ``"labels"``, ``"boxes"`` and ``"scores"`` (in that
        order), holding outputs 0, 1 and 2 moved to CPU and converted with
        ``tolist()``.
    """
    images, target_sizes = input_data

    if torch.cuda.is_available():
        target_device = torch.device('cuda')
    else:
        target_device = torch.device('cpu')

    # Re-applied on every call; both calls are idempotent.
    model.to(target_device)
    model.eval()

    with torch.no_grad():
        outputs = model(images.to(target_device), target_sizes.to(target_device))

    # Output position i maps to the i-th key below.
    return {
        key: outputs[position].cpu().tolist()
        for position, key in enumerate(("labels", "boxes", "scores"))
    }
def output_fn(prediction, response_content_type):
    """Serialize ``prediction`` as a JSON string.

    ``response_content_type`` is part of the handler signature but is not
    consulted; the response is always ``json.dumps`` output.
    """
    serialized = json.dumps(prediction)
    return serialized