OnurKerimoglu commited on
Commit
0975e8a
·
1 Parent(s): 3fee7c6

enabled showing TA plots

Browse files
Files changed (3) hide show
  1. app.py +22 -4
  2. src/stock_analysis_agent.py +2 -1
  3. src/technical_analysis.py +23 -9
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
 
3
  from src.stock_analysis_agent import StockAnalyst
 
4
 
5
 
6
  def initialize_agent():
@@ -14,12 +15,22 @@ def initialize_agent():
14
  message = f"Error initializing Stock-Analysis Agent:\n{e}"
15
  return message, stock_analyst
16
 
17
- def interact_with_agent(agent_instance, ticker):
18
  if agent_instance is None:
19
  return 'Stock-Analysis Agent is not initialized. Please initialize first.'
20
  response = agent_instance.get_formatted_stock_summary(ticker)
21
  return response
22
 
 
 
 
 
 
 
 
 
 
 
23
  with gr.Blocks() as demo:
24
  gr.Markdown("# Stock Analysis Agent")
25
  # State to store the RAG instance
@@ -35,12 +46,19 @@ with gr.Blocks() as demo:
35
  # gr.Text(label='Status', value=message)
36
  gr.Markdown("Enter a stock symbol ('ticker') to be analyzed (GOOG, MSFT, etc.)")
37
  ticker = gr.Textbox(label="ticker")
 
38
  analyze_button = gr.Button('Analyze Stock')
39
- output = gr.Markdown(label="Output Box")
 
 
 
 
 
 
40
  analyze_button.click(
41
- fn=interact_with_agent,
42
  inputs=[agent_instance, ticker],
43
- outputs=output,
44
  api_name="analyze_stock")
45
 
46
 
 
1
  import gradio as gr
2
 
3
  from src.stock_analysis_agent import StockAnalyst
4
+ from src.technical_analysis import TechnicalAnalysis
5
 
6
 
7
  def initialize_agent():
 
15
  message = f"Error initializing Stock-Analysis Agent:\n{e}"
16
  return message, stock_analyst
17
 
18
+ def ask_stock_agent(agent_instance, ticker):
19
  if agent_instance is None:
20
  return 'Stock-Analysis Agent is not initialized. Please initialize first.'
21
  response = agent_instance.get_formatted_stock_summary(ticker)
22
  return response
23
 
24
+ def plot_stock(ticker):
25
+ # run the technical analysis
26
+ _, fig = TechnicalAnalysis(
27
+ ticker=ticker,
28
+ fetchperiodinweeks=12,
29
+ plot_ta=True,
30
+ savefig=False,
31
+ debug=False).run()
32
+ return fig
33
+
34
  with gr.Blocks() as demo:
35
  gr.Markdown("# Stock Analysis Agent")
36
  # State to store the RAG instance
 
46
  # gr.Text(label='Status', value=message)
47
  gr.Markdown("Enter a stock symbol ('ticker') to be analyzed (GOOG, MSFT, etc.)")
48
  ticker = gr.Textbox(label="ticker")
49
+ # plot_button = gr.Button('Plot')
50
  analyze_button = gr.Button('Analyze Stock')
51
+ plot_output = gr.Plot(label=ticker, format="png")
52
+ analyze_button.click(
53
+ fn=plot_stock,
54
+ inputs=[ticker],
55
+ outputs=plot_output,
56
+ api_name="analyze_stock")
57
+ md_output = gr.Markdown(label="Output Box")
58
  analyze_button.click(
59
+ fn=ask_stock_agent,
60
  inputs=[agent_instance, ticker],
61
+ outputs=md_output,
62
  api_name="analyze_stock")
63
 
64
 
src/stock_analysis_agent.py CHANGED
@@ -25,10 +25,11 @@ def get_stock_prices(
25
  ticker: str
26
  The stock ticker symbol to fetch data for.
27
  """
28
- df = TechnicalAnalysis(
29
  ticker=ticker,
30
  fetchperiodinweeks=12,
31
  plot_ta=False,
 
32
  debug=False).run()
33
  if df.shape[0] > 0:
34
  df['Date'] = df.index.astype(str)
 
25
  ticker: str
26
  The stock ticker symbol to fetch data for.
27
  """
28
+ df, _ = TechnicalAnalysis(
29
  ticker=ticker,
30
  fetchperiodinweeks=12,
31
  plot_ta=False,
32
+ savefig=False,
33
  debug=False).run()
34
  if df.shape[0] > 0:
35
  df['Date'] = df.index.astype(str)
src/technical_analysis.py CHANGED
@@ -18,6 +18,7 @@ class TechnicalAnalysis():
18
  ticker:str,
19
  fetchperiodinweeks:int=12,
20
  plot_ta:bool=True,
 
21
  debug=False):
22
  # input arguments
23
  """
@@ -44,6 +45,7 @@ class TechnicalAnalysis():
44
  self.ticker = ticker
45
  self.fetchperiodinweeks = fetchperiodinweeks
46
  self.plot_ta = plot_ta
 
47
  # done initializing
48
  self.logger.info(f'Initialized TechnicalAnalysis object for ticker: {ticker}')
49
 
@@ -65,15 +67,21 @@ class TechnicalAnalysis():
65
  # plot the results
66
  if self.plot_ta:
67
  os.makedirs('plots', exist_ok=True)
68
- self.plot_stock_metrics(
69
  self.df,
70
  datasets={
71
  'Volume': ['Volume'],
72
  'Prices': ['Close', 'VWAP'], # 'High','Low',
73
  'Indices': ['RSI', 'StochOsc'],
74
- 'Trend': ['MACD', 'MACDsig', 'MACDdif']}
75
- )
76
- return self.df
 
 
 
 
 
 
77
 
78
  def fetch_data(
79
  self
@@ -183,8 +191,9 @@ class TechnicalAnalysis():
183
  datasets={
184
  'Volume': ['Volume'],
185
  'Price': ['Close'] # 'High','Low'
186
- }
187
- ) -> None:
 
188
  """
189
  Plots the given stock metrics datasets as subplots.
190
  This method takes in a DataFrame and a dictionary of datasets, where
@@ -199,6 +208,8 @@ class TechnicalAnalysis():
199
  datasets (dict)
200
  A dictionary of datasets, where each key is a dataset name and
201
  the value is a list of column names to be plotted
 
 
202
  """
203
  numax = len(datasets)
204
  fig, axes = plt.subplots(
@@ -214,8 +225,11 @@ class TechnicalAnalysis():
214
  df,
215
  colstoplot)
216
  plt.tight_layout()
217
- plt.savefig(os.path.join('plots', f'{self.ticker}.png'))
218
- plt.close()
 
 
 
219
 
220
  def plot_stock_metrics_ax(
221
  self,
@@ -259,7 +273,7 @@ class TechnicalAnalysis():
259
  ax.xaxis.set_minor_locator(mdates.DayLocator())
260
 
261
  ax.set_title(dataset)
262
- ax.set_xlabel('Date')
263
  ax.set_ylabel(dataset)
264
  if len(colstoplot) > 1:
265
  ax.legend()
 
18
  ticker:str,
19
  fetchperiodinweeks:int=12,
20
  plot_ta:bool=True,
21
+ savefig:bool=False,
22
  debug=False):
23
  # input arguments
24
  """
 
45
  self.ticker = ticker
46
  self.fetchperiodinweeks = fetchperiodinweeks
47
  self.plot_ta = plot_ta
48
+ self.savefig = savefig
49
  # done initializing
50
  self.logger.info(f'Initialized TechnicalAnalysis object for ticker: {ticker}')
51
 
 
67
  # plot the results
68
  if self.plot_ta:
69
  os.makedirs('plots', exist_ok=True)
70
+ fig = self.plot_stock_metrics(
71
  self.df,
72
  datasets={
73
  'Volume': ['Volume'],
74
  'Prices': ['Close', 'VWAP'], # 'High','Low',
75
  'Indices': ['RSI', 'StochOsc'],
76
+ 'Trend': ['MACD', 'MACDsig', 'MACDdif']},
77
+ savefig=self.savefig
78
+ )
79
+ else:
80
+ fig = None
81
+ else:
82
+ fig = None
83
+
84
+ return self.df, fig
85
 
86
  def fetch_data(
87
  self
 
191
  datasets={
192
  'Volume': ['Volume'],
193
  'Price': ['Close'] # 'High','Low'
194
+ },
195
+ savefig=False
196
+ ) -> None:
197
  """
198
  Plots the given stock metrics datasets as subplots.
199
  This method takes in a DataFrame and a dictionary of datasets, where
 
208
  datasets (dict)
209
  A dictionary of datasets, where each key is a dataset name and
210
  the value is a list of column names to be plotted
211
+ savefig (bool)
212
+ Whether to save the figure to a file
213
  """
214
  numax = len(datasets)
215
  fig, axes = plt.subplots(
 
225
  df,
226
  colstoplot)
227
  plt.tight_layout()
228
+ if savefig:
229
+ plt.savefig(os.path.join('plots', f'{self.ticker}.png'))
230
+ plt.close()
231
+ fig = None
232
+ return fig
233
 
234
  def plot_stock_metrics_ax(
235
  self,
 
273
  ax.xaxis.set_minor_locator(mdates.DayLocator())
274
 
275
  ax.set_title(dataset)
276
+ # ax.set_xlabel('Date')
277
  ax.set_ylabel(dataset)
278
  if len(colstoplot) > 1:
279
  ax.legend()