P-rateek committed on
Commit
a2a1eb2
·
verified ·
1 Parent(s): bcc46f8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +94 -94
inference.py CHANGED
@@ -1,95 +1,95 @@
1
- import base64
2
- import io
3
- import os
4
- import sys
5
- from typing import Dict, List, Any
6
-
7
- import numpy as np
8
- import torch
9
- from mmcv.runner import load_checkpoint
10
- from mmcv.utils import Config
11
- from PIL import Image
12
-
13
- # Add current directory to path to import local modules
14
- sys.path.append(os.path.dirname(os.path.realpath(__file__)))
15
-
16
- # Now we can import from the local mmseg and mmcv_custom
17
- from models.for_IML.mmseg.datasets.pipelines import Compose
18
- from models.for_IML.mmseg.models import build_segmentor
19
-
20
-
21
- class Pipeline:
22
- def __init__(self, model_path: str):
23
- """
24
- Initializes the pipeline by loading the model and preprocessing steps.
25
- Args:
26
- model_path (str): The path to the model checkpoint file. It's automatically
27
- passed by the Hugging Face infrastructure.
28
- """
29
- # --- Device ---
30
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
-
32
- # --- Model Configuration ---
33
- # The config file path is relative to the repository root
34
- config_path = 'models for IML/apscnet.py'
35
- # The checkpoint path is also relative to the repository root
36
- checkpoint_path = 'models for IML/APSC-Net.pth'
37
-
38
- if not os.path.exists(checkpoint_path):
39
- raise FileNotFoundError(
40
- f"Checkpoint file not found at {checkpoint_path}. "
41
- "Please download it and place it in the 'models for IML' directory."
42
- )
43
-
44
- cfg = Config.fromfile(config_path)
45
-
46
- # --- Build Model ---
47
- self.model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
48
- load_checkpoint(self.model, checkpoint_path, map_location='cpu')
49
- self.model.to(self.device)
50
- self.model.eval()
51
-
52
- # --- Build Preprocessing Pipeline ---
53
- # We extract the transforms from the test_pipeline in the config
54
- test_pipeline_cfg = cfg.data.test.pipeline[1]['transforms']
55
- self.pipeline = Compose(test_pipeline_cfg)
56
-
57
- def __call__(self, inputs: Image.Image) -> Dict[str, Any]:
58
- """
59
- Performs inference on a single image.
60
- Args:
61
- inputs (Image.Image): A PIL Image to be processed.
62
- Returns:
63
- Dict[str, Any]: A dictionary containing the resulting mask as a base64 encoded string.
64
- """
65
- # Convert PIL image to numpy array (RGB)
66
- img = np.array(inputs.convert('RGB'))
67
-
68
- # Prepare data for the pipeline
69
- data = {'img': img, 'img_shape': img.shape, 'ori_shape': img.shape}
70
- data = self.pipeline(data)
71
-
72
- # Move data to the device
73
- img_tensor = data['img'][0].unsqueeze(0).to(self.device)
74
-
75
- # --- Inference ---
76
- with torch.no_grad():
77
- result = self.model(return_loss=False, img=[img_tensor])
78
-
79
- # --- Post-process ---
80
- # The model output is logits of shape (1, 2, H, W)
81
- # We take argmax to get the class (0=authentic, 1=tampered)
82
- mask_pred = result[0].argmax(0).astype(np.uint8)
83
-
84
- # Convert mask to a visual format (0 -> 0, 1 -> 255)
85
- mask_pred *= 255
86
-
87
- # Create a PIL image from the numpy mask
88
- mask_image = Image.fromarray(mask_pred, mode='L')
89
-
90
- # --- Encode to Base64 ---
91
- buffered = io.BytesIO()
92
- mask_image.save(buffered, format="PNG")
93
- img_str = base64.b64encode(buffered.getvalue()).decode()
94
-
95
  return {"image": img_str}
 
1
+ import base64
2
+ import io
3
+ import os
4
+ import sys
5
+ from typing import Dict, List, Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ from mmcv.runner import load_checkpoint
10
+ from mmcv.utils import Config
11
+ from PIL import Image
12
+
13
+ # Add current directory to path to import local modules
14
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
15
+
16
+ # Now we can import from the local mmseg and mmcv_custom
17
+ from models_for_IML.mmseg.datasets.pipelines import Compose
18
+ from models_for_IML.mmseg.models import build_segmentor
19
+
20
+
21
class Pipeline:
    """APSC-Net image-manipulation-localization inference pipeline.

    Builds the segmentor described by the repo-local mmseg config, loads the
    checkpoint weights, and exposes ``__call__`` to run single-image
    inference, returning the predicted tamper mask as a base64-encoded PNG.
    """

    def __init__(self, model_path: str):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path (str): Path to the model checkpoint file (passed
                automatically by the Hugging Face infrastructure). If it
                points to an existing file it is used as the checkpoint;
                otherwise the repo-local default path is used, matching the
                original hard-coded behaviour.

        Raises:
            FileNotFoundError: If no usable checkpoint file exists, or the
                model config file is missing.
        """
        # --- Device ---
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- Model configuration (paths are relative to the repository root) ---
        # NOTE(review): the local imports above use ``models_for_IML``
        # (underscores) while these paths use "models for IML" (spaces) —
        # confirm which directory name the repository actually uses.
        config_path = 'models for IML/apscnet.py'
        default_checkpoint = 'models for IML/APSC-Net.pth'

        # Fix: ``model_path`` is documented as the checkpoint location but the
        # original ignored it entirely. Prefer it when it points to a real
        # file; otherwise fall back to the original hard-coded default.
        if model_path and os.path.exists(model_path):
            checkpoint_path = model_path
        else:
            checkpoint_path = default_checkpoint

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint file not found at {checkpoint_path}. "
                "Please download it and place it in the 'models for IML' directory."
            )
        if not os.path.exists(config_path):
            # Fail with a clear message instead of whatever Config.fromfile raises.
            raise FileNotFoundError(f"Config file not found at {config_path}.")

        cfg = Config.fromfile(config_path)

        # --- Build model ---
        self.model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        # Load weights onto CPU first; the module is moved to the target device below.
        load_checkpoint(self.model, checkpoint_path, map_location='cpu')
        self.model.to(self.device)
        self.model.eval()

        # --- Build preprocessing pipeline ---
        # The second stage of the config's test pipeline holds the per-image
        # transforms; compose only those (presumably a MultiScaleFlipAug-style
        # wrapper — TODO confirm against the config file).
        test_pipeline_cfg = cfg.data.test.pipeline[1]['transforms']
        self.pipeline = Compose(test_pipeline_cfg)

    def __call__(self, inputs: Image.Image) -> Dict[str, Any]:
        """Run inference on a single image.

        Args:
            inputs (Image.Image): The PIL image to analyse.

        Returns:
            Dict[str, Any]: ``{"image": <base64-encoded PNG>}`` where mask
            pixels are 0 (authentic) or 255 (tampered).
        """
        # Convert the PIL image to an RGB numpy array.
        img = np.array(inputs.convert('RGB'))

        # Run the mmseg preprocessing transforms.
        data = {'img': img, 'img_shape': img.shape, 'ori_shape': img.shape}
        data = self.pipeline(data)

        # The pipeline wraps tensors in a list; take the first entry and add
        # a batch dimension before moving it to the device.
        img_tensor = data['img'][0].unsqueeze(0).to(self.device)

        # --- Inference ---
        with torch.no_grad():
            result = self.model(return_loss=False, img=[img_tensor])

        # --- Post-process ---
        # result[0] holds per-class scores; argmax over the class axis gives
        # 0 = authentic, 1 = tampered.
        mask_pred = result[0].argmax(0).astype(np.uint8)

        # Scale {0, 1} to {0, 255} so the mask is viewable as grayscale.
        mask_pred *= 255

        mask_image = Image.fromarray(mask_pred, mode='L')

        # --- Encode to base64 ---
        buffered = io.BytesIO()
        mask_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {"image": img_str}