File size: 1,921 Bytes
1e4485c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85db738
60dd668
1e4485c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import torch
import numpy as np
import joblib
from PIL import Image
from flask import Flask, request, jsonify
from transformers import CLIPProcessor, CLIPModel
from io import BytesIO
from flask_cors import CORS
import base64
import io

# --- Flask application -------------------------------------------------------
app = Flask(__name__)
# Enable CORS on every route (flask_cors defaults allow any origin), so a
# separately hosted front end can call the API.
CORS(app)

# --- Models: loaded once at import time, shared by every request -------------
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"[INFO] Using device: {device}")

# CLIP backbone, used only as an image-feature extractor (see extract_features).
# NOTE(review): from_tf=True asks transformers to load the TensorFlow
# checkpoint, which requires TensorFlow to be installed; confirm this is
# intentional rather than using the default PyTorch weights.
_clip_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_clip_checkpoint, from_tf=True).to(device)
processor = CLIPProcessor.from_pretrained(_clip_checkpoint, from_tf=True)

# Classifier over CLIP embeddings, unpickled with joblib (the file name
# suggests a tuned random forest). The path is resolved against the current
# working directory, and joblib.load unpickles, so only load trusted files.
ensemble_clf = joblib.load("model/random_forest_tuned_aug.pkl")

# Index of the classifier's predict_proba column -> human-readable label.
# Assumes the columns are ordered as classes 0, 1, 2 (TODO: confirm against
# ensemble_clf.classes_).
label_map = dict(enumerate(("real", "deepfake", "ai_gen")))

def extract_features(image):
    """Return the CLIP image embedding of a PIL image as a NumPy array.

    The image is stretched to 224x224, preprocessed by the module-level CLIP
    ``processor``, and embedded with ``model.get_image_features`` without
    gradient tracking.  The result is moved to the CPU, converted to NumPy,
    and every size-1 dimension is squeezed away.  For a single image this is
    presumably a 1-D feature vector (TODO confirm the projection size).

    Args:
        image: A PIL image (``predict`` passes an RGB image).

    Returns:
        numpy.ndarray: The squeezed embedding.
    """
    # NOTE: resize() ignores aspect ratio.  The downstream classifier was
    # presumably trained on features produced the same way, so changing this
    # would shift the feature distribution.
    resized = image.resize((224, 224))
    batch = processor(images=resized, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embedding = model.get_image_features(**batch)

    return embedding.cpu().numpy().squeeze()

def _decode_image(encoded):
    """Decode a base64 string (optionally a ``data:`` URL) into an RGB PIL image.

    Args:
        encoded: Base64 text, either bare or in the form
            ``data:<mime>;base64,<payload>``.

    Returns:
        PIL.Image.Image: The decoded image converted to RGB.

    Raises:
        ValueError: If the text is not valid base64, or the bytes are not an
            image Pillow can open (including decompression-bomb rejection).
    """
    # Browsers (e.g. FileReader.readAsDataURL) send data URLs.  Lenient
    # b64decode discards only non-alphabet characters, so the header's letters
    # would otherwise be decoded as part of the payload and corrupt it.
    if encoded.startswith("data:") and "," in encoded:
        encoded = encoded.partition(",")[2]
    try:
        raw = base64.b64decode(encoded)  # binascii.Error subclasses ValueError
        # UnidentifiedImageError and truncated-file errors are OSErrors.
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")
    except (ValueError, OSError, Image.DecompressionBombError) as err:
        raise ValueError("Invalid image data") from err


@app.route("/predict", methods=["POST"])
def predict():
    """Classify a base64-encoded image as real, deepfake, or AI-generated.

    Expects a JSON body ``{"image": "<base64 or data URL>"}``.

    Returns:
        200 with ``{"prediction": <label>, "probabilities": [...]}``, where
        ``probabilities`` is the classifier's ``predict_proba`` row.
        400 with ``{"error": "No image provided"}`` if the body is not a JSON
        object with a string ``image`` field.
        400 with ``{"error": "Invalid image data"}`` if the field cannot be
        decoded into an image.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it gets the same 400 as a missing field rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get("image"), str):
        return jsonify({"error": "No image provided"}), 400

    try:
        image = _decode_image(data["image"])
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    # Extract CLIP features and classify them.
    features = extract_features(image)
    probs = ensemble_clf.predict_proba([features])[0]
    # np.argmax returns a NumPy integer; use a plain int for the dict lookup.
    top_idx = int(np.argmax(probs))

    response = {
        "prediction": label_map[top_idx],
        "probabilities": probs.tolist(),
    }
    return jsonify(response)

if __name__ == "__main__":
    # Debug mode is opt-in via FLASK_DEBUG, using the same truthiness rules
    # as `flask run`.  The Werkzeug interactive debugger lets anyone who can
    # reach the server execute arbitrary Python.  Its reloader also re-runs
    # this module in a child process, loading every model a second time.
    debug_flag = os.environ.get("FLASK_DEBUG", "")
    debug_enabled = bool(debug_flag) and debug_flag.lower() not in {"0", "false", "no"}
    app.run(debug=debug_enabled)