File size: 3,452 Bytes
a2a1eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f57e2b8
 
a2a1eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767dcae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import base64
import io
import os
import sys
from typing import Dict, List, Any

import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmcv.utils import Config
from PIL import Image

# Add current directory to path to import local modules
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

# Now we can import from the local mmseg and mmcv_custom
from modelsforIML.mmseg.datasets.pipelines import Compose
from modelsforIML.mmseg.models import build_segmentor


class Pipeline:
    """Inference pipeline for the APSC-Net image-manipulation segmentation model.

    Loads an mmseg-style segmentor from a config + checkpoint and exposes a
    single-image ``__call__`` that returns the predicted tamper mask as a
    base64-encoded PNG.
    """

    # Default locations (relative to the repository root), used when the
    # ``model_path`` supplied by the serving infrastructure is unusable.
    _DEFAULT_CONFIG_PATH = 'models for IML/apscnet.py'
    _DEFAULT_CHECKPOINT_PATH = 'models for IML/APSC-Net.pth'

    def __init__(self, model_path: str):
        """
        Initializes the pipeline: selects a device, builds the segmentor from
        its config, loads checkpoint weights, and builds the test-time
        preprocessing pipeline.

        Args:
            model_path (str): The path to the model checkpoint file. It's
                automatically passed by the Hugging Face infrastructure. If it
                does not point to an existing file, the repository-local
                default checkpoint path is used instead.

        Raises:
            FileNotFoundError: If no usable checkpoint file can be found.
        """
        # --- Device ---
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- Model Configuration ---
        # The config file path is relative to the repository root.
        config_path = self._DEFAULT_CONFIG_PATH

        # Bug fix: ``model_path`` was previously ignored and the checkpoint
        # path hard-coded. Honor the argument when it points at a real file;
        # otherwise fall back to the repository-relative default, preserving
        # the original behavior.
        if model_path and os.path.isfile(model_path):
            checkpoint_path = model_path
        else:
            checkpoint_path = self._DEFAULT_CHECKPOINT_PATH

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint file not found at {checkpoint_path}. "
                "Please download it and place it in the 'models for IML' directory."
            )

        cfg = Config.fromfile(config_path)

        # --- Build Model ---
        self.model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        # Load weights on CPU first; the module is moved to the target device
        # afterwards, which avoids GPU OOM spikes during deserialization.
        load_checkpoint(self.model, checkpoint_path, map_location='cpu')
        self.model.to(self.device)
        self.model.eval()

        # --- Build Preprocessing Pipeline ---
        # NOTE(review): assumes the test pipeline's second entry is a wrapper
        # (e.g. MultiScaleFlipAug) whose 'transforms' key holds the per-image
        # transforms — confirm against the config file.
        test_pipeline_cfg = cfg.data.test.pipeline[1]['transforms']
        self.pipeline = Compose(test_pipeline_cfg)

    def __call__(self, inputs: Image.Image) -> Dict[str, Any]:
        """
        Performs inference on a single image.

        Args:
            inputs (Image.Image): A PIL Image to be processed.

        Returns:
            Dict[str, Any]: ``{"image": <base64 PNG string>}`` where the PNG
            is a single-channel ('L') mask with 0 = authentic, 255 = tampered.
        """
        # Convert PIL image to a numpy array in RGB order (H, W, 3).
        img = np.array(inputs.convert('RGB'))

        # Seed the data dict with the keys the mmseg transforms expect.
        data = {'img': img, 'img_shape': img.shape, 'ori_shape': img.shape}
        data = self.pipeline(data)

        # The pipeline wraps the tensor in a list; take the first entry, add
        # a batch dimension, and move it to the inference device.
        img_tensor = data['img'][0].unsqueeze(0).to(self.device)

        # --- Inference ---
        with torch.no_grad():
            result = self.model(return_loss=False, img=[img_tensor])

        # --- Post-process ---
        # NOTE(review): assumes the model returns per-image class scores of
        # shape (num_classes, H, W) so argmax over axis 0 yields the class
        # map (0 = authentic, 1 = tampered) — confirm against the model's
        # test-mode forward API.
        mask_pred = result[0].argmax(0).astype(np.uint8)

        # Map class ids {0, 1} to a visible grayscale mask {0, 255}.
        mask_pred *= 255

        # Create a single-channel PIL image from the numpy mask.
        mask_image = Image.fromarray(mask_pred, mode='L')

        # --- Encode to Base64 ---
        buffered = io.BytesIO()
        mask_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {"image": img_str}