File size: 5,000 Bytes
b7bca53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# predict.py
import argparse
import json
import os
import re

import torch
from PIL import Image
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2VLForConditionalGeneration,
)

from qwen_vl_utils import process_vision_info

# --- CONFIGURATION ---
# HuggingFace hub id of the base Qwen2.5-VL checkpoint. It is downloaded
# (and cached) automatically on the first run; no local weights needed.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

def load_model():
    """Load the VLM with 4-bit NF4 quantization and return (model, processor).

    Downloads the checkpoint on first run. Requires a CUDA-capable GPU
    (bitsandbytes 4-bit weights); `device_map="auto"` places the shards.
    """
    print(f"⏳ Loading Model: {MODEL_ID}...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # BUG FIX: MODEL_ID points at a Qwen2.5-VL checkpoint, but the original
    # code loaded it through Qwen2VLForConditionalGeneration — the Qwen2-VL
    # class, which mismatches the checkpoint's `qwen2_5_vl` architecture.
    # The Auto class resolves the correct model class from the hub config.
    # 'sdpa' attention keeps broad compatibility (Colab T4 / RTX GPUs).
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("✅ Model Loaded.")
    return model, processor

def analyze_image(model, processor, image_path):
    """Run the VLM forensic analysis on a single image.

    Args:
        model: loaded (quantized) Qwen VLM, already placed via device_map.
        processor: the matching AutoProcessor.
        image_path: filesystem path to the image to analyze.

    Returns:
        dict with keys 'authenticity_score', 'manipulation_type',
        'vlm_reasoning' (fallback defaults if the model output fails to parse).
    """
    prompt_text = (
        "You are a Forensic Image Analyst. Analyze this image for GenAI manipulation.\n"
        "Focus on: Lighting inconsistencies, structural logic, and unnatural textures.\n"
        "Provide your output STRICTLY as a JSON object with these keys:\n"
        "- 'authenticity_score': float (0.0=Real, 1.0=Fake)\n"
        "- 'manipulation_type': string (e.g., 'Inpainting', 'None')\n"
        "- 'vlm_reasoning': string (max 2 sentences)\n"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]

    # Preprocess: render the chat template to text, extract the vision inputs.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    # BUG FIX: follow the model's actual placement instead of hard-coding
    # "cuda" — device_map="auto" decides where the weights live.
    ).to(model.device)

    # Generate. Greedy decoding (do_sample=False) is fully deterministic,
    # which is what the original intended: it passed temperature=0.1 without
    # do_sample=True, and transformers ignores temperature (with a warning)
    # under the default greedy decoding.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    # Decode only the newly generated tokens (strip the echoed prompt).
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return clean_json(output_text)

def clean_json(text):
    """Extract the first {...} JSON object from a model response.

    Always returns a dict with the expected submission keys:
    - no braces found  -> 'Unknown' / 'Parse Error' fallback
    - braces found but invalid JSON -> 'Error' / 'JSON Error' fallback
    """
    # Greedy DOTALL match grabs the outermost { ... } span, even multiline.
    json_match = re.search(r"\{.*\}", text, re.DOTALL)
    if not json_match:
        return {"authenticity_score": 0.5, "manipulation_type": "Unknown", "vlm_reasoning": "Parse Error"}
    try:
        return json.loads(json_match.group(0))
    # BUG FIX: was a bare `except:` that swallowed everything, including
    # KeyboardInterrupt/SystemExit. Only a parse failure is expected here.
    except json.JSONDecodeError:
        return {"authenticity_score": 0.5, "manipulation_type": "Error", "vlm_reasoning": "JSON Error"}

def main(input_dir, output_file):
    """Run VLM inference on every image in input_dir and write a JSON report.

    Args:
        input_dir: directory containing the images to analyze.
        output_file: path where the predictions JSON array is written.
    """
    # Load model once and reuse it for every image.
    model, processor = load_model()

    predictions = []

    valid_extensions = ('.png', '.jpg', '.jpeg', '.webp')
    # BUG FIX: os.listdir order is filesystem-dependent; sort so the
    # processing (and output) order is deterministic across runs/machines.
    files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(valid_extensions))

    print(f"🚀 Starting inference on {len(files)} images...")

    for img_name in files:
        img_path = os.path.join(input_dir, img_name)
        try:
            result = analyze_image(model, processor, img_path)

            entry = {
                "image_name": img_name,
                "authenticity_score": result.get("authenticity_score", 0.5),
                "manipulation_type": result.get("manipulation_type", "Unknown"),
                "vlm_reasoning": result.get("vlm_reasoning", "No reasoning provided.")
            }
            predictions.append(entry)
            print(f"Processed: {img_name}")
        except Exception as e:
            # Best-effort batch run: report the failure and keep going.
            print(f"Failed to process {img_name}: {e}")

    # BUG FIX: explicit UTF-8 encoding (open() otherwise uses the locale
    # default) and ensure_ascii=False so non-ASCII reasoning text survives.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)
    print(f"✅ Submission file saved to: {output_file}")

if __name__ == "__main__":
    # Command-line entry point: validate the input directory, then run.
    cli = argparse.ArgumentParser()
    cli.add_argument("--input_dir", type=str, required=True, help="Path to input images")
    cli.add_argument("--output_file", type=str, required=True, help="Path to output JSON")
    opts = cli.parse_args()

    # Fail fast with a clear error before loading the (large) model.
    if not os.path.exists(opts.input_dir):
        raise FileNotFoundError(f"Input directory {opts.input_dir} not found.")

    main(opts.input_dir, opts.output_file)