Spaces:
Runtime error
Runtime error
| # ============================================================ | |
| # APPLICATION FLASK - CLASSIFICATION D'IMAGES | |
| # ============================================================ | |
| from flask import Flask, render_template, request | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| import tensorflow as tf | |
| from PIL import Image | |
| import numpy as np | |
| import io | |
| import os | |
| app = Flask(__name__) | |
| CLASSES = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street'] | |
| IMG_SIZE = 150 | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # ββ DΓ©finition du CNN PyTorch (mΓͺme architecture) ββββββββββββ | |
| class CNN(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.features = nn.Sequential( | |
| nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2), | |
| nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2), | |
| nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2), | |
| ) | |
| self.classifier = nn.Sequential( | |
| nn.Flatten(), | |
| nn.Linear(128 * 18 * 18, 256), nn.ReLU(), nn.Dropout(0.5), | |
| nn.Linear(256, 6), | |
| ) | |
| def forward(self, x): | |
| return self.classifier(self.features(x)) | |
| # ββ Chargement des modΓ¨les βββββββββββββββββββββββββββββββββββ | |
| model_pytorch = CNN() | |
| model_pytorch.load_state_dict( | |
| torch.load( | |
| os.path.join(BASE_DIR, "models", "binta_py_model.pth"), | |
| map_location="cpu" | |
| ) | |
| ) | |
| model_pytorch.eval() | |
| model_tensorflow = tf.keras.models.load_model( | |
| os.path.join(BASE_DIR, "models", "binta_ten_model.keras"), | |
| compile=False | |
| ) | |
| # ββ Transform pour PyTorch βββββββββββββββββββββββββββββββββββ | |
| transform = T.Compose([ | |
| T.Resize((IMG_SIZE, IMG_SIZE)), | |
| T.ToTensor(), | |
| T.Normalize(mean=[0.485, 0.456, 0.406], | |
| std =[0.229, 0.224, 0.225]), | |
| ]) | |
| # ββ Fonction de prΓ©diction βββββββββββββββββββββββββββββββββββ | |
| def predict(image, model_choice): | |
| img = Image.open(io.BytesIO(image)).convert("RGB") | |
| if model_choice == "pytorch": | |
| tensor = transform(img).unsqueeze(0) # ajoute dimension batch | |
| with torch.no_grad(): | |
| out = model_pytorch(tensor) | |
| return CLASSES[out.argmax(1).item()] | |
| elif model_choice == "tensorflow": | |
| img = img.resize((IMG_SIZE, IMG_SIZE)) | |
| arr = np.array(img) / 255.0 | |
| arr = np.expand_dims(arr, axis=0) # ajoute dimension batch | |
| out = model_tensorflow.predict(arr) | |
| return CLASSES[np.argmax(out)] | |
| # ββ Routes βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def index(): | |
| prediction = None | |
| model_choice = None | |
| if request.method == "POST": | |
| model_choice = request.form["model"] | |
| image_file = request.files["image"].read() | |
| prediction = predict(image_file, model_choice) | |
| return render_template("index.html", | |
| prediction=prediction, | |
| model_choice=model_choice) | |
| if __name__ == "__main__": | |
| app.run(host='0.0.0.0', port=7860) |