File size: 4,201 Bytes
ab51159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Intel Scene Classifier — Flask App
"""

import io
import os
import numpy as np
from flask import Flask, jsonify, render_template, request
from PIL import Image

# Flask application object; the route decorators below register on it.
app = Flask(import_name=__name__)

# Label names, indexed by model output position.
CLASSES = [
    "buildings",
    "forest",
    "glacier",
    "mountain",
    "sea",
    "street",
]
# Square input resolution (pixels) used by both preprocessing paths.
IMG_SIZE = 150

# Lazily-filled model caches; each stays None until its loader first succeeds.
_pytorch_model = _tf_model = None


# ── Loaders ────────────────────────────────────────────────────────────────────
def load_pytorch():
    """Lazily build and cache the PyTorch inference bundle.

    Returns:
        tuple: ``(model, device, transform)`` — the CNN in eval mode on
        ``device`` (CUDA when available, otherwise CPU), and the torchvision
        preprocessing pipeline (resize to IMG_SIZE x IMG_SIZE, convert to a
        tensor, normalise with the standard ImageNet mean/std values).

    torch/torchvision are imported only when this first runs. The weights are
    read from ``parfait_model.pth`` (a relative path, so resolved against the
    working directory). The bundle is kept in the module-level
    ``_pytorch_model``. Any exception raised while loading propagates, and
    nothing is cached in that case.
    """
    global _pytorch_model
    if _pytorch_model is not None:
        return _pytorch_model

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision import transforms

    # Feature count after three 2x2 max-pools on a 150px input: 150 -> 75 -> 37 -> 18.
    flat_features = 128 * 18 * 18

    class CNN_Torch(nn.Module):
        # Attribute names must not change: they are the keys of the saved state_dict.
        def __init__(self, num_classes=6):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.conv3_drop = nn.Dropout2d(p=0.25)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(flat_features, 256)
            self.fc2 = nn.Linear(256, num_classes)

        def forward(self, x):
            stage1 = self.pool(F.relu(self.conv1(x)))
            stage2 = self.pool(F.relu(self.conv2(stage1)))
            # Spatial dropout sits between conv3 and its ReLU.
            stage3 = self.pool(F.relu(self.conv3_drop(self.conv3(stage2))))
            hidden = F.relu(self.fc1(stage3.view(-1, flat_features)))
            hidden = F.dropout(hidden, training=self.training)
            # Log-probabilities; callers exponentiate to recover probabilities.
            return F.log_softmax(self.fc2(hidden), dim=1)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    net = CNN_Torch(num_classes=6).to(device)
    net.load_state_dict(torch.load("parfait_model.pth", map_location=device))
    net.eval()

    preprocess = transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    _pytorch_model = net, device, preprocess
    return _pytorch_model


def load_tensorflow():
    """Return the cached Keras model, loading it on first use.

    TensorFlow is imported lazily, only when this loader first runs. The model
    is read from ``parfait_model.keras`` (a relative path, so resolved against
    the working directory) and memoised in the module-level ``_tf_model``. If
    loading raises, the exception propagates and nothing is cached.
    """
    global _tf_model
    if _tf_model is not None:
        return _tf_model

    import tensorflow as tf

    _tf_model = tf.keras.models.load_model("parfait_model.keras")
    return _tf_model


# ── Routes ─────────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    """Render the upload page (``templates/index.html``)."""
    page = "index.html"
    return render_template(page)


def _predict_pytorch(pil_img):
    """Return a 1-D NumPy array of class probabilities for *pil_img* (PyTorch).

    The network built by ``load_pytorch`` ends in ``log_softmax``, so ``exp``
    converts its output back to probabilities.
    """
    import torch

    model, device, transform = load_pytorch()
    batch = transform(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        log_probs = model(batch)
        return torch.exp(log_probs).cpu().numpy()[0]


def _predict_tensorflow(pil_img):
    """Return the Keras model's output vector for *pil_img*.

    The image is resized to IMG_SIZE x IMG_SIZE and fed as raw float32 pixel
    values. No rescaling is done here; presumably the saved model normalises
    its own input (TODO confirm).
    """
    model = load_tensorflow()
    pixels = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE)), dtype=np.float32)
    return model.predict(np.expand_dims(pixels, 0), verbose=0)[0]


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image with the selected framework's model.

    Form fields:
        image: the uploaded image file (required).
        model: ``"pytorch"`` (the default) selects the PyTorch model; any other
            value selects the TensorFlow/Keras model.

    Returns:
        On success, JSON ``{"class", "confidence", "probabilities"}``.
        On failure, JSON ``{"error": ...}`` with status 400 for a missing or
        undecodable image, or 500 when a model file is missing or inference
        fails. Unexpected failures are logged with their traceback.
    """
    if "image" not in request.files:
        return jsonify({"error": "Aucune image fournie"}), 400

    framework = request.form.get("model", "pytorch")

    try:
        pil_img = Image.open(io.BytesIO(request.files["image"].read())).convert("RGB")
    except Exception:
        # Any decode failure (corrupt or unsupported file) is reported as a client error.
        return jsonify({"error": "Fichier image invalide"}), 400

    try:
        if framework == "pytorch":
            probs = _predict_pytorch(pil_img)
        else:
            probs = _predict_tensorflow(pil_img)

        # zip() below would silently truncate on a size mismatch; fail loudly instead.
        if len(probs) != len(CLASSES):
            raise ValueError(
                f"Sortie du modèle inattendue : {len(probs)} valeurs pour {len(CLASSES)} classes"
            )

        pred_idx = int(np.argmax(probs))
        return jsonify({
            "class":         CLASSES[pred_idx],
            "confidence":    float(probs[pred_idx]),
            "probabilities": {c: float(p) for c, p in zip(CLASSES, probs)},
        })

    except FileNotFoundError as e:
        app.logger.error("Model file not found: %s", e)
        return jsonify({"error": f"Modèle introuvable : {e}. Placez les fichiers .pth et .keras à la racine."}), 500
    except Exception as e:
        # Keep the full traceback in the server log; the client only sees the message.
        app.logger.exception("Prediction failed (framework=%r)", framework)
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # Built-in server entry point: listen on all interfaces, port from $PORT (default 5000).
    listen_port = int(os.getenv("PORT", "5000"))
    app.run(host="0.0.0.0", port=listen_port, debug=False)