"""ResNet inference service implementation."""
import base64
import os
from io import BytesIO
import torch
import torchvision.transforms as transforms
from PIL import Image
from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import ImageRequest, Labels, PredictionResponse
from app.services.models.alignment_pretrained.unet import UNetImageDecoder
from app.services.models.CLIPSvD import CLIPSvD
from app.services.models.alignment_pretrained.model_with_bce_images_blip import MMModerator
class CLIPDINOProcessor:
"""Image processor for CLIP and DINO models with 224x224 resizing and normalization."""
def __init__(self, image_size: int = 224):
self.image_size = image_size
        # ImageNet normalization statistics (DINO uses these; OpenAI CLIP's
        # reference preprocessing uses slightly different values)
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], # ImageNet mean
std=[0.229, 0.224, 0.225] # ImageNet std
)
])
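        # Note (illustrative arithmetic): Normalize maps each channel as
        # (x - mean) / std, so ToTensor outputs in [0, 1] land in roughly
        # [-2.1, 2.6] after normalization.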
def __call__(self, image: Image.Image, return_tensors: str = "pt"):
"""
Process an image for CLIP/DINO input.
Args:
image: PIL Image object
            return_tensors: Tensor format; only "pt" (PyTorch tensors) is supported
Returns:
Dictionary with 'pixel_values' key containing the processed tensor
"""
if not isinstance(image, Image.Image):
raise ValueError("Input must be a PIL Image")
# Apply transforms
pixel_values = self.transform(image)
# Add batch dimension if needed
if pixel_values.dim() == 3:
pixel_values = pixel_values.unsqueeze(0)
return {"pixel_values": pixel_values}

def create_vision_encoder():
    """Build the CLIP-based vision encoder (CLIPSvD) used by the moderation model."""
    return CLIPSvD()
class CLIPInferenceService(InferenceService[ImageRequest, PredictionResponse]):
"""ResNet-18 inference service for image classification."""
    def __init__(self, model_name: str = "microsoft/resnet-18"):
        # model_name only selects the checkpoint directory under "models/";
        # no Hugging Face model is downloaded by this service.
        self.model_name = model_name
        self.processor = CLIPDINOProcessor(image_size=224)
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        pretraining = False
        num_classes = 4
        self.vision_encoder = create_vision_encoder()
        self.unet_decoder = UNetImageDecoder(
            num_patches=256,    # 16 x 16 token grid
            token_dim=1024,     # per-token embedding dimension
            out_channels=3,     # RGB reconstruction (use 1 for a single-channel mask)
            base_channels=256,
            img_size=256,
            grid_hw=(16, 16)    # must match the patch grid: 16 * 16 = num_patches
        )
        self.model = MMModerator(
            pretraining=pretraining,
            vision_encoder=self.vision_encoder,
            unet_decoder=self.unet_decoder,
            num_classes=num_classes,
        )
logger.info(f"Initializing CLIP service: {self.model_path}")
def load_model(self) -> None:
if self._is_loaded:
return
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")
        logger.info(f"Loading model from {self.model_path}")
        checkpoint_path = os.path.join(self.model_path, "model_state.pt")
        checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt")
        checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt")

        def _load_state_dict(path: str) -> dict:
            # Checkpoints may wrap weights in {"model_state_dict": ...} and, when
            # saved from nn.DataParallel, prefix every key with "module.";
            # map_location="cpu" keeps loading robust on GPU-less hosts.
            raw = torch.load(path, map_location="cpu")
            sd = raw.get("model_state_dict", raw)
            return {k.replace("module.", ""): v for k, v in sd.items()}

        self.model.load_state_dict(_load_state_dict(checkpoint_path))
        self.vision_encoder.load_state_dict(_load_state_dict(checkpoint_path_encoder))
        self.unet_decoder.load_state_dict(_load_state_dict(checkpoint_path_decoder))  # strict=True by default
        self._is_loaded = True
        logger.info("Model, encoder, and decoder weights loaded")
def predict(self, request: ImageRequest) -> PredictionResponse:
if not self.is_loaded:
            raise RuntimeError("Model is not loaded; call load_model() first")
assert self.processor is not None
assert self.model is not None
image_data = base64.b64decode(request.image.data)
image = Image.open(BytesIO(image_data))
if image.mode != 'RGB':
image = image.convert('RGB')
processed = self.processor(image, return_tensors="pt")
pixel_values = processed["pixel_values"]
self.model.eval()
with torch.no_grad():
            logits_cls, logits, losses, labels_expanded, data_labels_expanded, image_recon, overlay = self.model(images=pixel_values)
            # Normalize the first len(Labels) class logits. Slicing the class
            # dimension with an explicit dim= fixes the original, which sliced the
            # batch dimension and relied on log_softmax's implicit dim.
            logprobs = torch.nn.functional.log_softmax(logits[:, :len(Labels)], dim=-1).tolist()[0]
        # TODO: derive a localization mask from image_recon (e.g. via
        # BinaryMask.from_numpy) instead of returning None below.
        logger.debug(f"logprobs: {logprobs}")
return PredictionResponse(
logprobs=logprobs,
localizationMask=None,
)
@property
def is_loaded(self) -> bool:
return self._is_loaded
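
# Minimal end-to-end sketch (illustrative; assumes ImageRequest carries a base64
# payload under request.image.data, as consumed in predict() above, and that the
# three model_state*.pt checkpoints exist under models/microsoft/resnet-18/):
#
#   service = CLIPInferenceService()
#   service.load_model()
#   with open("example.jpg", "rb") as f:  # hypothetical input image
#       payload = base64.b64encode(f.read()).decode("ascii")
#   # response = service.predict(ImageRequest(...))  # wrap `payload` per the API schema
#   # response.logprobs  # per-class log-probabilities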