# trufor-splicing-detecto / trufor_api.py
# Last change by 123Sashank12: "Update trufor_api.py" (commit 7282699, verified)
import glob
import os
import subprocess
import sys
import time

import numpy as np
from PIL import Image
class SplicingDetector:
    """
    Run TruFor's ``test.py`` in a subprocess and load the ``.npz`` it writes.
    """

    # Some filesystems store modification times at coarse resolution (FAT uses
    # 2-second steps), so an output file counts as "written by this run" when its
    # mtime is at most this many seconds before the run started.
    _MTIME_SLACK_SECONDS = 2.0

    def __init__(self, trufor_dir, python_executable="python"):
        """
        Initialize the detector by pointing it to your TruFor installation.

        :param trufor_dir: The absolute or relative path to the 'TruFor_train_test' folder.
        :param python_executable: The python executable to run test.py with.
        """
        self.trufor_dir = os.path.abspath(trufor_dir)
        self.python_exec = python_executable
        self.test_script = os.path.join(self.trufor_dir, "test.py")
        # Relative to trufor_dir: test.py is launched with cwd=trufor_dir.
        self.weights_path = "pretrained_models/trufor.pth.tar"

    def analyze_image(self, image_path, output_dir="./temp_results", timeout=None):
        """
        Run the TruFor model on an image and return the raw forensic data.

        :param image_path: Path to the image to analyze.
        :param output_dir: Directory test.py writes its ``.npz`` output into
            (created if missing). Relative paths resolve against this process's
            current working directory.
        :param timeout: Optional limit in seconds for the test.py subprocess;
            ``None`` (default) waits indefinitely.
        :returns: dict with keys ``anomaly_map``, ``confidence_map`` and
            ``global_score`` (``None`` when the output has no ``score`` entry),
            or ``None`` if inference fails, times out, or produces no output.
        :raises NotADirectoryError: if ``trufor_dir`` is not an existing directory.
        :raises FileNotFoundError: if ``test.py`` is missing from ``trufor_dir``.
        """
        self._validate_installation()

        abs_image_path = os.path.abspath(image_path)
        abs_output_dir = os.path.abspath(output_dir)
        os.makedirs(abs_output_dir, exist_ok=True)

        # NOTE(review): assumes test.py writes "<image basename>.npz" into -out;
        # if it names files differently, _locate_output falls back to the newest
        # .npz written during this run.
        expected_npz = os.path.join(
            abs_output_dir, os.path.basename(abs_image_path) + ".npz"
        )
        # Remove an earlier result for an image with the same file name so it can
        # never be returned as the result of this run.
        try:
            os.remove(expected_npz)
        except FileNotFoundError:
            pass

        print(f"🔍 Analyzing {os.path.basename(abs_image_path)} for splicing...")
        started = time.time()
        if not self._run_trufor(abs_image_path, abs_output_dir, timeout):
            return None

        npz_path = self._locate_output(abs_output_dir, expected_npz, started)
        if npz_path is None:
            print(f"Error: No .npz output found in {abs_output_dir}")
            return None

        print(f"✅ Analysis complete. Extracting forensic data from: {os.path.basename(npz_path)}")
        # The context manager closes the archive. Without mmap_mode, indexing the
        # NpzFile returns in-memory arrays that stay valid after it is closed.
        with np.load(npz_path) as data:
            return {
                "anomaly_map": data["map"],
                "confidence_map": data["conf"],
                "global_score": float(data["score"]) if "score" in data.files else None,
            }

    def _validate_installation(self):
        """Raise if trufor_dir or its test.py is missing."""
        if not os.path.isdir(self.trufor_dir):
            raise NotADirectoryError(
                f"TruFor directory not found or invalid: {self.trufor_dir}\n"
                f"Tip: TRUFOR_FOLDER must point to the folder that contains test.py."
            )
        if not os.path.isfile(self.test_script):
            raise FileNotFoundError(
                f"Could not find test.py at: {self.test_script}\n"
                f"Tip: TRUFOR_FOLDER must be the 'TruFor_train_test' folder (the one that contains test.py)."
            )

    def _run_trufor(self, abs_image_path, abs_output_dir, timeout):
        """Run test.py on one image; return True on success, False (after printing why) otherwise."""
        command = [
            self.python_exec,
            self.test_script,
            "-g",
            "-1",  # CPU
            "-in",
            abs_image_path,
            "-out",
            abs_output_dir,
            "-exp",
            "trufor_ph3",
            "TEST.MODEL_FILE",
            self.weights_path,
        ]
        # Ensure repo-local imports work (dataset/, lib/, etc.). Skip an unset or
        # empty existing PYTHONPATH so no empty entry is appended.
        env = os.environ.copy()
        env["PYTHONPATH"] = os.pathsep.join(
            entry for entry in (self.trufor_dir, env.get("PYTHONPATH")) if entry
        )
        try:
            subprocess.run(
                command,
                cwd=self.trufor_dir,
                env=env,
                check=True,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.CalledProcessError as e:
            print("🚨 TruFor Execution Failed!")
            # stderr is typically the most useful; stdout may contain progress logs.
            print("STDERR:\n", e.stderr)
            print("STDOUT:\n", e.stdout)
            return False
        except subprocess.TimeoutExpired:
            print(f"🚨 TruFor timed out after {timeout} seconds!")
            return False
        return True

    def _locate_output(self, abs_output_dir, expected_npz, started):
        """Return the .npz produced by the run that began at `started`, or None."""
        # expected_npz was deleted before the run, so if it exists it is fresh.
        if os.path.isfile(expected_npz):
            return expected_npz
        cutoff = started - self._MTIME_SLACK_SECONDS
        fresh = [
            path
            for path in glob.glob(os.path.join(abs_output_dir, "*.npz"))
            if os.path.getmtime(path) >= cutoff
        ]
        return max(fresh, key=os.path.getmtime) if fresh else None

    def visualize_results(self, image_path, results):
        """
        Local-only visualization (not needed for Hugging Face Spaces).
        Kept for convenience when running on your machine.

        :param image_path: Path to the analyzed image (shown as the first panel).
        :param results: dict returned by ``analyze_image``.
        """
        import matplotlib.pyplot as plt  # lazy import so server does not require matplotlib

        # Close the file handle immediately; convert() returns an independent copy.
        with Image.open(image_path) as img:
            original_img = img.convert("RGB")

        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(original_img)
        plt.title("Original Image")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(results["anomaly_map"], cmap="jet", vmin=0, vmax=1)
        score = results.get("global_score")
        plt.title(f"Anomaly Map\nGlobal Score: {score:.4f}" if score is not None else "Anomaly Map")
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.imshow(results["confidence_map"], cmap="magma", vmin=0, vmax=1)
        plt.title("Confidence Map")
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.axis("off")

        plt.tight_layout()
        plt.show()
if __name__ == "__main__":
    # Local test runner (optional): expects TruFor_train_test/ and sample.jpg
    # to sit next to this file.
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
    TRUFOR_FOLDER = os.path.join(SCRIPT_DIR, "TruFor_train_test")
    detector = SplicingDetector(trufor_dir=TRUFOR_FOLDER, python_executable=sys.executable)
    test_image = os.path.join(SCRIPT_DIR, "sample.jpg")
    forensic_data = detector.analyze_image(
        test_image, output_dir=os.path.join(SCRIPT_DIR, "temp_results_local")
    )
    if forensic_data is not None:
        global_score = forensic_data["global_score"]
        # analyze_image reports global_score as None when the output has no
        # "score" entry; formatting None with :.4f would raise TypeError.
        if global_score is not None:
            print(f"✅ Global AI Score: {global_score:.4f}")
        else:
            print("✅ Global AI Score: unavailable")
        detector.visualize_results(test_image, forensic_data)