File size: 8,094 Bytes
16f3614
 
c4ae52a
 
16f3614
 
c4ae52a
2dff18f
 
c4ae52a
 
 
 
 
 
 
 
 
 
 
2dff18f
c4ae52a
 
 
 
 
 
 
 
 
 
2dff18f
c4ae52a
 
 
16f3614
c4ae52a
16f3614
2dff18f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4ae52a
 
 
 
 
 
 
 
16f3614
c4ae52a
16f3614
c4ae52a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dff18f
c4ae52a
 
 
 
 
 
 
 
2dff18f
c4ae52a
 
 
 
 
 
 
 
 
16f3614
c4ae52a
 
 
 
 
 
 
 
 
 
2dff18f
c4ae52a
16f3614
 
c4ae52a
 
 
 
 
 
 
 
 
16f3614
 
 
 
 
 
 
c4ae52a
16f3614
 
c4ae52a
 
16f3614
 
c4ae52a
2dff18f
 
 
16f3614
c4ae52a
 
16f3614
c4ae52a
16f3614
 
 
 
 
 
2dff18f
16f3614
 
 
 
c4ae52a
16f3614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4ae52a
16f3614
 
 
 
 
c4ae52a
16f3614
 
 
 
 
c4ae52a
16f3614
 
 
 
c4ae52a
 
16f3614
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import gradio as gr
from PIL import Image
import onnxruntime as ort
import torchvision.transforms as transforms
import json
import os
import numpy as np
import pandas as pd
import random
from huggingface_hub import snapshot_download, HfApi
from transformers import CLIPTokenizer

# --- Config ---
# Hugging Face Hub repository holding the leaderboard JSON and the other
# downloaded assets; it is accessed with repo_type="dataset".
HUB_REPO_ID = "aarodi/OpenArenaLeaderboard"
# Access token read from the environment. When it is unset (None),
# save_leaderboard() skips the push to the Hub.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): LOCAL_JSON is not referenced in the visible code; HUB_JSON is
# used both as the local file path and as the path inside the repo.
LOCAL_JSON = "leaderboard.json"
HUB_JSON = "leaderboard.json"
# ONNX model files, loaded from the working directory at import time.
MODEL_PATH = "mobilenet_v2_fake_detector.onnx"
CLIP_IMAGE_ENCODER_PATH = "clip_image_encoder.onnx"
CLIP_TEXT_ENCODER_PATH = "clip_text_encoder.onnx"
# CSV file read by load_prompts(); it must contain a "prompt" column.
PROMPT_CSV_PATH = "generate2_1.csv"
PROMPT_MATCH_THRESHOLD = 10  # minimum prompt-match score (cosine similarity x 100) a submission must reach

# --- Download leaderboard + model checkpoint from HF Hub ---
def load_assets():
    """Fetch the leaderboard, ONNX models and prompt CSV from the HF Hub.

    Only the files named by the config constants are requested, and they are
    placed in the current working directory. Any ``Exception`` raised along
    the way is printed and swallowed, so a failed download does not abort
    the import of this module.
    """
    try:
        wanted_files = [
            HUB_JSON,
            MODEL_PATH,
            CLIP_IMAGE_ENCODER_PATH,
            CLIP_TEXT_ENCODER_PATH,
            PROMPT_CSV_PATH,
        ]
        snapshot_download(
            repo_id=HUB_REPO_ID,
            repo_type="dataset",
            local_dir=".",
            allow_patterns=wanted_files,
            token=HF_TOKEN,
        )
    except Exception as err:
        print(f"Failed to load assets from HF Hub: {err}")

load_assets()

# --- Load prompts from CSV ---
def load_prompts():
    """Return the non-null values of the CSV's ``prompt`` column as a list.

    An empty list is returned (and a message printed) when the file cannot
    be read or when it has no ``prompt`` column.
    """
    try:
        frame = pd.read_csv(PROMPT_CSV_PATH)
        # Guard clause: without the expected column there is nothing to load.
        if "prompt" not in frame.columns:
            print("CSV missing 'prompt' column.")
            return []
        prompt_column = frame["prompt"]
        return prompt_column.dropna().tolist()
    except Exception as err:
        print(f"Failed to load prompts: {err}")
        return []

PROMPT_LIST = load_prompts()

# --- Load leaderboard ---
def load_leaderboard():
    """Read the leaderboard from the local JSON file.

    Returns the decoded JSON content, or an empty dict (after printing the
    error) when the file cannot be opened or parsed.
    """
    scores = {}
    try:
        with open(HUB_JSON, "r") as handle:
            scores = json.load(handle)
    except Exception as err:
        print(f"Failed to read leaderboard: {err}")
    return scores

leaderboard_scores = load_leaderboard()

# --- Save and push to HF Hub ---
def save_leaderboard():
    """Write the leaderboard to disk and, if a token is set, push it to the Hub.

    The module-level ``leaderboard_scores`` mapping is serialized as JSON to
    ``HUB_JSON``. When ``HF_TOKEN`` is ``None`` the upload step is skipped
    with a printed notice. Any ``Exception`` from either step is printed and
    swallowed.
    """
    try:
        with open(HUB_JSON, "w") as handle:
            json.dump(leaderboard_scores, handle)

        if HF_TOKEN is not None:
            hub_client = HfApi()
            hub_client.upload_file(
                path_or_fileobj=HUB_JSON,
                path_in_repo=HUB_JSON,
                repo_id=HUB_REPO_ID,
                repo_type="dataset",
                token=HF_TOKEN,
                commit_message="Update leaderboard",
            )
        else:
            print("HF_TOKEN not set. Skipping push to hub.")
    except Exception as err:
        print(f"Failed to save leaderboard to HF Hub: {err}")

# --- Load ONNX models ---
# Every session is created on the CPU execution provider. These statements run
# at import time, so the model files must already exist on disk; load_assets()
# only prints a message when the download fails, in which case the failure
# surfaces here instead.
session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
# Name of the detector's first input, used as the feed key at inference time.
input_name = session.get_inputs()[0].name

clip_image_sess = ort.InferenceSession(CLIP_IMAGE_ENCODER_PATH, providers=["CPUExecutionProvider"])
clip_text_sess = ort.InferenceSession(CLIP_TEXT_ENCODER_PATH, providers=["CPUExecutionProvider"])
# NOTE(review): presumably the two ONNX encoders were exported from this same
# CLIP checkpoint, so the tokenizer matches the text encoder — confirm.
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Preprocessing applied before the CLIP image encoder: resize directly to
# 224x224 (aspect ratio is not preserved), convert to a tensor, then normalize
# each of the three channels.
# NOTE(review): the mean/std values look like the standard OpenAI CLIP
# normalization statistics — confirm against the exported model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
])

def compute_prompt_match(image: Image.Image, prompt: str) -> float:
    """Score how well ``image`` matches ``prompt`` with the CLIP ONNX encoders.

    The image and the prompt are embedded separately, both embeddings are
    L2-normalized, and their cosine similarity is scaled by 100.

    Args:
        image: PIL image to score. It is converted to RGB first because the
            normalization step is configured for exactly three channels.
        prompt: Text the image is compared against; tokenized to a fixed
            length of 77 tokens with truncation.

    Returns:
        The similarity times 100, rounded to two decimals, as a plain Python
        float. Returns 0.0 when anything fails (the error is printed) or when
        the similarity is undefined (zero-norm or non-finite embeddings), so
        a broken computation can never pass the prompt-match gate.
    """
    try:
        # Uploads that are not already RGB (RGBA, grayscale, palette) would
        # otherwise make Normalize() raise and be silently scored as 0.0.
        rgb_image = image.convert("RGB")
        img_tensor = transform(rgb_image).unsqueeze(0).numpy().astype(np.float32)
        image_input_name = clip_image_sess.get_inputs()[0].name
        image_features = clip_image_sess.run(None, {image_input_name: img_tensor})[0][0]

        inputs = clip_tokenizer(prompt, return_tensors="np", padding="max_length", truncation=True, max_length=77)
        # NOTE(review): assumes the text encoder's inputs are, in order,
        # input_ids then attention_mask — confirm against the exported model.
        text_inputs = clip_text_sess.get_inputs()
        text_features = clip_text_sess.run(None, {
            text_inputs[0].name: inputs["input_ids"],
            text_inputs[1].name: inputs["attention_mask"],
        })[0][0]

        image_norm = np.linalg.norm(image_features)
        text_norm = np.linalg.norm(text_features)
        # Written as "not (x > 0)" so that a NaN norm is rejected as well.
        if not (image_norm > 0 and text_norm > 0):
            print("CLIP ONNX match failed: embedding has zero or invalid norm")
            return 0.0

        sim = float(np.dot(image_features / image_norm, text_features / text_norm))
        if not np.isfinite(sim):
            print("CLIP ONNX match failed: similarity is not finite")
            return 0.0
        return round(sim * 100, 2)
    except Exception as e:
        # Deliberate best-effort: a failed match is reported as "no match".
        print(f"CLIP ONNX match failed: {e}")
        return 0.0

# --- Main prediction logic ---
def detect_with_model(image: Image.Image, prompt: str, username: str):
    """Score one submission and update the leaderboard.

    The submission is rejected early when the name is blank, when no image
    was uploaded, or when the prompt-match score is below
    ``PROMPT_MATCH_THRESHOLD``. Otherwise the fake detector is run; a "Real"
    verdict earns the player one point, and only then is the leaderboard
    saved (which may also push it to the Hub).

    Args:
        image: Uploaded PIL image, or None when nothing was uploaded.
        prompt: Prompt the image is supposed to depict.
        username: Player name; surrounding whitespace is ignored, so " bob"
            and "bob" share one leaderboard entry.

    Returns:
        A 5-tuple in the order of the outputs wired to the submit button:
        (message, image to display or None, leaderboard rows, update for the
        input section, update for the "Try Again" button).

    Side effects:
        Mutates the module-level ``leaderboard_scores`` and calls
        ``save_leaderboard()`` when a point is awarded.
    """
    def rejected(reason):
        # Keep the input form visible and the "Try Again" button hidden.
        return reason, None, [], gr.update(visible=True), gr.update(visible=False)

    player = (username or "").strip()
    if not player:
        return rejected("Please enter your name.")

    # Without this guard a missing upload was reported as a low prompt match.
    if image is None:
        return rejected("Please upload an image.")

    prompt_score = compute_prompt_match(image, prompt)
    # "not (score >= threshold)" also rejects a NaN score, which would slip
    # through a plain "score < threshold" comparison.
    if not (prompt_score >= PROMPT_MATCH_THRESHOLD):
        return rejected(
            f"⚠️ Prompt match too low ({round(prompt_score, 2)}%). Please generate an image that better matches the prompt."
        )

    # The detector receives the resized image scaled by ToTensor() with no
    # mean/std normalization, as in the original code.
    # NOTE(review): assumes the detector expects 3-channel input — confirm.
    resized = transforms.Resize((224, 224))(image.convert("RGB"))
    image_tensor = transforms.ToTensor()(resized).unsqueeze(0).numpy().astype(np.float32)
    outputs = session.run(None, {input_name: image_tensor})

    # NOTE(review): assumes outputs[0][0][0] is a single logit whose sigmoid
    # is the probability of "Real" — confirm against the exported model.
    logit = np.float64(outputs[0][0][0])
    prob_real = float(1.0 / (1.0 + np.exp(-logit)))

    # Threshold on the unrounded probability; rounding first would classify
    # values just above 0.5 as "Fake". Rounding is applied for display only.
    is_real = prob_real > 0.5
    prediction = "Real" if is_real else "Fake"
    confidence = round((prob_real if is_real else 1.0 - prob_real) * 100, 2)

    message = f"Prediction: {prediction} ({confidence}% confidence)\n🧐 Prompt match: {prompt_score}%"

    if is_real:
        leaderboard_scores[player] = leaderboard_scores.get(player, 0) + 1
        # Persist only when the scores actually changed, so a "Fake" verdict
        # no longer triggers a file write and Hub upload.
        save_leaderboard()
        message += "\n🎉 Nice! You fooled the AI. +1 point!"
    else:
        message += "\n😅 The AI caught you this time. Try again!"

    sorted_scores = sorted(leaderboard_scores.items(), key=lambda item: item[1], reverse=True)
    leaderboard_table = [[name, points] for name, points in sorted_scores]

    return (
        message,
        image,
        leaderboard_table,
        gr.update(visible=False),
        gr.update(visible=True),
    )

# --- UI Layout ---
def get_random_prompt():
    """Return a random entry of PROMPT_LIST, or a stock prompt if it is empty."""
    if not PROMPT_LIST:
        return "A synthetic scene with dramatic lighting"
    return random.choice(PROMPT_LIST)

# Page layout: an input form (name, prompt, image, submit), a result panel and
# a leaderboard table. Submitting hides the form and shows "Try Again";
# "Try Again" clears the result and leaderboard panels and shows the form again.
with gr.Blocks(css=".gr-button {font-size: 16px !important}") as demo:
    gr.Markdown("## 🎝 OpenFake Arena")
    gr.Markdown("Welcome to the OpenFake Arena!\n\n**Your mission:** Generate a synthetic image for the prompt, upload it, and try to fool the AI detector into thinking it’s real.\n\n**Rules:**\n- Only synthetic images allowed!\n- No cheating with real photos.\n- Licensing is your responsibility.\n\nMake it wild. Make it weird. Most of all — make it fun.")

    with gr.Group(visible=True) as input_section:
        username_input = gr.Textbox(label="Your Name", placeholder="Enter your name")

        with gr.Row():
            # The suggested prompt is drawn once, when the layout is built.
            prompt_input = gr.Textbox(
                label="Suggested Prompt",
                placeholder="e.g., A portrait photograph of a politician delivering a speech...",
                value=get_random_prompt(),
                lines=2
            )

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Synthetic Image")

        with gr.Row():
            submit_btn = gr.Button("Upload")

    try_again_btn = gr.Button("Try Again", visible=False)

    with gr.Group():
        gr.Markdown("### 🎯 Result")
        with gr.Row():
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            image_output = gr.Image(label="Submitted Image", show_label=False)

    with gr.Group():
        gr.Markdown("### 🏆 Leaderboard")
        leaderboard = gr.Dataframe(
            headers=["Username", "Score"],
            datatype=["str", "number"],
            interactive=False,
            row_count=5
        )

    # The order of `outputs` must match the 5-tuple returned by
    # detect_with_model.
    submit_btn.click(
        fn=detect_with_model,
        inputs=[image_input, prompt_input, username_input],
        outputs=[
            prediction_output,
            image_output,
            leaderboard,
            input_section,
            try_again_btn
        ]
    )

    # Reset: clear message, image and table, show the form, hide the button.
    try_again_btn.click(
        fn=lambda: ("", None, [], gr.update(visible=True), gr.update(visible=False)),
        outputs=[
            prediction_output,
            image_output,
            leaderboard,
            input_section,
            try_again_btn
        ]
    )

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()