# ForensicFusion / predict.py
# bmrayan's picture
# Upload predict.py with huggingface_hub
# 9ea88f7 verified
"""
Prediction script combining DINOv2 classifier and Qwen3-VL reasoner
Outputs predictions.json in required format
"""
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import (
AutoImageProcessor,
Dinov2Model,
Qwen3VLForConditionalGeneration,
AutoProcessor
)
from peft import PeftModel
from PIL import Image
import json
import os
from pathlib import Path
from tqdm import tqdm
from qwen_vl_utils import process_vision_info
class DINOv2Classifier(nn.Module):
    """DINOv2 backbone followed by a small MLP classification head.

    The head consumes a 768-wide feature vector (the first token of the
    backbone's last hidden state) and narrows it through 512 and 256 units
    before projecting to ``num_classes`` logits.
    """

    def __init__(self, num_classes=3):
        """Create the pretrained backbone and the classification head.

        Args:
            num_classes: Number of output logits produced by the head.
        """
        super().__init__()
        self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base")

        # Head widths: 768 -> 512 -> 256, each stage followed by ReLU and
        # dropout. The layer order is kept exactly as Linear/ReLU/Dropout
        # triples so the Sequential indices (and therefore checkpoint keys)
        # are 0..6 with Linear layers at 0, 3 and 6.
        widths = (768, 512, 256)
        head_layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            head_layers.append(nn.Linear(fan_in, fan_out))
            head_layers.append(nn.ReLU())
            head_layers.append(nn.Dropout(0.3))
        head_layers.append(nn.Linear(widths[-1], num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return class logits for a batch of preprocessed images.

        Args:
            pixel_values: Tensor passed straight to the DINOv2 backbone.

        Returns:
            The head's output for the first token of each sequence in the
            backbone's ``last_hidden_state``.
        """
        # Position 0 of the token sequence serves as the image-level embedding.
        first_token = self.dinov2(pixel_values).last_hidden_state[:, 0]
        return self.classifier(first_token)
class GenAIDetector:
    """Two-stage detector: classify an image, then ask a VLM to explain the label.

    Stage one is a fine-tuned ``DINOv2Classifier`` that scores the image as
    ``real``, ``manipulated`` or ``fake``. Stage two feeds the image and the
    predicted label to Qwen3-VL, which writes a short justification.
    """

    def __init__(self, classifier_path):
        """Load the classifier checkpoint and the vision-language model.

        Args:
            classifier_path: Path to a checkpoint whose ``'model_state_dict'``
                entry holds the ``DINOv2Classifier`` weights.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load DINOv2 classifier
        print("Loading classifier...")
        self.classifier = DINOv2Classifier(num_classes=3).to(self.device)
        # NOTE(review): torch.load deserializes the checkpoint file; depending
        # on the installed torch version this may unpickle arbitrary objects,
        # so only load checkpoints from a trusted source.
        checkpoint = torch.load(classifier_path, map_location=self.device)
        self.classifier.load_state_dict(checkpoint['model_state_dict'])
        self.classifier.eval()
        self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

        # Load VLM. device_map="auto" lets the library decide where the
        # weights live, which is why inputs are later moved to self.vlm.device
        # rather than self.device.
        print("Loading VLM reasoner...")
        self.vlm = Qwen3VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen3-VL-8B-Instruct",
            torch_dtype="auto",
            device_map="auto"
        )
        self.vlm_processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        self.vlm.eval()

        # Index in this list == class index produced by the classifier.
        self.class_names = ['real', 'manipulated', 'fake']
        # Each class is reported with one fixed manipulation label.
        self.manipulation_types = {
            'real': 'none',
            'manipulated': 'inpainting',
            'fake': 'full_synthesis'
        }

    def classify_image(self, image_path):
        """Classify an image and return per-class probabilities.

        Args:
            image_path: Path of the image file to open.

        Returns:
            ``(pred_class, confidence)`` where ``pred_class`` is the index of
            the most probable class and ``confidence`` is a NumPy array with
            the softmax probability of every class for this image.
        """
        # The context manager releases the file handle as soon as the RGB
        # copy has been made.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        inputs = self.image_processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].to(self.device)

        with torch.no_grad():
            logits = self.classifier(pixel_values)
            probs = torch.softmax(logits, dim=1)

        pred_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0].cpu().numpy()
        return pred_class, confidence

    def _decode_new_tokens(self, inputs, output_ids):
        """Decode only the tokens generated after the prompt.

        ``generate`` returns prompt and completion concatenated; slicing the
        prompt off by length is more reliable than searching the decoded text
        for a role marker, which breaks when the answer itself contains that
        word.

        Args:
            inputs: Processor output exposing ``input_ids`` (one row per prompt).
            output_ids: Sequences returned by ``generate`` for those prompts.

        Returns:
            The stripped text of the first sequence's completion.
        """
        completions = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
        ]
        decoded = self.vlm_processor.batch_decode(
            completions,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        return decoded[0].strip()

    def generate_reasoning(self, image_path, predicted_class):
        """Ask the VLM why the image received the predicted label.

        Args:
            image_path: Path of the image shown to the VLM.
            predicted_class: Class index; looked up in ``self.class_names``.

        Returns:
            The model's explanation as a stripped string. Sampling is enabled,
            so the text can differ between runs.
        """
        class_name = self.class_names[predicted_class]

        # Prepare prompt
        prompt = f"The given image has been flagged as {class_name}. Explain in 2-3 sentences why that might be. Focus on specific features which indicated this."
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        # Apply chat template
        text = self.vlm_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Process inputs
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.vlm_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        # Follow the VLM's own placement; it is loaded with device_map="auto"
        # and is not guaranteed to sit on self.device.
        inputs = inputs.to(self.vlm.device)

        # Generate
        with torch.no_grad():
            output_ids = self.vlm.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                do_sample=True
            )

        return self._decode_new_tokens(inputs, output_ids)

    def predict(self, image_path):
        """Run the full pipeline for one image.

        Args:
            image_path: Path of the image to analyse.

        Returns:
            A dict with ``'authenticity_score'`` (rounded to 2 decimals),
            ``'manipulation_type'`` and ``'vlm_reasoning'``.
        """
        # Classify
        pred_class, confidence = self.classify_image(image_path)

        # confidence[0] is the probability of the 'real' class, so the score
        # is 1 - P(real): a higher score means more likely manipulated.
        authenticity_score = float(1.0 - confidence[0])

        # Get manipulation type
        class_name = self.class_names[pred_class]
        manipulation_type = self.manipulation_types[class_name]

        # Generate reasoning
        reasoning = self.generate_reasoning(image_path, pred_class)

        return {
            'authenticity_score': round(authenticity_score, 2),
            'manipulation_type': manipulation_type,
            'vlm_reasoning': reasoning
        }
def main(image_dir, classifier_path, output_file):
    """Run the detector on every image in a directory and save the results.

    Args:
        image_dir: Directory searched (non-recursively) for ``.jpg``,
            ``.jpeg`` and ``.png`` files, matched case-insensitively.
        classifier_path: Checkpoint path forwarded to ``GenAIDetector``.
        output_file: Path of the JSON file that receives the list of
            per-image prediction dicts.
    """
    # Initialize detector
    detector = GenAIDetector(classifier_path)

    # Collect images with one case-insensitive suffix test. Globbing both
    # '*.jpg' and '*.JPG' lists every file twice on case-insensitive
    # filesystems and misses mixed-case suffixes such as '.Jpg'.
    image_extensions = {'.jpg', '.jpeg', '.png'}
    image_root = Path(image_dir)
    if image_root.is_dir():
        # Sorted so the order of entries in the output is reproducible.
        image_files = sorted(
            entry for entry in image_root.iterdir()
            if entry.is_file() and entry.suffix.lower() in image_extensions
        )
    else:
        # A missing directory yields no images rather than an exception.
        image_files = []
    print(f"Found {len(image_files)} images")

    # Process images
    predictions = []
    for image_path in tqdm(image_files, desc="Processing images"):
        try:
            result = detector.predict(str(image_path))
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole run.
            print(f"Error processing {image_path.name}: {str(e)}")
        else:
            result['image_name'] = image_path.name
            predictions.append(result)

    # Save predictions
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=2)

    print(f"\n✓ Processed {len(predictions)} images")
    print(f"✓ Saved predictions to {output_file}")
if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, default, help text); all are strings.
    option_specs = (
        ('--image_dir', './test_images',
         'Directory containing images to predict'),
        ('--classifier_path', 'best_model.pth',
         'Path to trained DINOv2 checkpoint (.pth file)'),
        ('--output_file', 'predictions.json',
         'Output JSON file'),
    )

    cli = argparse.ArgumentParser()
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()
    main(options.image_dir, options.classifier_path, options.output_file)