WEB_DEV / app.py.save
Binta26's picture
Upload 8 files
43c0ec5 verified
Raw
History Blame Contribute Delete
12.5 kB
"""
Flask ML Web Application - Iris Flower Classifier
Database schema:
- users (id, username, email, password, created)
- sessions (id, user_id, login_at, logout_at)
- predictions(id, username, sepal_length, sepal_width, petal_length, petal_width, prediction, confidence, timestamp)
"""
import hashlib
import os
import secrets
import sqlite3
from datetime import datetime
from functools import wraps

import joblib
import numpy as np
from flask import (Flask, render_template, request, redirect,
                   url_for, session, flash, g)
# ── App setup ─────────────────────────────────────────────────────────────────
app = Flask(__name__)
# Flask signs the session cookie with this key. Anyone who knows it can forge a
# session (e.g. session["username"] = "admin", which the /admin route trusts),
# so never fall back to a constant committed to the repository. If SECRET_KEY
# is unset (or empty), use a random per-process key instead.
# NOTE: with the random fallback, sessions do not survive a restart and are not
# shared between worker processes -- set SECRET_KEY in any real deployment.
app.secret_key = os.environ.get("SECRET_KEY") or secrets.token_hex(32)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(BASE_DIR, "users.db")
MODEL_DIR = os.path.join(BASE_DIR, "model")
# ── Load ML artefacts ─────────────────────────────────────────────────────────
def _load_artefact(filename):
    """Load one joblib-serialised artefact stored in MODEL_DIR."""
    # joblib.load unpickles the file, which can execute arbitrary code:
    # only ever ship trusted artefacts in MODEL_DIR.
    return joblib.load(os.path.join(MODEL_DIR, filename))


# Trained classifier, the feature scaler applied before it, and the list used
# to turn a predicted index into a class name.
model = _load_artefact("iris_model.pkl")
scaler = _load_artefact("scaler.pkl")
class_names = _load_artefact("class_names.pkl")
# Display metadata for each species, keyed by the class name produced by the
# model (predict() looks entries up with FLOWER_INFO.get(pred_name, {})).
# Each entry holds an emoji, a one-line description and a hex colour.
FLOWER_INFO = {
    "setosa": dict(
        emoji="🌸",
        description="Iris setosa is a hardy species found in arctic and subarctic regions.",
        color="#f9a8d4",
    ),
    "versicolor": dict(
        emoji="💜",
        description="Iris versicolor (Blue Flag Iris) is native to North America.",
        color="#c4b5fd",
    ),
    "virginica": dict(
        emoji="🌺",
        description="Iris virginica (Virginia Iris) thrives in moist, coastal habitats.",
        color="#6ee7b7",
    ),
}
# ── Database helpers ───────────────────────────────────────────────────────────
def get_db():
    """Return the SQLite connection for the current app context.

    The connection is opened on first use and cached on ``flask.g`` so later
    calls in the same context reuse it. Rows are returned as ``sqlite3.Row``,
    so columns can be read by name.
    """
    conn = getattr(g, "_database", None)
    if conn is not None:
        return conn
    conn = sqlite3.connect(DB_PATH)
    g._database = conn
    conn.row_factory = sqlite3.Row
    return conn
@app.teardown_appcontext
def close_db(_):
    """Close the context's cached SQLite connection, if one was opened.

    Registered as an app-context teardown hook; the argument Flask passes in
    is ignored.
    """
    conn = getattr(g, "_database", None)
    if conn is None:
        return
    conn.close()
def init_db():
    """Create the database schema if needed and seed the default admin account.

    Creates the ``users``, ``sessions`` and ``predictions`` tables and adds the
    ``email`` column to a ``users`` table created by an older version of the
    app. It then inserts an ``admin`` user unless one already exists.

    The admin password comes from the ``ADMIN_PASSWORD`` environment variable.
    When that variable is unset, the historical default ``password123`` is
    used so existing setups keep working.
    NOTE: that default is publicly known -- set ADMIN_PASSWORD when deploying.
    """
    with app.app_context():
        db = get_db()
        # users: one row per account; `password` stores hash_pw() output.
        db.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                email TEXT UNIQUE NOT NULL,
                password TEXT NOT NULL,
                created TEXT NOT NULL
            )
        """)
        # sessions: records each login; logout_at is filled in by /logout.
        db.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                login_at TEXT NOT NULL,
                logout_at TEXT,
                FOREIGN KEY (user_id) REFERENCES users(id)
            )
        """)
        # predictions: every classification made, tagged with the username.
        db.execute("""
            CREATE TABLE IF NOT EXISTS predictions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL,
                sepal_length REAL,
                sepal_width REAL,
                petal_length REAL,
                petal_width REAL,
                prediction TEXT,
                confidence REAL,
                timestamp TEXT NOT NULL
            )
        """)
        # Migration: databases created before the email column existed lack it.
        # Inspect the schema rather than swallowing every OperationalError, so
        # real failures (locked database, I/O errors) are not silently ignored.
        columns = {row["name"] for row in db.execute("PRAGMA table_info(users)")}
        if "email" not in columns:
            db.execute("ALTER TABLE users ADD COLUMN email TEXT NOT NULL DEFAULT ''")
        # Default admin account. INSERT OR IGNORE leaves an existing row (and its
        # password) untouched, matching the previous try/except IntegrityError.
        admin_password = os.environ.get("ADMIN_PASSWORD", "password123")
        db.execute(
            "INSERT OR IGNORE INTO users (username, email, password, created) "
            "VALUES (?, ?, ?, ?)",
            ("admin", "admin@irisai.local", hash_pw(admin_password),
             datetime.now().isoformat()),
        )
        db.commit()
def hash_pw(pw: str) -> str:
    """Return the hex-encoded SHA-256 digest of *pw* encoded as UTF-8.

    NOTE(review): a single unsalted SHA-256 round is weak for password storage.
    A salted KDF (hashlib.pbkdf2_hmac / hashlib.scrypt) would be stronger, but
    switching requires migrating the stored hashes and changing the login
    query, which currently compares hashes directly in SQL.
    """
    digest = hashlib.sha256()
    digest.update(pw.encode("utf-8"))
    return digest.hexdigest()
# ── Auth decorator ─────────────────────────────────────────────────────────────
def login_required(f):
    """Decorate a view so it only runs for signed-in users.

    Visitors without a "username" key in the Flask session get a warning flash
    message and are redirected to the login page instead.
    """
    @wraps(f)
    def guarded_view(*args, **kwargs):
        if "username" in session:
            return f(*args, **kwargs)
        flash("Veuillez vous connecter pour accéder à cette page.", "warning")
        return redirect(url_for("login"))

    return guarded_view
# ── Routes ─────────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    """Send signed-in users to /predict and everyone else to /login."""
    target = "predict" if "username" in session else "login"
    return redirect(url_for(target))
@app.route("/login", methods=["GET", "POST"])
def login():
    """Show the login form (GET) or check the submitted credentials (POST).

    On success, a row is added to the ``sessions`` table with the login time.
    The Flask session then receives the username, the user id and the id of
    that ``sessions`` row, so /logout can stamp it. Already signed-in users
    are sent straight to /predict.
    """
    if "username" in session:
        return redirect(url_for("predict"))
    if request.method == "POST":
        username = request.form.get("username", "").strip()
        password = request.form.get("password", "")
        db = get_db()
        user = db.execute(
            "SELECT * FROM users WHERE username = ? AND password = ?",
            (username, hash_pw(password)),
        ).fetchone()
        if user:
            # Record the login. The INSERT's cursor already carries the new row
            # id, so no separate "SELECT last_insert_rowid()" query is needed.
            cursor = db.execute(
                "INSERT INTO sessions (user_id, login_at) VALUES (?, ?)",
                (user["id"], datetime.now().isoformat()),
            )
            db.commit()
            # Populate the session only once the login record is committed, so
            # a DB failure never leaves the user half signed-in.
            session["username"] = username
            session["user_id"] = user["id"]
            session["session_db_id"] = cursor.lastrowid
            flash(f"Bienvenue, {username} ! 👋", "success")
            return redirect(url_for("predict"))
        flash("Nom d'utilisateur ou mot de passe incorrect.", "danger")
    return render_template("login.html")
def _is_valid_email(email):
    """Return True if *email* has the rough shape ``local@domain.tld``.

    This is a light sanity check, not full RFC 5322 validation. It requires:
    no whitespace, exactly one "@", a non-empty local part, and a domain made
    of at least two non-empty dot-separated labels.
    """
    if any(ch.isspace() for ch in email) or email.count("@") != 1:
        return False
    local, _, domain = email.partition("@")
    labels = domain.split(".")
    return bool(local) and len(labels) >= 2 and all(labels)


@app.route("/register", methods=["GET", "POST"])
def register():
    """Show the sign-up form (GET) or create an account from it (POST).

    Validation problems and duplicate usernames/emails are reported with
    flash() and the form is re-rendered. On success the user is redirected to
    /login. Already signed-in users are sent to /predict.
    """
    if "username" in session:
        return redirect(url_for("predict"))
    if request.method == "POST":
        username = request.form.get("username", "").strip()
        email = request.form.get("email", "").strip().lower()
        password = request.form.get("password", "")
        confirm = request.form.get("confirm", "")
        if not username or not email or not password:
            flash("Tous les champs sont obligatoires.", "danger")
        elif not _is_valid_email(email):
            flash("Adresse email invalide.", "danger")
        elif password != confirm:
            flash("Les mots de passe ne correspondent pas.", "danger")
        elif len(password) < 6:
            flash("Le mot de passe doit comporter au moins 6 caractères.", "danger")
        else:
            db = get_db()
            try:
                db.execute(
                    "INSERT INTO users (username, email, password, created) VALUES (?, ?, ?, ?)",
                    (username, email, hash_pw(password), datetime.now().isoformat()),
                )
                db.commit()
            except sqlite3.IntegrityError as e:
                # SQLite names the violated column in the message,
                # e.g. "UNIQUE constraint failed: users.username".
                if "username" in str(e):
                    flash("Ce nom d'utilisateur est déjà pris.", "danger")
                else:
                    flash("Cette adresse email est déjà utilisée.", "danger")
            else:
                flash("Compte créé ! Vous pouvez vous connecter.", "success")
                return redirect(url_for("login"))
    return render_template("register.html")
@app.route("/logout")
def logout():
    """Close the current login record, clear the Flask session, go to /login.

    If the session carries the id of a ``sessions`` row (stored at login),
    that row's ``logout_at`` is set to the current time first.
    """
    row_id = session.get("session_db_id")
    if row_id:
        conn = get_db()
        stamp = datetime.now().isoformat()
        conn.execute("UPDATE sessions SET logout_at = ? WHERE id = ?", (stamp, row_id))
        conn.commit()
    session.clear()
    flash("Vous avez été déconnecté.", "info")
    return redirect(url_for("login"))
def _parse_measurements(form):
    """Read the four Iris measurements from *form* as floats.

    Returns (sepal_length, sepal_width, petal_length, petal_width).
    Raises KeyError if a field is missing and ValueError if a value is not
    numeric.
    """
    fields = ("sepal_length", "sepal_width", "petal_length", "petal_width")
    return tuple([float(form[name]) for name in fields])


def _classify(measurements):
    """Scale one sample and run the model on it.

    Returns (predicted class name, confidence in percent, mapping of every
    class name to its probability in percent rounded to one decimal).
    """
    X = np.array([list(measurements)])
    X_scaled = scaler.transform(X)
    pred_idx = model.predict(X_scaled)[0]
    proba = model.predict_proba(X_scaled)[0]
    # NOTE(review): indexing class_names and proba with the predicted label
    # assumes the model was trained on integer labels 0..n-1 -- confirm.
    pred_name = class_names[pred_idx]
    confidence = float(proba[pred_idx]) * 100
    all_probs = {class_names[i]: round(float(p) * 100, 1) for i, p in enumerate(proba)}
    return pred_name, confidence, all_probs


@app.route("/predict", methods=["GET", "POST"])
@login_required
def predict():
    """Show the prediction form and, on POST, classify the submitted flower.

    A valid POST runs the model, stores the prediction in the ``predictions``
    table and renders the result. Non-numeric or out-of-range input is
    reported with flash(). The page always lists the user's five most recent
    predictions.
    """
    result = None
    form_data = {}
    if request.method == "POST":
        # Only the parsing step is guarded: failures in the model or database
        # are real errors and must not be reported to the user as bad input.
        try:
            measurements = _parse_measurements(request.form)
        except (ValueError, KeyError):
            flash("Entrée invalide — veuillez saisir des valeurs numériques.", "danger")
        else:
            # NaN fails these comparisons too, so it is rejected here.
            if not all(0 < value < 20 for value in measurements):
                flash("Entrez des mesures réalistes (0–20 cm).", "warning")
                return redirect(url_for("predict"))
            sepal_length, sepal_width, petal_length, petal_width = measurements
            form_data = {
                "sepal_length": sepal_length, "sepal_width": sepal_width,
                "petal_length": petal_length, "petal_width": petal_width,
            }
            pred_name, confidence, all_probs = _classify(measurements)
            db = get_db()
            db.execute(
                """INSERT INTO predictions
                   (username, sepal_length, sepal_width, petal_length, petal_width,
                    prediction, confidence, timestamp)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                (session["username"], sepal_length, sepal_width,
                 petal_length, petal_width, pred_name, confidence,
                 datetime.now().isoformat()),
            )
            db.commit()
            result = {
                "name": pred_name,
                "confidence": round(confidence, 1),
                "info": FLOWER_INFO.get(pred_name, {}),
                "all_probs": all_probs,
            }
    db = get_db()
    history = db.execute(
        "SELECT * FROM predictions WHERE username = ? ORDER BY id DESC LIMIT 5",
        (session["username"],),
    ).fetchall()
    return render_template("predict.html",
                           result=result,
                           form_data=form_data,
                           history=history,
                           username=session["username"])
# ── Entry point ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Script entry point: create/migrate the database, then serve on all
    # interfaces on $PORT (default 5000) with the debugger disabled.
    # NOTE(review): app.run() blocks, so routes defined further down this file
    # (e.g. /admin) are never registered when the file is run as a script --
    # this block should be moved to the very end of the file.
    # NOTE(review): init_db() only runs here; if the app is imported by an
    # external WSGI server the schema and admin account are never created --
    # confirm how this app is deployed.
    init_db()
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 5000)), debug=False)
# ── Admin dashboard (temps réel, refresh auto) ────────────────────────────────
@app.route("/admin")
@login_required
def admin():
# Seulement l'admin peut accéder
if session.get("username") != "admin":
flash("Accès réservé à l'administrateur.", "danger")
return redirect(url_for("predict"))
db = get_db()
users = db.execute(
"SELECT id, username, email, created FROM users ORDER BY id DESC"
).fetchall()
sessions_log = db.execute("""
SELECT s.id, u.username, u.email, s.login_at, s.logout_at
FROM sessions s
JOIN users u ON s.user_id = u.id
ORDER BY s.id DESC
LIMIT 50
""").fetchall()
stats = {
"total_users": db.execute("SELECT COUNT(*) FROM users").fetchone()[0],
"total_sessions": db.execute("SELECT COUNT(*) FROM sessions").fetchone()[0],
"active_now": db.execute("SELECT CO