# Elm-Challenge / predict.py
# (HuggingFace upload metadata: user manar54, "Upload 2 files", commit b7bca53 verified)
# predict.py
import argparse
import os
import json
import re
import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
# --- CONFIGURATION ---
# Base Qwen2.5-VL instruct checkpoint; downloaded automatically from the
# HuggingFace Hub on the first run.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
def load_model():
    """Load the Qwen2.5-VL model (4-bit NF4 quantized) and its processor.

    Returns:
        tuple: ``(model, processor)`` — the quantized model, device-mapped
        automatically across available accelerators, and the matching
        ``AutoProcessor``.
    """
    # BUGFIX: MODEL_ID points at a Qwen2.5-VL checkpoint, whose config class
    # does not match Qwen2VLForConditionalGeneration (the Qwen2-VL class).
    # Prefer the dedicated Qwen2.5-VL class when the installed transformers
    # version provides it; fall back to the old class otherwise so older
    # environments keep the previous behavior.
    try:
        from transformers import Qwen2_5_VLForConditionalGeneration as _ModelClass
    except ImportError:
        _ModelClass = Qwen2VLForConditionalGeneration

    print(f"⏳ Loading Model: {MODEL_ID}...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,  # nested quantization saves extra VRAM
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # 'sdpa' attention keeps broad compatibility (Colab T4 / consumer RTX GPUs).
    model = _ModelClass.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("✅ Model Loaded.")
    return model, processor
def analyze_image(model, processor, image_path):
    """Run the VLM forensic analysis on a single image.

    Args:
        model: the loaded (quantized) Qwen VLM.
        processor: the matching AutoProcessor.
        image_path: filesystem path of the image to analyze.

    Returns:
        dict: parsed JSON verdict with keys ``authenticity_score``,
        ``manipulation_type`` and ``vlm_reasoning`` (fallback values on
        parse failure — see ``clean_json``).
    """
    prompt_text = (
        "You are a Forensic Image Analyst. Analyze this image for GenAI manipulation.\n"
        "Focus on: Lighting inconsistencies, structural logic, and unnatural textures.\n"
        "Provide your output STRICTLY as a JSON object with these keys:\n"
        "- 'authenticity_score': float (0.0=Real, 1.0=Fake)\n"
        "- 'manipulation_type': string (e.g., 'Inpainting', 'None')\n"
        "- 'vlm_reasoning': string (max 2 sentences)\n"
    )
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    # Preprocess: render the chat template, then extract vision inputs.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    # BUGFIX: follow the model's actual device instead of hardcoding "cuda";
    # device_map="auto" may place the (quantized) model elsewhere.
    ).to(model.device)
    # Generate
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            # Low temperature for consistency. NOTE(review): temperature only
            # takes effect when sampling is enabled by the model's generation
            # config — confirm the checkpoint's default do_sample setting.
            temperature=0.1,
        )
    # Decode: strip the prompt tokens from each generated sequence.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return clean_json(output_text)
def clean_json(text):
    """Extract and parse the first JSON object embedded in *text*.

    The VLM is prompted to answer with a bare JSON object, but often wraps it
    in markdown fences or prose; grab the outermost ``{...}`` span and parse it.

    Args:
        text: raw decoded model output.

    Returns:
        dict: the parsed JSON object, or a neutral fallback (score 0.5) when
        no JSON object is found ("Parse Error") or the matched span is not
        valid JSON ("JSON Error").
    """
    json_match = re.search(r"\{.*\}", text, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(0))
        # BUGFIX: narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit) to the actual parse failure.
        except json.JSONDecodeError:
            return {"authenticity_score": 0.5, "manipulation_type": "Error", "vlm_reasoning": "JSON Error"}
    return {"authenticity_score": 0.5, "manipulation_type": "Unknown", "vlm_reasoning": "Parse Error"}
def main(input_dir, output_file):
    """Run inference over every image in *input_dir* and write a JSON report.

    Args:
        input_dir: directory containing the images to analyze.
        output_file: path of the JSON submission file to write.
    """
    # Load model once and reuse it for every image.
    model, processor = load_model()
    predictions = []
    valid_extensions = ('.png', '.jpg', '.jpeg', '.webp')
    # BUGFIX: sort for a deterministic processing/output order —
    # os.listdir() order is platform-dependent.
    files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(valid_extensions))
    print(f"🚀 Starting inference on {len(files)} images...")
    for img_name in files:
        img_path = os.path.join(input_dir, img_name)
        try:
            result = analyze_image(model, processor, img_path)
            # Defensive .get(): clean_json guarantees these keys today, but a
            # schema drift in the model output must not crash the whole batch.
            entry = {
                "image_name": img_name,
                "authenticity_score": result.get("authenticity_score", 0.5),
                "manipulation_type": result.get("manipulation_type", "Unknown"),
                "vlm_reasoning": result.get("vlm_reasoning", "No reasoning provided."),
            }
            predictions.append(entry)
            print(f"Processed: {img_name}")
        except Exception as e:
            # Best-effort batch run: log the failure and continue.
            print(f"Failed to process {img_name}: {e}")
    # Save output (explicit UTF-8 so the reasoning text round-trips on any OS).
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=4)
    print(f"✅ Submission file saved to: {output_file}")
if __name__ == "__main__":
    # CLI entry point: --input_dir with images in, --output_file with JSON out.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="Path to input images")
    parser.add_argument("--output_file", type=str, required=True, help="Path to output JSON")
    args = parser.parse_args()
    # BUGFIX: require an actual directory, not merely an existing path —
    # a regular file would pass os.path.exists() but crash os.listdir().
    if not os.path.isdir(args.input_dir):
        raise FileNotFoundError(f"Input directory {args.input_dir} not found.")
    main(args.input_dir, args.output_file)