File size: 3,633 Bytes
aa3a1a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3543934
aa3a1a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import sys
import tempfile

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Load the trained models at import time: `model` (diabetes classifier) and
# `model_cluster` (clustering model). The app cannot run without them, so a
# missing file aborts the process with exit status 1.
# NOTE(review): the default directory is a user-specific absolute Windows path;
# set APP_MODEL_DIR to point elsewhere without editing the code.
# joblib.load unpickles the file, so only load model files you trust.
MODEL_DIR = os.environ.get(
    "APP_MODEL_DIR",
    "C:\\Users\\karballah\\Documents\\APP_Class_Clustering_ML_L3",
)
try:
    model = joblib.load(os.path.join(MODEL_DIR, "lg.joblib"))
    model_cluster = joblib.load(os.path.join(MODEL_DIR, "kmeans_model.joblib"))
    print("Modèles chargés avec succès.")
except FileNotFoundError as err:
    # Name the missing file so the user knows which model to restore.
    print(f"Erreur: Le fichier n'a pas été trouvé : {err.filename}", file=sys.stderr)
    # sys.exit is always available; the builtin exit() is injected by `site`
    # and is missing under `python -S` or in frozen executables.
    sys.exit(1)

# Diabetes prediction (classification)
def predict_diabetes(pregnancies, glucose, blood_pressure, skin_thickness, insulin, bmi, dpf, age):
    """Classify a single patient record with the loaded classifier `model`.

    The eight features are passed to ``model.predict`` as one row, in
    parameter order: pregnancies, glucose, blood pressure, skin thickness,
    insulin, BMI, DPF, age.

    Returns:
        "Diabétique" when the model predicts class 1, "Non diabétique" for any
        other class, or a prompt asking the user to fill every field when at
        least one value is missing (None).
    """
    features = [pregnancies, glucose, blood_pressure, skin_thickness, insulin, bmi, dpf, age]
    # A Gradio Number field left empty can be delivered as None. np.array would
    # then build an object array and model.predict would fail with a raw error,
    # so ask the user for the missing input instead.
    if any(value is None for value in features):
        return "Veuillez renseigner tous les champs."
    input_data = np.array([features], dtype=float)
    prediction = model.predict(input_data)[0]
    return "Diabétique" if prediction == 1 else "Non diabétique"

# Cluster visualisation and CSV export.
# NOTE(review): both functions below use synthetic placeholder data; the
# loaded `model_cluster` is not used here. Replace _synthetic_cluster_data
# with real PCA-projected features and model_cluster labels when available.

_CLUSTER_SEED = 42
_CLUSTER_N_POINTS = 100
_CLUSTER_N_LABELS = 5


def _synthetic_cluster_data(seed=_CLUSTER_SEED, n_points=_CLUSTER_N_POINTS, n_clusters=_CLUSTER_N_LABELS):
    """Return a DataFrame with columns PC1, PC2 (random normal) and Cluster (0..n_clusters-1).

    Uses a private legacy RandomState rather than ``np.random.seed`` so the
    global NumPy RNG is left untouched. With the same seed it yields the same
    numbers the global generator would after ``np.random.seed(seed)``.
    """
    rng = np.random.RandomState(seed)
    features = rng.randn(n_points, 2)
    labels = rng.randint(0, n_clusters, size=n_points)
    frame = pd.DataFrame(features, columns=['PC1', 'PC2'])
    frame['Cluster'] = labels
    return frame


def plot_clusters(selected_cluster):
    """Build a Plotly scatter of the cluster points.

    Args:
        selected_cluster: "Tous" (or None, e.g. a cleared dropdown) to show
            every cluster, otherwise a cluster id given as a string or int.

    Returns:
        A Plotly figure; an empty titled figure when no point matches.
    """
    if selected_cluster is None:
        selected_cluster = "Tous"

    pca_df = _synthetic_cluster_data()

    if selected_cluster == "Tous":
        selected_data = pca_df
    else:
        selected_data = pca_df[pca_df['Cluster'] == int(selected_cluster)]

    if selected_data.empty:
        return px.scatter(title="Aucun point à afficher")

    fig = px.scatter(selected_data, x='PC1', y='PC2', color=selected_data['Cluster'].astype(str),
                     title=f"Visualisation du Cluster {selected_cluster}", labels={'color': 'Cluster'})
    return fig


def download_clusters():
    """Return ``(csv_text, "clusters.csv")`` for the same points plot_clusters shows.

    The CSV has columns PC1, PC2, Cluster and no index column.
    """
    df_clusters = _synthetic_cluster_data()
    return df_clusters.to_csv(index=False), "clusters.csv"

# Gradio user interface
with gr.Blocks() as app:
    gr.Markdown("## Application Machine Learning : Classification et Clustering")

    # Classification section: eight numeric inputs feeding predict_diabetes.
    gr.Markdown("### Prédiction du Diabète")
    with gr.Row():
        pregnancies = gr.Number(label="Grossesses")
        glucose = gr.Number(label="Glucose")
        blood_pressure = gr.Number(label="Pression artérielle")
    with gr.Row():
        skin_thickness = gr.Number(label="Épaisseur de peau")
        insulin = gr.Number(label="Insuline")
        bmi = gr.Number(label="IMC")
    with gr.Row():
        dpf = gr.Number(label="DPF")
        age = gr.Number(label="Âge")

    predict_button = gr.Button("Prédire")
    output_label = gr.Textbox(label="Résultat")

    # The input order must match predict_diabetes's positional parameters.
    predict_button.click(
        fn=predict_diabetes,
        inputs=[pregnancies, glucose, blood_pressure, skin_thickness, insulin, bmi, dpf, age],
        outputs=output_label,
    )

    # Clustering section
    gr.Markdown("### Visualisation des Clusters des Réactions en Ligne")
    cluster_selector = gr.Dropdown(
        ["Tous"] + [str(i) for i in range(5)],
        value="Tous",
        label="Sélectionner un cluster",
    )
    cluster_plot = gr.Plot()

    def update_plot(selected_cluster):
        """Redraw the cluster plot for the dropdown's current value."""
        return plot_clusters(selected_cluster)

    cluster_selector.change(fn=update_plot, inputs=[cluster_selector], outputs=[cluster_plot])
    # Render the default ("Tous") view when the page loads instead of an empty plot.
    app.load(fn=update_plot, inputs=[cluster_selector], outputs=[cluster_plot])

    # Clusters download
    def export_clusters():
        """Write download_clusters()'s CSV text to a temp file and return its path.

        gr.File outputs expect a file path. Passing download_clusters's
        (csv_text, filename) tuple straight to it does not produce a download.
        """
        csv_text, filename = download_clusters()
        # A fresh directory per click keeps concurrent downloads from
        # overwriting each other's file.
        export_path = os.path.join(tempfile.mkdtemp(), filename)
        # newline="" writes the text as-is, so pandas' line endings are not
        # translated a second time on Windows.
        with open(export_path, "w", encoding="utf-8", newline="") as handle:
            handle.write(csv_text)
        return export_path

    download_button = gr.Button("Télécharger les clusters")
    download_file = gr.File()
    download_button.click(fn=export_clusters, outputs=download_file)

app.launch()