boluobobo's picture
Upload app.py with huggingface_hub
f9aed6b verified
"""
ItsNotAI v2 - AI Image Detector
Gradio app for Hugging Face Spaces
Dual-head model: Multi-class + Binary classification
+ FLUX specialist detector (ensemble)
"""
import gradio as gr
import torch
import torch.nn as nn
import json
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor, ViTForImageClassification, ViTImageProcessor
from huggingface_hub import hf_hub_download
# ==============================================
# Main Model (v2 multi-class + binary head)
# ==============================================
MODEL_ID = "boluobobo/ItsNotAI-ai-detector-v2"
print("Loading main model...")
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model.eval()
# Load source metadata
try:
meta_path = hf_hub_download(repo_id=MODEL_ID, filename="source_meta.json")
with open(meta_path) as f:
meta = json.load(f)
source_names = meta["source_names"]
source_is_real = meta["source_is_real"]
dual_head = meta.get("dual_head", False)
hidden_size = meta.get("hidden_size", model.config.hidden_size)
except Exception:
# Fallback
source_names = list(model.config.id2label.values())
source_is_real = {}
dual_head = False
hidden_size = model.config.hidden_size
# Load binary head if dual-head mode
binary_head = None
if dual_head:
try:
binary_head_path = hf_hub_download(repo_id=MODEL_ID, filename="binary_head.pt")
binary_head = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(hidden_size, 2),
)
binary_head.load_state_dict(torch.load(binary_head_path, map_location="cpu", weights_only=True))
binary_head.eval()
print("Binary head loaded - using dual-head mode")
except Exception as e:
print(f"Binary head not found, falling back to single-head mode: {e}")
dual_head = False
print(f"Main model: {len(source_names)} classes, dual_head={dual_head}")
# ==============================================
# FLUX Specialist Detector (ensemble)
# ==============================================
FLUX_MODEL_ID = "ash12321/flux-detector-vit"
FLUX_THRESHOLD = 0.85 # High confidence threshold
flux_model = None
flux_processor = None
try:
print("Loading FLUX specialist detector...")
flux_model = ViTForImageClassification.from_pretrained(FLUX_MODEL_ID)
flux_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
flux_model.eval()
print(f"FLUX detector loaded (threshold={FLUX_THRESHOLD})")
except Exception as e:
print(f"FLUX detector not available: {e}")
def get_backbone_features(pixel_values):
"""Extract backbone features (CLS token)"""
if hasattr(model, 'beit'):
outputs = model.beit(pixel_values)
return outputs.last_hidden_state[:, 0]
elif hasattr(model, 'vit'):
outputs = model.vit(pixel_values)
return outputs.last_hidden_state[:, 0]
else:
for name, module in model.named_children():
if name not in ["classifier", "head", "fc"]:
outputs = module(pixel_values)
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
return outputs.pooler_output
else:
return outputs.last_hidden_state[:, 0]
raise ValueError("Cannot extract backbone features from this model")
def detect_flux(image: Image.Image):
"""
FLUX specialist detection.
Returns (is_flux, confidence) or (None, None) if detector not available.
"""
if flux_model is None or flux_processor is None:
return None, None
try:
inputs = flux_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = flux_model(**inputs)
probs = torch.softmax(outputs.logits, dim=1)[0]
# index 0 = Real, index 1 = FLUX
flux_prob = probs[1].item()
return flux_prob >= FLUX_THRESHOLD, flux_prob
except Exception:
return None, None
def predict(image: Image.Image):
"""Predict if image is real or AI-generated"""
if image is None:
return None, None, "Please upload an image", None
# Preprocess
image = image.convert("RGB")
# ==============================================
# Step 1: FLUX specialist detection (priority)
# ==============================================
flux_detected, flux_confidence = detect_flux(image)
used_flux_detector = False
if flux_detected:
# High confidence FLUX detection - use specialist result
ai_prob = flux_confidence
human_prob = 1.0 - ai_prob
is_real = False
predicted_source = "FLUX"
used_flux_detector = True
# ==============================================
# Step 2: Main model inference
# ==============================================
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
probs = torch.softmax(outputs.logits, dim=-1)[0]
if not used_flux_detector:
# Dual-head mode: use binary head
if dual_head and binary_head is not None:
features = get_backbone_features(inputs["pixel_values"])
binary_logits = binary_head(features)
binary_probs = torch.softmax(binary_logits, dim=-1)[0]
# Binary head: index 0 = Real, index 1 = AI
human_prob = binary_probs[0].item()
ai_prob = binary_probs[1].item()
is_real = human_prob > ai_prob
else:
# Fallback: Top-1 decision
pred_idx = probs.argmax().item()
predicted_source_top1 = source_names[pred_idx]
confidence = probs[pred_idx].item()
is_real = source_is_real.get(predicted_source_top1, False)
if is_real:
human_prob = confidence
ai_prob = 1.0 - human_prob
else:
ai_prob = confidence
human_prob = 1.0 - ai_prob
# Top-1 multi-class prediction (for source display)
pred_idx = probs.argmax().item()
if not used_flux_detector:
predicted_source = source_names[pred_idx]
confidence = probs[pred_idx].item()
# Get top 3 AI sources only (exclude real sources)
ai_sources = []
for i, (name, prob) in enumerate(zip(source_names, probs.tolist())):
if not source_is_real.get(name, False):
ai_sources.append({"label": name, "score": round(prob, 3)})
ai_sources.sort(key=lambda x: x["score"], reverse=True)
top3_sources = ai_sources[:3]
# If FLUX detected, add it to top sources
if used_flux_detector:
top3_sources.insert(0, {"label": "FLUX (specialist)", "score": round(flux_confidence, 3)})
top3_sources = top3_sources[:3]
# API-style JSON output
api_output = {
"ai_probability": round(ai_prob, 3),
"human_probability": round(human_prob, 3),
"predicted_source": predicted_source,
"top3_sources": top3_sources,
"dual_head": dual_head,
"flux_detected": used_flux_detector,
"flux_confidence": round(flux_confidence, 3) if flux_confidence else None,
}
# Top predictions for bar chart
top_preds = {}
for i, (name, prob) in enumerate(zip(source_names, probs.tolist())):
if prob > 0.01:
marker = "[Real]" if source_is_real.get(name, False) else "[AI]"
top_preds[f"{marker} {name}"] = prob
# Add FLUX specialist result if detected
if used_flux_detector:
top_preds["[AI] FLUX (specialist)"] = flux_confidence
top_preds = dict(sorted(top_preds.items(), key=lambda x: x[1], reverse=True)[:10])
# Summary
if used_flux_detector:
mode_note = " *(FLUX Specialist)*"
elif dual_head:
mode_note = " *(Binary Head)*"
else:
mode_note = ""
verdict_emoji = "Human" if is_real else "AI Generated"
summary = f"""
### Verdict: **{verdict_emoji}**{mode_note}
| Category | Probability |
|----------|-------------|
| Human | {human_prob:.1%} |
| AI Generated | {ai_prob:.1%} |
---
**Predicted Source**: {predicted_source} ({confidence:.1%})
"""
if used_flux_detector:
summary += f"\n**FLUX Detector**: {flux_confidence:.1%} confidence"
summary += "\n\n**Top AI Sources**:"
for src in top3_sources[:3]:
summary += f"\n- {src['label']}: {src['score']:.1%}"
return (
{"Human": human_prob, "AI Generated": ai_prob},
top_preds,
summary,
api_output
)
# Custom CSS
css = """
.main-title {
text-align: center;
margin-bottom: 1rem;
}
.gradio-container {
max-width: 1200px !important;
}
"""
# Gradio interface
with gr.Blocks(css=css, title="ItsNotAI v2 - AI Image Detector") as demo:
gr.Markdown(
"""
# ItsNotAI v2 - AI Image Detector
Upload an image to detect if it's **Human-made** or **AI-generated**, and identify the potential source.
**New in v2**: Dual-head architecture with 95%+ binary accuracy. Supports FLUX, Midjourney, Stable Diffusion, DALL-E, and more.
""",
elem_classes="main-title"
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(
type="pil",
label="Upload Image",
height=400
)
submit_btn = gr.Button("Analyze", variant="primary", size="lg")
with gr.Column(scale=1):
# Main result
result_label = gr.Label(
label="Human vs AI",
num_top_classes=2
)
# Top predictions
top_preds_label = gr.Label(
label="Top Predictions by Source",
num_top_classes=10
)
# Detailed summary
summary_md = gr.Markdown(label="Details")
# API-style JSON output
with gr.Accordion("API Output (JSON)", open=False):
json_output = gr.JSON(label="API Output")
# Event handlers
submit_btn.click(
fn=predict,
inputs=[input_image],
outputs=[result_label, top_preds_label, summary_md, json_output]
)
input_image.change(
fn=predict,
inputs=[input_image],
outputs=[result_label, top_preds_label, summary_md, json_output]
)
gr.Markdown(
"""
---
### About ItsNotAI
Most AI detectors focus on catching AI. **We help artists prove their work is human-made.**
- **Model**: [boluobobo/ItsNotAI-ai-detector-v2](https://huggingface.co/boluobobo/ItsNotAI-ai-detector-v2)
- **Binary Accuracy**: 95.07%
- **Multi-class Accuracy**: 93.47%
- **Website**: [itsnotai.org](https://itsnotai.org)
### Disclaimer
This tool is for educational and research purposes. Results should not be used as definitive proof of image authenticity.
"""
)
if __name__ == "__main__":
demo.launch()