"""ResNet inference service implementation.""" import base64 import os from io import BytesIO import numpy as np import torch import torchvision.transforms as transforms from PIL import Image from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped] from app.core.logging import logger from app.services.base import InferenceService from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse from app.services.models.alignment_pretrained.unet import UNetImageDecoder from app.services.models.CLIPSvD import CLIPSvD from app.services.models.alignment_pretrained.model_with_bce_images_blip import MMModerator class CLIPDINOProcessor: """Image processor for CLIP and DINO models with 224x224 resizing and normalization.""" def __init__(self, image_size: int = 224): self.image_size = image_size # Standard ImageNet normalization used by both CLIP and DINO self.transform = transforms.Compose([ transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], # ImageNet mean std=[0.229, 0.224, 0.225] # ImageNet std ) ]) def __call__(self, image: Image.Image, return_tensors: str = "pt"): """ Process an image for CLIP/DINO input. Args: image: PIL Image object return_tensors: Format of returned tensors (default: "pt" for PyTorch) Returns: Dictionary with 'pixel_values' key containing the processed tensor """ if not isinstance(image, Image.Image): raise ValueError("Input must be a PIL Image") # Apply transforms pixel_values = self.transform(image) # Add batch dimension if needed if pixel_values.dim() == 3: pixel_values = pixel_values.unsqueeze(0) return {"pixel_values": pixel_values} def create_vision_encoder(): REPO_DIR = "/media/NAS/USERS/shahid/MultimodalAudioVisualModerator/dinov3" model = CLIPSvD() return model class CLIPInferenceService(InferenceService[ImageRequest, PredictionResponse]): """ResNet-18 inference service for image classification.""" def __init__(self, model_name: str = "microsoft/resnet-18"): self.model_name = model_name # self.model = None self.processor = CLIPDINOProcessor(image_size=224) self._is_loaded = False self.model_path = os.path.join("models", model_name) pretraining = False num_classes = 4 self.vision_encoder = create_vision_encoder() # self.vision_encoder.to(device=device, dtype=torch.float32) self.unet_decoder = UNetImageDecoder( num_patches=256, # 7 × 7 grid (ViT-B/32) token_dim=1024, # ViT-B/32 embedding dim out_channels=3, # mask or 3 for RGB base_channels=256, img_size=256, grid_hw=(16, 16) # explicitly set to match patch grid ) self.model = MMModerator(pretraining=pretraining,vision_encoder=self.vision_encoder, unet_decoder=self.unet_decoder, num_classes=num_classes) logger.info(f"Initializing CLIP service: {self.model_path}") def load_model(self) -> None: if self._is_loaded: return if not os.path.exists(self.model_path): raise FileNotFoundError(f"Model not found: {self.model_path}") config_path = os.path.join(self.model_path, "config.json") # if not os.path.exists(config_path): # raise FileNotFoundError(f"Config not found: {config_path}") logger.info(f"Loading model from {self.model_path}") checkpoint_path = os.path.join(self.model_path, "model_state.pt") checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt") checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt") raw = torch.load(checkpoint_path) raw_encoder = torch.load(checkpoint_path_encoder) 
    def load_model(self) -> None:
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        logger.info(f"Loading model from {self.model_path}")

        checkpoint_path = os.path.join(self.model_path, "model_state.pt")
        checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt")
        checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt")

        # map_location="cpu" keeps loading robust on hosts without the GPU
        # the checkpoints were saved from.
        raw = torch.load(checkpoint_path, map_location="cpu")
        raw_encoder = torch.load(checkpoint_path_encoder, map_location="cpu")
        raw_decoder = torch.load(checkpoint_path_decoder, map_location="cpu")

        def strip_module_prefix(state_dict: dict) -> dict:
            # Checkpoints saved from nn.DataParallel prefix every key with
            # "module."; strip it so the bare modules accept the weights.
            return {k.replace("module.", ""): v for k, v in state_dict.items()}

        # Each load uses strict=True (the default) so missing or unexpected
        # keys fail loudly rather than silently skipping weights.
        self.model.load_state_dict(
            strip_module_prefix(raw.get("model_state_dict", raw))
        )
        self.vision_encoder.load_state_dict(
            strip_module_prefix(raw_encoder.get("model_state_dict", raw_encoder))
        )
        self.unet_decoder.load_state_dict(
            strip_module_prefix(raw_decoder.get("model_state_dict", raw_decoder))
        )

        self._is_loaded = True
        logger.info("Model, encoder, and decoder weights loaded")

    def predict(self, request: ImageRequest) -> PredictionResponse:
        if not self.is_loaded:
            raise RuntimeError("Model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        # Decode the base64 payload and normalize to RGB.
        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != "RGB":
            image = image.convert("RGB")

        processed = self.processor(image, return_tensors="pt")
        pixel_values = processed["pixel_values"]

        self.model.eval()
        with torch.no_grad():
            # Only the moderation logits (and, eventually, the reconstruction)
            # are consumed here; the remaining outputs are ignored.
            (_logits_cls, logits, _losses, _labels_expanded,
             _data_labels_expanded, image_recon, _overlay) = self.model(images=pixel_values)

        # Log-probabilities over the label set for the single image in the batch.
        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)[0, : len(Labels)].tolist()

        # TODO: derive a localization mask from image_recon, e.g. via
        # app.api.models.BinaryMask.from_numpy(image_recon.cpu().numpy()).
        logger.debug(f"Predicted logprobs: {logprobs}")
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=None,
        )

    @property
    def is_loaded(self) -> bool:
        return self._is_loaded
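

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the service API). It assumes the
# checkpoint tree described above exists under models/microsoft/resnet-18/
# and that ImageRequest accepts a nested payload whose image.data field holds
# a base64-encoded image, matching how predict() reads it; adjust the request
# construction to the actual schema in app.api.models.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Encode a synthetic solid-color image as base64, mimicking a client payload.
    buffer = BytesIO()
    Image.new("RGB", (256, 256), color=(128, 64, 32)).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")

    service = CLIPInferenceService()
    service.load_model()
    # Field layout assumed from predict(): request.image.data is the base64 string.
    response = service.predict(ImageRequest(image={"data": encoded}))
    print(response.logprobs)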