# File size: 1,893 Bytes
# b5b2f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import torch
import numpy as np
import joblib
from PIL import Image
from flask import Flask, request, jsonify
from transformers import CLIPProcessor, CLIPModel
from io import BytesIO
from flask_cors import CORS
import base64
import io

# --- Flask application ------------------------------------------------------
app = Flask(__name__)
CORS(app)  # enable Flask-CORS on every route with its default settings

# --- Heavy resources, loaded once at import time ----------------------------
# Prefer a CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# CLIP model (moved to `device`) and its matching preprocessing pipeline,
# both pulled from the same pretrained checkpoint.
_clip_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_clip_checkpoint).to(device)
processor = CLIPProcessor.from_pretrained(_clip_checkpoint)

# Previously fitted classifier; `predict` feeds it CLIP image features and
# reads class probabilities back via `predict_proba`.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary
# code, so this path must only ever point at a trusted, locally built artifact.
# The path is relative, so it resolves against the current working directory.
ensemble_clf = joblib.load("model/random_forest_tuned_aug.pkl")

# Class id -> human-readable label returned by the API.
label_map = dict(enumerate(("real", "deepfake", "ai_gen")))

def extract_features(image):
    """Embed a PIL image with the CLIP image encoder.

    The image is resized to 224x224, preprocessed by the module-level CLIP
    ``processor``, moved to ``device`` and passed through
    ``model.get_image_features`` with gradient tracking disabled.

    Args:
        image: A PIL ``Image`` (``predict`` passes an RGB-converted image).

    Returns:
        The embedding as a NumPy array on the CPU, with all size-1
        dimensions squeezed out (e.g. the batch dimension of 1).
    """
    # Fixed 224x224 resize that ignores aspect ratio. Kept as-is because the
    # downstream classifier was presumably fitted on features produced the
    # same way -- TODO confirm before changing the preprocessing.
    resized = image.resize((224, 224))
    batch = processor(images=resized, return_tensors="pt").to(device)

    # Inference only: avoid building the autograd graph.
    with torch.no_grad():
        embedding = model.get_image_features(**batch)

    return embedding.cpu().numpy().squeeze()

@app.route("/predict", methods=["POST"])
def predict():
    """Classify a base64-encoded image posted as JSON.

    Expects a JSON object body with an ``"image"`` key holding a base64
    string. An optional ``data:<mime>;base64,`` data-URL header is accepted
    and stripped.

    Returns:
        200 with ``{"prediction": <label>, "probabilities": [...]}``, where
        ``probabilities`` is the classifier's ``predict_proba`` row; or 400
        with ``{"error": ...}`` when the body or the image cannot be decoded.
    """
    # silent=True yields None (instead of raising) for a missing or invalid
    # JSON body, so every malformed request gets the same 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({"error": "No image provided"}), 400

    encoded = data['image']
    if not isinstance(encoded, str):
        return jsonify({"error": "Image must be a base64 string"}), 400
    # Browser clients (e.g. FileReader.readAsDataURL) send
    # "data:image/png;base64,...". b64decode's non-strict mode would decode
    # that header's letters as payload bytes and corrupt the image.
    if encoded.startswith("data:"):
        encoded = encoded.partition(",")[2]

    try:
        image_data = base64.b64decode(encoded)
        # convert() forces a full decode, so truncated files fail here too.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError, Image.DecompressionBombError):
        # binascii.Error (bad base64) subclasses ValueError;
        # PIL.UnidentifiedImageError and truncated-image errors are OSError.
        return jsonify({"error": "Invalid image data"}), 400

    # Extract features and predict
    features = extract_features(image)
    probs = ensemble_clf.predict_proba([features])[0]
    top_idx = int(np.argmax(probs))

    # predict_proba columns are ordered like ensemble_clf.classes_, so map the
    # column index back to the actual class label before naming it.
    classes = getattr(ensemble_clf, "classes_", None)
    label = classes[top_idx] if classes is not None else top_idx

    response = {
        # Unknown labels (e.g. string classes) are returned verbatim.
        "prediction": label_map.get(label, str(label)),
        "probabilities": probs.tolist(),
    }
    return jsonify(response)

if __name__ == "__main__":
    # Run Flask's built-in development server. Debug mode is opt-in via the
    # FLASK_DEBUG environment variable (e.g. FLASK_DEBUG=1), for two reasons.
    # First, the Werkzeug interactive debugger lets anyone who can reach an
    # error page execute arbitrary Python. Second, its auto-reloader
    # re-executes this script in a child process, loading the CLIP model and
    # the classifier a second time.
    _debug_env = os.environ.get("FLASK_DEBUG", "")
    debug = bool(_debug_env) and _debug_env.lower() not in {"0", "false", "no"}
    app.run(debug=debug)