""" Prediction script combining DINOv2 classifier and Qwen2-VL reasoner Outputs predictions.json in required format """ import torch import torch.nn as nn from torchvision import transforms from transformers import ( AutoImageProcessor, Dinov2Model, Qwen3VLForConditionalGeneration, AutoProcessor ) from peft import PeftModel from PIL import Image import json import os from pathlib import Path from tqdm import tqdm from qwen_vl_utils import process_vision_info class DINOv2Classifier(nn.Module): def __init__(self, num_classes=3): super().__init__() self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base") # Classification head self.classifier = nn.Sequential( nn.Linear(768, 512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes) ) def forward(self, pixel_values): outputs = self.dinov2(pixel_values) cls_token = outputs.last_hidden_state[:, 0] logits = self.classifier(cls_token) return logits class GenAIDetector: def __init__(self, classifier_path): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {self.device}") # Load DINOv2 classifier print("Loading classifier...") self.classifier = DINOv2Classifier(num_classes=3).to(self.device) checkpoint = torch.load(classifier_path, map_location=self.device) self.classifier.load_state_dict(checkpoint['model_state_dict']) self.classifier.eval() self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") # Load VLM print("Loading VLM reasoner...") base_model = Qwen3VLForConditionalGeneration.from_pretrained( "Qwen/Qwen3-VL-8B-Instruct", torch_dtype="auto", device_map="auto" ) self.vlm = base_model self.vlm_processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct") self.vlm.eval() self.class_names = ['real', 'manipulated', 'fake'] self.manipulation_types = { 'real': 'none', 'manipulated': 'inpainting', 'fake': 'full_synthesis' } def classify_image(self, image_path): """Classify image and get confidence scores""" image = Image.open(image_path).convert('RGB') inputs = self.image_processor(images=image, return_tensors="pt") pixel_values = inputs['pixel_values'].to(self.device) with torch.no_grad(): logits = self.classifier(pixel_values) probs = torch.softmax(logits, dim=1) pred_class = torch.argmax(probs, dim=1).item() confidence = probs[0].cpu().numpy() return pred_class, confidence def generate_reasoning(self, image_path, predicted_class): """Generate reasoning using VLM""" class_name = self.class_names[predicted_class] # Prepare prompt prompt = f"The given image has been flagged as {class_name}. Explain in 2-3 sentences why that might be. Focus on specific features which indicated this." messages = [ { "role": "user", "content": [ {"type": "image", "image": image_path}, {"type": "text", "text": prompt} ] } ] # Apply chat template text = self.vlm_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Process inputs image_inputs, video_inputs = process_vision_info(messages) inputs = self.vlm_processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt" ) inputs = inputs.to(self.device) # Generate with torch.no_grad(): output_ids = self.vlm.generate( **inputs, max_new_tokens=150, temperature=0.7, do_sample=True ) # Decode generated_text = self.vlm_processor.batch_decode( output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] # Extract assistant response if "assistant" in generated_text.lower(): reasoning = generated_text.split("assistant")[-1].strip() else: reasoning = generated_text.strip() return reasoning def predict(self, image_path): """Full prediction pipeline""" # Classify pred_class, confidence = self.classify_image(image_path) # Get authenticity score (confidence that it's real, i.e., confidence[0]) authenticity_score = float(1.0 - confidence[0]) # Higher score = more manipulated # Get manipulation type class_name = self.class_names[pred_class] manipulation_type = self.manipulation_types[class_name] # Generate reasoning reasoning = self.generate_reasoning(image_path, pred_class) return { 'authenticity_score': round(authenticity_score, 2), 'manipulation_type': manipulation_type, 'vlm_reasoning': reasoning } def main(image_dir, classifier_path, output_file): """Main prediction function""" # Initialize detector detector = GenAIDetector(classifier_path) # Get all images image_extensions = ['.jpg', '.jpeg', '.png'] image_files = [] for ext in image_extensions: image_files.extend(Path(image_dir).glob(f'*{ext}')) image_files.extend(Path(image_dir).glob(f'*{ext.upper()}')) print(f"Found {len(image_files)} images") # Process images predictions = [] for image_path in tqdm(image_files, desc="Processing images"): try: result = detector.predict(str(image_path)) result['image_name'] = image_path.name predictions.append(result) except Exception as e: print(f"Error processing {image_path.name}: {str(e)}") continue # Save predictions with open(output_file, 'w') as f: json.dump(predictions, f, indent=2) print(f"\nāœ“ Processed {len(predictions)} images") print(f"āœ“ Saved predictions to {output_file}") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument('--image_dir', type=str, default='./test_images', help='Directory containing images to predict') parser.add_argument('--classifier_path', type=str, default='best_model.pth', help='Path to trained DINOv2 checkpoint (.pth file)') parser.add_argument('--output_file', type=str, default='predictions.json', help='Output JSON file') args = parser.parse_args() main(args.image_dir, args.classifier_path, args.output_file)