| """ |
| Flask ML Web Application - Iris Flower Classifier |
| Database schema: |
| - users (id, username, email, password, created) |
| - sessions (id, user_id, login_at, logout_at) |
| - predictions(id, username, sepal_length, sepal_width, petal_length, petal_width, prediction, confidence, timestamp) |
| """ |
|
|
| import sqlite3 |
| import hashlib |
| import os |
| from functools import wraps |
| from datetime import datetime |
|
|
| import joblib |
| import numpy as np |
| from flask import (Flask, render_template, request, redirect, |
| url_for, session, flash, g) |
|
|
| |
app = Flask(__name__)
# Flask signs the session cookie with this key, and the session alone decides
# who is logged in (and who counts as "admin"). A key published in the source
# code therefore lets anyone forge a valid session. Read it from the
# environment; otherwise fall back to a random per-process key. The fallback
# invalidates sessions on every restart and is not shared between worker
# processes, so set SECRET_KEY in production.
app.secret_key = os.environ.get("SECRET_KEY") or os.urandom(32)
|
|
# Paths are resolved relative to this file, not the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(BASE_DIR, "users.db")  # SQLite database file
MODEL_DIR = os.path.join(BASE_DIR, "model")   # directory of trained artifacts

# Load the trained artifacts once, at import time, in a fixed order.
# NOTE: joblib.load unpickles its input, which can execute arbitrary code,
# so only place artifacts from a trusted source in MODEL_DIR.
model, scaler, class_names = (
    joblib.load(os.path.join(MODEL_DIR, artifact))
    for artifact in ("iris_model.pkl", "scaler.pkl", "class_names.pkl")
)
|
|
# Display metadata for each predicted class name: an emoji, a short
# description and a colour (hex string) used by the result view.
FLOWER_INFO = {
    "setosa": dict(
        emoji="🌸",
        description="Iris setosa is a hardy species found in arctic and subarctic regions.",
        color="#f9a8d4",
    ),
    "versicolor": dict(
        emoji="💜",
        description="Iris versicolor (Blue Flag Iris) is native to North America.",
        color="#c4b5fd",
    ),
    "virginica": dict(
        emoji="🌺",
        description="Iris virginica (Virginia Iris) thrives in moist, coastal habitats.",
        color="#6ee7b7",
    ),
}
|
|
| |
def get_db():
    """Return the SQLite connection for the current app context.

    The connection is opened lazily on first use and cached on ``flask.g``
    as ``_database``. Its row factory is ``sqlite3.Row``, so columns can be
    read by name (``row["id"]``). ``close_db`` closes it at teardown.
    """
    conn = getattr(g, "_database", None)
    if conn is not None:
        return conn
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    g._database = conn
    return conn
|
|
@app.teardown_appcontext
def close_db(_):
    """Close the connection cached by ``get_db`` (if one was opened) when the app context ends."""
    conn = getattr(g, "_database", None)
    if conn is None:
        return
    conn.close()
|
|
def init_db():
    """Create the schema, migrate older databases and seed the ``admin`` account.

    Safe to call repeatedly: tables use ``CREATE TABLE IF NOT EXISTS``, the
    ``email`` column is only added when it is missing, and the admin row is
    inserted with ``INSERT OR IGNORE`` so an existing account is left as is.

    The admin password is read from the ``ADMIN_PASSWORD`` environment
    variable. If that is unset or empty, a random password is generated and
    printed once, and only when the admin account is actually created. A
    hard-coded default would give anyone who has read the source access to
    the admin-only pages, which are gated on the ``admin`` username.
    """
    import secrets  # only needed to generate a bootstrap admin password

    with app.app_context():
        db = get_db()

        # Registered accounts; ``password`` stores the hash_pw() hex digest.
        db.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id       INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                email    TEXT UNIQUE NOT NULL,
                password TEXT NOT NULL,
                created  TEXT NOT NULL
            )
        """)

        # One row per successful login; logout_at stays NULL until /logout.
        db.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                id        INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id   INTEGER NOT NULL,
                login_at  TEXT NOT NULL,
                logout_at TEXT,
                FOREIGN KEY (user_id) REFERENCES users(id)
            )
        """)

        # Prediction history, keyed by username.
        db.execute("""
            CREATE TABLE IF NOT EXISTS predictions (
                id           INTEGER PRIMARY KEY AUTOINCREMENT,
                username     TEXT NOT NULL,
                sepal_length REAL,
                sepal_width  REAL,
                petal_length REAL,
                petal_width  REAL,
                prediction   TEXT,
                confidence   REAL,
                timestamp    TEXT NOT NULL
            )
        """)

        # Databases created before the ``email`` column existed need it added.
        # Inspect the schema rather than running ALTER TABLE and ignoring every
        # OperationalError, which would also hide real failures (e.g. a locked
        # database).
        existing_columns = {row["name"] for row in db.execute("PRAGMA table_info(users)")}
        if "email" not in existing_columns:
            db.execute("ALTER TABLE users ADD COLUMN email TEXT NOT NULL DEFAULT ''")

        admin_password = os.environ.get("ADMIN_PASSWORD")
        password_generated = not admin_password
        if password_generated:
            admin_password = secrets.token_urlsafe(12)

        # INSERT OR IGNORE skips the row when the username/email already exists.
        cursor = db.execute(
            "INSERT OR IGNORE INTO users (username, email, password, created) "
            "VALUES (?, ?, ?, ?)",
            ("admin", "admin@irisai.local", hash_pw(admin_password),
             datetime.now().isoformat()),
        )
        # rowcount is 1 only when the admin row was really inserted.
        if password_generated and cursor.rowcount == 1:
            print(
                "init_db: created user 'admin' with generated password "
                f"{admin_password!r}; set ADMIN_PASSWORD to choose it yourself."
            )
        db.commit()
|
|
def hash_pw(pw: str) -> str:
    """Return the hex-encoded SHA-256 digest of *pw*, encoded as UTF-8.

    NOTE(review): a single unsalted SHA-256 is fast to brute-force. A salted
    KDF (``hashlib.pbkdf2_hmac``, ``hashlib.scrypt``) would be safer. Switching
    requires migrating the stored hashes and changing the login query, which
    compares digests directly in SQL.
    """
    digest = hashlib.sha256()
    digest.update(pw.encode("utf-8"))
    return digest.hexdigest()
|
|
| |
def login_required(f):
    """Decorate a view so that it requires a logged-in user.

    A request is considered authenticated when ``"username"`` is present in
    the Flask session. Anonymous requests get a warning flash and are
    redirected to the login page instead of reaching the view.
    """
    @wraps(f)
    def _guarded(*args, **kwargs):
        if "username" in session:
            return f(*args, **kwargs)
        flash("Veuillez vous connecter pour accéder à cette page.", "warning")
        return redirect(url_for("login"))

    return _guarded
|
|
| |
@app.route("/")
def index():
    """Site root: send logged-in users to /predict and everyone else to /login."""
    destination = "predict" if "username" in session else "login"
    return redirect(url_for(destination))
|
|
@app.route("/login", methods=["GET", "POST"])
def login():
    """Render the login form (GET) or authenticate submitted credentials (POST).

    Already-authenticated users are redirected to /predict. On a successful
    POST:
      - a row is recorded in ``sessions``;
      - the Flask session is reset and populated with ``username``,
        ``user_id`` and ``session_db_id``, the id of that row, which /logout
        uses to stamp ``logout_at``;
      - the user is redirected to /predict.
    On failure the form is re-rendered with an error flash.
    """
    if "username" in session:
        return redirect(url_for("predict"))
    if request.method == "POST":
        username = request.form.get("username", "").strip()
        password = request.form.get("password", "")
        db = get_db()
        user = db.execute(
            "SELECT * FROM users WHERE username = ? AND password = ?",
            (username, hash_pw(password)),
        ).fetchone()
        if user:
            cursor = db.execute(
                "INSERT INTO sessions (user_id, login_at) VALUES (?, ?)",
                (user["id"], datetime.now().isoformat()),
            )
            db.commit()

            # Start from a clean session when privileges change, so no
            # pre-login state survives into the authenticated session.
            session.clear()
            session["username"] = username
            session["user_id"] = user["id"]
            # lastrowid is the id of the row inserted by this cursor, which
            # avoids a separate "SELECT last_insert_rowid()" round trip.
            session["session_db_id"] = cursor.lastrowid

            flash(f"Bienvenue, {username} ! 👋", "success")
            return redirect(url_for("predict"))
        flash("Nom d'utilisateur ou mot de passe incorrect.", "danger")
    return render_template("login.html")
|
|
def _looks_like_email(address):
    """Return True if *address* has the basic ``local@domain.tld`` shape.

    Deliberately lenient (no full RFC 5322 parsing). It requires:
      - no whitespace;
      - a non-empty part before the last "@";
      - a domain that contains a dot and neither starts nor ends with one.
    """
    if any(ch.isspace() for ch in address):
        return False
    local, at, domain = address.rpartition("@")
    return bool(
        local
        and at
        and "." in domain
        and not domain.startswith(".")
        and not domain.endswith(".")
    )


@app.route("/register", methods=["GET", "POST"])
def register():
    """Render the sign-up form (GET) or create a new account (POST).

    Already-authenticated users are redirected to /predict. A POST is
    validated in order:
      1. all fields are present;
      2. the email has a plausible shape;
      3. the password matches the confirmation;
      4. the password has at least 6 characters.
    On success the account is stored (password hashed with ``hash_pw``) and
    the user is redirected to /login. A duplicate username or email produces
    an error flash instead.
    """
    if "username" in session:
        return redirect(url_for("predict"))
    if request.method == "POST":
        username = request.form.get("username", "").strip()
        email = request.form.get("email", "").strip().lower()
        password = request.form.get("password", "")
        confirm = request.form.get("confirm", "")
        if not username or not email or not password:
            flash("Tous les champs sont obligatoires.", "danger")
        elif not _looks_like_email(email):
            # The previous check ('"@" in email and "." in email') accepted
            # strings such as "@." or "a.b@c".
            flash("Adresse email invalide.", "danger")
        elif password != confirm:
            flash("Les mots de passe ne correspondent pas.", "danger")
        elif len(password) < 6:
            flash("Le mot de passe doit comporter au moins 6 caractères.", "danger")
        else:
            db = get_db()
            try:
                db.execute(
                    "INSERT INTO users (username, email, password, created) VALUES (?, ?, ?, ?)",
                    (username, email, hash_pw(password), datetime.now().isoformat()),
                )
                db.commit()
            except sqlite3.IntegrityError as e:
                # Both columns are UNIQUE; the error message is expected to
                # name the violated column, which selects the message to show.
                if "username" in str(e):
                    flash("Ce nom d'utilisateur est déjà pris.", "danger")
                else:
                    flash("Cette adresse email est déjà utilisée.", "danger")
            else:
                flash("Compte créé ! Vous pouvez vous connecter.", "success")
                return redirect(url_for("login"))
    return render_template("register.html")
|
|
@app.route("/logout")
def logout():
    """Log the user out and redirect to /login.

    If the session remembers which ``sessions`` row was created at login
    (``session_db_id``), that row's ``logout_at`` is stamped with the current
    time. The Flask session is then cleared and an info flash is shown.
    """
    row_id = session.get("session_db_id")
    if row_id:
        logged_out_at = datetime.now().isoformat()
        conn = get_db()
        conn.execute("UPDATE sessions SET logout_at = ? WHERE id = ?", (logged_out_at, row_id))
        conn.commit()

    session.clear()
    flash("Vous avez été déconnecté.", "info")
    return redirect(url_for("login"))
|
|
# Names of the four measurement form fields, in the order they are placed in
# the feature row given to the scaler/model (presumably the training column
# order; confirm against the training script).
_FEATURES = ("sepal_length", "sepal_width", "petal_length", "petal_width")


def _parse_measurements(form):
    """Read the four measurement fields from *form* as floats, keyed by field name.

    A comma is accepted as decimal separator ("5,1" -> 5.1) because the UI is
    in French. Raises KeyError when a field is missing (werkzeug's
    BadRequestKeyError subclasses it) and ValueError when a value is not
    numeric.
    """
    return {name: float(form[name].replace(",", ".")) for name in _FEATURES}


def _classify_and_record(values, username):
    """Classify *values*, store the prediction for *username*, return the view result.

    The returned dict has ``name``, ``confidence`` (percent, 1 decimal),
    ``info`` (the FLOWER_INFO entry, or {}) and ``all_probs`` (class name ->
    percent, 1 decimal).
    """
    X = np.array([[values[name] for name in _FEATURES]])
    X_scaled = scaler.transform(X)
    # NOTE(review): assumes the model was trained on integer labels 0..n-1 that
    # index both class_names and the predict_proba columns; confirm.
    pred_idx = model.predict(X_scaled)[0]
    proba = model.predict_proba(X_scaled)[0]
    pred_name = class_names[pred_idx]
    confidence = float(proba[pred_idx]) * 100

    db = get_db()
    db.execute(
        """INSERT INTO predictions
           (username, sepal_length, sepal_width, petal_length, petal_width,
            prediction, confidence, timestamp)
           VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
        (username, *(values[name] for name in _FEATURES),
         pred_name, confidence, datetime.now().isoformat()),
    )
    db.commit()

    return {
        "name": pred_name,
        "confidence": round(confidence, 1),
        "info": FLOWER_INFO.get(pred_name, {}),
        "all_probs": {class_names[i]: round(float(p) * 100, 1) for i, p in enumerate(proba)},
    }


@app.route("/predict", methods=["GET", "POST"])
@login_required
def predict():
    """Show the classifier form and, on POST, classify the submitted measurements.

    Every measurement must lie strictly between 0 and 20 (cm); otherwise a
    warning is flashed and the user is redirected back to the form.
    Non-numeric or missing fields produce an error flash. The page always
    shows the user's 5 most recent predictions.

    Only form parsing is inside the ``try``. Errors raised by the model or
    the database are not disguised as "invalid input".
    """
    result = None
    form_data = {}

    if request.method == "POST":
        try:
            values = _parse_measurements(request.form)
        except (ValueError, KeyError):
            flash("Entrée invalide — veuillez saisir des valeurs numériques.", "danger")
        else:
            # NaN/inf also fail this chained comparison, so they are rejected.
            if not all(0 < v < 20 for v in values.values()):
                flash("Entrez des mesures réalistes (0–20 cm).", "warning")
                return redirect(url_for("predict"))
            form_data = values
            result = _classify_and_record(values, session["username"])

    history = get_db().execute(
        "SELECT * FROM predictions WHERE username = ? ORDER BY id DESC LIMIT 5",
        (session["username"],),
    ).fetchall()

    return render_template("predict.html",
                           result=result,
                           form_data=form_data,
                           history=history,
                           username=session["username"])
|
|
| |
if __name__ == "__main__":
    # NOTE(review): app.run() blocks, so any route defined *below* this guard
    # (the /admin view later in this file) is never registered when the
    # module is run as a script. This guard should be the last thing in the file.
    # NOTE(review): init_db() only runs here; a WSGI server that imports this
    # module would start without creating the schema. Confirm how it is deployed.
    init_db()
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 5000)), debug=False)
|
|
|
|
| |
| @app.route("/admin") |
| @login_required |
| def admin(): |
| |
| if session.get("username") != "admin": |
| flash("Accès réservé à l'administrateur.", "danger") |
| return redirect(url_for("predict")) |
| db = get_db() |
|
|
| users = db.execute( |
| "SELECT id, username, email, created FROM users ORDER BY id DESC" |
| ).fetchall() |
|
|
| sessions_log = db.execute(""" |
| SELECT s.id, u.username, u.email, s.login_at, s.logout_at |
| FROM sessions s |
| JOIN users u ON s.user_id = u.id |
| ORDER BY s.id DESC |
| LIMIT 50 |
| """).fetchall() |
|
|
| stats = { |
| "total_users": db.execute("SELECT COUNT(*) FROM users").fetchone()[0], |
| "total_sessions": db.execute("SELECT COUNT(*) FROM sessions").fetchone()[0], |
| "active_now": db.execute("SELECT CO |
| |