"""CLIP-based inference service implementation."""

import base64
import os
from io import BytesIO

import torch
import torchvision.transforms as transforms
from PIL import Image

from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import ImageRequest, Labels, PredictionResponse

from app.services.models.alignment_pretrained.unet import UNetImageDecoder
from app.services.models.CLIPSvD import CLIPSvD
from app.services.models.alignment_pretrained.model_with_bce_images_blip import MMModerator


class CLIPDINOProcessor:
    """Image processor for CLIP and DINO models with 224x224 resizing and normalization."""

    def __init__(self, image_size: int = 224):
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # Standard ImageNet channel statistics.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __call__(self, image: Image.Image, return_tensors: str = "pt"):
        """
        Process an image for CLIP/DINO input.

        Args:
            image: PIL Image object
            return_tensors: Format of returned tensors; only "pt" (PyTorch)
                is currently supported.

        Returns:
            Dictionary with 'pixel_values' key containing the processed tensor
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        pixel_values = self.transform(image)

        # Add a batch dimension if the transform produced a single image.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        return {"pixel_values": pixel_values}


def create_vision_encoder():
    """Build the CLIP-based vision encoder (CLIPSvD) used by the moderator."""
    model = CLIPSvD()
    return model


class CLIPInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """CLIP-based inference service for image classification."""

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        # model_name only selects the checkpoint directory under models/.
        self.model_name = model_name
        self.processor = CLIPDINOProcessor(image_size=224)
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)

        pretraining = False
        num_classes = 4
        self.vision_encoder = create_vision_encoder()
        self.unet_decoder = UNetImageDecoder(
            num_patches=256,
            token_dim=1024,
            out_channels=3,
            base_channels=256,
            img_size=256,
            grid_hw=(16, 16),
        )
        self.model = MMModerator(
            pretraining=pretraining,
            vision_encoder=self.vision_encoder,
            unet_decoder=self.unet_decoder,
            num_classes=num_classes,
        )
        logger.info(f"Initializing CLIP service: {self.model_path}")
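
    # Note: __init__ builds the module graph with randomly initialised
    # weights; trained weights are attached only when load_model() is called.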
    def load_model(self) -> None:
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

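        # Expected checkpoint layout under self.model_path
        # (e.g. models/microsoft/resnet-18/):
        #
        #   model_state.pt          # full MMModerator state dict
        #   model_state_encoder.pt  # vision encoder state dict
        #   model_state_decoder.pt  # UNet decoder state dict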
logger.info(f"Loading model from {self.model_path}") |
|
|
checkpoint_path = os.path.join(self.model_path, "model_state.pt") |
|
|
checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt") |
|
|
checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt") |
|
|
raw = torch.load(checkpoint_path) |
|
|
raw_encoder = torch.load(checkpoint_path_encoder) |
|
|
raw_decoder = torch.load(checkpoint_path_decoder) |
|
|
sd = raw.get("model_state_dict", raw) |
|
|
new_sd = {} |
|
|
for k, v in sd.items(): |
|
|
new_key = k.replace("module.", "") |
|
|
new_sd[new_key] = v |
|
|
|
|
|
self.model.load_state_dict(new_sd) |
|
|
sd_encoder = raw_encoder.get("model_state_dict", raw_encoder) |
|
|
new_sd_encoder = {} |
|
|
for k, v in sd_encoder.items(): |
|
|
new_key = k.replace("module.", "") |
|
|
new_sd_encoder[new_key] = v |
|
|
self.vision_encoder.load_state_dict(new_sd_encoder) |
|
|
sd_decoder = raw_decoder.get("model_state_dict", raw_decoder) |
|
|
new_sd_decoder = {} |
|
|
for k, v in sd_decoder.items(): |
|
|
new_key = k.replace("module.", "") |
|
|
new_sd_decoder[new_key] = v |
|
|
self.unet_decoder.load_state_dict(new_sd_decoder) |
|
|
|
|
|
self._is_loaded = True |

    def predict(self, request: ImageRequest) -> PredictionResponse:
        if not self.is_loaded:
            raise RuntimeError("Model is not loaded; call load_model() first")

        assert self.processor is not None
        assert self.model is not None

        # Decode the base64-encoded image payload into an RGB PIL image.
        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != 'RGB':
            image = image.convert('RGB')

        processed = self.processor(image, return_tensors="pt")
        pixel_values = processed["pixel_values"]

        self.model.eval()
        with torch.no_grad():
            # MMModerator returns a tuple; only the classification logits are
            # used here (loss, reconstruction, and overlay outputs are ignored).
            logits_cls, logits, losses, labels_expanded, data_labels_expanded, image_recon, overlay = self.model(images=pixel_values)

        # Log-probabilities over the supported labels for the single input image.
        logprobs = torch.nn.functional.log_softmax(logits[:, :len(Labels)], dim=-1)[0].tolist()
        logger.debug(f"Predicted logprobs: {logprobs}")

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=None,
        )

    @property
    def is_loaded(self) -> bool:
        return self._is_loaded
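

# A hypothetical end-to-end sketch (the exact ImageRequest construction
# depends on the app.api.models schema; this service only assumes that
# request.image.data holds base64-encoded image bytes):
#
#   service = CLIPInferenceService()
#   service.load_model()
#   response = service.predict(request)  # request: ImageRequest
#   response.logprobs                    # log-probabilities over Labels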