# Source: ameythakur / Stock-Trading-RL (revision 9ffa007, unverified)
"""
Project: Optimizing Stock Trading Strategy With Reinforcement Learning
Authors: Amey Thakur & Mega Satish
Reference: https://github.com/Amey-Thakur/OPTIMIZING-STOCK-TRADING-STRATEGY-WITH-REINFORCEMENT-LEARNING
License: MIT
Description:
This script contains the Main Application logic served via Streamlit.
It loads the pre-trained Q-Learning model (model.pkl), processes user-selected
stock data, simulates the trading strategy on unseen data, and visualizes
the portfolio performance using interactive Plotly charts.
"""
import numpy as np
import pandas as pd
from pandas._libs.missing import NA
import streamlit as st
import time
import plotly.graph_objects as go
import pickle as pkl
# ==========================================
# 1. Data Processing Logic
# ==========================================
# @st.cache(persist=True)
def data_prep(data, name):
    """
    Prepares the dataset for the selected stock ticker.

    Rows whose 'Name' equals `name` are selected, rows containing any missing
    value are dropped, the index is reset to 0..n-1, and two indicator columns
    are added. These indicators are the inputs later compared by get_state.

    Args:
        data (pd.DataFrame): The raw dataset; must provide 'Name' and 'close' columns.
        name (str): The specific stock name selected by the user.

    Returns:
        pd.DataFrame: A new dataframe (the input is not modified) with the
        extra columns '5day_MA' and '1day_MA'. '5day_MA' is 0 for the leading
        rows that do not yet have a full 5-row window.
    """
    # Work on an explicit copy so the in-place cleanup below never touches the caller's frame.
    df = data[data['Name'] == name].copy()
    df.dropna(inplace=True)
    df.reset_index(drop=True, inplace=True)

    closes = df['close']

    # Calculate Moving Averages (Technical Indicators).
    # rolling(5).mean() is NaN only while the window is incomplete (the first
    # four rows); fillna(0) zeroes exactly those rows. The previous
    # `df.loc[:4, ...] = 0` was label-inclusive and also wiped the first valid
    # average at row 4.
    df['5day_MA'] = closes.rolling(5).mean().fillna(0)
    df['1day_MA'] = closes.rolling(1).mean()
    return df
# ==========================================
# 2. Agent Logic (Inference)
# ==========================================
# @st.cache(persist=True)
def get_state(long_ma, short_ma, t):
    """
    Maps a pair of moving averages plus the trade flag to a Q-Table state.

    Args:
        long_ma: Long-window moving average for the current row.
        short_ma: Short-window moving average for the current row.
        t: Trade flag as produced by trade_t; only the value 1 is treated
            specially, every other value maps to position 0.

    Returns:
        tuple: (trend, position) where trend is 0 when short_ma < long_ma and
        1 when short_ma > long_ma, and position is 1 if t == 1 else 0.
        When neither comparison holds (equal values, or NaN), the fixed
        default (0, 1) is returned regardless of t.
    """
    # Resolve the trend first; the comparisons are evaluated in the same
    # order as before (below, then above).
    if short_ma < long_ma:
        trend = 0  # Bearish
    elif short_ma > long_ma:
        trend = 1  # Bullish
    else:
        return (0, 1)  # Default

    position = 1 if t == 1 else 0
    return (trend, position)
# @st.cache(persist=True)
def trade_t(num_of_stocks, port_value, current_price):
    """
    Checks if a trade (Buy) is financially feasible.

    Args:
        num_of_stocks: Current share count. Kept for interface compatibility;
            the previous implementation branched on its sign but both
            branches performed the same check, so it never affected the
            result.
        port_value: Funds currently available.
        current_price: Price of one share.

    Returns:
        int: 1 when port_value is strictly greater than current_price,
        otherwise 0.
    """
    return 1 if port_value > current_price else 0
# @st.cache(persist=True)
def next_act(state, qtable, epsilon, action=3):
    """
    Picks the next action with an epsilon-greedy rule over the Q-Table.

    Args:
        state: Index into `qtable` (e.g. the tuple returned by get_state).
        qtable: Table of action values; `qtable[state]` must be array-like.
        epsilon: Exploration threshold. A uniform sample from [0, 1) is drawn
            on every call; when it is below epsilon a random action is taken.
            With epsilon == 0 the greedy branch is always used.
        action: On entry, the number of available actions (exclusive upper
            bound for the random draw). Defaults to 3.

    Returns:
        The chosen action index: a random integer in [0, action) when
        exploring, otherwise the argmax of `qtable[state]`.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return np.random.randint(action)
    return np.argmax(qtable[state])
# @st.cache(persist=True)
def test_stock(stocks_test, q_table, invest):
    """
    Runs a simulation of the trading strategy on the selected stock.

    Every row of `stocks_test` is visited in order. For each row the trade
    flag (trade_t), the state (get_state) and the action (next_act with
    epsilon = 0) are computed, then the running value is updated:

        action 0 (Buy):  num_stocks += 1, value -= close
        action 1 (Sell): num_stocks -= 1, value += close
        action 2 (Hold): value += close
        any other action: nothing is appended for that row

    Each appended value is rounded to one decimal place.

    NOTE(review): Hold credits `close` exactly like Sell (without changing
    num_stocks), and Sell is not blocked when num_stocks is 0. Both look
    unintended for a net-worth series, but they are kept unchanged here
    because the Q-table was presumably trained against the same accounting -
    confirm against the training script before changing.

    Args:
        stocks_test (pd.DataFrame): The stock data to test on; must provide
            '5day_MA', '1day_MA' and 'close' columns.
        q_table (np.array): The loaded reinforcement learning model.
        invest (int): Initial investment amount.

    Returns:
        list: A time-series list of values over the simulation period,
        starting with `invest`.
    """
    num_stocks = 0
    epsilon = 0  # No exploration during testing/inference
    net_worth = [invest]
    # Reseeds NumPy's global RNG (kept from the original; next_act still
    # draws one sample per step even though epsilon is 0).
    np.random.seed()

    # Walk the three needed columns in lockstep instead of materialising a
    # row Series three times per step via iloc.
    rows = zip(
        stocks_test['5day_MA'],
        stocks_test['1day_MA'],
        stocks_test['close'],
    )
    for long_ma, short_ma, close_price in rows:
        # Determine Current State
        t = trade_t(num_stocks, net_worth[-1], close_price)
        state = get_state(long_ma, short_ma, t)

        # Agent chooses action
        action = next_act(state, q_table, epsilon)

        if action == 0:  # Buy
            num_stocks += 1
            net_worth.append(np.round(net_worth[-1] - close_price, 1))
        elif action == 1:  # Sell
            num_stocks -= 1
            net_worth.append(np.round(net_worth[-1] + close_price, 1))
        elif action == 2:  # Hold (see NOTE(review) in the docstring)
            net_worth.append(np.round(net_worth[-1] + close_price, 1))

        # The former look-ahead `try: next_state = ... except: break` was
        # removed: its result was never used, it only fired on the last row
        # (where the loop ends anyway), and its bare `except` could silently
        # truncate the simulation on any unrelated error.

    return net_worth
# ==========================================
# 3. Streamlit Interface
# ==========================================
def fun():
    """
    Builds the Streamlit page and reacts to the sidebar controls.

    Flow:
        1. Load 'all_stocks_5yr.csv' from the working directory and list the
           unique values of its 'Name' column in a sidebar select box.
        2. Once a real stock is chosen, prepare its data with data_prep.
        3. "Show Stock Trend" button: plot 'close' against 'date' and show a
           short note comparing an early reference close with the first close.
        4. "Calculate" button: load the Q-table from 'model.pkl', run
           test_stock with the amount chosen on the slider, and plot the
           resulting series.

    Returns:
        None. Returns early when no stock is selected or when 'model.pkl'
        is missing (an error message is shown in that case).
    """
    placeholder = "<Select Names>"

    # Reading the Dataset
    # Ensure all_stocks_5yr.csv is in the working directory
    data = pd.read_csv('all_stocks_5yr.csv')
    names = list(data['Name'].unique())
    names.insert(0, placeholder)

    st.title("Optimizing Stock Trading Strategy With Reinforcement Learning")
    st.sidebar.title("Choose Stock and Investment")
    st.sidebar.subheader("Choose Company Stocks")

    # User Input: Select Stock
    stock = st.sidebar.selectbox("(*select one stock only)", names, index=0)
    if stock == placeholder:
        # Nothing else is rendered until a real ticker is picked.
        return

    stock_df = data_prep(data, stock)

    # Sidebar Button: Plot Data Trend
    if st.sidebar.button("Show Stock Trend", key=1):
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=stock_df['date'],
            y=stock_df['close'],
            mode='lines',
            name='Stock_Trend',
            line=dict(color='cyan', width=2)
        ))
        fig.update_layout(
            title='Stock Trend of ' + stock,
            xaxis_title='Date',
            yaxis_title='Price ($) '
        )
        st.plotly_chart(fig, use_container_width=True)

        # Simple heuristic for trend feedback: compare the close at row 500
        # with the first close. The row index is clamped to the last
        # available row so tickers with fewer than 501 rows no longer raise
        # IndexError; an empty frame simply gets no note.
        if not stock_df.empty:
            ref_row = min(500, len(stock_df) - 1)
            if stock_df.iloc[ref_row]['close'] > stock_df.iloc[0]['close']:
                original_title = '<b><p style="font-family:Play; color:Cyan; font-size: 20px;">NOTE:<br>Stock is on a solid upward trend. Investing here might be profitable.</p>'
                st.markdown(original_title, unsafe_allow_html=True)
            else:
                original_title = '<b><p style="font-family:Play; color:Red; font-size: 20px;">NOTE:<br> Stock does not appear to be in a solid uptrend. Better not to invest here; instead, pick different stock.</p>'
                st.markdown(original_title, unsafe_allow_html=True)

    # Sidebar Controls: Investment Simulation
    st.sidebar.subheader("Enter Your Available Initial Investment Fund")
    invest = st.sidebar.slider('Select a range of values', 1000, 1000000)

    if st.sidebar.button("Calculate", key=2):
        # Load Pre-trained Model
        # NOTE: unpickling can execute arbitrary code, so 'model.pkl' must
        # come from a trusted source (it is expected to be produced by this
        # project's own training step).
        try:
            # Using 'model.pkl' as standardized; `with` closes the handle
            # even if unpickling fails.
            with open('model.pkl', 'rb') as model_file:
                q_table = pkl.load(model_file)
        except FileNotFoundError:
            st.error("Model file 'model.pkl' not found. Please ensure the model is trained.")
            return

        # Run Simulation
        net_worth = test_stock(stock_df, q_table, invest)
        net_worth = pd.DataFrame(net_worth, columns=['value'])

        # Plot Results
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=net_worth.index,
            y=net_worth['value'],
            mode='lines',
            name='Net_Worth_Trend',
            line=dict(color='cyan', width=2)
        ))
        fig.update_layout(
            title='Change in Portfolio Value Day by Day',
            xaxis_title='Number of Days since Feb 2013 ',
            yaxis_title='Value ($) '
        )
        st.plotly_chart(fig, use_container_width=True)

        original_title = '<b><p style="font-family:Play; color:Cyan; font-size: 20px;">NOTE:<br> Increase in your net worth as a result of a model decision.</p>'
        st.markdown(original_title, unsafe_allow_html=True)
# Script entry point: render the app only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    fun()
# Dummy chart for layout purposes if needed, otherwise optional
# chart_data = pd.DataFrame(np.random.randn(20, 3), columns=['a', 'b', 'c'])