# Uploaded by ash12321 via huggingface_hub (commit fba5b26, verified)
"""
V15 - Self-Learning Deepfake Detector with Web Search
"""
import os
import json
import time
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm
from torchvision import transforms
from PIL import Image
from safetensors.torch import load_file, save_file
from huggingface_hub import hf_hub_download
import numpy as np
import hashlib
from datetime import datetime
import requests
from io import BytesIO
# Optional API keys for the web-search providers; when unset, the
# corresponding search helper degrades to a no-op ({'indicators': 0}).
SERPAPI_KEY = os.environ.get("SERPAPI_KEY", "")
SERPER_KEY = os.environ.get("SERPER_KEY", "")

# Runtime configuration: model source, local persistence paths, and
# fine-tuning hyperparameters for the feedback-driven retraining loop.
CONFIG = {
    'model_repo': 'ash12321/deepfake-detector-v15',   # HF Hub repo holding the base weights
    'data_dir': './data',                             # root for all persisted state
    'feedback_file': './data/feedback.json',          # user feedback records (JSON list)
    'images_dir': './data/images',                    # saved copies of feedback images
    'checkpoint': './data/v15_model.safetensors',     # locally fine-tuned weights
    'retrain_threshold': 50,                          # pending-feedback goal shown in the UI
    'learning_rate': 5e-6,                            # small LR: only head/adapter are trained
    'batch_size': 8,
    'epochs': 3,
}

# Prefer GPU when available; everything below moves tensors to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.makedirs(CONFIG['data_dir'], exist_ok=True)
os.makedirs(CONFIG['images_dir'], exist_ok=True)
def upload_image_temp(img):
    """Upload *img* to a temporary file host (litterbox, 1-hour lifetime).

    Used only to obtain a public URL for reverse-image search.

    Returns:
        The hosted URL string on success, or None on any failure.
    """
    try:
        buf = BytesIO()
        img.save(buf, format='JPEG', quality=85)
        buf.seek(0)
        r = requests.post(
            'https://litterbox.catbox.moe/resources/internals/api.php',
            files={'reqtype': (None, 'fileupload'), 'time': (None, '1h'),
                   'fileToUpload': ('img.jpg', buf, 'image/jpeg')},
            timeout=30)
        # The API returns the bare URL as plain text on success.
        if r.status_code == 200 and r.text.startswith('http'):
            return r.text.strip()
    except Exception:
        # Best-effort boundary (was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit): any encode or network failure just
        # disables web search for this image.
        pass
    return None
def serpapi_search(img):
    """Reverse-image search via SerpAPI, counting deepfake-related keywords.

    Uploads *img* to a temporary host, runs Google reverse-image search on
    the resulting URL, and counts occurrences of deepfake-related keywords
    in the raw JSON response.

    Returns:
        dict with 'indicators' (int keyword-hit count) and, on success,
        'source': 'serpapi'. Always {'indicators': 0} when the API key is
        missing, the upload fails, or the request errors out.
    """
    if not SERPAPI_KEY:
        return {'indicators': 0}
    url = upload_image_temp(img)
    if not url:
        return {'indicators': 0}
    try:
        r = requests.get(
            "https://serpapi.com/search.json",
            params={"engine": "google_reverse_image", "image_url": url, "api_key": SERPAPI_KEY},
            timeout=20)
        if r.status_code == 200:
            # Crude signal: keyword frequency anywhere in the response body.
            text = json.dumps(r.json()).lower()
            count = sum(text.count(k) for k in ['deepfake', 'fake', 'ai generated', 'synthetic'])
            return {'indicators': count, 'source': 'serpapi'}
    except Exception:
        # Best-effort boundary (was a bare `except:`): network/JSON failures
        # simply yield no indicators.
        pass
    return {'indicators': 0}
def serper_search():
    """Text search via the Serper API, counting deepfake-related results.

    NOTE(review): the query is a fixed string ("deepfake AI generated face")
    and is NOT derived from the analyzed image, so this contributes a
    near-constant signal -- confirm whether this is intentional.

    Returns:
        dict with 'indicators' (number of organic results whose title or
        snippet mentions a deepfake keyword) and 'source': 'serper' on
        success; {'indicators': 0} when the key is missing or on error.
    """
    if not SERPER_KEY:
        return {'indicators': 0}
    try:
        r = requests.post(
            "https://google.serper.dev/search",
            headers={'X-API-KEY': SERPER_KEY, 'Content-Type': 'application/json'},
            data=json.dumps({"q": "deepfake AI generated face", "num": 10}),
            timeout=15)
        if r.status_code == 200:
            count = sum(1 for x in r.json().get('organic', [])
                        if any(k in (x.get('title', '') + x.get('snippet', '')).lower()
                               for k in ['deepfake', 'fake', 'ai generated']))
            return {'indicators': count, 'source': 'serper'}
    except Exception:
        # Best-effort boundary (was a bare `except:`).
        pass
    return {'indicators': 0}
def web_search(img):
    """Aggregate deepfake-indicator counts from every configured provider.

    Returns a dict with a single 'total' key: the summed indicator counts
    from SerpAPI reverse-image search and the Serper text search.
    """
    provider_results = (serpapi_search(img), serper_search())
    combined = sum(result.get('indicators', 0) for result in provider_results)
    return {'total': combined}
class DeepfakeDetectorV15(nn.Module):
    """Swin-L feature backbone + residual adapter + MLP head producing a
    single fake-vs-real logit per image (sigmoid is applied by callers).

    NOTE(review): the attribute names (backbone/adapter/classifier) and the
    exact Sequential layouts define the state_dict keys expected by the
    safetensors checkpoints loaded below -- do not rename or restructure.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor only (num_classes=0 strips timm's head); real
        # weights come from a checkpoint, hence pretrained=False.
        self.backbone = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=0)
        d = 1536  # Swin-L output feature dimension
        # Bottleneck adapter (d -> 512 -> d), applied as a scaled residual
        # in forward(); this is the cheap-to-fine-tune part of the model.
        self.adapter = nn.Sequential(nn.Linear(d, 512), nn.LayerNorm(512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, d))
        # Two-stage MLP head ending in a single logit.
        self.classifier = nn.Sequential(
            nn.Linear(d, 512), nn.BatchNorm1d(512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 128), nn.BatchNorm1d(128), nn.GELU(), nn.Dropout(0.15), nn.Linear(128, 1))

    def forward(self, x):
        # Mix the adapter output into the backbone features at a small
        # fixed scale (0.1), then squeeze the trailing dim so the output
        # is shape (batch,) rather than (batch, 1).
        f = self.backbone(x)
        return self.classifier(f + 0.1 * self.adapter(f)).squeeze(-1)
# --- Weight loading ----------------------------------------------------
# Priority: local fine-tuned checkpoint -> v15 Hub weights -> v14 fallback.
print("Loading model...")
model = DeepfakeDetectorV15()
try:
    if os.path.exists(CONFIG['checkpoint']):
        # Prefer the locally fine-tuned checkpoint written by train_model().
        model.load_state_dict(load_file(CONFIG['checkpoint']))
    else:
        path = hf_hub_download(repo_id=CONFIG['model_repo'], filename="model.safetensors")
        model.load_state_dict(load_file(path), strict=False)
except Exception as e:
    # Was a bare `except:` that hid the failure reason entirely; surface it
    # before falling back to the previous-generation weights.
    print(f"V15 weights unavailable ({e}); falling back to v14 checkpoint")
    path = hf_hub_download(repo_id="ash12321/deepfake-detector-v14", filename="model_3.safetensors")
    model.load_state_dict(load_file(path), strict=False)
model = model.to(device).eval()
print(f"Model ready on {device}")
def load_feedback():
    """Load the persisted feedback list from CONFIG['feedback_file'].

    Returns:
        The decoded list, or [] when the file is absent, unreadable, or
        contains corrupt JSON (a truncated file previously crashed startup).
    """
    if os.path.exists(CONFIG['feedback_file']):
        try:
            with open(CONFIG['feedback_file']) as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            # A corrupt/truncated feedback file should not take the app down;
            # start with an empty history instead.
            return []
    return []
def save_feedback(data):
    """Persist the feedback list as JSON at CONFIG['feedback_file']."""
    with open(CONFIG['feedback_file'], 'w') as handle:
        json.dump(data, handle)
# In-memory feedback store, seeded from disk at startup and persisted on
# every submit()/train_model() call.
feedback_data = load_feedback()

# Standard ImageNet preprocessing (224x224, mean/std normalize) matching
# the Swin backbone's expected input.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Mutable state of the most recent analyze() call, consumed by submit().
last = {'img': None, 'prob': 0.5, 'web': {}}
class FBDataset(Dataset):
    """Dataset over feedback records whose image files still exist on disk.

    Records whose 'path' is missing or points at a deleted file are
    silently dropped at construction time.
    """

    def __init__(self, data):
        self.data = [record for record in data if os.path.exists(record.get('path', ''))]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        record = self.data[i]
        image = Image.open(record['path']).convert('RGB')
        label = torch.tensor(record['label'], dtype=torch.float32)
        return transform(image), label
def train_model():
    """Fine-tune the adapter + classifier head on untrained feedback samples.

    The Swin backbone stays frozen; only adapter/classifier parameters
    receive gradients. On success the updated weights are written to the
    local checkpoint and the consumed samples are flagged 'trained'.

    Returns:
        A human-readable status string for the UI.
    """
    global model, feedback_data
    samples = [d for d in feedback_data if not d.get('trained')]
    if len(samples) < 5:
        return f"Need at least 5 samples (have {len(samples)})"
    dataset = FBDataset(samples)
    if len(dataset) == 0:
        # Every referenced image file was deleted; the old code would have
        # "trained" on an empty loader and still marked samples consumed.
        return "No usable samples (image files missing)"
    # BatchNorm1d raises on a size-1 batch in train mode; drop a trailing
    # singleton batch when there is more than one batch to spare.
    drop_last = (len(dataset) > CONFIG['batch_size']
                 and len(dataset) % CONFIG['batch_size'] == 1)
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'],
                        shuffle=True, drop_last=drop_last)
    # Freeze the backbone; everything else is trainable.
    for n, p in model.named_parameters():
        p.requires_grad = 'backbone' not in n
    opt = optim.AdamW([p for p in model.parameters() if p.requires_grad],
                      lr=CONFIG['learning_rate'])
    criterion = nn.BCEWithLogitsLoss()  # hoisted: was rebuilt every batch
    model.train()
    for ep in range(CONFIG['epochs']):
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            opt.zero_grad()
            criterion(model(imgs), labels).backward()
            opt.step()
    model.eval()
    save_file(model.state_dict(), CONFIG['checkpoint'])
    for d in samples:
        d['trained'] = True
    save_feedback(feedback_data)
    return f"Trained on {len(samples)} samples!"
def analyze(image, use_web):
    """Classify *image* as real or fake, optionally blending in web evidence.

    Stores the image and final probability in the module-level `last` dict
    so a subsequent submit() can attach a ground-truth label.

    Returns:
        (verdict markdown, web-search summary, feedback stats) strings.
    """
    global last
    if image is None:
        return "Upload an image!", "", ""
    pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    pil_img = pil_img.convert('RGB')
    last['img'] = pil_img
    batch = transform(pil_img).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        prob = torch.sigmoid(model(batch)).item()
    if not use_web:
        web_text = "Web search disabled"
    else:
        web = web_search(pil_img)
        last['web'] = web
        hits = web['total']
        if hits > 0:
            # Each online indicator nudges the probability up, capped at 0.99.
            prob = min(prob + hits * 0.03, 0.99)
            web_text = f"Found {hits} deepfake indicators online!"
        else:
            web_text = "No deepfake indicators found online"
    last['prob'] = prob
    is_fake = prob > 0.5
    confidence = prob if is_fake else 1 - prob
    result = f"## {'🚨 FAKE' if is_fake else '✅ REAL'}\n**Confidence**: {confidence:.1%}"
    pending = sum(1 for d in feedback_data if not d.get('trained'))
    stats = f"Feedback: {len(feedback_data)} total | {pending} pending"
    return result, web_text, stats
def submit(label):
    """Record the user's ground-truth label for the last analyzed image.

    Saves the image under a time-derived filename, appends a feedback
    record, persists the store, and reports progress toward the retrain
    threshold.

    Returns:
        A status string for the UI.
    """
    global feedback_data, last
    if last['img'] is None:
        return "Analyze an image first!"
    # Time-derived short hash keeps filenames unique enough for this flow.
    stamp = hashlib.md5(str(time.time()).encode()).hexdigest()[:12]
    path = os.path.join(CONFIG['images_dir'], f"{stamp}.jpg")
    last['img'].save(path)
    record = {
        'path': path,
        'label': 1 if label == "Fake" else 0,
        'prob': last['prob'],
        'trained': False,
    }
    feedback_data.append(record)
    save_feedback(feedback_data)
    pending = sum(1 for d in feedback_data if not d.get('trained'))
    return f"Saved! ({pending}/{CONFIG['retrain_threshold']})"
# --- Gradio UI ---------------------------------------------------------
# Left column: image input + web-search toggle + feedback controls.
# Right column: verdict, web summary, stats, and action status.
with gr.Blocks(title="V15 Deepfake Detector") as app:
    gr.Markdown("# 🧠 V15 Self-Learning Deepfake Detector")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="pil", label="Upload Image")
            web_cb = gr.Checkbox(label="Enable Web Search", value=True)
            btn1 = gr.Button("Analyze", variant="primary")
            gr.Markdown("---")
            # Feedback loop: user corrects the verdict, then retrains.
            radio = gr.Radio(["Real", "Fake"], label="Correct label:")
            btn2 = gr.Button("Submit Feedback")
            btn3 = gr.Button("Train Model")
        with gr.Column():
            out1 = gr.Markdown()
            out2 = gr.Markdown()
            out3 = gr.Markdown()
            out4 = gr.Markdown()
    # Event wiring: analyze -> three outputs; feedback/train -> status line.
    btn1.click(analyze, [img, web_cb], [out1, out2, out3])
    btn2.click(submit, radio, out4)
    btn3.click(train_model, outputs=out4)

app.queue().launch()