| | """ |
| | Prediction script combining DINOv2 classifier and Qwen2-VL reasoner |
| | Outputs predictions.json in required format |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torchvision import transforms |
| | from transformers import ( |
| | AutoImageProcessor, |
| | Dinov2Model, |
| | Qwen3VLForConditionalGeneration, |
| | AutoProcessor |
| | ) |
| | from peft import PeftModel |
| | from PIL import Image |
| | import json |
| | import os |
| | from pathlib import Path |
| | from tqdm import tqdm |
| | from qwen_vl_utils import process_vision_info |
| |
|
class DINOv2Classifier(nn.Module):
    """DINOv2 backbone with a small MLP head for 3-way image classification."""

    def __init__(self, num_classes=3):
        super().__init__()
        # ViT backbone; its [CLS] token embedding (768-d for dinov2-base)
        # feeds the classification head.
        self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base")

        # Head: 768 -> 512 -> 256 -> num_classes, ReLU + Dropout(0.3)
        # between the hidden layers.
        head_layers = []
        in_dim = 768
        for hidden_dim in (512, 256):
            head_layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            in_dim = hidden_dim
        head_layers.append(nn.Linear(in_dim, num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return class logits of shape (batch, num_classes)."""
        backbone_out = self.dinov2(pixel_values)
        # Sequence position 0 is the [CLS] summary token.
        cls_embedding = backbone_out.last_hidden_state[:, 0]
        return self.classifier(cls_embedding)
| |
|
class GenAIDetector:
    """Two-stage detector: a DINOv2 classifier predicts the label, then a
    Qwen3-VL model generates a natural-language explanation for it.
    """

    def __init__(self, classifier_path):
        """Load the trained classifier checkpoint and the Qwen3-VL reasoner.

        Args:
            classifier_path: Path to a .pth checkpoint holding a
                'model_state_dict' entry for DINOv2Classifier.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        print("Loading classifier...")
        self.classifier = DINOv2Classifier(num_classes=3).to(self.device)
        checkpoint = torch.load(classifier_path, map_location=self.device)
        self.classifier.load_state_dict(checkpoint['model_state_dict'])
        self.classifier.eval()

        self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

        print("Loading VLM reasoner...")
        # device_map="auto" lets Accelerate shard/place the model itself, so
        # the VLM may not live on self.device.
        base_model = Qwen3VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen3-VL-8B-Instruct",
            torch_dtype="auto",
            device_map="auto"
        )
        self.vlm = base_model
        self.vlm_processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        self.vlm.eval()

        # Index order must match the label order used at classifier training
        # time — TODO confirm against the training script.
        self.class_names = ['real', 'manipulated', 'fake']
        self.manipulation_types = {
            'real': 'none',
            'manipulated': 'inpainting',
            'fake': 'full_synthesis'
        }

    def classify_image(self, image_path):
        """Classify one image with the DINOv2 head.

        Args:
            image_path: Path to an image file readable by PIL.

        Returns:
            (pred_class, confidence): the argmax class index and the full
            softmax probability vector (numpy array of length 3).
        """
        image = Image.open(image_path).convert('RGB')
        inputs = self.image_processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].to(self.device)

        with torch.no_grad():
            logits = self.classifier(pixel_values)
            probs = torch.softmax(logits, dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0].cpu().numpy()

        return pred_class, confidence

    def generate_reasoning(self, image_path, predicted_class):
        """Ask the VLM to explain why the image was flagged as the given class.

        Args:
            image_path: Path to the image file.
            predicted_class: Integer index into self.class_names.

        Returns:
            The VLM's explanation as a plain string.
        """
        class_name = self.class_names[predicted_class]

        prompt = f"The given image has been flagged as {class_name}. Explain in 2-3 sentences why that might be. Focus on specific features which indicated this."

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        text = self.vlm_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.vlm_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        # BUGFIX: with device_map="auto" the VLM may not be on self.device;
        # self.vlm.device is where the model expects its inputs.
        inputs = inputs.to(self.vlm.device)

        with torch.no_grad():
            output_ids = self.vlm.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                do_sample=True
            )

        # BUGFIX: decode only the newly generated tokens by slicing off the
        # prompt token count (the documented Qwen-VL pattern). The previous
        # approach split the fully decoded text on the word "assistant",
        # which broke whenever the reply itself contained that word — and it
        # checked lower() but split case-sensitively.
        trimmed_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, output_ids)
        ]
        reasoning = self.vlm_processor.batch_decode(
            trimmed_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0].strip()

        return reasoning

    def predict(self, image_path):
        """Full pipeline: classify, map to manipulation type, explain.

        Returns:
            Dict with 'authenticity_score', 'manipulation_type', and
            'vlm_reasoning' keys ('image_name' is added by the caller).
        """
        pred_class, confidence = self.classify_image(image_path)

        # confidence[0] is P(real), so this score is 1 - P(real): HIGHER
        # means LESS authentic. NOTE(review): confirm this polarity matches
        # the required predictions.json format.
        authenticity_score = float(1.0 - confidence[0])

        class_name = self.class_names[pred_class]
        manipulation_type = self.manipulation_types[class_name]

        reasoning = self.generate_reasoning(image_path, pred_class)

        return {
            'authenticity_score': round(authenticity_score, 2),
            'manipulation_type': manipulation_type,
            'vlm_reasoning': reasoning
        }
| |
|
| | def main(image_dir, classifier_path, output_file): |
| | """Main prediction function""" |
| | |
| | |
| | detector = GenAIDetector(classifier_path) |
| | |
| | |
| | image_extensions = ['.jpg', '.jpeg', '.png'] |
| | image_files = [] |
| | for ext in image_extensions: |
| | image_files.extend(Path(image_dir).glob(f'*{ext}')) |
| | image_files.extend(Path(image_dir).glob(f'*{ext.upper()}')) |
| | |
| | print(f"Found {len(image_files)} images") |
| | |
| | |
| | predictions = [] |
| | for image_path in tqdm(image_files, desc="Processing images"): |
| | try: |
| | result = detector.predict(str(image_path)) |
| | result['image_name'] = image_path.name |
| | predictions.append(result) |
| | except Exception as e: |
| | print(f"Error processing {image_path.name}: {str(e)}") |
| | continue |
| | |
| | |
| | with open(output_file, 'w') as f: |
| | json.dump(predictions, f, indent=2) |
| | |
| | print(f"\n✓ Processed {len(predictions)} images") |
| | print(f"✓ Saved predictions to {output_file}") |
| |
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: a directory of images in, a predictions JSON file out.
    arg_parser = argparse.ArgumentParser()
    cli_options = (
        ('--image_dir', './test_images',
         'Directory containing images to predict'),
        ('--classifier_path', 'best_model.pth',
         'Path to trained DINOv2 checkpoint (.pth file)'),
        ('--output_file', 'predictions.json',
         'Output JSON file'),
    )
    for flag, default, help_text in cli_options:
        arg_parser.add_argument(flag, type=str, default=default, help=help_text)

    args = arg_parser.parse_args()
    main(args.image_dir, args.classifier_path, args.output_file)