"""ResNet inference service implementation.""" import base64 import os from io import BytesIO import numpy as np import torch import torchvision.transforms as transforms from PIL import Image from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped] from app.core.logging import logger from app.services.base import InferenceService from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse from app.services.models.alignment_pretrained.unet import UNetImageDecoder from app.services.models.DinoLORA import DINOEncoderLoRA from app.services.models.alignment_pretrained.model_with_bce_images_dino import MMModerator class DINODINOProcessor: """Image processor for DINO and DINO models with 224x224 resizing and normalization.""" def __init__(self, image_size: int = 512): self.image_size = image_size # Standard ImageNet normalization used by both DINO and DINO self.transform = transforms.Compose([ transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], # ImageNet mean std=[0.229, 0.224, 0.225] # ImageNet std ) ]) def __call__(self, image: Image.Image, return_tensors: str = "pt"): """ Process an image for DINO/DINO input. Args: image: PIL Image object return_tensors: Format of returned tensors (default: "pt" for PyTorch) Returns: Dictionary with 'pixel_values' key containing the processed tensor """ if not isinstance(image, Image.Image): raise ValueError("Input must be a PIL Image") # Apply transforms pixel_values = self.transform(image) # Add batch dimension if needed if pixel_values.dim() == 3: pixel_values = pixel_values.unsqueeze(0) return {"pixel_values": pixel_values} def create_vision_encoder(): REPO_DIR = "/media/NAS/USERS/shahid/MultimodalAudioVisualModerator/dinov3" encoder = torch.hub.load(REPO_DIR, 'dinov3_vitb16', source='local', weights="https://dinov3.llamameta.net/dinov3_vitb16/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoib3VpbXR2cHlhZXE5c2JwajNucnN3aWF2IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NTk5MDI0NDF9fX1dfQ__&Signature=GIYJR4%7ESJVx0gkcm7lgAvDljIfpR30MXgWb2VpCqbDeVpnwjn97k%7EOcPeGF-lkR0q1Sn3Iw5Y3iYWqspcpPoDJ4FXUmMKhWtd-m00HO73Aknq2kyrKVMBpzwQB-k-2zZe7okJfXTj46EWbzu9mNcxt%7ErDPe7phQpRJi8Dleida1BJ823oXFx8d7oRSa4NDSzT2TNXqNNZ8ux7N0aDfdT9dupEeEr4AP06LhYB2I7kF%7Ef4bvKQsKnlPMVDAADyYG9nQ7HqAW41LaWZtR-BrDGm%7ESNu-6L44cUVnk3qEPVRMQB4GW7ixRGGhtr37F6HVz%7EKilrCpivFD6ej4reNUWaGQ__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=829796286371500") model = DINOEncoderLoRA(encoder, r=16, emb_dim=1024, use_lora=True) return model class DINOInferenceService(InferenceService[ImageRequest, PredictionResponse]): """ResNet-18 inference service for image classification.""" def __init__(self, model_name: str = "microsoft/resnet-18"): self.model_name = model_name # self.model = None self.processor = DINODINOProcessor(image_size=512) self._is_loaded = False self.model_path = os.path.join("models", model_name) pretraining = False num_classes = 4 self.vision_encoder = create_vision_encoder() # self.vision_encoder.to(device=device, dtype=torch.float32) self.unet_decoder = UNetImageDecoder( num_patches=1024, # MUST match N=1024 token_dim=768, # because tokens are [B,1024,768] out_channels=1, # RGB reconstructed output base_channels=128, # recommended for 512px img_size=512, # image resolution grid_hw=(32, 32) # MUST match N=1024 ) self.model = MMModerator(pretraining=pretraining,vision_encoder=self.vision_encoder, unet_decoder=self.unet_decoder, num_classes=num_classes) self.model.eval() logger.info(f"Initializing DINO service: {self.model_path}") def load_model(self) -> None: if self._is_loaded: return if not os.path.exists(self.model_path): raise FileNotFoundError(f"Model not found: {self.model_path}") config_path = os.path.join(self.model_path, "config.json") # if not os.path.exists(config_path): # raise FileNotFoundError(f"Config not found: {config_path}") logger.info(f"Loading model from {self.model_path}") checkpoint_path = os.path.join(self.model_path, "model_state.pt") checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt") checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt") raw = torch.load(checkpoint_path) raw_encoder = torch.load(checkpoint_path_encoder) raw_decoder = torch.load(checkpoint_path_decoder) sd = raw.get("model_state_dict", raw) new_sd = {} for k, v in sd.items(): new_key = k.replace("module.", "") new_sd[new_key] = v self.model.load_state_dict(new_sd) sd_encoder = raw_encoder.get("model_state_dict", raw_encoder) new_sd_encoder = {} for k, v in sd_encoder.items(): new_key = k.replace("module.", "") new_sd_encoder[new_key] = v self.vision_encoder.load_state_dict(new_sd_encoder) sd_decoder = raw_decoder.get("model_state_dict", raw_decoder) new_sd_decoder = {} for k, v in sd_decoder.items(): new_key = k.replace("module.", "") new_sd_decoder[new_key] = v self.unet_decoder.load_state_dict(new_sd_decoder) # strict=True by default self._is_loaded = True # logger.info(f"Model loaded: {len(self.model.config.id2label)} classes") # pyright: ignore def predict(self, request: ImageRequest) -> PredictionResponse: if not self.is_loaded: raise RuntimeError("model is not loaded") assert self.processor is not None assert self.model is not None image_data = base64.b64decode(request.image.data) image = Image.open(BytesIO(image_data)) if image.mode != 'RGB': image = image.convert('RGB') processed = self.processor(image, return_tensors="pt") pixel_values = processed["pixel_values"] self.model.eval() with torch.no_grad(): # logits, losses, label, image_recon, overlay = logits,logits_multi_cls, losses, label,data_label, image_recon, overlay, shuffled_images, gt_masks = self.model(images=pixel_values) logprobs = torch.nn.functional.log_softmax(logits_multi_cls[:len(Labels)]).tolist()[0] # x = image.width // 3 # y = image.height // 3 # # Row-major order # mask = np.zeros((image.height, image.width), dtype=np.uint8) # mask[y:(2*y), x:(2*x)] = 1 # mask_obj = BinaryMask.from_numpy(image_recon.cpu().numpy()) # mask_obj = BinaryMask.from_numpy(image_recon.cpu().numpy()) print(logprobs) return PredictionResponse( logprobs=logprobs, localizationMask=None, ) @property def is_loaded(self) -> bool: return self._is_loaded