"""
Project: Optimizing Stock Trading Strategy With Reinforcement Learning
Authors: Amey Thakur & Mega Satish
Reference: https://github.com/Amey-Thakur/OPTIMIZING-STOCK-TRADING-STRATEGY-WITH-REINFORCEMENT-LEARNING
License: MIT

Description:
This script contains the Main Application logic served via Streamlit.
It loads the pre-trained Q-Learning model (model.pkl), processes user-selected
stock data, simulates the trading strategy on unseen data, and visualizes the
portfolio performance using interactive Plotly charts.
"""

import numpy as np
import pandas as pd
from pandas._libs.missing import NA  # NOTE(review): unused here; private pandas API
import streamlit as st
import time  # NOTE(review): unused here
import plotly.graph_objects as go
import pickle as pkl


# ==========================================
# 1. Data Processing Logic
# ==========================================

def data_prep(data, name):
    """
    Prepares the dataset for the selected stock ticker.

    Args:
        data (pd.DataFrame): The raw dataset (must have 'Name' and 'close' columns).
        name (str): The specific stock name selected by the user.

    Returns:
        pd.DataFrame: Rows for ``name`` with NaN rows dropped, a fresh 0..n-1
        index, and two added columns: '5day_MA' (5-row rolling mean of 'close')
        and '1day_MA' (1-row rolling mean, i.e. 'close' as float).
    """
    # Explicit copy so the in-place operations below never touch ``data``.
    df = data[data['Name'] == name].copy()
    df.dropna(inplace=True)
    df.reset_index(drop=True, inplace=True)

    # Calculate Moving Averages (Technical Indicators)
    # These indicators form the basis of the State Space for the RL agent.
    df['5day_MA'] = df['close'].rolling(5).mean()
    df['1day_MA'] = df['close'].rolling(1).mean()

    # Only the first four rows lack a full 5-row window. Zero just those NaNs;
    # the previous ``df.loc[:4, ...] = 0`` was inclusive and also overwrote
    # the valid average in row 4.
    df['5day_MA'] = df['5day_MA'].fillna(0)
    return df


# ==========================================
# 2. Agent Logic (Inference)
# ==========================================

def get_state(long_ma, short_ma, t):
    """
    Maps the moving-average relationship and the buy-feasibility flag to a
    Q-table state tuple ``(trend, flag)``.

    ``trend`` is 1 when ``short_ma > long_ma`` and 0 when ``short_ma < long_ma``;
    ``flag`` is 1 when ``t == 1`` and 0 otherwise. When neither comparison
    holds (equal averages, or a NaN operand) the default ``(0, 1)`` is returned.
    """
    # presumably flag 1 = "able to buy / holding cash", matching the Q-table
    # layout used during training -- confirm against the training code.
    flag = 1 if t == 1 else 0
    if short_ma < long_ma:
        return (0, flag)  # Bearish
    if short_ma > long_ma:
        return (1, flag)  # Bullish
    return (0, 1)  # Default


def trade_t(num_of_stocks, port_value, current_price):
    """
    Checks if a trade (Buy) is financially feasible.

    Returns 1 when ``port_value`` exceeds ``current_price`` and 0 otherwise.
    ``num_of_stocks`` is accepted for interface compatibility; it does not
    affect the result (the former per-sign branches were identical).
    """
    return 1 if port_value > current_price else 0


def next_act(state, qtable, epsilon, action=3):
    """
    Decides the next action with an epsilon-greedy rule over the Q-table.

    With probability ``epsilon`` a random action in ``[0, action)`` is drawn;
    otherwise the index of the largest value in ``qtable[state]`` is returned.
    During inference epsilon is 0, so the agent always exploits.
    """
    if np.random.rand() < epsilon:
        action = np.random.randint(action)
    else:
        action = np.argmax(qtable[state])
    return action


def test_stock(stocks_test, q_table, invest):
    """
    Runs a simulation of the trading strategy on the selected stock.

    Args:
        stocks_test (pd.DataFrame): Data with '5day_MA', '1day_MA' and 'close'.
        q_table (np.array): The loaded reinforcement learning model.
        invest (int): Initial investment amount.

    Returns:
        list: ``invest`` followed by one value (rounded to 0.1) per row:
        the previous value minus the close on Buy, plus the close on Sell
        or Hold.
    """
    num_stocks = 0
    epsilon = 0  # No exploration during testing/inference
    net_worth = [invest]
    np.random.seed()

    rows = zip(stocks_test['5day_MA'], stocks_test['1day_MA'], stocks_test['close'])
    for long_ma, short_ma, close_price in rows:
        # Determine Current State
        t = trade_t(num_stocks, net_worth[-1], close_price)
        state = get_state(long_ma, short_ma, t)

        # Agent chooses action
        action = next_act(state, q_table, epsilon)

        if action == 0:  # Buy one share
            num_stocks += 1
            net_worth.append(np.round(net_worth[-1] - close_price, 1))
        elif action == 1:  # Sell one share
            num_stocks -= 1
            net_worth.append(np.round(net_worth[-1] + close_price, 1))
        elif action == 2:  # Hold
            # NOTE(review): this adds one share price per Hold regardless of
            # holdings; it is not a mark-to-market valuation (that would be
            # cash + num_stocks * close_price). Kept as-is pending review.
            net_worth.append(np.round(net_worth[-1] + close_price, 1))
    return net_worth


# ==========================================
# 3. Streamlit Interface
# ==========================================

def _show_note(message):
    """Render a bold 'NOTE:' banner followed by ``message`` (a static string)."""
    st.markdown('<p><b>NOTE:</b><br>' + message + '</p>', unsafe_allow_html=True)


def _line_chart(x, y, trace_name, title, xaxis_title, yaxis_title):
    """Render one cyan line trace as a full-width Plotly chart."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='lines',
        name=trace_name,
        line=dict(color='cyan', width=2),
    ))
    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title)
    st.plotly_chart(fig, use_container_width=True)


def _show_trend_note(stock_df):
    """Show a simple buy/avoid hint comparing the first close with a later one."""
    if stock_df.empty:
        return
    # Row 500 was hard-coded; clamp so shorter histories don't raise IndexError.
    later = min(500, len(stock_df) - 1)
    if stock_df.iloc[later]['close'] > stock_df.iloc[0]['close']:
        _show_note('Stock is on a solid upward trend. Investing here might be profitable.')
    else:
        _show_note('Stock does not appear to be in a solid uptrend. Better not to invest '
                   'here; instead, pick different stock.')


def _load_model(path):
    """Unpickle the Q-table at ``path``; show an error and return None if missing."""
    try:
        # SECURITY: unpickling can execute arbitrary code -- only load a model
        # file produced by this project's training or otherwise trusted.
        with open(path, 'rb') as fh:
            return pkl.load(fh)
    except FileNotFoundError:
        st.error("Model file 'model.pkl' not found. Please ensure the model is trained.")
        return None


def fun():
    """
    Render the app: stock picker, optional trend chart, and trading simulation.

    Reads ``all_stocks_5yr.csv`` from the working directory and, when the user
    clicks "Calculate", loads the Q-table from ``model.pkl``.
    """
    # Reading the Dataset
    # Ensure all_stocks_5yr.csv is in the working directory
    data = pd.read_csv('all_stocks_5yr.csv')

    # NOTE(review): the stock-selection widget was unreadable in the reviewed
    # copy of this file; the placeholder and labels below are a reconstruction
    # -- confirm against the upstream repository.
    placeholder = "<Select Names>"
    names = list(data['Name'].unique())
    names.insert(0, placeholder)
    st.sidebar.subheader("Choose Company Stocks")
    stock = st.sidebar.selectbox("(*select one stock only)", names, index=0)
    if stock == placeholder:
        return  # Nothing to show until a real ticker is chosen.

    stock_df = data_prep(data, stock)

    # Sidebar button: plot the closing-price history.
    if st.sidebar.button("Show Stock Trend", key=1):
        _line_chart(stock_df['date'], stock_df['close'], 'Stock_Trend',
                    'Stock Trend of ' + stock, 'Date', 'Price ($) ')
        _show_trend_note(stock_df)

    # Sidebar button: investment simulation.
    st.sidebar.subheader("Enter Your Available Initial Investment Fund")
    invest = st.sidebar.slider('Select a range of values', 1000, 1000000)
    if st.sidebar.button("Calculate", key=2):
        q_table = _load_model('model.pkl')
        if q_table is None:
            return

        net_worth = pd.DataFrame(test_stock(stock_df, q_table, invest), columns=['value'])
        _line_chart(net_worth.index, net_worth['value'], 'Net_Worth_Trend',
                    'Change in Portfolio Value Day by Day',
                    'Number of Days since Feb 2013 ', 'Value ($) ')
        _show_note('Increase in your net worth as a result of a model decision.')


if __name__ == '__main__':
    fun()