ColinceTatsa commited on
Commit
4bda154
·
verified ·
1 Parent(s): 835d03f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +210 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,212 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from sklearn.metrics import mean_squared_error
6
+ from datetime import timedelta
7
+ import torch
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+
11
+ # --- Configuration de la page ---
12
+ # Doit être la première commande Streamlit du script
13
+ st.set_page_config(
14
+ page_title="Prédiction Boursière GRU",
15
+ page_icon="📈",
16
+ layout="wide"
17
+ )
18
+
19
+ # --- Définition des modèles et fonctions (partie non visible dans l'UI) ---
20
+
21
+ # === Définition de l'architecture du modèle GRU ===
22
+ class GRUModel(nn.Module):
23
+ def __init__(self, input_size=1, hidden_layer_size=50, num_layers=2, output_size=1):
24
+ super().__init__()
25
+ self.hidden_layer_size = hidden_layer_size
26
+ self.gru = nn.GRU(input_size, hidden_layer_size, num_layers, batch_first=True)
27
+ self.fc = nn.Linear(hidden_layer_size, output_size)
28
+
29
+ def forward(self, input_seq):
30
+ gru_out, _ = self.gru(input_seq)
31
+ predictions = self.fc(gru_out[:, -1])
32
+ return predictions
33
+
34
+ # === Fonctions de chargement et de traitement (avec cache pour la performance) ===
35
+ @st.cache_data
36
+ def load_data(csv_path="action_amd.csv"):
37
+ """Charge les données depuis le fichier CSV et les formate correctement."""
38
+ try:
39
+ df = pd.read_csv(csv_path)
40
+ except FileNotFoundError:
41
+ st.error(f"Erreur : Le fichier '{csv_path}' est introuvable. "
42
+ "Assurez-vous qu'il se trouve dans le même dossier que votre script.")
43
+ st.stop()
44
+ if 'Date' not in df.columns or 'Close' not in df.columns:
45
+ st.error("Le fichier CSV doit contenir les colonnes 'Date' et 'Close'.")
46
+ st.stop()
47
+ df_filtered = df[['Date', 'Close']].copy()
48
+ df_filtered['Date'] = pd.to_datetime(df_filtered['Date'])
49
+ df_renamed = df_filtered.rename(columns={'Date': 'ds', 'Close': 'y'})
50
+ return df_renamed.sort_values(by='ds')
51
+
52
+ @st.cache_resource
53
+ def load_gru_model(path, model_class):
54
+ """Charge le modèle GRU pré-entraîné."""
55
+ model = model_class()
56
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
57
+ model.eval()
58
+ return model
59
+
60
+ def predict_gru(model, df, forecast_days, window_size=20):
61
+ """Effectue des prédictions sur les N prochains jours."""
62
+ data_values = df['y'].values
63
+ predictions = []
64
+ input_seq_np = data_values[-window_size:]
65
+
66
+ for _ in range(forecast_days):
67
+ input_seq_torch = torch.from_numpy(input_seq_np).float().view(1, window_size, 1)
68
+ with torch.no_grad():
69
+ pred = model(input_seq_torch).item()
70
+ predictions.append(pred)
71
+ input_seq_np = np.append(input_seq_np[1:], pred)
72
+
73
+ last_date = df['ds'].max()
74
+ future_dates = []
75
+ current_date = last_date
76
+ while len(future_dates) < forecast_days:
77
+ current_date += timedelta(days=1)
78
+ if current_date.weekday() < 5:
79
+ future_dates.append(current_date)
80
+
81
+ return pd.DataFrame({'ds': future_dates, 'yhat': predictions})
82
+
83
+ def calculate_rmse(y_true, y_pred):
84
+ return np.sqrt(mean_squared_error(y_true, y_pred))
85
+
86
+
87
+ # --- Définition des pages de l'application ---
88
+
89
+ def page_accueil():
90
+ """Affiche la page d'accueil."""
91
+ st.title("Projet de Prédiction de Séries Temporelles avec GRU")
92
+ st.markdown("---")
93
+
94
+ col1, col2 = st.columns([1, 3])
95
+ with col1:
96
+ try:
97
+ logo_keyce = Image.open("Keyce_Logo.jpg")
98
+ st.image(logo_keyce, width=150)
99
+ except FileNotFoundError:
100
+ st.warning("Logo Keyce Keyce_Logo.jpg non trouvé.")
101
+
102
+ with col2:
103
+ st.header("KEYCE INFORMATIQUE - Master II IA")
104
+ st.subheader("Session Normale de Réseaux de Neurones Récurrents (RNN)")
105
+
106
+ st.markdown("---")
107
+
108
+ st.header("Présentation de l'étudiant")
109
+ st.markdown("### **Nom :** TATSA TCHINDA Colince")
110
+
111
+ st.info("Utilisez le menu de navigation à gauche pour accéder à la page de prédiction.")
112
+
113
+ def page_prediction():
114
+ """Affiche la page de prédiction et ses résultats."""
115
+ st.title("📈 Prédiction du Cours de l'Action AMD")
116
+
117
+ # --- Étape 1: Chargement des données ---
118
+ st.header("Étape 1 : Chargement et Visualisation des Données")
119
+ with st.spinner("Chargement des données historiques..."):
120
+ data = load_data()
121
+ st.success("Données chargées avec succès !")
122
+
123
+ fig, ax = plt.subplots(figsize=(12, 5))
124
+ ax.plot(data['ds'], data['y'], label="Historique", color='black')
125
+ ax.set_xlabel("Date")
126
+ ax.set_ylabel("Prix de clôture ($)")
127
+ ax.set_title("Cours historique de l'action AMD")
128
+ ax.grid(True)
129
+ ax.legend()
130
+ st.pyplot(fig)
131
+
132
+ # --- Étape 2: Chargement du modèle ---
133
+ st.header("Étape 2 : Chargement du Modèle GRU")
134
+ with st.spinner("Chargement du modèle pré-entraîné..."):
135
+ try:
136
+ gru_model = load_gru_model("model_gru.pth", GRUModel)
137
+ st.success("Modèle GRU chargé avec succès !")
138
+ except FileNotFoundError:
139
+ st.error("Erreur : Le fichier 'model_gru.pth' est introuvable.")
140
+ st.stop()
141
+
142
+ # --- Étape 3: Prédictions ---
143
+ st.header("Étape 3 : Génération des Prédictions")
144
+ WINDOW_SIZE = 20
145
+ FORECAST_DAYS = 21
146
+
147
+ with st.spinner(f"Calcul des prédictions pour les {FORECAST_DAYS} prochains jours..."):
148
+ gru_forecast = predict_gru(gru_model, data, FORECAST_DAYS, window_size=WINDOW_SIZE)
149
+
150
+ st.subheader("Prédictions du modèle GRU vs Historique récent")
151
+ fig, ax = plt.subplots(figsize=(12, 5))
152
+ ax.plot(data['ds'].tail(100), data['y'].tail(100), label="Historique récent", color='black')
153
+ ax.plot(gru_forecast['ds'], gru_forecast['yhat'], label=f"Prédiction GRU ({FORECAST_DAYS} jours)", color='orange', linestyle='--')
154
+ ax.set_xlabel("Date")
155
+ ax.set_ylabel("Prix de clôture ($)")
156
+ ax.set_title("Prédictions GRU vs Historique")
157
+ ax.grid(True)
158
+ ax.legend()
159
+ st.pyplot(fig)
160
+
161
+ # --- Étape 4: Évaluation et Résultats ---
162
+ st.header("Étape 4 : Évaluation et Résultats")
163
+
164
+ st.subheader("Performance du Modèle (Backtesting)")
165
+ true_values = data['y'].values[-WINDOW_SIZE:]
166
+ # ... (Le reste de votre logique de backtesting est correct)
167
+ input_for_backtest_np = data['y'].values[-(WINDOW_SIZE * 2):-WINDOW_SIZE]
168
+ backtest_preds = []
169
+ input_seq_for_backtest = input_for_backtest_np.copy()
170
+ for _ in range(WINDOW_SIZE):
171
+ input_tensor = torch.from_numpy(input_seq_for_backtest[-WINDOW_SIZE:]).float().view(1, WINDOW_SIZE, 1)
172
+ with torch.no_grad():
173
+ pred = gru_model(input_tensor).item()
174
+ backtest_preds.append(pred)
175
+ input_seq_for_backtest = np.append(input_seq_for_backtest, pred)
176
+ gru_rmse = calculate_rmse(true_values, backtest_preds)
177
+
178
+ st.metric(label="RMSE (Backtest sur 20 jours)", value=f"{gru_rmse:.4f}")
179
+ st.info("Le RMSE évalue l'écart moyen entre les valeurs prédites et les valeurs réelles sur la période de test.")
180
+
181
+ st.subheader("Tableau des Prédictions")
182
+ styled_df = gru_forecast.style.format({'yhat': '{:.2f}'})
183
+ st.dataframe(styled_df, use_container_width=True)
184
+
185
+ csv = gru_forecast.to_csv(index=False).encode('utf-8')
186
+ st.download_button(
187
+ label="📥 Télécharger les prédictions (.csv)",
188
+ data=csv,
189
+ file_name="predictions_gru_amd.csv",
190
+ mime="text/csv",
191
+ )
192
+
193
+ # --- Barre latérale et Navigation ---
194
+
195
+ st.sidebar.header("Navigation")
196
+ try:
197
+ logo_theme = Image.open("Theme_Logo.jpg")
198
+ st.sidebar.image(logo_theme, use_container_width=True)
199
+ except FileNotFoundError:
200
+ st.sidebar.warning("Logo thème Theme_Logo.jpg non trouvé.")
201
+
202
+ page = st.sidebar.selectbox(
203
+ "Choisissez une page",
204
+ ["Accueil", "Prédiction"]
205
+ )
206
+
207
+ # --- Affichage de la page sélectionnée ---
208
 
209
+ if page == "Accueil":
210
+ page_accueil()
211
+ elif page == "Prédiction":
212
+ page_prediction()