|
|
"""ResNet inference service implementation.""" |
|
|
|
|
|
import base64 |
|
|
import os |
|
|
from io import BytesIO |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
from transformers import AutoImageProcessor, ResNetForImageClassification |
|
|
|
|
|
from app.core.logging import logger |
|
|
from app.services.base import InferenceService |
|
|
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse |
|
|
|
|
|
from app.services.models.alignment_pretrained.unet import UNetImageDecoder |
|
|
|
|
|
from app.services.models.DinoLORA import DINOEncoderLoRA |
|
|
from app.services.models.alignment_pretrained.model_with_bce_images_dino import MMModerator |
|
|
|
|
|
|
|
|
class DINODINOProcessor:
    """Image processor for DINO-based models.

    Resizes the input to a square ``image_size`` x ``image_size`` (bicubic)
    and normalizes with ImageNet channel statistics, matching the
    preprocessing the DINO backbone was pretrained with.
    """

    def __init__(self, image_size: int = 512):
        # Target side length of the square resize (default 512, matching the
        # UNet decoder's img_size used elsewhere in this service).
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            # ImageNet mean/std normalization.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __call__(self, image: Image.Image, return_tensors: str = "pt") -> dict:
        """
        Process an image for DINO input.

        Args:
            image: PIL Image object.
            return_tensors: Kept for API parity with HuggingFace processors;
                PyTorch tensors are produced regardless of this value.

        Returns:
            Dictionary with a 'pixel_values' key containing a float tensor of
            shape (1, 3, image_size, image_size).

        Raises:
            ValueError: If ``image`` is not a PIL Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        pixel_values = self.transform(image)

        # Ensure an explicit batch dimension: (C, H, W) -> (1, C, H, W).
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        return {"pixel_values": pixel_values}
|
|
|
|
|
|
|
|
def create_vision_encoder(repo_dir=None, weights=None):
    """Create the LoRA-wrapped DINOv3 ViT-B/16 vision encoder.

    Args:
        repo_dir: Local checkout of the dinov3 torch.hub repo. Defaults to
            the ``DINOV3_REPO_DIR`` environment variable, falling back to the
            original hard-coded NAS path for backward compatibility.
        weights: Path or URL for the pretrained backbone weights. Defaults to
            the ``DINOV3_WEIGHTS`` environment variable, falling back to the
            original download URL.

    Returns:
        A DINOEncoderLoRA wrapping the pretrained encoder (r=16, emb_dim=1024,
        LoRA adapters enabled).
    """
    # NOTE(review): the fallback weights URL is a time-limited signed link
    # (Policy/Signature/Expires query params) and WILL expire — set
    # DINOV3_WEIGHTS to a stable local checkpoint path in deployment.
    repo_dir = repo_dir or os.environ.get(
        "DINOV3_REPO_DIR",
        "/media/NAS/USERS/shahid/MultimodalAudioVisualModerator/dinov3",
    )
    weights = weights or os.environ.get(
        "DINOV3_WEIGHTS",
        "https://dinov3.llamameta.net/dinov3_vitb16/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoib3VpbXR2cHlhZXE5c2JwajNucnN3aWF2IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NTk5MDI0NDF9fX1dfQ__&Signature=GIYJR4%7ESJVx0gkcm7lgAvDljIfpR30MXgWb2VpCqbDeVpnwjn97k%7EOcPeGF-lkR0q1Sn3Iw5Y3iYWqspcpPoDJ4FXUmMKhWtd-m00HO73Aknq2kyrKVMBpzwQB-k-2zZe7okJfXTj46EWbzu9mNcxt%7ErDPe7phQpRJi8Dleida1BJ823oXFx8d7oRSa4NDSzT2TNXqNNZ8ux7N0aDfdT9dupEeEr4AP06LhYB2I7kF%7Ef4bvKQsKnlPMVDAADyYG9nQ7HqAW41LaWZtR-BrDGm%7ESNu-6L44cUVnk3qEPVRMQB4GW7ixRGGhtr37F6HVz%7EKilrCpivFD6ej4reNUWaGQ__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=829796286371500",
    )

    encoder = torch.hub.load(repo_dir, 'dinov3_vitb16', source='local', weights=weights)
    return DINOEncoderLoRA(encoder, r=16, emb_dim=1024, use_lora=True)
|
|
|
|
|
|
|
|
|
|
|
class DINOInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """DINO-based multimodal moderation service for image classification."""

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Build the (unloaded) model graph; call load_model() before predict().

        Args:
            model_name: Subdirectory under ``models/`` holding the checkpoint
                files. The default is kept for backward compatibility with
                existing callers.
        """
        self.model_name = model_name
        self.processor = DINODINOProcessor(image_size=512)
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)

        # Inference-time configuration: 4-way classification head, no
        # pretraining objective.
        pretraining = False
        num_classes = 4
        self.vision_encoder = create_vision_encoder()

        # Decoder maps the 32x32 grid of 768-dim patch tokens back to a
        # single-channel 512x512 image-space output.
        self.unet_decoder = UNetImageDecoder(
            num_patches=1024,
            token_dim=768,
            out_channels=1,
            base_channels=128,
            img_size=512,
            grid_hw=(32, 32),
        )
        self.model = MMModerator(
            pretraining=pretraining,
            vision_encoder=self.vision_encoder,
            unet_decoder=self.unet_decoder,
            num_classes=num_classes,
        )
        self.model.eval()
        logger.info(f"Initializing DINO service: {self.model_path}")

    @staticmethod
    def _load_clean_state_dict(path):
        """Load a checkpoint and strip DataParallel 'module.' prefixes.

        Accepts either a raw state dict or a dict wrapping it under a
        'model_state_dict' key. map_location='cpu' so checkpoints saved on
        GPU load on CPU-only hosts.
        """
        raw = torch.load(path, map_location="cpu")
        sd = raw.get("model_state_dict", raw)
        return {k.replace("module.", ""): v for k, v in sd.items()}

    def load_model(self) -> None:
        """Load model, encoder, and decoder weights from ``self.model_path``.

        Idempotent: returns immediately if weights are already loaded.

        Raises:
            FileNotFoundError: If the model directory does not exist.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        logger.info(f"Loading model from {self.model_path}")
        self.model.load_state_dict(
            self._load_clean_state_dict(os.path.join(self.model_path, "model_state.pt"))
        )
        self.vision_encoder.load_state_dict(
            self._load_clean_state_dict(os.path.join(self.model_path, "model_state_encoder.pt"))
        )
        self.unet_decoder.load_state_dict(
            self._load_clean_state_dict(os.path.join(self.model_path, "model_state_decoder.pt"))
        )
        self._is_loaded = True

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image.

        Args:
            request: Request whose ``image.data`` holds a base64-encoded image.

        Returns:
            PredictionResponse with per-label log-probabilities and no
            localization mask.

        Raises:
            RuntimeError: If called before load_model().
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")

        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != 'RGB':
            image = image.convert('RGB')

        processed = self.processor(image, return_tensors="pt")
        pixel_values = processed["pixel_values"]

        self.model.eval()
        with torch.no_grad():
            # Only the multi-class logits are consumed at inference time; the
            # remaining outputs (losses, reconstructions, masks, ...) are
            # training-time artifacts of MMModerator's forward pass.
            (logits, logits_multi_cls, losses, label, data_label,
             image_recon, overlay, shuffled_images, gt_masks) = self.model(images=pixel_values)

        # NOTE(review): ``[:len(Labels)]`` slices the *batch* dimension, not
        # the class dimension — with batch size 1 it is a no-op, but confirm
        # the intent was not ``logits_multi_cls[:, :len(Labels)]``.
        # dim=-1 made explicit (implicit dim is deprecated in PyTorch).
        logprobs = torch.nn.functional.log_softmax(
            logits_multi_cls[:len(Labels)], dim=-1
        ).tolist()[0]

        logger.debug(f"DINO logprobs: {logprobs}")
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=None,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether checkpoints have been loaded via load_model()."""
        return self._is_loaded
|
|
|