# Hugging Face Space: V15 deepfake detector (page-status header removed)
"""V15 - Self-Learning Deepfake Detector with Web Search."""
| import os | |
| import json | |
| import time | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| import timm | |
| from torchvision import transforms | |
| from PIL import Image | |
| from safetensors.torch import load_file, save_file | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import hashlib | |
| from datetime import datetime | |
| import requests | |
| from io import BytesIO | |
# API keys for the optional web-search signal; an empty string disables that provider.
SERPAPI_KEY = os.environ.get("SERPAPI_KEY", "")
SERPER_KEY = os.environ.get("SERPER_KEY", "")

# Runtime configuration: weight source, local data layout, fine-tuning hyperparameters.
CONFIG = {
    'model_repo': 'ash12321/deepfake-detector-v15',  # HF Hub repo holding the weights
    'data_dir': './data',
    'feedback_file': './data/feedback.json',         # persisted user-feedback records
    'images_dir': './data/images',                   # saved copies of labelled images
    'checkpoint': './data/v15_model.safetensors',    # locally fine-tuned weights
    'retrain_threshold': 50,                         # target count shown in submit()'s status
    'learning_rate': 5e-6,
    'batch_size': 8,
    'epochs': 3,
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Ensure local data directories exist before anything writes into them.
os.makedirs(CONFIG['data_dir'], exist_ok=True)
os.makedirs(CONFIG['images_dir'], exist_ok=True)
def upload_image_temp(img):
    """Upload *img* to litterbox.catbox.moe as a 1-hour temporary JPEG.

    Returns the public URL string on success, or None on any failure.
    Best-effort by design: callers treat None as "skip reverse-image search".
    """
    try:
        buf = BytesIO()
        img.save(buf, format='JPEG', quality=85)
        buf.seek(0)
        r = requests.post(
            'https://litterbox.catbox.moe/resources/internals/api.php',
            files={'reqtype': (None, 'fileupload'), 'time': (None, '1h'),
                   'fileToUpload': ('img.jpg', buf, 'image/jpeg')},
            timeout=30)
        if r.status_code == 200 and r.text.startswith('http'):
            return r.text.strip()
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        # Network/encode failures deliberately degrade to "no URL".
        pass
    return None
def serpapi_search(img):
    """Reverse-image-search *img* via SerpAPI and count deepfake-related keywords.

    Returns {'indicators': count, 'source': 'serpapi'} on success, otherwise
    {'indicators': 0} (missing key, failed temp upload, or request error).
    """
    if not SERPAPI_KEY:
        return {'indicators': 0}
    url = upload_image_temp(img)
    if not url:
        return {'indicators': 0}
    try:
        r = requests.get(
            "https://serpapi.com/search.json",
            params={"engine": "google_reverse_image", "image_url": url,
                    "api_key": SERPAPI_KEY},
            timeout=20)
        if r.status_code == 200:
            text = json.dumps(r.json()).lower()
            # NOTE(review): 'deepfake' contains 'fake', so every "deepfake" hit
            # is counted twice by this substring heuristic — confirm intended.
            count = sum(text.count(k) for k in
                        ['deepfake', 'fake', 'ai generated', 'synthetic'])
            return {'indicators': count, 'source': 'serpapi'}
    except Exception:
        # Was a bare `except:`; best-effort fallback to "no indicators".
        pass
    return {'indicators': 0}
def serper_search():
    """Run a Serper.dev text search and count deepfake-themed organic results.

    Returns {'indicators': count, 'source': 'serper'} on success, otherwise
    {'indicators': 0}.

    NOTE(review): the query is a fixed string and never incorporates the image
    being analyzed, so results for "deepfake AI generated face" will almost
    always match the keywords — this signal is biased toward non-zero counts.
    Consider passing image-derived context or removing this provider.
    """
    if not SERPER_KEY:
        return {'indicators': 0}
    try:
        r = requests.post(
            "https://google.serper.dev/search",
            headers={'X-API-KEY': SERPER_KEY, 'Content-Type': 'application/json'},
            data=json.dumps({"q": "deepfake AI generated face", "num": 10}),
            timeout=15)
        if r.status_code == 200:
            count = sum(
                1 for x in r.json().get('organic', [])
                if any(k in (x.get('title', '') + x.get('snippet', '')).lower()
                       for k in ['deepfake', 'fake', 'ai generated']))
            return {'indicators': count, 'source': 'serper'}
    except Exception:
        # Was a bare `except:`; best-effort fallback to "no indicators".
        pass
    return {'indicators': 0}
def web_search(img):
    """Aggregate deepfake-indicator counts from all web-search providers."""
    provider_results = (serpapi_search(img), serper_search())
    combined = sum(result.get('indicators', 0) for result in provider_results)
    return {'total': combined}
class DeepfakeDetectorV15(nn.Module):
    """Swin-Large backbone with a residual adapter and an MLP classifier head.

    forward(x) returns one raw logit per image; apply sigmoid for the
    fake-probability.
    """

    def __init__(self):
        super().__init__()
        # Headless Swin-Large (num_classes=0); weights arrive via a checkpoint.
        self.backbone = timm.create_model(
            'swin_large_patch4_window7_224', pretrained=False, num_classes=0)
        feat_dim = 1536
        # Bottleneck adapter: cheap fine-tuning path added residually below.
        self.adapter = nn.Sequential(
            nn.Linear(feat_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, feat_dim),
        )
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        features = self.backbone(x)
        # Mix in the adapter output at 0.1 weight, then squeeze to 1-D logits.
        adapted = features + 0.1 * self.adapter(features)
        return self.classifier(adapted).squeeze(-1)
print("Loading model...")
model = DeepfakeDetectorV15()
try:
    if os.path.exists(CONFIG['checkpoint']):
        # Prefer locally fine-tuned weights from an earlier training run.
        model.load_state_dict(load_file(CONFIG['checkpoint']))
    else:
        path = hf_hub_download(repo_id=CONFIG['model_repo'], filename="model.safetensors")
        model.load_state_dict(load_file(path), strict=False)
except Exception:
    # Fall back to the previous-generation v14 checkpoint when v15 weights are
    # missing or incompatible. (Was a bare `except:`, which also caught
    # KeyboardInterrupt/SystemExit during startup.)
    path = hf_hub_download(repo_id="ash12321/deepfake-detector-v14", filename="model_3.safetensors")
    model.load_state_dict(load_file(path), strict=False)
model = model.to(device).eval()
print(f"Model ready on {device}")
def load_feedback():
    """Load persisted feedback records; a missing file means no feedback yet."""
    try:
        with open(CONFIG['feedback_file']) as fh:
            return json.load(fh)
    except FileNotFoundError:
        return []
def save_feedback(data):
    """Persist the feedback list as JSON, replacing any previous file."""
    with open(CONFIG['feedback_file'], 'w') as fh:
        json.dump(data, fh)
feedback_data = load_feedback()

# ImageNet-normalized 224x224 preprocessing, matching the Swin backbone input.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Most recent analysis state, shared between analyze() and submit().
last = {'img': None, 'prob': 0.5, 'web': {}}
class FBDataset(Dataset):
    """Dataset over feedback records whose saved image files still exist."""

    def __init__(self, data):
        # Skip records whose image was deleted; 'path' may also be absent.
        self.data = [rec for rec in data if os.path.exists(rec.get('path', ''))]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data[index]
        image = Image.open(record['path']).convert('RGB')
        label = torch.tensor(record['label'], dtype=torch.float32)
        return transform(image), label
def train_model():
    """Fine-tune the adapter and classifier on not-yet-trained feedback samples.

    Freezes the backbone, runs CONFIG['epochs'] epochs of BCE training, saves
    the checkpoint, and marks the used samples as trained.

    Returns a human-readable status string for the UI.
    """
    global model, feedback_data
    samples = [d for d in feedback_data if not d.get('trained')]
    if len(samples) < 5:
        return f"Need at least 5 samples (have {len(samples)})"
    dataset = FBDataset(samples)
    if len(dataset) == 0:
        # Guard: every feedback image was deleted from disk. The original code
        # would "train" on an empty loader and still flag samples as trained.
        return "No feedback images found on disk"
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True)
    # Freeze the backbone; only adapter + classifier receive gradients.
    for name, param in model.named_parameters():
        param.requires_grad = 'backbone' not in name
    opt = optim.AdamW([p for p in model.parameters() if p.requires_grad],
                      lr=CONFIG['learning_rate'])
    criterion = nn.BCEWithLogitsLoss()  # hoisted: was rebuilt on every batch
    model.train()
    for _ in range(CONFIG['epochs']):
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            opt.zero_grad()
            criterion(model(imgs), labels).backward()
            opt.step()
    model.eval()
    save_file(model.state_dict(), CONFIG['checkpoint'])
    for d in samples:
        d['trained'] = True
    save_feedback(feedback_data)
    return f"Trained on {len(samples)} samples!"
def analyze(image, use_web):
    """Score an image as real/fake, optionally blending in web-search hits.

    Returns a (verdict markdown, web-search summary, feedback stats) triple.
    """
    global last
    if image is None:
        return "Upload an image!", "", ""
    pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    pil_img = pil_img.convert('RGB')
    last['img'] = pil_img

    batch = transform(pil_img).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        prob = torch.sigmoid(model(batch)).item()

    web_text = "Web search disabled"
    if use_web:
        web = web_search(pil_img)
        last['web'] = web
        if web['total'] > 0:
            # Nudge probability up 3% per indicator, capped at 0.99.
            prob = min(prob + web['total'] * 0.03, 0.99)
            web_text = f"Found {web['total']} deepfake indicators online!"
        else:
            web_text = "No deepfake indicators found online"
    last['prob'] = prob

    verdict = '🚨 FAKE' if prob > 0.5 else '✅ REAL'
    confidence = prob if prob > 0.5 else 1 - prob
    result = f"## {verdict}\n**Confidence**: {confidence:.1%}"
    pending = sum(1 for d in feedback_data if not d.get('trained'))
    stats = f"Feedback: {len(feedback_data)} total | {pending} pending"
    return result, web_text, stats
def submit(label):
    """Save the last analyzed image plus its user-supplied label as feedback."""
    global feedback_data, last
    if last['img'] is None:
        return "Analyze an image first!"
    # Short hash of the current timestamp serves as a unique-enough filename.
    file_id = hashlib.md5(str(time.time()).encode()).hexdigest()[:12]
    path = os.path.join(CONFIG['images_dir'], f"{file_id}.jpg")
    last['img'].save(path)
    feedback_data.append({
        'path': path,
        'label': 1 if label == "Fake" else 0,
        'prob': last['prob'],
        'trained': False,
    })
    save_feedback(feedback_data)
    pending = sum(1 for d in feedback_data if not d.get('trained'))
    return f"Saved! ({pending}/{CONFIG['retrain_threshold']})"
# Gradio UI: left column takes input and feedback, right column shows results.
with gr.Blocks(title="V15 Deepfake Detector") as app:
    gr.Markdown("# 🧠 V15 Self-Learning Deepfake Detector")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="pil", label="Upload Image")
            web_cb = gr.Checkbox(label="Enable Web Search", value=True)
            btn1 = gr.Button("Analyze", variant="primary")
            gr.Markdown("---")
            radio = gr.Radio(["Real", "Fake"], label="Correct label:")
            btn2 = gr.Button("Submit Feedback")
            btn3 = gr.Button("Train Model")
        with gr.Column():
            out1 = gr.Markdown()  # verdict / confidence
            out2 = gr.Markdown()  # web-search summary
            out3 = gr.Markdown()  # feedback stats
            out4 = gr.Markdown()  # submit/train status
    # Wiring: analyze fills the three result panes; feedback/train share out4.
    btn1.click(analyze, [img, web_cb], [out1, out2, out3])
    btn2.click(submit, radio, out4)
    btn3.click(train_model, outputs=out4)
app.queue().launch()