from __future__ import annotations
import streamlit as st
from PIL import Image
from src.ai_image_detector.config import (
MODEL_PATH,
)
from src.ai_image_detector.inference import (
CalibrationConfig,
load_trained_model,
predict_image_bytes,
)
# Page-wide configuration: browser-tab title/icon and the wide layout used by
# the result views below.
st.set_page_config(
    layout="wide",
    page_icon="📷",
    page_title="AI Image Detector",
)
@st.cache_resource
def get_model():
    """Return the trained detector, loaded through ``load_trained_model``.

    Wrapped in ``st.cache_resource`` so the loaded model object is reused
    instead of being reloaded on every script rerun.
    """
    detector = load_trained_model()
    return detector
# CSS for the classes produced by ``decision_class``. The ``<style>`` tag must
# start at column 0: Markdown treats lines indented by four or more spaces as a
# code block and would print the CSS as literal text.
_DECISION_PILL_CSS = """<style>
.decision-pill {
    display: inline-block;
    padding: 0.25rem 0.9rem;
    margin: 0.5rem 0;
    border-radius: 999px;
    font-weight: 600;
}
.decision-ai { background-color: #fdecea; color: #b3261e; }
.decision-real { background-color: #e6f4ea; color: #1e7b34; }
.decision-uncertain { background-color: #fff4e5; color: #8a5300; }
</style>"""


def inject_styles() -> None:
    """Inject the page-level CSS for decision pills.

    Defines ``decision-pill`` plus the ``decision-ai`` / ``decision-real`` /
    ``decision-uncertain`` modifiers returned by ``decision_class``, so an
    element carrying those classes renders as a colour-coded pill. Previously
    the injected stylesheet was empty and the call was a no-op.
    """
    st.markdown(_DECISION_PILL_CSS, unsafe_allow_html=True)
def render_hero() -> None:
    """Render the page header: an eyebrow label, the app title, and a usage blurb.

    The three parts are emitted as separate elements because consecutive
    Markdown lines with no blank line between them are merged into a single
    paragraph, which ran the label, title, and blurb together.
    """
    st.caption("Visual Forensics")
    st.title("AI Image Detector")
    st.markdown(
        "Check one image or a batch in a cleaner tab-based workspace. "
        "Use the default scan for balanced decisions or switch to the sensitive tab "
        "when you want the detector to lean more aggressively toward AI signals."
    )
def decision_class(label: str) -> str:
    """Map a prediction label to the CSS classes for its decision pill.

    ``"AI-generated"`` and ``"Real"`` get their own modifier class; any other
    label falls back to the uncertain style.
    """
    pill_classes = {
        "AI-generated": "decision-pill decision-ai",
        "Real": "decision-pill decision-real",
    }
    return pill_classes.get(label, "decision-pill decision-uncertain")
def render_empty_state(title: str, body: str) -> None:
    """Render a placeholder message: a bold title above explanatory text.

    Args:
        title: Short heading line, shown in bold.
        body: Explanatory text shown as its own paragraph under the title.

    Both values are interpolated into Markdown with HTML enabled, so callers
    should pass trusted text only.
    """
    # The blank line keeps title and body as separate paragraphs; without it
    # Markdown joins them into one run-on line.
    st.markdown(f"**{title}**\n\n{body}", unsafe_allow_html=True)
def _clamp_unit(value: float) -> float:
    """Clamp *value* into the closed interval [0.0, 1.0] used by ``st.progress``."""
    return min(max(value, 0.0), 1.0)


def _render_decision(row: dict) -> None:
    """Render the label pill and AI-probability bar for one scored row."""
    label = row["Label"]
    st.markdown(
        f'<span class="{decision_class(label)}">{label}</span>',
        unsafe_allow_html=True,
    )
    st.progress(_clamp_unit(row["ai_prob_raw"]))


def _score_uploads(
    uploaded_files: list,
    *,
    calibration: CalibrationConfig,
    orientation_conservative: bool,
    model,
) -> tuple[list[dict], list[Image.Image]]:
    """Score every readable upload with the detector.

    Returns ``(rows, previews)`` as parallel lists: ``rows[i]`` holds the
    display fields plus the raw ``ai_prob_raw`` value, and ``previews[i]`` is
    the RGB image for the same upload. Pairing them by position rather than by
    file name keeps two uploads that share a name from overwriting each other.
    Files Pillow cannot decode are reported with ``st.warning`` and skipped, so
    one bad file does not abort the whole batch.
    """
    rows: list[dict] = []
    previews: list[Image.Image] = []
    for file in uploaded_files:
        try:
            # ``with`` closes the decoder handle; convert() returns an independent copy.
            with Image.open(file) as opened:
                preview = opened.convert("RGB")
        except OSError:
            # Pillow's UnidentifiedImageError and truncated-file errors subclass OSError.
            st.warning(f"Skipped {file.name}: the file could not be read as an image.")
            continue
        result = predict_image_bytes(
            model,
            file.getvalue(),
            calibration=calibration,
            orientation_conservative=orientation_conservative,
        )
        previews.append(preview)
        rows.append(
            {
                "File": file.name,
                "Label": result.label,
                "AI Probability": f"{result.ai_probability:.2%}",
                "Confidence": f"{result.confidence:.2%}",
                "ai_prob_raw": result.ai_probability,
            }
        )
    return rows, previews


def render_detection_tab(
    *,
    key: str,
    title: str,
    description: str,
    calibration: CalibrationConfig,
    orientation_conservative: bool,
    model,
) -> None:
    """Render one scan tab: uploader, per-file predictions, and result previews.

    Args:
        key: Widget-key prefix for this tab's uploader and preview selector;
            each tab needs a distinct prefix.
        title: Heading shown at the top of the tab.
        description: Markdown text shown under the heading.
        calibration: Decision thresholds forwarded to ``predict_image_bytes``.
        orientation_conservative: Forwarded unchanged to ``predict_image_bytes``.
        model: The loaded detector, as returned by ``get_model``.

    A single readable upload gets a full result card. Several uploads get a
    summary table plus a selector to preview any one of them.
    """
    st.markdown(f"### {title}")
    st.markdown(description)
    uploaded_files = st.file_uploader(
        "Upload Image(s)",
        type=["jpg", "jpeg", "png", "webp", "bmp"],
        accept_multiple_files=True,
        help="Upload one image or a batch to compare results quickly.",
        key=key,
    )
    if not uploaded_files:
        render_empty_state(
            "Drop files to start a scan",
            "Your results will appear here with a preview, label, AI probability, and confidence score.",
        )
        return

    rows, previews = _score_uploads(
        uploaded_files,
        calibration=calibration,
        orientation_conservative=orientation_conservative,
        model=model,
    )
    if not rows:
        # Every upload was skipped; the per-file warnings already explain why.
        return

    if len(rows) == 1:
        row = rows[0]
        st.image(previews[0], caption=row["File"], use_container_width=True)
        _render_decision(row)
        ai_col, confidence_col = st.columns(2)
        ai_col.metric("AI Probability", row["AI Probability"])
        confidence_col.metric("Confidence", row["Confidence"])
        return

    st.dataframe(
        [{k: v for k, v in row.items() if k != "ai_prob_raw"} for row in rows],
        use_container_width=True,
        hide_index=True,
    )
    # Select by position, not by name, so uploads with duplicate file names
    # each map to their own preview and label.
    selected = st.selectbox(
        "Preview one result",
        range(len(rows)),
        format_func=lambda index: rows[index]["File"],
        key=f"{key}_preview",
    )
    chosen = rows[selected]
    st.image(previews[selected], caption=chosen["File"], use_container_width=True)
    _render_decision(chosen)
    st.caption(f"AI Probability: {chosen['AI Probability']} | Confidence: {chosen['Confidence']}")
def main() -> None:
    """Render the detector page with a default and an AI-sensitive scan tab.

    Shows a warning and calls ``st.stop()`` when ``MODEL_PATH`` does not exist.
    The model is therefore only loaded once a trained checkpoint is present.
    """
    inject_styles()
    if not MODEL_PATH.exists():
        st.warning("No trained model found. Train first with `python train.py`, then reload.")
        st.stop()

    render_hero()
    model = get_model()
    default_tab, sensitive_tab = st.tabs(["Default Scan", "AI-Sensitive"])

    with default_tab:
        st.markdown("Balanced mode for the cleanest everyday result view.")
        render_detection_tab(
            key="default_scan",
            title="Default Scan",
            description="Use this when you want a smoother, more balanced prediction flow for normal checks.",
            # Higher threshold and band than the sensitive profile. Presumably
            # scores inside [uncertain_low, uncertain_high] are reported as
            # uncertain — confirm against inference.CalibrationConfig.
            calibration=CalibrationConfig(
                threshold=0.65,
                uncertain_low=0.45,
                uncertain_high=0.70,
            ),
            orientation_conservative=True,
            model=model,
        )

    with sensitive_tab:
        st.markdown("More aggressive mode when you want stronger AI catching behavior.")
        render_detection_tab(
            key="sensitive_scan",
            title="AI-Sensitive Scan",
            description="This profile reacts faster to possible AI traits and is useful when you want a stricter pass.",
            # Lower threshold and band: intended to flag AI content earlier
            # (per the tab copy above).
            calibration=CalibrationConfig(
                threshold=0.40,
                uncertain_low=0.30,
                uncertain_high=0.50,
            ),
            orientation_conservative=False,
            model=model,
        )
# Entry point: build the page when this file is executed as the main module.
if __name__ == "__main__":
    main()