# MIML / inference.py
# Last updated by P-rateek (commit f57e2b8, verified): "Update inference.py"
import base64
import io
import os
import sys
from typing import Dict, List, Any
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmcv.utils import Config
from PIL import Image
# Add current directory to path to import local modules
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
# Now we can import from the local mmseg and mmcv_custom
from modelsforIML.mmseg.datasets.pipelines import Compose
from modelsforIML.mmseg.models import build_segmentor
class Pipeline:
    """Hugging Face inference pipeline for the APSC-Net segmentation model.

    Loads an mmseg-style segmentor plus its test-time preprocessing pipeline,
    and exposes ``__call__`` to turn a PIL image into a binary tamper mask
    (0 = authentic, 255 = tampered) returned as a base64-encoded PNG.
    """

    # Default locations, relative to the repository root.
    _DEFAULT_CONFIG = 'models for IML/apscnet.py'
    _DEFAULT_CHECKPOINT = 'models for IML/APSC-Net.pth'

    def __init__(self, model_path: str):
        """
        Initializes the pipeline by loading the model and preprocessing steps.

        Args:
            model_path (str): Path to the model checkpoint file, passed
                automatically by the Hugging Face infrastructure. If it points
                at an existing file it is used as the checkpoint; otherwise
                the default repository-relative checkpoint path is used.
                (Previously this argument was accepted but silently ignored.)

        Raises:
            FileNotFoundError: If no checkpoint file can be found.
        """
        # --- Device ---
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- Model Configuration ---
        # The config file path is relative to the repository root.
        config_path = self._DEFAULT_CONFIG

        # Honor the caller-supplied checkpoint when it is an actual file;
        # fall back to the bundled default otherwise (e.g. when the
        # infrastructure passes a repository directory instead of a file).
        if model_path and os.path.isfile(model_path):
            checkpoint_path = model_path
        else:
            checkpoint_path = self._DEFAULT_CHECKPOINT

        # isfile (not exists) so a directory at this path fails here with a
        # clear message instead of deep inside load_checkpoint.
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint file not found at {checkpoint_path}. "
                "Please download it and place it in the 'models for IML' directory."
            )

        cfg = Config.fromfile(config_path)

        # --- Build Model ---
        self.model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        load_checkpoint(self.model, checkpoint_path, map_location='cpu')
        self.model.to(self.device)
        self.model.eval()

        # --- Build Preprocessing Pipeline ---
        # NOTE(review): pipeline[1] is assumed to be the test-time aug wrapper
        # (e.g. MultiScaleFlipAug) whose 'transforms' list holds the
        # per-image transforms — verify against the config file.
        test_pipeline_cfg = cfg.data.test.pipeline[1]['transforms']
        self.pipeline = Compose(test_pipeline_cfg)

    def __call__(self, inputs: Image.Image) -> Dict[str, Any]:
        """
        Performs inference on a single image.

        Args:
            inputs (Image.Image): A PIL Image to be processed.

        Returns:
            Dict[str, Any]: ``{"image": <base64-encoded PNG of the mask>}``,
            where the mask is single-channel with 0 = authentic, 255 = tampered.
        """
        # Convert PIL image to numpy array (RGB).
        img = np.array(inputs.convert('RGB'))

        # Prepare data for the preprocessing pipeline.
        data = {'img': img, 'img_shape': img.shape, 'ori_shape': img.shape}
        data = self.pipeline(data)

        # The pipeline wraps the tensor in a list; add a batch dim and move
        # it to the model's device.
        img_tensor = data['img'][0].unsqueeze(0).to(self.device)

        # --- Inference ---
        with torch.no_grad():
            result = self.model(return_loss=False, img=[img_tensor])

        # --- Post-process ---
        # NOTE(review): assumes result[0] is a (num_classes, H, W) logit/score
        # array (2 classes: 0 = authentic, 1 = tampered) — confirm against the
        # model's forward_test; some mmseg models return an argmaxed map.
        mask_pred = result[0].argmax(0).astype(np.uint8)

        # Map class ids to a visual mask (0 -> 0, 1 -> 255).
        mask_pred *= 255
        mask_image = Image.fromarray(mask_pred, mode='L')

        # --- Encode to Base64 ---
        buffered = io.BytesIO()
        mask_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return {"image": img_str}