# safe-image-challenge / app/services/inference_dino.py
# Uploaded by Shahidmuneer via huggingface_hub (commit 8bd3ef8, verified)
"""DINO-based inference service implementation."""
import base64
import os
from io import BytesIO
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped]
from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
from app.services.models.alignment_pretrained.unet import UNetImageDecoder
from app.services.models.DinoLORA import DINOEncoderLoRA
from app.services.models.alignment_pretrained.model_with_bce_images_dino import MMModerator
class DINODINOProcessor:
    """Image processor for DINO-style vision models.

    Resizes inputs to a square ``image_size x image_size`` (default 512)
    with bicubic interpolation and applies standard ImageNet normalization.
    """

    def __init__(self, image_size: int = 512):
        self.image_size = image_size
        # ImageNet statistics — the normalization used for DINO pretraining.
        self.transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # ImageNet mean
                std=[0.229, 0.224, 0.225],   # ImageNet std
            ),
        ])

    def __call__(self, image: Image.Image, return_tensors: str = "pt"):
        """
        Process an image for DINO input.

        Args:
            image: PIL Image object.
            return_tensors: Accepted for HF-processor API compatibility;
                only "pt" (PyTorch) tensors are produced.

        Returns:
            Dictionary with 'pixel_values' mapped to a [1, 3, H, W] float tensor.

        Raises:
            ValueError: If *image* is not a PIL Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")
        pixel_values = self.transform(image)
        # Ensure a leading batch dimension for single images.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        return {"pixel_values": pixel_values}
def create_vision_encoder():
    """Build the DINOv3 ViT-B/16 backbone wrapped with LoRA adapters.

    The local torch.hub repo directory and the weights location can be
    overridden with the ``DINOV3_REPO_DIR`` / ``DINOV3_WEIGHTS`` environment
    variables; the defaults are the original hard-coded values, so existing
    deployments are unaffected.

    Returns:
        DINOEncoderLoRA wrapping the ``dinov3_vitb16`` encoder.
    """
    # NOTE(review): the default weights URL is a time-limited signed CDN link
    # (the Policy blob embeds an AWS:EpochTime expiry) and will stop working
    # once the signature lapses — prefer pointing DINOV3_WEIGHTS at a local
    # checkpoint file.
    repo_dir = os.environ.get(
        "DINOV3_REPO_DIR",
        "/media/NAS/USERS/shahid/MultimodalAudioVisualModerator/dinov3",
    )
    weights = os.environ.get(
        "DINOV3_WEIGHTS",
        "https://dinov3.llamameta.net/dinov3_vitb16/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoib3VpbXR2cHlhZXE5c2JwajNucnN3aWF2IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NTk5MDI0NDF9fX1dfQ__&Signature=GIYJR4%7ESJVx0gkcm7lgAvDljIfpR30MXgWb2VpCqbDeVpnwjn97k%7EOcPeGF-lkR0q1Sn3Iw5Y3iYWqspcpPoDJ4FXUmMKhWtd-m00HO73Aknq2kyrKVMBpzwQB-k-2zZe7okJfXTj46EWbzu9mNcxt%7ErDPe7phQpRJi8Dleida1BJ823oXFx8d7oRSa4NDSzT2TNXqNNZ8ux7N0aDfdT9dupEeEr4AP06LhYB2I7kF%7Ef4bvKQsKnlPMVDAADyYG9nQ7HqAW41LaWZtR-BrDGm%7ESNu-6L44cUVnk3qEPVRMQB4GW7ixRGGhtr37F6HVz%7EKilrCpivFD6ej4reNUWaGQ__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=829796286371500",
    )
    encoder = torch.hub.load(repo_dir, 'dinov3_vitb16', source='local', weights=weights)
    return DINOEncoderLoRA(encoder, r=16, emb_dim=1024, use_lora=True)
class DINOInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """DINO-based multimodal moderation inference service.

    Combines a DINO(-LoRA) vision encoder, a UNet image decoder, and an
    MMModerator head; classifies a base64-encoded image and returns
    per-class log-probabilities for the first ``len(Labels)`` classes.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        # NOTE(review): the default model_name is a leftover from the ResNet
        # service this file was adapted from; it only determines the
        # checkpoint directory models/<model_name>.
        self.model_name = model_name
        self.processor = DINODINOProcessor(image_size=512)
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        pretraining = False
        num_classes = 4
        self.vision_encoder = create_vision_encoder()
        self.unet_decoder = UNetImageDecoder(
            num_patches=1024,   # must match the 32x32 grid below (32*32 == 1024)
            token_dim=768,      # encoder tokens are [B, 1024, 768]
            out_channels=1,     # single-channel reconstruction output
            base_channels=128,  # recommended width for 512px inputs
            img_size=512,       # must match DINODINOProcessor image_size
            grid_hw=(32, 32),   # patch grid; must match num_patches
        )
        self.model = MMModerator(
            pretraining=pretraining,
            vision_encoder=self.vision_encoder,
            unet_decoder=self.unet_decoder,
            num_classes=num_classes,
        )
        self.model.eval()
        logger.info(f"Initializing DINO service: {self.model_path}")

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with any DataParallel 'module.' prefix removed."""
        return {k.replace("module.", ""): v for k, v in state_dict.items()}

    def _load_state_dict_file(self, filename: str):
        """Load a checkpoint from self.model_path and return its cleaned state dict.

        Checkpoints may be saved either bare or wrapped as
        ``{"model_state_dict": ...}``; both forms are handled.
        """
        path = os.path.join(self.model_path, filename)
        # map_location="cpu" so GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict moves tensors to the module's device afterwards.
        raw = torch.load(path, map_location="cpu")
        return self._strip_module_prefix(raw.get("model_state_dict", raw))

    def load_model(self) -> None:
        """Load moderator, encoder, and decoder weights from self.model_path.

        Idempotent: returns immediately if already loaded.

        Raises:
            FileNotFoundError: If the model directory does not exist.
        """
        if self._is_loaded:
            return
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")
        logger.info(f"Loading model from {self.model_path}")
        self.model.load_state_dict(self._load_state_dict_file("model_state.pt"))
        self.vision_encoder.load_state_dict(self._load_state_dict_file("model_state_encoder.pt"))
        self.unet_decoder.load_state_dict(self._load_state_dict_file("model_state_decoder.pt"))
        self._is_loaded = True

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image from *request*.

        Args:
            request: ImageRequest whose ``image.data`` is base64-encoded bytes.

        Returns:
            PredictionResponse with per-class log-probabilities and no
            localization mask.

        Raises:
            RuntimeError: If load_model() has not been called successfully.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None
        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"]
        self.model.eval()
        with torch.no_grad():
            (logits, logits_multi_cls, losses, label, data_label,
             image_recon, overlay, shuffled_images, gt_masks) = self.model(images=pixel_values)
        # NOTE(review): [:len(Labels)] slices the *batch* dimension of the
        # logits; with a single input image this is a no-op, but it looks
        # like the class dimension was intended — confirm upstream.
        # dim=-1 made explicit: implicit-dim log_softmax is deprecated.
        logprobs = torch.nn.functional.log_softmax(
            logits_multi_cls[:len(Labels)], dim=-1
        ).tolist()[0]
        logger.debug("logprobs: %s", logprobs)
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=None,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether load_model() has completed successfully."""
        return self._is_loaded