Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,9 @@ from sklearn.model_selection import train_test_split
|
|
| 7 |
import gradio as gr
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from datetime import datetime, timedelta
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Define stock tickers
|
| 12 |
STOCK_TICKERS = [
|
|
@@ -130,14 +133,34 @@ def buy_or_sell(current_price: float, predicted_price: float) -> str:
|
|
| 130 |
else:
|
| 131 |
return "Sell"
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def stock_prediction_app(ticker: str, start_date: str, end_date: str):
|
| 134 |
"""
|
| 135 |
Main function to handle stock prediction and return outputs.
|
| 136 |
|
| 137 |
Parameters:
|
| 138 |
- ticker (str): Selected stock ticker.
|
| 139 |
-
- start_date (str): Training start date.
|
| 140 |
-
- end_date (str): Training end date.
|
| 141 |
|
| 142 |
Returns:
|
| 143 |
- percentage_change (str): Percentage change from start to end date.
|
|
@@ -146,9 +169,23 @@ def stock_prediction_app(ticker: str, start_date: str, end_date: str):
|
|
| 146 |
- decision (str): Buy or Sell decision.
|
| 147 |
- plot (matplotlib.figure.Figure): Plot of historical prices with tomorrow's prediction.
|
| 148 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
# Fetch data
|
| 150 |
data = fetch_stock_data(ticker, start_date, end_date)
|
| 151 |
-
|
| 152 |
if data.empty:
|
| 153 |
return "N/A", "N/A", "N/A", "No Data Available", None
|
| 154 |
|
|
@@ -158,7 +195,7 @@ def stock_prediction_app(ticker: str, start_date: str, end_date: str):
|
|
| 158 |
percentage_change = ((end_price - start_price) / start_price) * 100
|
| 159 |
highest_price = data['Close'].max()
|
| 160 |
lowest_price = data['Close'].min()
|
| 161 |
-
|
| 162 |
# Preprocess data
|
| 163 |
try:
|
| 164 |
X, y = preprocess_data(data)
|
|
@@ -167,33 +204,33 @@ def stock_prediction_app(ticker: str, start_date: str, end_date: str):
|
|
| 167 |
|
| 168 |
if len(X) == 0:
|
| 169 |
return f"{percentage_change:.2f}%", highest_price, lowest_price, "No Prediction", None
|
| 170 |
-
|
| 171 |
# Train the model
|
| 172 |
try:
|
| 173 |
model = train_model(X, y)
|
| 174 |
except Exception as e:
|
| 175 |
return f"Error in training model: {e}", highest_price, lowest_price, "Error", None
|
| 176 |
-
|
| 177 |
# Make prediction
|
| 178 |
try:
|
| 179 |
predicted_price = make_prediction(model, data)
|
| 180 |
except Exception as e:
|
| 181 |
return f"Error in making prediction: {e}", highest_price, lowest_price, "Error", None
|
| 182 |
-
|
| 183 |
# Current price is the last closing price
|
| 184 |
current_price = data['Close'].iloc[-1]
|
| 185 |
decision = buy_or_sell(current_price, predicted_price)
|
| 186 |
-
|
| 187 |
# Plotting historical prices and predicted tomorrow's price
|
| 188 |
plt.figure(figsize=(10,5))
|
| 189 |
plt.plot(data['Close'], label='Historical Close Price')
|
| 190 |
-
|
| 191 |
# Add predicted price for tomorrow
|
| 192 |
tomorrow_date = data.index[-1] + timedelta(days=1)
|
| 193 |
# Ensure tomorrow is a business day
|
| 194 |
while tomorrow_date.weekday() >= 5: # Saturday=5, Sunday=6
|
| 195 |
tomorrow_date += timedelta(days=1)
|
| 196 |
-
|
| 197 |
plt.scatter(tomorrow_date, predicted_price, color='red', label='Predicted Close Price (Tomorrow)')
|
| 198 |
plt.title(f'{ticker} Price Prediction for Tomorrow')
|
| 199 |
plt.xlabel('Date')
|
|
@@ -202,8 +239,30 @@ def stock_prediction_app(ticker: str, start_date: str, end_date: str):
|
|
| 202 |
plt.tight_layout()
|
| 203 |
fig = plt.gcf()
|
| 204 |
plt.close()
|
| 205 |
-
|
| 206 |
# Formatting outputs
|
| 207 |
percentage_change_str = f"{percentage_change:.2f}%"
|
| 208 |
-
|
| 209 |
return percentage_change_str, highest_price, lowest_price, decision, fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from datetime import datetime, timedelta
|
| 10 |
+
import joblib
|
| 11 |
+
import os
|
| 12 |
+
import re
|
| 13 |
|
| 14 |
# Define stock tickers
|
| 15 |
STOCK_TICKERS = [
|
|
|
|
| 133 |
else:
|
| 134 |
return "Sell"
|
| 135 |
|
| 136 |
+
def validate_date_format(date_text: str) -> bool:
|
| 137 |
+
"""
|
| 138 |
+
Validates that the input string is a date in 'YYYY-MM-DD' format.
|
| 139 |
+
|
| 140 |
+
Parameters:
|
| 141 |
+
- date_text (str): Date string to validate.
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
- bool: True if valid, False otherwise.
|
| 145 |
+
"""
|
| 146 |
+
# Regular expression for YYYY-MM-DD format
|
| 147 |
+
regex = r'^\d{4}-\d{2}-\d{2}$'
|
| 148 |
+
if re.match(regex, date_text):
|
| 149 |
+
try:
|
| 150 |
+
datetime.strptime(date_text, '%Y-%m-%d')
|
| 151 |
+
return True
|
| 152 |
+
except ValueError:
|
| 153 |
+
return False
|
| 154 |
+
return False
|
| 155 |
+
|
| 156 |
def stock_prediction_app(ticker: str, start_date: str, end_date: str):
|
| 157 |
"""
|
| 158 |
Main function to handle stock prediction and return outputs.
|
| 159 |
|
| 160 |
Parameters:
|
| 161 |
- ticker (str): Selected stock ticker.
|
| 162 |
+
- start_date (str): Training start date in 'YYYY-MM-DD' format.
|
| 163 |
+
- end_date (str): Training end date in 'YYYY-MM-DD' format.
|
| 164 |
|
| 165 |
Returns:
|
| 166 |
- percentage_change (str): Percentage change from start to end date.
|
|
|
|
| 169 |
- decision (str): Buy or Sell decision.
|
| 170 |
- plot (matplotlib.figure.Figure): Plot of historical prices with tomorrow's prediction.
|
| 171 |
"""
|
| 172 |
+
# Validate date formats
|
| 173 |
+
if not (validate_date_format(start_date) and validate_date_format(end_date)):
|
| 174 |
+
return "Invalid date format. Please use YYYY-MM-DD.", "N/A", "N/A", "Error", None
|
| 175 |
+
|
| 176 |
+
# Convert strings to datetime objects
|
| 177 |
+
try:
|
| 178 |
+
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
|
| 179 |
+
end_dt = datetime.strptime(end_date, '%Y-%m-%d')
|
| 180 |
+
except ValueError:
|
| 181 |
+
return "Invalid date values. Please ensure dates are correct.", "N/A", "N/A", "Error", None
|
| 182 |
+
|
| 183 |
+
if start_dt >= end_dt:
|
| 184 |
+
return "Start date must be before end date.", "N/A", "N/A", "Error", None
|
| 185 |
+
|
| 186 |
# Fetch data
|
| 187 |
data = fetch_stock_data(ticker, start_date, end_date)
|
| 188 |
+
|
| 189 |
if data.empty:
|
| 190 |
return "N/A", "N/A", "N/A", "No Data Available", None
|
| 191 |
|
|
|
|
| 195 |
percentage_change = ((end_price - start_price) / start_price) * 100
|
| 196 |
highest_price = data['Close'].max()
|
| 197 |
lowest_price = data['Close'].min()
|
| 198 |
+
|
| 199 |
# Preprocess data
|
| 200 |
try:
|
| 201 |
X, y = preprocess_data(data)
|
|
|
|
| 204 |
|
| 205 |
if len(X) == 0:
|
| 206 |
return f"{percentage_change:.2f}%", highest_price, lowest_price, "No Prediction", None
|
| 207 |
+
|
| 208 |
# Train the model
|
| 209 |
try:
|
| 210 |
model = train_model(X, y)
|
| 211 |
except Exception as e:
|
| 212 |
return f"Error in training model: {e}", highest_price, lowest_price, "Error", None
|
| 213 |
+
|
| 214 |
# Make prediction
|
| 215 |
try:
|
| 216 |
predicted_price = make_prediction(model, data)
|
| 217 |
except Exception as e:
|
| 218 |
return f"Error in making prediction: {e}", highest_price, lowest_price, "Error", None
|
| 219 |
+
|
| 220 |
# Current price is the last closing price
|
| 221 |
current_price = data['Close'].iloc[-1]
|
| 222 |
decision = buy_or_sell(current_price, predicted_price)
|
| 223 |
+
|
| 224 |
# Plotting historical prices and predicted tomorrow's price
|
| 225 |
plt.figure(figsize=(10,5))
|
| 226 |
plt.plot(data['Close'], label='Historical Close Price')
|
| 227 |
+
|
| 228 |
# Add predicted price for tomorrow
|
| 229 |
tomorrow_date = data.index[-1] + timedelta(days=1)
|
| 230 |
# Ensure tomorrow is a business day
|
| 231 |
while tomorrow_date.weekday() >= 5: # Saturday=5, Sunday=6
|
| 232 |
tomorrow_date += timedelta(days=1)
|
| 233 |
+
|
| 234 |
plt.scatter(tomorrow_date, predicted_price, color='red', label='Predicted Close Price (Tomorrow)')
|
| 235 |
plt.title(f'{ticker} Price Prediction for Tomorrow')
|
| 236 |
plt.xlabel('Date')
|
|
|
|
| 239 |
plt.tight_layout()
|
| 240 |
fig = plt.gcf()
|
| 241 |
plt.close()
|
| 242 |
+
|
| 243 |
# Formatting outputs
|
| 244 |
percentage_change_str = f"{percentage_change:.2f}%"
|
| 245 |
+
|
| 246 |
return percentage_change_str, highest_price, lowest_price, decision, fig
|
| 247 |
+
|
| 248 |
+
# Define the Gradio interface
|
| 249 |
+
iface = gr.Interface(
|
| 250 |
+
fn=stock_prediction_app,
|
| 251 |
+
inputs=[
|
| 252 |
+
gr.Dropdown(choices=STOCK_TICKERS, label="Select Stock Ticker"),
|
| 253 |
+
gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", placeholder="e.g., 2020-01-01"),
|
| 254 |
+
gr.Textbox(label="Enter End Date (YYYY-MM-DD)", placeholder="e.g., 2023-12-31")
|
| 255 |
+
],
|
| 256 |
+
outputs=[
|
| 257 |
+
gr.Textbox(label="Percentage Change"),
|
| 258 |
+
gr.Number(label="Highest Closing Price"),
|
| 259 |
+
gr.Number(label="Lowest Closing Price"),
|
| 260 |
+
gr.Textbox(label="Decision (Buy/Sell)"),
|
| 261 |
+
gr.Plot(label="Stock Performance")
|
| 262 |
+
],
|
| 263 |
+
title="Stock Prediction App",
|
| 264 |
+
description="Predict whether to buy or sell a stock based on historical data. Please enter dates in YYYY-MM-DD format."
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Launch the interface
|
| 268 |
+
iface.launch()
|