# Deploy_app1 / app.py
# Binta26's picture
# Update app.py
# 2a7c5fd verified
# Raw
# History Blame Contribute Delete
# 3.46 kB
# ============================================================
# APPLICATION FLASK - CLASSIFICATION D'IMAGES
# ============================================================
from flask import Flask, render_template, request
import torch
import torch.nn as nn
import torchvision.transforms as T
import tensorflow as tf
from PIL import Image
import numpy as np
import io
import os
app = Flask(import_name=__name__)

# Label names; the class index predicted by either model is used to index this list.
CLASSES = "buildings forest glacier mountain sea street".split()
# Side length, in pixels, of the square images fed to both models.
IMG_SIZE = 150
# Directory containing this file; model paths are resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ── PyTorch CNN definition (same architecture as the saved weights) ──────
class CNN(nn.Module):
    """Three-block convolutional classifier for square RGB images.

    The layer layout must stay identical to the one that produced the saved
    ``state_dict`` (it is loaded with ``load_state_dict``); only the sizes
    are parameterised, and the defaults reproduce the original network.

    Args:
        num_classes: Number of output logits. Defaults to 6.
        img_size: Side length in pixels of the square input images. It fixes
            the input width of the first linear layer. Defaults to 150.

    Raises:
        ValueError: If ``num_classes`` is not positive, or if ``img_size`` is
            below 8 (three 2x2 max-pools would shrink the feature map to 0).
    """

    def __init__(self, num_classes=6, img_size=150):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if img_size < 8:
            raise ValueError(f"img_size must be >= 8, got {img_size}")

        # Each 3x3 conv with padding=1 keeps the spatial size and each
        # MaxPool2d(2) floors it in half, so after three pools the side is
        # img_size // 8 (150 -> 75 -> 37 -> 18, the original hard-coded 18).
        feat_side = img_size // 8

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * feat_side * feat_side, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes).

        ``x`` must be a (batch, 3, img_size, img_size) float tensor; any
        other spatial size will not match the first linear layer.
        """
        return self.classifier(self.features(x))
# ── Model loading ─────────────────────────────────────────────
# Both models are loaded once, at import time, and reused by every request.
model_pytorch = CNN()
model_pytorch.load_state_dict(
    torch.load(
        os.path.join(BASE_DIR, "models", "binta_py_model.pth"),
        # Place every tensor on the CPU regardless of the device it was saved from.
        map_location="cpu",
        # Only tensors and plain containers are unpickled, so a tampered
        # checkpoint cannot run arbitrary code. A state_dict needs nothing more.
        weights_only=True,
    )
)
model_pytorch.eval()  # inference mode: use stored batch-norm stats, disable dropout

model_tensorflow = tf.keras.models.load_model(
    os.path.join(BASE_DIR, "models", "binta_ten_model.keras"),
    compile=False,  # skip compiling: only inference is performed here
)
# ── Preprocessing for the PyTorch model ──────────────────────
# Widely used ImageNet per-channel statistics, in (R, G, B) order.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

transform = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),  # square resize to the CNN input size
        T.ToTensor(),  # PIL image -> float (C, H, W) tensor scaled to [0, 1]
        T.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
)
# ── Prediction function ───────────────────────────────────────
def predict(image, model_choice):
    """Classify an encoded image with the selected model.

    Args:
        image: Raw bytes of an image file in any format Pillow can decode.
        model_choice: ``"pytorch"`` or ``"tensorflow"``.

    Returns:
        The predicted label from ``CLASSES``, or ``None`` when
        ``model_choice`` is neither supported value.

    Raises:
        PIL.UnidentifiedImageError: If ``image`` cannot be identified as an
            image (other decoding failures surface as ``OSError``).
    """
    # Decode and force 3 channels (drops alpha, expands grayscale/palette).
    # convert() returns an independent image, so the source can be closed.
    with Image.open(io.BytesIO(image)) as src:
        img = src.convert("RGB")

    if model_choice == "pytorch":
        batch = transform(img).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():
            logits = model_pytorch(batch)
        return CLASSES[int(logits.argmax(dim=1).item())]

    if model_choice == "tensorflow":
        img = img.resize((IMG_SIZE, IMG_SIZE))
        # Scale pixels to [0, 1]; float32 is Keras' default float type.
        arr = np.asarray(img, dtype=np.float32) / 255.0
        arr = np.expand_dims(arr, axis=0)  # add the batch dimension
        # For a single sample, calling the model directly is cheaper than
        # Model.predict(), which sets up batching machinery and prints a
        # progress bar on every request. training=False keeps dropout and
        # batch-norm layers in inference mode.
        scores = np.asarray(model_tensorflow(arr, training=False))
        return CLASSES[int(np.argmax(scores))]

    # Unsupported model name: no prediction (same as the original behaviour).
    return None
# ── Routes ────────────────────────────────────────────────────
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the upload page and, on POST, classify the uploaded image.

    Reads the ``model`` form field and the ``image`` file, passes them to
    ``predict`` and renders ``index.html`` with ``prediction`` and
    ``model_choice`` (both ``None`` on GET).

    Responds with HTTP 400 instead of an internal error when the form names
    an unsupported model, when the uploaded file is empty, or when the
    upload cannot be decoded as an image.
    """
    from flask import abort  # local import keeps this block self-contained

    prediction = None
    model_choice = None
    if request.method == "POST":
        model_choice = request.form["model"]
        if model_choice not in ("pytorch", "tensorflow"):
            abort(400, description="Unsupported model choice.")
        image_bytes = request.files["image"].read()
        if not image_bytes:
            # e.g. the form was submitted without choosing a file
            abort(400, description="No image file was uploaded.")
        try:
            prediction = predict(image_bytes, model_choice)
        except (OSError, Image.DecompressionBombError):
            # Pillow reports unidentifiable/truncated images as OSError
            # (UnidentifiedImageError is a subclass); oversized images raise
            # DecompressionBombError. Both are client errors, not server bugs.
            abort(400, description="The uploaded file is not a valid image.")
    return render_template(
        "index.html",
        prediction=prediction,
        model_choice=model_choice,
    )
if __name__ == "__main__":
    # When executed as a script, serve on every interface at port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**server_options)