# DeepTrust / model / predict.py
# Author: priyansh-nagar
# Commit message: Update model/predict.py
# Commit 005fde1 (verified)
import torch
from torchvision import models, transforms
from PIL import Image
# Checkpoint location. This is a relative path, so it is resolved against the
# process's current working directory at import time.
MODEL_PATH = "models/deeptrust_weights.pt"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the network the checkpoint is loaded into: EfficientNet-B0 without
# pretrained weights, with the last classifier layer replaced by a 2-way
# linear head.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=None`; switch once the minimum supported torchvision is confirmed.
model = models.efficientnet_b0(pretrained=False)
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=2,
)

# Restore the trained parameters, remapping tensors onto the selected device.
# NOTE(review): torch.load unpickles the file. On PyTorch versions where
# `weights_only` defaults to False, an untrusted checkpoint can execute
# arbitrary code. Only a state_dict is needed here, so consider passing
# `weights_only=True` if the installed PyTorch supports it.
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)

# Move the model to the target device and switch to inference mode
# (both calls return the module itself).
model.to(device).eval()

# Preprocessing applied to every input image: resize to 224x224, convert to a
# float tensor, then normalise each of the three channels with the standard
# ImageNet mean / std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image: Image.Image):
    """Classify an image as real or AI-generated and derive a trust score.

    The image is converted to 3-channel RGB before preprocessing. The
    module-level ``transform`` normalises with three per-channel statistics,
    so grayscale, palette or RGBA inputs would otherwise make the
    normalisation step raise.

    Args:
        image: PIL image in any mode.

    Returns:
        A ``(label, confidence, trust_score)`` tuple:

        * ``label`` -- ``"Real"`` when the arg-max class index is 0,
          ``"Fake"`` otherwise.
        * ``confidence`` -- softmax probability of the predicted class.
        * ``trust_score`` -- ``int(confidence * 100)`` when the label is
          ``"Real"``, else ``int((1 - confidence) * 100)``. With the 2-class
          head this is the truncated percentage assigned to class 0.

    Side effects:
        Prints the predicted index, label and confidence to stdout.
    """
    # Guarantee exactly three channels so Normalize's 3-element mean/std apply.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension and move the input to the model's device.
    tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)

    # Most probable class and the probability assigned to it.
    pred_index = torch.argmax(probs, dim=1).item()
    confidence = probs[0, pred_index].item()

    # Class index 0 is treated as "Real"; any other index as "Fake".
    label = "Real" if pred_index == 0 else "Fake"

    # High score for confident "Real", low score for confident "Fake".
    trust_score = int(confidence * 100 if label == "Real" else (1 - confidence) * 100)

    # Debug output kept so existing stdout behaviour is unchanged.
    print("pred_index:", pred_index, "label:", label, "confidence:", confidence)
    return label, confidence, trust_score