bmrayan commited on
Commit
9ea88f7
·
verified ·
1 Parent(s): 7c73207

Upload predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. predict.py +217 -0
predict.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prediction script combining DINOv2 classifier and Qwen3-VL reasoner
3
+ Outputs predictions.json in required format
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torchvision import transforms
9
+ from transformers import (
10
+ AutoImageProcessor,
11
+ Dinov2Model,
12
+ Qwen3VLForConditionalGeneration,
13
+ AutoProcessor
14
+ )
15
+ from peft import PeftModel
16
+ from PIL import Image
17
+ import json
18
+ import os
19
+ from pathlib import Path
20
+ from tqdm import tqdm
21
+ from qwen_vl_utils import process_vision_info
22
+
23
class DINOv2Classifier(nn.Module):
    """DINOv2 backbone with an MLP head for 3-way image classification.

    Class indices follow the training label order used elsewhere in this
    file: 0=real, 1=manipulated, 2=fake.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # Pretrained ViT backbone; its CLS token is the image embedding.
        self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base")

        # Head: 768 -> 512 -> 256 -> num_classes, ReLU + Dropout(0.3)
        # after each hidden layer. Built programmatically; layer order is
        # identical to writing the Sequential out by hand.
        head_layers = []
        in_dim = 768
        for hidden_dim in (512, 256):
            head_layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            in_dim = hidden_dim
        head_layers.append(nn.Linear(in_dim, num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return class logits of shape (batch, num_classes)."""
        backbone_out = self.dinov2(pixel_values)
        # [:, 0] selects the CLS token embedding for each image.
        image_embedding = backbone_out.last_hidden_state[:, 0]
        return self.classifier(image_embedding)
44
+
45
class GenAIDetector:
    """Two-stage detector: a DINOv2 classifier produces the label and
    confidence, then Qwen3-VL generates a natural-language explanation.
    """

    def __init__(self, classifier_path):
        """Load both models.

        Args:
            classifier_path: torch checkpoint containing 'model_state_dict'
                for DINOv2Classifier.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load DINOv2 classifier (eval mode: disables dropout in the head).
        print("Loading classifier...")
        self.classifier = DINOv2Classifier(num_classes=3).to(self.device)
        checkpoint = torch.load(classifier_path, map_location=self.device)
        self.classifier.load_state_dict(checkpoint['model_state_dict'])
        self.classifier.eval()

        self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

        # Load VLM. NOTE(review): device_map="auto" may shard the model
        # across devices while inputs are moved to self.device below;
        # accelerate places the embedding layer on the first device, so
        # this works on single-GPU setups — confirm for multi-GPU.
        print("Loading VLM reasoner...")
        base_model = Qwen3VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen3-VL-8B-Instruct",
            torch_dtype="auto",
            device_map="auto"
        )
        self.vlm = base_model
        self.vlm_processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        self.vlm.eval()

        # Index order must match the classifier's training label order.
        self.class_names = ['real', 'manipulated', 'fake']
        self.manipulation_types = {
            'real': 'none',
            'manipulated': 'inpainting',
            'fake': 'full_synthesis'
        }

    def classify_image(self, image_path):
        """Classify one image.

        Returns:
            (pred_class, confidence): predicted class index and the full
            softmax probability vector (numpy array of length 3).
        """
        image = Image.open(image_path).convert('RGB')
        inputs = self.image_processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].to(self.device)

        with torch.no_grad():
            logits = self.classifier(pixel_values)
            probs = torch.softmax(logits, dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0].cpu().numpy()

        return pred_class, confidence

    def generate_reasoning(self, image_path, predicted_class):
        """Ask the VLM to explain the classifier's verdict in 2-3 sentences.

        Args:
            image_path: path to the image being explained.
            predicted_class: class index from classify_image().

        Returns:
            The generated explanation text (sampling is enabled, so the
            wording varies between runs).
        """
        class_name = self.class_names[predicted_class]

        # Prepare prompt
        prompt = f"The given image has been flagged as {class_name}. Explain in 2-3 sentences why that might be. Focus on specific features which indicated this."

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        # Apply chat template
        text = self.vlm_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Process inputs
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.vlm_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        inputs = inputs.to(self.device)

        # Generate
        with torch.no_grad():
            output_ids = self.vlm.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                do_sample=True
            )

        # Trim the prompt tokens before decoding so only the newly
        # generated answer remains (the usage pattern documented for
        # Qwen-VL models). This replaces a fragile split on the literal
        # substring "assistant", which returned the entire transcript
        # whenever the role marker appeared with different casing.
        trimmed_ids = [
            out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)
        ]
        reasoning = self.vlm_processor.batch_decode(
            trimmed_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0].strip()

        return reasoning

    def predict(self, image_path):
        """Full prediction pipeline for one image.

        Returns:
            dict with keys 'authenticity_score' (float in [0, 1], rounded
            to 2 decimals; 1 - P(real), so higher = more likely altered),
            'manipulation_type', and 'vlm_reasoning'.
        """
        # Classify
        pred_class, confidence = self.classify_image(image_path)

        # Authenticity score = 1 - P(real); higher score = more manipulated.
        authenticity_score = float(1.0 - confidence[0])

        # Map the predicted class to its manipulation-type label.
        class_name = self.class_names[pred_class]
        manipulation_type = self.manipulation_types[class_name]

        # Generate reasoning
        reasoning = self.generate_reasoning(image_path, pred_class)

        return {
            'authenticity_score': round(authenticity_score, 2),
            'manipulation_type': manipulation_type,
            'vlm_reasoning': reasoning
        }
170
+
171
def main(image_dir, classifier_path, output_file):
    """Run the detector over every image in image_dir and write a JSON
    list of per-image predictions to output_file.

    Args:
        image_dir: directory scanned (non-recursively) for .jpg/.jpeg/.png
            images, in either case.
        classifier_path: checkpoint path forwarded to GenAIDetector.
        output_file: destination JSON path.
    """
    # Initialize detector (loads both models; slow).
    detector = GenAIDetector(classifier_path)

    # Collect images. Deduplicate via a set: on case-insensitive
    # filesystems (macOS/Windows) the lowercase and uppercase glob
    # patterns match the same files, which previously made every image
    # be processed twice. Sorting makes the output order deterministic.
    image_extensions = ['.jpg', '.jpeg', '.png']
    found = set()
    for ext in image_extensions:
        found.update(Path(image_dir).glob(f'*{ext}'))
        found.update(Path(image_dir).glob(f'*{ext.upper()}'))
    image_files = sorted(found)

    print(f"Found {len(image_files)} images")

    # Process images; a failure on one image is logged and skipped so
    # the rest of the batch still completes.
    predictions = []
    for image_path in tqdm(image_files, desc="Processing images"):
        try:
            result = detector.predict(str(image_path))
            result['image_name'] = image_path.name
            predictions.append(result)
        except Exception as e:
            print(f"Error processing {image_path.name}: {str(e)}")
            continue

    # Save predictions
    with open(output_file, 'w') as f:
        json.dump(predictions, f, indent=2)

    print(f"\n✓ Processed {len(predictions)} images")
    print(f"✓ Saved predictions to {output_file}")
203
+
204
if __name__ == "__main__":
    import argparse

    # CLI: all three options are optional and fall back to defaults.
    cli = argparse.ArgumentParser()
    cli.add_argument('--image_dir', type=str, default='./test_images',
                     help='Directory containing images to predict')
    cli.add_argument('--classifier_path', type=str, default='best_model.pth',
                     help='Path to trained DINOv2 checkpoint (.pth file)')
    cli.add_argument('--output_file', type=str, default='predictions.json',
                     help='Output JSON file')
    opts = cli.parse_args()

    main(opts.image_dir, opts.classifier_path, opts.output_file)