# Real-Estate-Forensics / predict.py
# Uploaded by LinaAlkh ("Upload 5 files", commit 2d96524, verified)
import os
import json
import torch
import torch.nn as nn
from torchvision import models, transforms
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image
from tqdm import tqdm
import argparse
import random
# ==========================================
# 1. Class labels (hardcoded)
# ==========================================
# Detector output labels, indexed by classifier-head position (the detector
# maps the arg-max logit index straight into this list).
# NOTE(review): presumably this order must match the label order used when the
# checkpoint was trained — confirm against the training script.
FINAL_CLASSES = [
    "fake_ai",
    "fake_splice",
    "real",
]
class ManipulateDetector:
    """ResNet-18 classifier that labels an image as one of ``FINAL_CLASSES``.

    ``predict`` returns ``(authenticity_score, label)``. ``label`` is the
    arg-max class name. ``authenticity_score`` is the softmax probability the
    model assigns to the ``'real'`` class.
    """

    def __init__(self, model_path, device):
        """Build the network, load weights from *model_path*, move it to *device*.

        Weight loading is best-effort. A failure is reported but does not
        abort; in that case the network keeps its random initialisation.
        """
        self.device = device
        self.class_names = FINAL_CLASSES
        print(f"🔧 Initializing Detector with classes: {self.class_names}")
        # Architecture only; the trained weights come from model_path below.
        self.model = models.resnet18(weights=None)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, len(self.class_names))
        try:
            state_dict = torch.load(model_path, map_location=device)
            incompatible = self.model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            print(f"⚠️ Warning loading weights: {e}")
        else:
            # strict=False silently skips mismatched keys. Surface them so that
            # a partially (or entirely) unloaded checkpoint is not reported as
            # a success.
            missing = incompatible.missing_keys
            unexpected = incompatible.unexpected_keys
            if missing or unexpected:
                print(
                    "⚠️ Weights loaded with mismatches: "
                    f"{len(missing)} missing (e.g. {missing[:3]}), "
                    f"{len(unexpected)} unexpected (e.g. {unexpected[:3]})"
                )
            else:
                print("✅ Weights loaded successfully!")
        self.model.to(device)
        self.model.eval()
        # Resize to 224x224 and normalise with the ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def predict(self, image_path):
        """Classify the image stored at *image_path*.

        Returns:
            tuple[float, str]: ``(authenticity_score, label)``, where
            ``authenticity_score`` is P('real') and ``label`` is the arg-max
            class name.
        """
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        img_t = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = torch.softmax(self.model(img_t), dim=1)[0]
        return self._score_from_probs(probs, self.class_names)

    @staticmethod
    def _score_from_probs(probs, class_names):
        """Map a 1-D probability vector to ``(P(real), arg-max label)``.

        Using P(real) for every prediction keeps the score monotonic in the
        model's belief. The alternative ``1 - P(top fake class)`` rule can jump
        from ~0.34 to ~0.66 between two nearly identical distributions.
        """
        # The head has exactly len(class_names) outputs, so the index is in range.
        label = class_names[int(torch.argmax(probs).item())]
        score = probs[class_names.index('real')].item()
        return score, label
# ==========================================
# 2. Smart Forensic Reasoner
# ==========================================
class ForensicVLM:
    """Explain detector verdicts in text by querying a BLIP VQA model.

    If the BLIP weights cannot be loaded, ``loaded`` is False and ``explain``
    returns a fixed error string instead of raising.
    """

    def __init__(self, device):
        self.device = device
        print("🔧 Loading VLM (BLIP Pro)...")
        try:
            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
            self.model.eval()
            self.loaded = True
        except Exception as e:
            # Best effort: the pipeline keeps running without explanations,
            # but the reason is reported rather than silently swallowed.
            print(f"⚠️ Warning loading VLM: {e}")
            self.loaded = False

    @staticmethod
    def _answer_has(answer, word):
        """Return True if *word* occurs as a whole word in a VQA *answer*.

        Whole-word matching avoids false hits such as "no" inside "not",
        "none" or "snow".
        """
        tokens = {tok.strip(".,!?;:'\"") for tok in answer.lower().split()}
        return word in tokens

    def ask(self, image, question):
        """Return BLIP's decoded answer to *question* about the PIL *image*."""
        inputs = self.processor(image, question, return_tensors="pt").to(self.device)
        out = self.model.generate(**inputs)
        return self.processor.decode(out[0], skip_special_tokens=True)

    def explain(self, image_path, pred_label):
        """Build a one-paragraph textual explanation for *pred_label*.

        Args:
            image_path: path of the image to analyse.
            pred_label: detector label; ``'real'`` yields a scene description,
                anything else yields a manipulation report.

        Returns:
            str: the explanation, or "System error during analysis." when the
            VLM failed to load.
        """
        if not self.loaded:
            return "System error during analysis."
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # Authentic image: ask for the scene type and describe it positively.
        if pred_label == 'real':
            scene_desc = self.ask(image, "What type of room is this?")
            return f"Authentic scene. The {scene_desc} displays consistent global illumination and natural perspective geometry."

        # --- Forensic probing of a manipulated image ---
        # 1. Identify the suspicious object (instead of the generic word "furniture").
        suspicious_object = self.ask(image, "What is the main piece of furniture in this image?")
        lowered = suspicious_object.strip().lower()
        # An empty answer, or one naming the room itself, is useless as an object name.
        if not lowered or "room" in lowered or "living" in lowered:
            suspicious_object = "furniture object"  # fallback
        # 2. Check the shadow of that specific object.
        shadow_check = self.ask(image, f"Does the {suspicious_object} cast a realistic shadow on the floor?")
        # 3. Check lighting consistency.
        light_check = self.ask(image, "Is the lighting on the furniture matching the background?")
        # 4. Check whether the object appears to float.
        float_check = self.ask(image, f"Does the {suspicious_object} look like it is floating?")

        # --- Build the report ---
        reasons = []
        if self._answer_has(shadow_check, "no"):
            reasons.append(f"the {suspicious_object} lacks a grounded contact shadow")
        if self._answer_has(light_check, "no"):
            reasons.append(f"illumination on the {suspicious_object} contradicts the room's light source")
        if self._answer_has(float_check, "yes"):
            reasons.append(f"spatial disconnection observed (floating {suspicious_object})")
        # The detector flagged the image, but no specific cue was confirmed.
        if not reasons:
            reasons.append(f"digital artifacts detected around the {suspicious_object}")

        joined_reasons = "; ".join(reasons)
        return f"Manipulation detected: {joined_reasons}. The integration of the {suspicious_object} into the scene is physically inconsistent."
def main():
    """Run detection and VLM explanation over a folder of images, then write JSON.

    CLI options:
        --input_dir    folder scanned (non-recursively) for .jpg/.jpeg/.png files
        --output_file  path of the JSON list written at the end
        --model_path   detector checkpoint (defaults to the original Colab path)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, default="./test_images")
    parser.add_argument("--output_file", type=str, default="predictions.json")
    parser.add_argument(
        "--model_path",
        type=str,
        default="/content/drive/MyDrive/RealEstate_Challenge/detector_model.pth",
    )
    args = parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = args.model_path
    if not os.path.exists(model_path):
        print("❌ Model file not found!")
        return
    # Validate the input folder before loading the (heavy) models.
    if not os.path.isdir(args.input_dir):
        print(f"❌ Input directory not found: {args.input_dir}")
        return
    detector = ManipulateDetector(model_path, device)
    vlm = ForensicVLM(device)
    results = []
    # Sorted for a reproducible output order. The extension match is
    # case-insensitive so that files like "IMG_01.JPG" are not skipped.
    files = sorted(
        f for f in os.listdir(args.input_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    print(f"🚀 Processing {len(files)} images...")
    for img_file in tqdm(files):
        img_path = os.path.join(args.input_dir, img_file)
        try:
            score, label = detector.predict(img_path)
            reasoning = vlm.explain(img_path, label)
        except Exception as e:
            # One bad image should not abort the batch, but it must not vanish
            # from the output silently either. tqdm.write keeps the bar intact.
            tqdm.write(f"⚠️ Skipping {img_file}: {e}")
            continue
        results.append({
            "image_name": img_file,
            "authenticity_score": round(float(score), 4),
            "manipulation_type": label,
            "vlm_reasoning": reasoning,
        })
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print("✅ Done!")


if __name__ == "__main__":
    main()