# predict.py
import argparse
import json
import os
import re

import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

from qwen_vl_utils import process_vision_info
# --- CONFIGURATION ---
# We use the base Qwen model. It will download automatically on the first run.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
def load_model():
    """Load the VLM with 4-bit NF4 quantization and return (model, processor).

    Downloads MODEL_ID from the Hugging Face hub on first use; weights are
    placed automatically across available devices via device_map="auto".
    """
    print(f"⏳ Loading Model: {MODEL_ID}...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # BUG FIX: MODEL_ID is a Qwen2.5-VL checkpoint (config type "qwen2_5_vl"),
    # but the original code instantiated Qwen2VLForConditionalGeneration (the
    # Qwen2-VL class, config type "qwen2_vl"), which fails to load it. Use the
    # matching Qwen2.5-VL class.
    # 'sdpa' attention keeps broad compatibility (Colab T4 / RTX GPUs).
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("✅ Model Loaded.")
    return model, processor
def analyze_image(model, processor, image_path):
    """Run the forensic-analysis prompt against one image.

    Args:
        model: generation-capable VLM returned by load_model().
        processor: matching processor (chat template + tokenizer + image preproc).
        image_path: filesystem path to the image under analysis.

    Returns:
        dict parsed from the model's JSON reply; clean_json() supplies a
        fallback dict with the same three keys when parsing fails.
    """
    prompt_text = (
        "You are a Forensic Image Analyst. Analyze this image for GenAI manipulation.\n"
        "Focus on: Lighting inconsistencies, structural logic, and unnatural textures.\n"
        "Provide your output STRICTLY as a JSON object with these keys:\n"
        "- 'authenticity_score': float (0.0=Real, 1.0=Fake)\n"
        "- 'manipulation_type': string (e.g., 'Inpainting', 'None')\n"
        "- 'vlm_reasoning': string (max 2 sentences)\n"
    )
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    # Preprocess: render the chat template, then extract the vision inputs.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # BUG FIX: was hard-coded "cuda" — follow wherever device_map placed the model
    # Generate
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,   # BUG FIX: temperature is ignored (with a warning) unless sampling is on
            temperature=0.1,  # low temp for consistency
        )
    # Decode only the newly generated tokens (strip the echoed prompt).
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return clean_json(output_text)
def clean_json(text):
    """Extract the first {...} span from *text* and parse it as JSON.

    Models often wrap JSON in prose or markdown fences; the greedy DOTALL
    regex grabs the outermost braced span. Always returns a dict with the
    three expected keys: a "Parse Error" fallback when no braces are found,
    and a "JSON Error" fallback when the span is not valid JSON.
    """
    json_match = re.search(r"\{.*\}", text, re.DOTALL)
    if json_match is None:
        return {"authenticity_score": 0.5, "manipulation_type": "Unknown", "vlm_reasoning": "Parse Error"}
    try:
        return json.loads(json_match.group(0))
    except json.JSONDecodeError:
        # BUG FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit; only parse failures should hit the fallback.
        return {"authenticity_score": 0.5, "manipulation_type": "Error", "vlm_reasoning": "JSON Error"}
def main(input_dir, output_file):
    """Run inference on every image in *input_dir* and write predictions JSON.

    Each output entry holds image_name, authenticity_score,
    manipulation_type, and vlm_reasoning. Images that raise during analysis
    are logged and skipped (best-effort batch processing).
    """
    # Load model once
    model, processor = load_model()
    predictions = []
    valid_extensions = ('.png', '.jpg', '.jpeg', '.webp')
    # FIX: sort for a deterministic, reproducible output order —
    # os.listdir() order is arbitrary and filesystem-dependent.
    files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(valid_extensions))
    print(f"🚀 Starting inference on {len(files)} images...")
    for img_name in files:
        img_path = os.path.join(input_dir, img_name)
        try:
            result = analyze_image(model, processor, img_path)
            predictions.append({
                "image_name": img_name,
                "authenticity_score": result.get("authenticity_score", 0.5),
                "manipulation_type": result.get("manipulation_type", "Unknown"),
                "vlm_reasoning": result.get("vlm_reasoning", "No reasoning provided."),
            })
            print(f"Processed: {img_name}")
        except Exception as e:
            # Best-effort: report and keep going so one bad image
            # doesn't abort the whole submission run.
            print(f"Failed to process {img_name}: {e}")
    # FIX: explicit UTF-8 so non-ASCII reasoning text round-trips on any
    # platform (the default encoding is locale-dependent).
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=4)
    print(f"✅ Submission file saved to: {output_file}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="Path to input images")
    parser.add_argument("--output_file", type=str, required=True, help="Path to output JSON")
    args = parser.parse_args()
    # FIX: os.path.isdir — a regular file at this path also passes
    # os.path.exists but would crash os.listdir() later with a worse error.
    if not os.path.isdir(args.input_dir):
        raise FileNotFoundError(f"Input directory {args.input_dir} not found.")
    main(args.input_dir, args.output_file)