# Scraped from the Hugging Face Space file view: author JKrishnanandhaa,
# commit 3870592 ("Update app.py", verified), file size 6.91 kB.
"""
Document Forgery Detection – Professional Gradio Dashboard
Hugging Face Spaces Deployment
"""
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import plotly.graph_objects as go
from pathlib import Path
import sys
import json
# -------------------------------------------------
# PATH SETUP
# -------------------------------------------------
sys.path.insert(0, str(Path(__file__).parent))
from src.models import get_model
from src.config import get_config
from src.data.preprocessing import DocumentPreprocessor
from src.data.augmentation import DatasetAwareAugmentation
from src.features.region_extraction import get_mask_refiner, get_region_extractor
from src.features.feature_extraction import get_feature_extractor
from src.training.classifier import ForgeryClassifier
# -------------------------------------------------
# CONSTANTS
# -------------------------------------------------
# Classifier output index -> human-readable forgery type.
CLASS_NAMES = dict(enumerate(("Copy-Move", "Splicing", "Generation")))

# Overlay colour per class index, as (R, G, B) tuples: the detector draws on
# RGB arrays, so 0 -> red, 1 -> green, 2 -> blue.
CLASS_COLORS = {
    class_id: rgb
    for class_id, rgb in enumerate(((255, 0, 0), (0, 255, 0), (0, 0, 255)))
}
# -------------------------------------------------
# FORGERY DETECTOR (UNCHANGED CORE LOGIC)
# -------------------------------------------------
class ForgeryDetector:
def __init__(self):
print("Loading models...")
self.config = get_config("config.yaml")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = get_model(self.config).to(self.device)
checkpoint = torch.load("models/best_doctamper.pth", map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.model.eval()
self.classifier = ForgeryClassifier(self.config)
self.classifier.load("models/classifier")
self.preprocessor = DocumentPreprocessor(self.config, "doctamper")
self.augmentation = DatasetAwareAugmentation(self.config, "doctamper", is_training=False)
self.mask_refiner = get_mask_refiner(self.config)
self.region_extractor = get_region_extractor(self.config)
self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
print("βœ“ Models loaded")
def detect(self, image):
if isinstance(image, Image.Image):
image = np.array(image)
if image.ndim == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
original = image.copy()
preprocessed, _ = self.preprocessor(image, None)
augmented = self.augmentation(preprocessed, None)
image_tensor = augmented["image"].unsqueeze(0).to(self.device)
with torch.no_grad():
logits, decoder_features = self.model(image_tensor)
prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
binary = (prob_map > 0.5).astype(np.uint8)
refined = self.mask_refiner.refine(binary, original_size=original.shape[:2])
regions = self.region_extractor.extract(refined, prob_map, original)
results = []
for r in regions:
features = self.feature_extractor.extract(
preprocessed, r["region_mask"], [f.cpu() for f in decoder_features]
)
if features.ndim == 1:
features = features.reshape(1, -1)
if features.shape[1] != 526:
pad = max(0, 526 - features.shape[1])
features = np.pad(features, ((0, 0), (0, pad)))[:, :526]
pred, conf = self.classifier.predict(features)
if conf[0] > 0.6:
results.append({
"bounding_box": r["bounding_box"],
"forgery_type": CLASS_NAMES[int(pred[0])],
"confidence": float(conf[0]),
})
overlay = self._draw_overlay(original, results)
return overlay, {
"num_detections": len(results),
"detections": results,
}
def _draw_overlay(self, image, results):
out = image.copy()
for r in results:
x, y, w, h = r["bounding_box"]
fid = [k for k, v in CLASS_NAMES.items() if v == r["forgery_type"]][0]
color = CLASS_COLORS[fid]
cv2.rectangle(out, (x, y), (x + w, y + h), color, 2)
label = f"{r['forgery_type']} ({r['confidence']*100:.1f}%)"
cv2.putText(out, label, (x, y - 6),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
return out
# Built once at import time: ForgeryDetector() loads the checkpoint and the
# classifier, and run_detection reuses this single instance for every request.
detector: ForgeryDetector = ForgeryDetector()
# -------------------------------------------------
# METRIC VISUALS
# -------------------------------------------------
def gauge(value, title):
    """Build a Plotly gauge on a fixed 0-100 axis showing *value* under *title*."""
    indicator = go.Indicator(
        mode="gauge+number",
        value=value,
        title=dict(text=title),
        gauge=dict(axis=dict(range=[0, 100]), bar=dict(color="#2563eb")),
    )
    figure = go.Figure(indicator)
    figure.update_layout(height=240, margin={"t": 40, "b": 20})
    return figure
# -------------------------------------------------
# GRADIO CALLBACK
# -------------------------------------------------
def run_detection(file):
    """Gradio callback for the "Run Detection" button.

    Args:
        file: Value of the ``gr.File`` input -- a filepath string (Gradio 4+
            default), a ``pathlib.Path``, or an object exposing the path as
            ``.name`` (Gradio 3 tempfile wrapper). ``None`` if nothing was
            uploaded.

    Returns:
        ``(overlay, result, dice_fig, acc_fig, conf_fig)`` for the outputs
        wired to the button.

    Raises:
        gr.Error: if no file was uploaded or it cannot be decoded as an image;
            Gradio shows the message to the user instead of a traceback.
    """
    if file is None:
        raise gr.Error("Please upload a document image first.")
    path = file if isinstance(file, (str, Path)) else file.name

    try:
        # Image.open is lazy and holds the file open; copy() forces a full
        # decode so the handle is released when the with-block exits.
        with Image.open(path) as img:
            image = img.copy()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise gr.Error(
            "Could not read the uploaded file as an image "
            "(PDF upload is not supported)."
        ) from err

    overlay, result = detector.detect(image)

    avg_conf = (
        sum(d["confidence"] for d in result["detections"]) / max(1, result["num_detections"])
    ) * 100

    return (
        overlay,
        result,
        # Fixed, pre-recorded figures: these two gauges are NOT computed
        # from the uploaded image.
        gauge(75, "Localization Dice (%)"),
        gauge(92, "Classifier Accuracy (%)"),
        gauge(avg_conf, "Avg Detection Confidence (%)"),
    )
# -------------------------------------------------
# UI
# -------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Document Forgery Detection") as demo:
    gr.Markdown("# 📄 Document Forgery Detection System")

    with gr.Row():
        # The callback decodes uploads with Pillow, which cannot read PDFs,
        # so only image uploads are offered.
        file_input = gr.File(label="Upload Document (Image)", file_types=["image"])
        detect_btn = gr.Button("Run Detection", variant="primary")

    output_img = gr.Image(label="Forgery Localization Result", type="numpy")

    with gr.Tabs():
        with gr.Tab("📊 Metrics"):
            with gr.Row():
                dice_plot = gr.Plot()
                acc_plot = gr.Plot()
                conf_plot = gr.Plot()
        with gr.Tab("🧾 Details"):
            json_out = gr.JSON()
        with gr.Tab("👥 Team"):
            gr.Markdown("""
**Document Forgery Detection Project**

- Krishnanandhaa — Model & Training
- Teammate 1 — Feature Engineering
- Teammate 2 — Evaluation
- Teammate 3 — Deployment

*Collaborators are added via Hugging Face Space settings.*
""")

    # Output order must match the tuple returned by run_detection.
    detect_btn.click(
        run_detection,
        inputs=file_input,
        outputs=[output_img, json_out, dice_plot, acc_plot, conf_plot],
    )

if __name__ == "__main__":
    demo.launch()