# Provenance: HuggingFace upload by HussainKAUST ("Upload 5 files", commit a1e6863, verified)
#!/usr/bin/env python3
"""
GenAI Image Detection - Track A: Social Media & Influencer Authenticity
MenaML Winter School 2026 Hackathon
Usage: python predict.py --input_dir /test_images --output_file predictions.json
"""
import os
import sys
import json
import argparse
import warnings
import re
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
warnings.filterwarnings('ignore')
# Global variables
# Populated once by load_models(); kept at module level so predict_single()
# and analyze_with_vlm() can reuse the loaded models across all images.
clip_model = None        # CLIP ViT-L/14 encoder (Module 1 feature extractor)
clip_preprocess = None   # CLIP image preprocessing transform
classifier = None        # pickled forensic classifier over CLIP features
scaler = None            # feature scaler fitted alongside the classifier
vlm_model = None         # Qwen2-VL-7B model (Module 2)
vlm_processor = None     # processor/tokenizer paired with the VLM
device = "cuda" if torch.cuda.is_available() else "cpu"
# Prompt sent to the VLM for every image; asks for a strict JSON reply that
# analyze_with_vlm() parses (with a keyword fallback if JSON is malformed).
VLM_PROMPT = """You are an AI forensics expert. Analyze this image and determine if it is REAL or AI-GENERATED.
Check for: hands/fingers (count should be 5), facial features, text readability, backgrounds, physics violations.
Respond in JSON:
{"verdict": "REAL" or "FAKE", "confidence": 0.5-0.95, "artifacts_found": [], "reasoning": "explanation"}"""
def load_models():
    """Initialize both detection modules and stash them in module globals.

    Module 1 is CLIP ViT-L/14 plus a pickled forensic classifier/scaler;
    Module 2 is Qwen2-VL-7B-Instruct loaded with 4-bit NF4 quantization.
    Heavy third-party imports are kept local so merely importing this
    module stays cheap.
    """
    global clip_model, clip_preprocess, classifier, scaler, vlm_model, vlm_processor
    import clip
    import pickle
    from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

    print("Loading models...")

    # --- Module 1: CLIP feature extractor ---
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    print(" ✓ CLIP ViT-L/14 loaded")

    # Forensic classifier: prefer the copy sitting next to this script,
    # fall back to the current working directory.
    detector_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_forensic_detector.pkl')
    if not os.path.exists(detector_path):
        detector_path = 'clip_forensic_detector.pkl'
    with open(detector_path, 'rb') as fh:
        payload = pickle.load(fh)
    classifier = payload['classifier']
    scaler = payload['scaler']
    print(" ✓ Classifier loaded")

    # --- Module 2: Qwen2-VL, 4-bit quantized so it fits on one GPU ---
    print(" Loading VLM...")
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        quantization_config=bnb_cfg,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    vlm_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
    )
    print(" ✓ VLM loaded")
    print("✅ Both modules ready!")
def parse_vlm_response(response):
    """Parse the VLM's free-text reply into a verdict dict.

    Tries the strict JSON format requested by VLM_PROMPT first; on failure
    falls back to keyword spotting with word boundaries so that e.g.
    "unrealistic" does not match "real" (the old substring check did).

    Returns a dict with keys: verdict, confidence, artifacts, reasoning.
    """
    json_match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
    if json_match:
        try:
            data = json.loads(json_match.group())
            confidence = float(data.get('confidence', 0.7))
            return {
                # str() guards against a non-string verdict in the JSON.
                'verdict': str(data.get('verdict', 'UNKNOWN')).upper(),
                # Clamp in case the model reports a value outside [0, 1].
                'confidence': min(max(confidence, 0.0), 1.0),
                'artifacts': data.get('artifacts_found', []),
                'reasoning': data.get('reasoning', '')
            }
        except (json.JSONDecodeError, TypeError, ValueError):
            pass  # malformed JSON -> fall through to keyword heuristics
    resp_lower = response.lower()
    if re.search(r'\b(fake|ai-generated|synthetic)\b', resp_lower):
        verdict = 'FAKE'
    elif re.search(r'\b(real|authentic)\b', resp_lower):
        verdict = 'REAL'
    else:
        verdict = 'UNKNOWN'
    return {'verdict': verdict, 'confidence': 0.65, 'artifacts': [], 'reasoning': response[:200]}


def analyze_with_vlm(image_path):
    """Run Module 2 (Qwen2-VL) on one image and return its verdict dict.

    Any failure (model error, unreadable image, parse error) degrades to
    an UNKNOWN verdict with neutral confidence instead of aborting the
    batch; this best-effort behavior is deliberate.
    """
    try:
        from qwen_vl_utils import process_vision_info
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": f"file://{image_path}"},
                {"type": "text", "text": VLM_PROMPT}
            ]
        }]
        text = vlm_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = vlm_processor(text=[text], images=image_inputs, videos=video_inputs,
                               padding=True, return_tensors="pt").to(vlm_model.device)
        with torch.no_grad():
            output_ids = vlm_model.generate(**inputs, max_new_tokens=500, do_sample=False)
        # Drop the echoed prompt tokens; keep only the generated tail.
        output_ids = output_ids[:, inputs.input_ids.shape[1]:]
        response = vlm_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        return parse_vlm_response(response)
    except Exception as e:
        return {'verdict': 'UNKNOWN', 'confidence': 0.5, 'artifacts': [], 'reasoning': f'Error: {str(e)[:50]}'}
def predict_single(image_path):
    """Score one image with both modules and fuse the results.

    Returns a dict with 'authenticity_score' (higher = more likely fake),
    'manipulation_type' (coarse label derived from the score), and
    'vlm_reasoning' (truncated human-readable trace). Any exception yields
    a neutral 0.5 score with type 'Error'.
    """
    try:
        # Module 1: CLIP embedding -> normalized -> scaled -> P(fake).
        image = Image.open(image_path).convert('RGB')
        tensor = clip_preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            emb = clip_model.encode_image(tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
        flat = emb.cpu().numpy().flatten()
        scaled = scaler.transform(flat.reshape(1, -1))
        forensic_score = float(classifier.predict_proba(scaled)[0][1])

        # Module 2: VLM verdict on the same image.
        vlm = analyze_with_vlm(image_path)

        # Fusion: a decisive forensic score stands alone; in the gray zone
        # a (confident) VLM verdict is allowed to pull the final score.
        fake_blend = 0.4 * forensic_score + 0.6 * vlm['confidence']
        if forensic_score >= 0.70:
            final_score = forensic_score
        elif forensic_score <= 0.25:
            # Low forensic score: only override when the VLM names artifacts.
            has_evidence = vlm['verdict'] == 'FAKE' and vlm['artifacts']
            final_score = fake_blend if has_evidence else forensic_score
        elif vlm['verdict'] == 'FAKE':
            final_score = fake_blend
        elif vlm['verdict'] == 'REAL' and vlm['confidence'] > 0.85:
            final_score = 0.6 * forensic_score + 0.4 * (1 - vlm['confidence'])
        else:
            final_score = forensic_score

        # Map the fused score onto a coarse label.
        if final_score >= 0.85:
            manipulation_type = "Full Synthesis"
        elif final_score >= 0.70:
            manipulation_type = "AI-generated"
        elif final_score >= 0.50:
            manipulation_type = "Possible manipulation"
        elif final_score >= 0.30:
            manipulation_type = "Light editing"
        else:
            manipulation_type = "Authentic"

        # Compose a readable trace of both modules' outputs.
        reasoning = f"Forensic: {forensic_score:.2f}. VLM: {vlm['verdict']} ({vlm['confidence']:.2f}). "
        if vlm['artifacts']:
            reasoning += f"Artifacts: {', '.join(vlm['artifacts'][:3])}. "
        reasoning += vlm['reasoning'][:300]

        return {
            'authenticity_score': round(final_score, 4),
            'manipulation_type': manipulation_type,
            'vlm_reasoning': reasoning[:500]
        }
    except Exception as e:
        return {
            'authenticity_score': 0.5,
            'manipulation_type': 'Error',
            'vlm_reasoning': f'Error: {str(e)[:100]}'
        }
def main():
    """CLI entry point.

    Scores every image in --input_dir with the dual-module system and
    writes a JSON list of per-image predictions to --output_file.
    Exits with status 1 if the input directory is missing or empty.
    """
    parser = argparse.ArgumentParser(description='GenAI Image Detection - Track A')
    parser.add_argument('--input_dir', required=True, help='Directory containing input images')
    parser.add_argument('--output_file', required=True, help='Output JSON file path')
    args = parser.parse_args()

    if not os.path.isdir(args.input_dir):
        print(f"Error: Input directory not found: {args.input_dir}")
        sys.exit(1)

    load_models()

    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif'}
    # Sort for a deterministic processing/output order — os.listdir order
    # is filesystem-dependent and would make runs non-reproducible.
    images = sorted(f for f in os.listdir(args.input_dir)
                    if Path(f).suffix.lower() in image_extensions)
    if not images:
        print(f"No images found in {args.input_dir}")
        sys.exit(1)

    print(f"\nProcessing {len(images)} images with DUAL-MODULE system...\n")
    predictions = []
    for img_name in tqdm(images, desc="Analyzing"):
        img_path = os.path.join(args.input_dir, img_name)
        result = predict_single(img_path)
        predictions.append({
            'image_name': img_name,
            'authenticity_score': result['authenticity_score'],
            'manipulation_type': result['manipulation_type'],
            'vlm_reasoning': result['vlm_reasoning']
        })

    with open(args.output_file, 'w') as f:
        json.dump(predictions, f, indent=2)
    print(f"\n✅ Results saved to: {args.output_file}")
    print(f"✅ Processed {len(predictions)} images")


if __name__ == '__main__':
    main()