manar54 commited on
Commit
b7bca53
·
verified ·
1 Parent(s): bfe8f56

Upload 2 files

Browse files
Files changed (2) hide show
  1. predict.py +139 -0
  2. requirements.txt +7 -0
predict.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # predict.py
2
+ import argparse
3
+ import os
4
+ import json
5
+ import re
6
+ import torch
7
+ from PIL import Image
8
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
9
+ from qwen_vl_utils import process_vision_info
10
+
11
+ # --- CONFIGURATION ---
12
+ # We use the base Qwen model. It will download automatically on the first run.
13
+ MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
14
+
15
def load_model():
    """Load the Qwen2.5-VL model with 4-bit quantization, plus its processor.

    Returns:
        tuple: ``(model, processor)`` ready for inference. The model is
        placed automatically across available devices (``device_map="auto"``).
    """
    print(f"⏳ Loading Model: {MODEL_ID}...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # BUG FIX: MODEL_ID points at a Qwen2.5-VL checkpoint, but the original
    # code instantiated Qwen2VLForConditionalGeneration (the Qwen2-VL class).
    # Qwen2.5-VL has its own architecture class (transformers >= 4.49).
    from transformers import Qwen2_5_VLForConditionalGeneration

    # 'sdpa' attention implementation for broad compatibility (Colab T4 / RTX GPUs).
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("✅ Model Loaded.")
    return model, processor
36
+
37
def analyze_image(model, processor, image_path):
    """Run the VLM forensic analysis on a single image.

    Args:
        model: Loaded Qwen2.5-VL model (see ``load_model``).
        processor: Matching ``AutoProcessor``.
        image_path: Filesystem path to the image to analyze.

    Returns:
        dict: Parsed JSON with keys 'authenticity_score',
        'manipulation_type', 'vlm_reasoning' (fallback values on parse
        failure — see ``clean_json``).
    """
    prompt_text = (
        "You are a Forensic Image Analyst. Analyze this image for GenAI manipulation.\n"
        "Focus on: Lighting inconsistencies, structural logic, and unnatural textures.\n"
        "Provide your output STRICTLY as a JSON object with these keys:\n"
        "- 'authenticity_score': float (0.0=Real, 1.0=Fake)\n"
        "- 'manipulation_type': string (e.g., 'Inpainting', 'None')\n"
        "- 'vlm_reasoning': string (max 2 sentences)\n"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]

    # Preprocess: apply the chat template, then extract vision features.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    # FIX: was hard-coded .to("cuda"), which crashes on CPU-only machines and
    # fights device_map="auto"; model.device is always the correct target.
    ).to(model.device)

    # Generate greedily for reproducible output.
    # FIX: the original passed temperature=0.1 without do_sample=True;
    # temperature is ignored under greedy decoding (and warns on recent
    # transformers), so explicit do_sample=False preserves behavior cleanly.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    # Strip the prompt tokens from each sequence before decoding.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return clean_json(output_text)
86
+
87
def clean_json(text):
    """Extract and parse the first JSON object embedded in a model response.

    Args:
        text: Raw generated text, possibly with prose surrounding the JSON.

    Returns:
        dict: The parsed object; a neutral fallback dict (score 0.5) when no
        JSON object is found ('Parse Error') or the match is not valid JSON
        ('JSON Error').
    """
    # Greedy match of the outermost {...} span; DOTALL lets it cross newlines.
    json_match = re.search(r"\{.*\}", text, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(0))
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt /
        # SystemExit; only a decode failure is expected here.
        except json.JSONDecodeError:
            return {"authenticity_score": 0.5, "manipulation_type": "Error", "vlm_reasoning": "JSON Error"}
    return {"authenticity_score": 0.5, "manipulation_type": "Unknown", "vlm_reasoning": "Parse Error"}
96
+
97
def main(input_dir, output_file):
    """Run inference over every image in *input_dir* and write a JSON report.

    Args:
        input_dir: Directory containing images (.png/.jpg/.jpeg/.webp).
        output_file: Destination path for the JSON list of predictions.
    """
    # Load the (heavy) model exactly once for the whole batch.
    model, processor = load_model()

    predictions = []

    valid_extensions = ('.png', '.jpg', '.jpeg', '.webp')
    # FIX: os.listdir() order is filesystem-dependent; sorted() makes the
    # processing order — and therefore the submission file — deterministic.
    files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(valid_extensions))

    print(f"🚀 Starting inference on {len(files)} images...")

    for img_name in files:
        img_path = os.path.join(input_dir, img_name)
        try:
            result = analyze_image(model, processor, img_path)
            predictions.append({
                "image_name": img_name,
                # .get() guards against the VLM omitting a key in its JSON.
                "authenticity_score": result.get("authenticity_score", 0.5),
                "manipulation_type": result.get("manipulation_type", "Unknown"),
                "vlm_reasoning": result.get("vlm_reasoning", "No reasoning provided."),
            })
            print(f"Processed: {img_name}")
        except Exception as e:
            # Best-effort batch: log the failure and continue with the rest.
            print(f"Failed to process {img_name}: {e}")

    # Save output
    with open(output_file, 'w') as f:
        json.dump(predictions, f, indent=4)
    print(f"✅ Submission file saved to: {output_file}")
129
+
130
if __name__ == "__main__":
    # CLI entry point: python predict.py --input_dir DIR --output_file FILE
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="Path to input images")
    parser.add_argument("--output_file", type=str, required=True, help="Path to output JSON")
    args = parser.parse_args()

    # FIX: os.path.exists() also accepts a plain file, which main() would then
    # fail on inside os.listdir(); require an actual directory up front.
    if not os.path.isdir(args.input_dir):
        raise FileNotFoundError(f"Input directory {args.input_dir} not found.")

    main(args.input_dir, args.output_file)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.4.0
2
+ transformers>=4.49.0
3
+ accelerate>=0.33.0
4
+ qwen-vl-utils
5
+ bitsandbytes
6
+ pillow
7
+ numpy