# Source: ai-image-detection / extract_features.py
# Uploaded by: nahid112376
# Commit message: "Upload extract_features.py with huggingface_hub"
# Commit: ecb8c34 (verified)
#!/usr/bin/env python3
"""
Feature Extraction Script for AI Detection
Extracts spatial features from images using Qwen2.5-VL model
Based on demectai feature extraction pipeline
"""
import os
import torch
import numpy as np
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from pathlib import Path
# ==================== CONFIGURATION ====================
# Model identifier handed to `from_pretrained` for both processor and model.
MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
# (width, height) of the black canvas every input image is letterboxed onto.
TARGET_SIZE = (256, 256)
# Current CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ==================== FEATURE EXTRACTOR ====================
class SpatialFeatureExtractor:
    """Wrap a Qwen2.5-VL checkpoint and expose per-image hidden-state features.

    The processor and model are loaded once in ``__init__``; call
    ``extract_features`` once per image path afterwards.
    """

    def __init__(self, model_path=MODEL_PATH):
        """Initialize the Qwen2.5-VL model for feature extraction.

        Args:
            model_path: Model identifier or path passed to ``from_pretrained``
                for both the processor and the model.
        """
        print(f"πŸš€ Initializing Feature Extractor...")
        print(f"πŸ’» Device: {DEVICE}")
        if DEVICE.type == "cuda":
            # Hand cached-but-unused CUDA memory back before loading weights.
            torch.cuda.empty_cache()
            print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f"πŸ“¦ Loading model from {model_path}...")
        # NOTE(review): trust_remote_code=True allows the checkpoint repository
        # to execute its own Python code; only use it with trusted repositories.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True
        )
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            # float16 weights on GPU, float32 on CPU.
            torch_dtype=torch.float16 if DEVICE.type == "cuda" else torch.float32,
            device_map="auto" if DEVICE.type == "cuda" else "cpu",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        self.model.eval()
        print(f"βœ… Model loaded successfully")

    def preprocess_image(self, image_path):
        """Load an image and letterbox it onto a black TARGET_SIZE canvas.

        The image is scaled (up or down) to fit inside TARGET_SIZE while
        keeping its aspect ratio, then centered on a black RGB canvas. An
        image that is already exactly TARGET_SIZE is returned as-is (in RGB).

        Args:
            image_path: Path to an image file readable by PIL.

        Returns:
            An RGB PIL image of size TARGET_SIZE.
        """
        # Close the source file deterministically; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        if image.size == TARGET_SIZE:
            return image

        target_w, target_h = TARGET_SIZE
        width, height = image.size
        scale = min(target_w / width, target_h / height)
        # round() instead of int(): (t / w) * w can come out one ulp below t,
        # and truncating would leave a 1-pixel black strip on the side that
        # should fill the canvas. The clamps keep each side within
        # [1, target] so very thin images never collapse to zero pixels.
        new_width = max(1, min(target_w, round(width * scale)))
        new_height = max(1, min(target_h, round(height * scale)))
        resized = image.resize((new_width, new_height), Image.LANCZOS)

        # Black canvas with the resized image centered on it.
        canvas = Image.new('RGB', TARGET_SIZE, (0, 0, 0))
        paste_x = (target_w - new_width) // 2
        paste_y = (target_h - new_height) // 2
        canvas.paste(resized, (paste_x, paste_y))
        return canvas

    def extract_features(self, image_path):
        """Run one image through the model and return its final hidden states.

        Args:
            image_path: Path to the image file.

        Returns:
            A dict with:
                'spatial_features': numpy array (seq_len, hidden_dim) taken
                    from ``last_hidden_state`` for the single batch item.
                'num_patches': ``spatial_features.shape[0]``. This is the
                    length of the whole prompt sequence, not only the image
                    patches; the name is kept for compatibility.
                'hidden_dim': ``spatial_features.shape[1]``.
            Returns None if any exception is raised; the error is printed.
        """
        try:
            image = self.preprocess_image(image_path)

            # Minimal chat turn: the image plus a one-word text prompt.
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Image"}
                ]
            }]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # NOTE(review): process_vision_info / the processor may resize the
            # letterboxed image again to fit the model's patch grid; confirm
            # this if exact pixel dimensions matter downstream.
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt"
            )
            # Follow the model's own device rather than the global DEVICE so
            # the inputs land wherever device_map placed the weights.
            # NOTE(review): with device_map="auto" sharded over several GPUs,
            # confirm model.device is the input-embedding device.
            target_device = self.model.device
            inputs = {k: v.to(target_device) if isinstance(v, torch.Tensor) else v
                      for k, v in inputs.items()}

            with torch.no_grad():
                # NOTE(review): assumes self.model.model is the composite
                # vision+language backbone whose forward accepts pixel_values
                # and image_grid_thw (recent transformers releases); confirm
                # against the pinned transformers version.
                # Per-layer hidden states are not requested: only
                # last_hidden_state is used, so collecting every layer's
                # activations would just cost memory.
                outputs = self.model.model(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    pixel_values=inputs.get('pixel_values'),
                    image_grid_thw=inputs.get('image_grid_thw'),
                )

            # Exactly one prompt was encoded, so take batch item 0. Its first
            # axis covers every position in input_ids (template text and
            # image tokens alike).
            spatial_features = outputs.last_hidden_state[0].cpu().numpy()
            return {
                'spatial_features': spatial_features,
                'num_patches': spatial_features.shape[0],
                'hidden_dim': spatial_features.shape[1]
            }
        except Exception as e:
            # Best-effort by design: report and return None so callers can
            # skip this image.
            print(f"❌ Error extracting features from {image_path}: {e}")
            return None
def extract_single_image(image_path, output_path=None, extractor=None):
    """
    Extract features from a single image and optionally save them.

    Args:
        image_path: Path to the image file.
        output_path: Optional path to save the features (as .npz). For a
            str/path target, NumPy appends ".npz" when it is missing.
        extractor: Optional already-initialized SpatialFeatureExtractor (or any
            object with an ``extract_features(image_path)`` method). Pass one
            to reuse a loaded model across calls. When None, a new
            SpatialFeatureExtractor is constructed, which loads the model.

    Returns:
        Dictionary with spatial features, or None if extraction failed. Nothing
        is saved in the failure case.
    """
    if extractor is None:
        extractor = SpatialFeatureExtractor()
    features = extractor.extract_features(image_path)
    if features is not None and output_path is not None:
        np.savez_compressed(output_path, **features)
        # np.savez_compressed appends ".npz" to path targets that lack it
        # (file objects are written as-is), so report the real filename.
        saved_to = output_path
        if not hasattr(output_path, 'write'):
            saved_to = os.fspath(output_path)
            if not saved_to.endswith('.npz'):
                saved_to += '.npz'
        print(f"πŸ’Ύ Features saved to {saved_to}")
    return features
if __name__ == "__main__":
    import sys

    # CLI: extract_features.py <image_path> [output_path]
    if len(sys.argv) < 2:
        print("Usage: python extract_features.py <image_path> [output_path]")
        sys.exit(1)
    image_path = sys.argv[1]
    output_path = sys.argv[2] if len(sys.argv) > 2 else None
    features = extract_single_image(image_path, output_path)
    if features is None:
        # The extractor has already printed the error; report failure to the
        # shell through the exit status instead of exiting 0.
        sys.exit(1)
    print(f"βœ… Feature extraction complete!")
    print(f" Shape: {features['spatial_features'].shape}")
    print(f" Num patches: {features['num_patches']}")
    print(f" Hidden dim: {features['hidden_dim']}")