"""Streamlit front end for the AI image detector.

The page offers two tabs that run the same trained model over uploaded
images with different ``CalibrationConfig`` values, and shows the label,
AI probability and confidence for a single image or a batch.
"""
from __future__ import annotations

import streamlit as st
from PIL import Image

from src.ai_image_detector.config import (
    MODEL_PATH,
)
from src.ai_image_detector.inference import (
    CalibrationConfig,
    load_trained_model,
    predict_image_bytes,
)

# Streamlit requires the page configuration to be the first Streamlit call,
# so it stays at module level ahead of every render function.
st.set_page_config(
    page_title="AI Image Detector",
    page_icon="📷",
    layout="wide",
)


@st.cache_resource
def get_model():
    """Return the trained model, loaded once and cached by Streamlit."""
    return load_trained_model()


def inject_styles() -> None:
    """Emit the page-level markup block that carries custom styling."""
    # NOTE(review): the payload is a single space in this view of the file;
    # presumably a <style> block lived here -- confirm against the original.
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )


def render_hero() -> None:
    """Render the headline section at the top of the page."""
    st.markdown(
        """
Visual Forensics
AI Image Detector
Check one image or a batch in a cleaner tab-based workspace. Use the default scan for balanced decisions or switch to the sensitive tab when you want the detector to lean more aggressively toward AI signals.
""",
        unsafe_allow_html=True,
    )


def decision_class(label: str) -> str:
    """Return the CSS class string used for a prediction label.

    ``"AI-generated"`` and ``"Real"`` get dedicated classes; every other
    label falls back to the "uncertain" class.
    """
    if label == "AI-generated":
        return "decision-pill decision-ai"
    if label == "Real":
        return "decision-pill decision-real"
    return "decision-pill decision-uncertain"


def render_empty_state(title: str, body: str) -> None:
    """Render a placeholder with a title and body text.

    Both values are interpolated into markup rendered with
    ``unsafe_allow_html=True``, so callers must pass trusted text only.
    """
    st.markdown(
        f"""
{title} {body}
""",
        unsafe_allow_html=True,
    )


def _clamp_unit(value: float) -> float:
    """Clamp ``value`` into [0.0, 1.0] and return it as a builtin float."""
    # float() guards against non-builtin numeric types coming back from the
    # model (presumably NumPy/Torch scalars -- verify against inference.py).
    return float(min(max(value, 0.0), 1.0))


def _scan_uploads(
    uploaded_files,
    *,
    model,
    calibration: CalibrationConfig,
    orientation_conservative: bool,
) -> tuple[list[dict], list[Image.Image]]:
    """Run the detector over every upload and collect rows plus previews.

    Returns two parallel lists: one result row per readable file and the
    decoded RGB preview for that row. Files that cannot be decoded as an
    image are reported with ``st.error`` and skipped so one bad upload does
    not abort the whole batch.
    """
    rows: list[dict] = []
    previews: list[Image.Image] = []
    for file in uploaded_files:
        try:
            # convert() forces a full decode, so truncated files fail here.
            image = Image.open(file).convert("RGB")
        except OSError:  # includes PIL.UnidentifiedImageError
            st.error(f"Could not read `{file.name}` as an image; skipped.")
            continue

        result = predict_image_bytes(
            model,
            file.getvalue(),
            calibration=calibration,
            orientation_conservative=orientation_conservative,
        )
        previews.append(image)
        rows.append(
            {
                "File": file.name,
                "Label": result.label,
                "AI Probability": f"{result.ai_probability:.2%}",
                "Confidence": f"{result.confidence:.2%}",
                "ai_prob_raw": result.ai_probability,
            }
        )
    return rows, previews


def _render_preview(row: dict, image: Image.Image) -> None:
    """Show the image, its label and the AI-probability progress bar."""
    st.image(image, caption=row["File"], use_container_width=True)
    st.markdown(
        f'{row["Label"]}',
        unsafe_allow_html=True,
    )
    st.progress(_clamp_unit(row["ai_prob_raw"]))


def render_detection_tab(
    *,
    key: str,
    title: str,
    description: str,
    calibration: CalibrationConfig,
    orientation_conservative: bool,
    model,
) -> None:
    """Render one scan tab: uploader, predictions and result views.

    Args:
        key: Unique Streamlit widget key for this tab's uploader; the
            preview selector uses ``f"{key}_preview"``.
        title: Heading shown at the top of the tab.
        description: Text shown under the heading (rendered as trusted
            markup).
        calibration: Passed through to ``predict_image_bytes``.
        orientation_conservative: Passed through to ``predict_image_bytes``.
        model: The loaded model handed to ``predict_image_bytes``.
    """
    st.markdown(f"### {title}")
    st.markdown(
        f"""
{description}
""",
        unsafe_allow_html=True,
    )

    uploaded_files = st.file_uploader(
        "Upload Image(s)",
        type=["jpg", "jpeg", "png", "webp", "bmp"],
        accept_multiple_files=True,
        help="Upload one image or a batch to compare results quickly.",
        key=key,
    )
    if not uploaded_files:
        render_empty_state(
            "Drop files to start a scan",
            "Your results will appear here with a preview, label, AI probability, and confidence score.",
        )
        return

    rows, previews = _scan_uploads(
        uploaded_files,
        model=model,
        calibration=calibration,
        orientation_conservative=orientation_conservative,
    )
    if not rows:
        # Every upload failed to decode; the errors are already on screen.
        return

    if len(rows) == 1:
        item = rows[0]
        _render_preview(item, previews[0])
        st.markdown(
            f"""
AI Probability
{item["AI Probability"]}
Confidence
{item["Confidence"]}
""",
            unsafe_allow_html=True,
        )
        return

    st.dataframe(
        [{k: v for k, v in row.items() if k != "ai_prob_raw"} for row in rows],
        use_container_width=True,
        hide_index=True,
    )
    # Select by position rather than by file name: two uploads may share a
    # name, and a name-keyed lookup would then show the wrong image/result.
    position = st.selectbox(
        "Preview one result",
        list(range(len(rows))),
        format_func=lambda i: rows[i]["File"],
        key=f"{key}_preview",
    )
    chosen = rows[position]
    _render_preview(chosen, previews[position])
    st.caption(
        f"AI Probability: {chosen['AI Probability']} | Confidence: {chosen['Confidence']}"
    )


def main() -> None:
    """Build the page: styles, model check, hero and the two scan tabs."""
    inject_styles()

    if not MODEL_PATH.exists():
        st.warning("No trained model found. Train first with `python train.py`, then reload.")
        st.stop()

    render_hero()
    model = get_model()

    default_tab, sensitive_tab = st.tabs(["Default Scan", "AI-Sensitive"])

    with default_tab:
        st.markdown(
            """
Balanced mode for the cleanest everyday result view.
""",
            unsafe_allow_html=True,
        )
        render_detection_tab(
            key="default_scan",
            title="Default Scan",
            description="Use this when you want a smoother, more balanced prediction flow for normal checks.",
            calibration=CalibrationConfig(
                threshold=0.65,
                uncertain_low=0.45,
                uncertain_high=0.70,
            ),
            orientation_conservative=True,
            model=model,
        )

    with sensitive_tab:
        st.markdown(
            """
More aggressive mode when you want stronger AI catching behavior.
""",
            unsafe_allow_html=True,
        )
        render_detection_tab(
            key="sensitive_scan",
            title="AI-Sensitive Scan",
            description="This profile reacts faster to possible AI traits and is useful when you want a stricter pass.",
            calibration=CalibrationConfig(
                threshold=0.40,
                uncertain_low=0.30,
                uncertain_high=0.50,
            ),
            orientation_conservative=False,
            model=model,
        )


if __name__ == "__main__":
    main()