macroadster committed on
Commit
140e2df
·
verified ·
1 Parent(s): 2efbe7f

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +308 -0
inference.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# inference.py
#
# Unified steganography-detection inference entry point. Heavy optional
# dependencies (onnxruntime, transformers, starlight_utils) are imported
# defensively so the module still imports in degraded environments.
import numpy as np
from PIL import Image
import os
import sys
import json
from typing import Dict, Any

# Optional ONNX imports — ONNX_AVAILABLE gates every neural-network code
# path below; without onnxruntime the detector is simply never loaded.
try:
    import onnx
    import onnxruntime as ort
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    print("Warning: ONNX not available. Neural network features disabled.")

# Optional Hugging Face imports — HF_AVAILABLE is set here for feature
# detection by callers; the Pipeline symbol itself is not used below.
try:
    from transformers import Pipeline
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    print("Warning: Hugging Face transformers not available. Pipeline features disabled.")

# Make the sibling ./scripts directory importable so the shared
# preprocessing utilities can be loaded regardless of the CWD.
scripts_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")
if scripts_dir not in sys.path:
    sys.path.append(scripts_dir)

# Import unified input loader (shared with scanner.py). Left as None on
# failure so downstream code can feature-detect instead of crashing.
try:
    from starlight_utils import load_unified_input
except ImportError as e:
    print(f"Warning: Could not import starlight_utils: {e}")
    load_unified_input = None
class StarlightModel:
    """Steganography detector backed by a unified ONNX model.

    Loading is fault-tolerant: when onnxruntime is unavailable, the model
    file is missing, or the session fails to build, ``self.detector`` is
    left as ``None`` and :meth:`predict` returns an ``{"error": ...}``
    dict instead of raising.
    """

    def __init__(
        self,
        detector_path: str = "model/detector.onnx",
        task: str = "detect"
    ):
        self.detector_path = detector_path
        self.task = task
        self.detector = None

        if not ONNX_AVAILABLE:
            return

        # Prefer hardware execution providers when present; CPU is the
        # always-available fallback appended last.
        providers = []
        available_providers = ort.get_available_providers()
        if 'CUDAExecutionProvider' in available_providers:
            providers.append('CUDAExecutionProvider')
        if 'CoreMLExecutionProvider' in available_providers:
            providers.append('CoreMLExecutionProvider')
        providers.append('CPUExecutionProvider')

        session_options = ort.SessionOptions()
        # Memory-pattern optimization is disabled for the accelerated
        # providers (same setting for CUDA and CoreML).
        if 'CUDAExecutionProvider' in providers or 'CoreMLExecutionProvider' in providers:
            session_options.enable_mem_pattern = False

        if not os.path.exists(detector_path):
            print(f"Warning: Detector model not found at {detector_path}")
            return
        try:
            self.detector = ort.InferenceSession(
                detector_path, sess_options=session_options, providers=providers
            )
        except Exception as e:
            print(f"Warning: Could not load detector: {e}")

    def _detect_method_from_filename(self, img_path: str) -> str:
        """Guess the embedding method from a ``*_<method>_<idx>`` filename.

        Falls back to ``"lsb"`` when the name does not follow the
        convention (fewer than three underscore-separated parts).
        """
        basename = os.path.basename(img_path)
        parts = basename.split("_")
        if len(parts) >= 3:
            return parts[-2]  # e.g., alpha, eoi, dct
        return "lsb"  # Default fallback

    def predict(self, img_path: str, method: str = None) -> Dict[str, Any]:
        """Run stego detection on one image.

        Args:
            img_path: Path to the image file.
            method: Optional embedding-method hint; derived from the
                filename when omitted.

        Returns:
            A result dict (stego_probability, predicted_method_id, ...)
            on success, or ``{"error": ...}`` on any failure.
        """
        if not load_unified_input:
            return {"error": "starlight_utils not available"}

        # Use unified input loader (aligned with scanner.py design).
        pixel_tensor, meta, alpha, lsb, palette, format_features, content_features = load_unified_input(img_path, fast_mode=True)

        # Convert to numpy for ONNX and add a batch dimension.
        # Note: lsb and alpha need to be in CHW format for ONNX.
        lsb_chw = lsb.permute(2, 0, 1) if lsb.dim() == 3 else lsb  # (3, 256, 256) — assumed; confirm against loader
        alpha_chw = alpha.unsqueeze(0) if alpha.dim() == 2 else alpha  # (1, 256, 256) — assumed; confirm against loader

        inputs = {
            'meta': np.expand_dims(meta.numpy(), 0),
            'alpha': np.expand_dims(alpha_chw.numpy(), 0),
            'lsb': np.expand_dims(lsb_chw.numpy(), 0),
            'palette': np.expand_dims(palette.numpy(), 0),
            'format_features': np.expand_dims(format_features.numpy(), 0),
            'content_features': np.expand_dims(content_features.numpy(), 0),
            'bit_order': np.array([[0.0, 1.0, 0.0]], dtype=np.float32)  # Default msb-first
        }

        method = method or self._detect_method_from_filename(img_path)

        # Guard clauses instead of nested if/else.
        if self.task != "detect":
            return {"error": f"Task '{self.task}' not supported in unified design"}
        if not self.detector:
            return {"error": "Detector model not loaded"}

        try:
            outputs = self.detector.run(None, inputs)
            # Only the first two outputs are consumed (matching the
            # pipeline class); previously outputs[2]/outputs[3] were read
            # into unused locals, which would fail on a 2-output model.
            stego_logits = outputs[0]
            method_logits = outputs[1]

            prob = float(1 / (1 + np.exp(-stego_logits[0][0])))  # Sigmoid
            predicted_method = int(np.argmax(method_logits[0]))

            return {
                "image_path": img_path,
                "stego_probability": prob,
                "task": self.task,
                "method": method,
                "predicted_method_id": predicted_method,
                "predicted": prob > 0.5
            }
        except Exception as e:
            return {"error": f"ONNX inference failed: {e}"}
# The HF-style pipeline is only defined when both onnxruntime and the
# unified input loader are importable.
if ONNX_AVAILABLE and load_unified_input:
    class StarlightSteganographyDetectionPipeline:
        """Hugging-Face-style pipeline wrapping the ONNX stego detector.

        Mirrors the transformers Pipeline contract:
        ``preprocess -> _forward -> postprocess``, orchestrated by
        ``__call__``.

        Fix: the original defined ``_sanitize_parameters``,
        ``preprocess``, ``_forward`` and ``postprocess`` twice with
        identical bodies; the later duplicates silently shadowed the
        first set and have been removed.
        """

        def __init__(self, model_path=None, config_path="config.json", **kwargs):
            # Config supplies id2label (and optionally model_path).
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config file not found at {config_path}")
            with open(config_path, 'r') as f:
                self.config = json.load(f)

            if model_path is None:
                model_path = self.config.get("model_path", "models/detector_balanced.onnx")

            # Prefer accelerated execution providers; CPU is the fallback.
            providers = []
            available_providers = ort.get_available_providers()
            if 'CUDAExecutionProvider' in available_providers:
                providers.append('CUDAExecutionProvider')
            if 'CoreMLExecutionProvider' in available_providers:
                providers.append('CoreMLExecutionProvider')
            providers.append('CPUExecutionProvider')

            session_options = ort.SessionOptions()
            # Memory-pattern optimization is disabled for the accelerated
            # providers (same setting for CUDA and CoreML).
            if 'CUDAExecutionProvider' in providers or 'CoreMLExecutionProvider' in providers:
                session_options.enable_mem_pattern = False

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model not found at {model_path}")

            self.model = ort.InferenceSession(model_path, sess_options=session_options, providers=providers)

        def __call__(self, image_path, **kwargs):
            """Run the full preprocess → forward → postprocess chain."""
            sanitized_kwargs, _, _ = self._sanitize_parameters(**kwargs)
            model_inputs = self.preprocess(image_path)
            model_outputs = self._forward(model_inputs)
            return self.postprocess(model_outputs)

        def _sanitize_parameters(self, **kwargs):
            # No specific parameters to sanitize for now; returns
            # (preprocess, forward, postprocess) kwarg dicts.
            return {}, {}, {}

        def preprocess(self, image_path):
            """Load one image into the batched numpy feed dict for ONNX.

            Raises:
                ValueError: on a bad path or a loader failure.
            """
            if not isinstance(image_path, str) or not os.path.exists(image_path):
                raise ValueError(f"Invalid image_path: {image_path}")

            # Use unified input loader.
            try:
                pixel_tensor, meta, alpha, lsb, palette, format_features, content_features = load_unified_input(image_path, fast_mode=True)
            except Exception as e:
                raise ValueError(f"Failed to preprocess image {image_path}: {e}")

            # Convert to numpy for ONNX and add a batch dimension.
            # Note: lsb and alpha need to be in CHW format for ONNX.
            lsb_chw = lsb.permute(2, 0, 1) if lsb.dim() == 3 else lsb  # (3, 256, 256) — assumed; confirm against loader
            alpha_chw = alpha.unsqueeze(0) if alpha.dim() == 2 else alpha  # (1, 256, 256) — assumed; confirm against loader

            model_inputs = {
                'meta': np.expand_dims(meta.numpy(), 0),
                'alpha': np.expand_dims(alpha_chw.numpy(), 0),
                'lsb': np.expand_dims(lsb_chw.numpy(), 0),
                'palette': np.expand_dims(palette.numpy(), 0),
                'format_features': np.expand_dims(format_features.numpy(), 0),
                'content_features': np.expand_dims(content_features.numpy(), 0),
                'bit_order': np.array([[0.0, 1.0, 0.0]], dtype=np.float32)  # Default msb-first
            }

            return model_inputs

        def _forward(self, model_inputs):
            """Execute the ONNX session; only the first two outputs are used."""
            try:
                outputs = self.model.run(None, model_inputs)
                return {
                    'stego_logits': outputs[0],
                    'method_logits': outputs[1],
                }
            except Exception as e:
                raise RuntimeError(f"ONNX inference failed: {e}")

        def postprocess(self, model_outputs):
            """Convert raw logits into the user-facing result dict."""
            stego_logits = model_outputs['stego_logits']
            method_logits = model_outputs['method_logits']

            prob = float(1 / (1 + np.exp(-stego_logits[0][0])))  # Sigmoid

            # Numerically stable softmax: shift by the max logit so
            # np.exp cannot overflow for large logits.
            logits = method_logits[0]
            exp_shifted = np.exp(logits - np.max(logits))
            method_probs = exp_shifted / np.sum(exp_shifted)

            predicted_method_id = int(np.argmax(method_logits[0]))
            predicted_method_name = self.config["id2label"].get(str(predicted_method_id), "unknown")

            return {
                "stego_probability": prob,
                "predicted_method": predicted_method_name,
                "predicted_method_id": predicted_method_id,
                "predicted_method_prob": float(method_probs[predicted_method_id]),
                "is_steganography": prob > 0.5
            }
# Convenience functions for specific tasks
def detect_steganography(img_path):
    """Run steganography detection on a single image.

    Builds a detection-task :class:`StarlightModel` and returns its
    ``predict`` result dict.
    """
    return StarlightModel(task="detect").predict(img_path)
def get_starlight_pipeline():
    """Initialize and return the StarlightSteganographyDetectionPipeline.

    Raises:
        ImportError: when onnxruntime or starlight_utils is missing.
    """
    # Collect the first missing-dependency message, then raise once.
    missing = None
    if not ONNX_AVAILABLE:
        missing = "ONNX runtime library not found. Please install it with 'pip install onnxruntime'."
    elif not load_unified_input:
        missing = "starlight_utils could not be imported. Please ensure the 'scripts' directory is in your Python path."
    if missing is not None:
        raise ImportError(missing)

    return StarlightSteganographyDetectionPipeline()