LoliRimuru commited on
Commit
8661c43
Β·
verified Β·
1 Parent(s): 8e96457

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import warnings
8
+ from huggingface_hub import hf_hub_download, RepositoryNotFoundError, HFValidationError
9
+ import os
10
+
11
+ warnings.filterwarnings("ignore")
12
+
13
# ============ MODEL DEFINITION ============
class BAILU(nn.Module):
    """Small convolutional multi-label detector with a 4-logit head.

    A stack of strided, padding-free convolutions downsamples the RGB input;
    adaptive average pooling then collapses whatever spatial extent remains
    to 1x1, so any sufficiently large input resolution is accepted. The head
    emits one raw logit per detector channel (sigmoid applied by callers).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 256 channels, GELU after every conv.
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )
        # Classifier head on the pooled 256-dim feature vector.
        self.head = nn.Sequential(
            nn.Linear(256, 32), nn.GELU(), nn.Linear(32, 4)
        )

    def forward(self, x):
        """Map a (B, 3, H, W) batch to (B, 4) raw logits."""
        pooled = self.conv_blocks(x)
        flat = torch.flatten(pooled, start_dim=1)  # same as view(B, -1) on contiguous pool output
        return self.head(flat)
35
+
36
# ============ GLOBALS ============
# Detector channel names, in the same order as the model head's 4 outputs;
# zipped against the per-channel probabilities when building the results table.
VAES = ['FLUX', 'FLUX2', 'SDXL', 'SD1.5']
# Decision boundary applied to each sigmoid probability (AI vs. Real).
THRESHOLD = 0.5

# ============ HUGGINGFACE REPO CONFIG ============
# Hub repository and checkpoint filename fetched by load_model().
HF_REPO_ID = "LoliRimuru/BAILU"
HF_MODEL_FILENAME = "model.pt"
43
+
44
# ============ LOAD MODEL ============
def load_model():
    """Load the pre-trained BAILU model from the HuggingFace Hub.

    Downloads HF_MODEL_FILENAME from HF_REPO_ID into ./checkpoints and
    restores the weights stored under the checkpoint's "model_state_dict" key.

    Returns:
        tuple: (model, device) — an eval-mode BAILU instance on the selected
        device, or (None, device) when the repo is missing/invalid or the
        download/load fails.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # FIX: Instantiate the correct model class
    model = BAILU().to(device)

    # Load from HuggingFace Hub
    try:
        print(f"📥 Downloading model from HuggingFace: {HF_REPO_ID}")
        # local_dir_use_symlinks was dropped: the kwarg is deprecated and
        # ignored by recent huggingface_hub releases (real files are copied
        # into local_dir by default), and it only emitted a FutureWarning.
        model_file = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_MODEL_FILENAME,
            repo_type="model",
            local_dir="./checkpoints",
        )

        # weights_only=True restricts unpickling to plain tensors/containers,
        # guarding against arbitrary-code-execution pickle payloads.
        checkpoint = torch.load(model_file, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        print(f"✅ Model loaded from HuggingFace: {HF_REPO_ID}")
        return model, device

    except RepositoryNotFoundError:
        print(f"❌ Repository '{HF_REPO_ID}' not found on HuggingFace.")
        print(" Please check the repository name and ensure it's public or you have access.")
        return None, device
    except HFValidationError:
        print(f"❌ Invalid repository ID: '{HF_REPO_ID}'")
        return None, device
    except Exception as e:
        # Broad catch is deliberate: the demo should degrade to the error UI
        # built in create_demo() rather than crash on network/IO failures.
        print(f"❌ Failed to download/load model from HuggingFace: {e}")
        print(" Check your internet connection and huggingface_hub installation.")
        return None, device
80
+
81
# ============ INFERENCE ============
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a (1, 3, 512, 512) model input tensor.

    The image is forced to RGB, center-cropped to 512x512 (torchvision's
    CenterCrop zero-pads when a side is shorter than the crop size), scaled
    to [0, 1] by ToTensor, and given a leading batch dimension of 1.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")

    pipeline = transforms.Compose([
        transforms.CenterCrop(512),
        transforms.ToTensor(),
    ])

    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)
93
+
94
def predict_image(model, device, image: Image.Image):
    """Score one image with the detector.

    Returns:
        tuple: (probs, is_ai, confidence) — probs is a 1-D numpy array of
        per-VAE sigmoid probabilities; is_ai is truthy when any probability
        exceeds THRESHOLD; confidence is the strongest detector's probability
        for an AI verdict, otherwise one minus the weakest detector's.
    """
    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        batch = preprocess_image(image).to(device)
        probs = torch.sigmoid(model(batch))[0].cpu().numpy()

    is_ai = np.any(probs > THRESHOLD)
    # Confidence mirrors the verdict side of the threshold.
    confidence = np.max(probs) if is_ai else 1 - np.min(probs)

    return probs, is_ai, confidence
107
+
108
# ============ GRADIO INTERFACE ============
def create_demo():
    """Build and return the Gradio Interface for the BAILU detector.

    Loads the model once at construction time; if loading fails, returns a
    stub interface that reports the load error instead of crashing the app.
    """
    model, device = load_model()

    if model is None:
        # Fallback UI: same output component shapes as the real demo so the
        # page still renders, but every prediction reports the load failure.
        def error_demo(image):
            return "❌ MODEL NOT LOADED", 0.0, [["ERROR", "0%", "N/A", "0%"]]

        interface = gr.Interface(
            fn=error_demo,
            inputs=gr.Image(type="pil", label="Upload Image"),
            outputs=[
                gr.Textbox(label="Overall Verdict", show_copy_button=False),
                gr.Number(label="Confidence Score", precision=2),
                gr.Dataframe(
                    headers=["VAE Detector", "AI Probability", "Prediction", "Confidence"],
                    label="Per-Model Analysis"
                )
            ],
            title="BAILU AI Detection Demo",
            description="Model failed to load. Please check console for details."
        )
        return interface

    def inference(image):
        # Gradio passes None when the user submits with no image uploaded.
        if image is None:
            return "🤔 NO IMAGE UPLOADED", 0.0, []

        probs, is_ai, confidence = predict_image(model, device, image)

        verdict_icon = "🔴 AI GENERATED" if is_ai else "🟢 HUMAN/REAL IMAGE"
        verdict_text = f"{verdict_icon}\n(Confidence: {confidence:.1%})"

        # One table row per VAE detector channel, in VAES order.
        results = []
        for vae_name, prob in zip(VAES, probs):
            prediction = "AI" if prob > THRESHOLD else "Real"
            # Per-row confidence: distance from 0 or 1 on the verdict's side.
            conf = prob if prob > THRESHOLD else (1 - prob)
            status = "🚨" if prob > 0.7 else "⚠️" if prob > 0.5 else "✅"
            results.append([
                f"{status} {vae_name}",
                f"{prob:.2%}",
                prediction,
                f"{conf:.1%}"
            ])

        # Sort rows by AI probability, highest first (parses the "xx.xx%" cell).
        results.sort(key=lambda x: float(x[1].replace('%', '')), reverse=True)

        return verdict_text, confidence, results

    interface = gr.Interface(
        fn=inference,
        inputs=gr.Image(
            type="pil",
            label="Upload Image (PNG, JPG, WEBP)",
            height=400
        ),
        outputs=[
            gr.Textbox(
                label="🎯 Overall Verdict",
                show_copy_button=False,
                lines=2,
                elem_classes="verdict-box"
            ),
            gr.Number(
                label="📊 Overall Confidence",
                precision=2,
                elem_classes="confidence-box"
            ),
            gr.Dataframe(
                headers=["🧠 Detector", "AI Probability", "Prediction", "Confidence"],
                label="🔍 Per-Model Breakdown",
                elem_classes="results-table",
                wrap=True
            )
        ],
        title="BAILU AI-Generated Image Detector",
        description="""
    ### Detect AI-generated images

    BAILU analyzes artifacts to identify
    images generated by popular diffusion models. The model checks for traces from:

    **🎨 FLUX.1 | 🚀 FLUX.2 | 🖼️ SDXL | 🎯 Stable Diffusion 1.5**

    **⚠️ IMPORTANT**: This is a research tool. Results should be verified by human experts
    for critical decisions. The model may produce false positives/negatives.
    """,
        theme=gr.themes.Soft(),
        # NOTE(review): allow_flagging was deprecated in Gradio 4.x (renamed
        # flagging_mode) and removed in 5.x — confirm the pinned gradio version.
        allow_flagging="never",
        css="""
    .verdict-box {
        font-size: 24px !important;
        font-weight: bold !important;
        text-align: center !important;
    }
    .confidence-box {
        font-size: 20px !important;
        font-weight: bold !important;
    }
    .results-table {
        font-size: 16px !important;
    }
    .gradio-container {
        max-width: 1000px !important;
        margin: auto !important;
    }
    """
    )

    return interface
219
+
220
# ============ MAIN ============
if __name__ == "__main__":
    # Serve on all interfaces at the standard HF Spaces port; no public
    # share tunnel and no REST API docs page.
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_api=False)