StockAnalysisAgent / src /stock_analysis_agent.py
OnurKerimoglu's picture
use the df_hist for FetchForecast
067b5d0 unverified
from typing import Union, Dict, TypedDict, Annotated, List
import dotenv
from IPython.display import Image, display
import json
from langgraph.graph import StateGraph, START # , END
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages
import logging
from pydantic import BaseModel, Field
from src.fetch_data import FetchData
from src.fetch_forecast import FetchForecast
from src.technical_analysis import TechnicalAnalysis
from src.fundamental_analysis import FundamentalAnalysis
from src.ticker_finder import TickerFinder
@tool
def get_stock_prices(
    ticker: str
) -> Union[Dict, str]:
    """
    Fetches recent historical stock prices, technical indicators and price/return forecasts for a given ticker.
    Args:
        ticker: str
            The stock ticker symbol to fetch data for.
    Returns:
        A dict with the keys 'recent prices', 'forecasted prices', 'forecasted returns' and 'indicators'
        (the last 10 rows of price/indicator data and the forecast records).
        If only one of the two data sources (technical data or forecast) is available, the entries of the
        missing source hold a short explanation instead of data.
        If neither is available, an error string is returned.
    """
    df_hist = FetchData(ticker, fetchperiodinweeks=12).run()
    df_past, df_fcst = FetchForecast(ticker, df_hist).run()
    df, _ = TechnicalAnalysis(
        ticker=ticker,
        df_hist=df_hist,
        df_past=df_past,
        df_fcst=df_fcst,
        plot_ta=False,
        savefig=False,
        debug=False).run()

    # Forecast section. A failed forecast is signalled with None; df_fcst is the frame we
    # actually index, so guard on it as well as on df_past.
    forecast_ok = df_past is not None and df_fcst is not None
    if forecast_ok:
        df_fcst['Date'] = df_fcst['Date'].astype(str)
        fcst_prices = df_fcst[['Date', 'Close']].to_dict(orient='records')
        fcst_returns = df_fcst[['Date', 'Returns']].to_dict(orient='records')
    else:
        fcst_prices = "Price forecasts could not be obtained"
        fcst_returns = "Return forecasts could not be obtained"

    # Technical section: keep only the last 10 rows, split into prices and indicators.
    technicals_ok = df is not None and df.shape[0] > 0
    if technicals_ok:
        df['Date'] = df.index.astype(str)
        last_rows = df.iloc[-10:, :]
        hist_prices = last_rows[['Date', 'Close', 'High', 'Low', 'Open', 'Volume']].to_dict(orient='records')
        indicators = last_rows[['VWAP', 'RSI', 'StochOsc', 'MACD', 'MACDsig', 'MACDdif']].to_dict(orient='records')
    else:
        hist_prices = "Recent price data could not be obtained"
        indicators = "Indicator data could not be obtained"

    # Only give up entirely when nothing could be obtained. On a partial failure, return the
    # available data together with the explanatory notes, so the agent can report the gap.
    if not (forecast_ok or technicals_ok):
        return f"Error fetching technical data for ticker: {ticker}"
    return {
        'recent prices': hist_prices,
        "forecasted prices": fcst_prices,
        "forecasted returns": fcst_returns,
        'indicators': indicators,
    }
@tool
def get_financial_metrics(
    ticker: str
) -> Union[Dict, str]:
    """
    Fetches key financial metrics for a given ticker.
    Args:
        ticker: str
            The stock ticker symbol to fetch data for.
    Returns:
        The dict of fundamentals produced by FundamentalAnalysis, or an error string if it is empty or missing.
    """
    dict_fundamentals = FundamentalAnalysis(
        ticker=ticker).run()
    # Truthiness treats both an empty dict and a None result as a failure
    # (len() would raise on None).
    if dict_fundamentals:
        return dict_fundamentals
    return f"Error fetching financial metrics for ticker: {ticker}"
class StockAnalysisResponse(BaseModel):
    """Stock Analysis Response Schema"""
    # Structured output the analyst LLM is asked to produce (passed as response_format).
    # Every field is required free text; field order is kept because the formatted
    # summary renders the keys in the order the model returns them.

    # identification
    stock: str = Field(..., description="Stock symbol")

    # individual analyses
    price_analysis: str = Field(..., description="Detailed analysis of stock price trends")
    forecast_analysis: str = Field(..., description="Detailed analysis of stock price forecasts")
    technical_analysis: str = Field(..., description="Detailed analysis of technical indicators")
    fundamental_analysis: str = Field(..., description="Detailed analysis of financial metrics")

    # conclusion; the allowed actions are stated only in the description, not enforced by the type
    final_summary: str = Field(..., description="Conclusive summary of the analyses above")
    recommended_action: str = Field(
        ...,
        description="Suggested action based on the above analyses among options: [strong sell, sell, hold, buy, strong buy]",
    )
class StockAnalyst():
    """
    LLM agent that analyses a stock by calling the price/technical and
    financial-metrics tools, and renders the model's structured answer as markdown.
    """

    def __init__(
            self,
            debug: bool = False) -> None:
        """
        Initialize StockAnalyst object.
        Sets up the logger, loads the .env data and builds the agent graph.
        Args:
            debug : bool, optional, default: False
                if True, logger will be set to DEBUG level
        """
        # set up logging
        self.logger_level = logging.DEBUG if debug else logging.INFO
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=self.logger_level)
        # basicConfig is a no-op when the root logger is already configured, so set the
        # level on this module's logger too; otherwise the `debug` flag could be ignored.
        self.logger.setLevel(self.logger_level)
        # load the env variables from the .env file
        dotenv.load_dotenv(dotenv.find_dotenv())
        # initialize the tickerfinder
        self.tickerfinder = TickerFinder()
        # build the graph
        self.graph = self.build_graph()
        self.logger.info('Initialized StockAnalyst object with TickerFinder and built the agent graph.')

    def get_prompt(
            self,
            company: str) -> str:
        """
        Generates a stock analysis prompt for a given company.
        Args:
            company : str
                The stock symbol (ticker) of the company to analyze.
        Returns: str
            A formatted string prompt for stock analysis, which includes
            instructions for evaluating the company's performance.
        """
        stock_analyst_prompt = """
You are a stock analyst specializing in evaluating the performance of a given company (whose symbol is {company})
based on recent price data and technical indicators as well as financial metrics.
Your task is to provide comprehensive summaries of price movements, technical and fundamental analysis for a given stock,
and, based on the analysis, provide a recommended action (see below for details).
You have access to the following tools:
1. **get_stock_prices**: Retrieves the historical price data, technical indicators like VWAP, RSI, Stochastic Oscillator and MACD metrics, forecasted prices and relative returns for the next 5 business days.
2. **get_financial_metrics**: Retrieves key financial metrics, such as revenue, earnings per share (EPS), price-to-earnings ratio (P/E), and debt-to-equity ratio.
### Your Tasks:
1. **Input Stock Symbol**: use the provided stock symbol to query the tools and gather the relevant information.
2. **Analyze Data**: evaluate the results from the tools
3. **Summarize and Synthesize**: in particular, we need:
a) A summary of recent stock price movements, highlighting final available closing prices.
b) A summary of trends and potential resistance.
c) A summary of technical indicators (e.g., whether the stock is overbought or oversold).
d) A summary of forecasted returns and closing prices for the next 5 business days.
e) A summary of financial health and performance based on financial metrics.
f) A final, conclusive synthesis that highlights key concerns and strengths.
g) Recommended action among following options:
- strong sell: if there are overwhelmingly bad signals
- sell: if there are some bad signals
- hold: there are either neutral signals, or good signals mixed with bad signals
- buy: there are some good signals
- strong buy: there are overwhelmingly good signals
### Constraints:
- Use only the data provided by the tools.
- If any tool fails to provide data, clearly state that in your summary.
- Try to provide a balanced synthesis based on the data provided by the tools.
- Avoid speculative language; focus on observable data and trends.
- Ensure that your response is objective, concise, and actionable.
"""
        return stock_analyst_prompt.format(company=company)

    def build_graph(
            self
    ) -> StateGraph:
        """
        Builds the stock-analysis agent graph.
        The graph has an LLM node ('stock_analyst') bound to the get_stock_prices and
        get_financial_metrics tools (with StockAnalysisResponse as the response format),
        plus a 'tools' node that executes the requested tool calls and loops back to the LLM.
        Returns:
            The compiled graph (the result of StateGraph.compile()), ready to be invoked/streamed.
        """
        class State(TypedDict):
            messages: Annotated[list, add_messages]
            stock: str

        graph_builder = StateGraph(State)
        tools = [get_stock_prices, get_financial_metrics]
        llm = ChatOpenAI(model='gpt-4o-mini')
        llm_with_tool = llm.bind_tools(
            tools,
            strict=True,
            response_format=StockAnalysisResponse)

        def stock_analyst(state: State):
            # The system prompt is rebuilt on every LLM call and prepended to the running history.
            messages = [
                SystemMessage(content=self.get_prompt(state['stock'])),
            ] + state['messages']
            return {
                'messages': llm_with_tool.invoke(messages)
            }

        graph_builder.add_node('stock_analyst', stock_analyst)
        graph_builder.add_edge(START, 'stock_analyst')
        graph_builder.add_node(ToolNode(tools))
        # tools_condition routes to 'tools' when the LLM requested tool calls and to END
        # otherwise, so no explicit edge to END is needed.
        graph_builder.add_conditional_edges('stock_analyst', tools_condition)
        graph_builder.add_edge('tools', 'stock_analyst')
        graph = graph_builder.compile()
        return graph

    def draw_graph(
            self,
            graph=None
    ) -> None:
        """
        Displays the agent graph as a Mermaid PNG (IPython/Jupyter display).
        Args:
            graph : optional
                A compiled graph to draw; defaults to this object's own graph.
        """
        if graph is None:
            graph = self.graph
        try:
            display(Image(graph.get_graph().draw_mermaid_png()))
        except Exception as e:
            # Rendering requires some extra dependencies and is optional: don't fail, but leave a trace.
            self.logger.debug('Could not draw the graph: %s', e)

    def get_stock_analyses(
            self,
            ticker
    ) -> List[Union[HumanMessage, AIMessage, ToolMessage]]:
        """
        Runs the agent for a given ticker and returns every message of the conversation.
        Sends the question "Should I buy this stock?" to the graph and collects,
        in order, the user message, the LLM messages and the tool outputs.
        Args:
            ticker : str
                The stock symbol (ticker) of the company to get suggestions for.
        Returns:
            List[Union[HumanMessage, AIMessage, ToolMessage]]: A list of messages of various types.
        """
        events = self.graph.stream(
            {
                'messages': [('user', 'Should I buy this stock?')],
                'stock': ticker
            },
            stream_mode='values'
        )
        # With stream_mode='values' each event carries the full message history. One step can
        # append several messages (e.g. one ToolMessage per parallel tool call), so take
        # everything that is new since the previous event, not only the last message.
        messages = []
        for event in events:
            state_messages = event.get('messages', [])
            messages.extend(state_messages[len(messages):])
        return messages

    def get_formatted_stock_summary(
            self,
            ticker
    ) -> str:
        """
        Runs the analysis for a ticker and returns it as a markdown-formatted string.
        The header contains the company name and sector, taken from the get_financial_metrics
        output. If these are unavailable, the ticker is shown instead. It is followed by one
        section per field of the final structured response (except 'stock').
        If the final response is not valid JSON, it is appended raw.
        Args:
            ticker : str
                The stock symbol (ticker) of the company to get the formatted summary for.
        Returns:
            str: A formatted string containing the company name, sector, and a summary of its stock analysis.
        """
        messages = self.get_stock_analyses(ticker)
        if not messages:
            return f"**ticker**: {ticker}\n\n*No response was produced by the agent.*"

        # The last message is the LLM's final (JSON) answer.
        summary_str = messages[-1].content
        # Find the financial-metrics tool output by tool name rather than by position:
        # the LLM may call the tools in any order or in parallel.
        fa_str = next(
            (m.content for m in reversed(messages)
             if isinstance(m, ToolMessage) and m.name == 'get_financial_metrics'),
            None)

        try:
            fa_dict = json.loads(fa_str)
            header = (f"**Company Name:** {fa_dict['Company Name']} \n"
                      f"**Sector:** {fa_dict['Sector']}\n\n")
        except (TypeError, ValueError, KeyError) as e:
            # missing tool output (None), non-JSON error string, or missing keys
            header = f"**ticker**: {ticker}\n\n"
            self.logger.warning('Error parsing the Financial Analysis response:\n%s', e)

        try:
            summary_dict = json.loads(summary_str)
            body = ''.join(
                f"**{key.replace('_', ' ').title()}**: {value}\n\n"
                for key, value in summary_dict.items()
                if key != 'stock')
        except (TypeError, ValueError, AttributeError) as e:
            # non-JSON content, or JSON that is not an object with string keys
            body = f'*An error occurred stylizing the response, printing the raw response*:\n{summary_str}'
            self.logger.warning('Error parsing summary response:\n%s\n', e)

        return header + body
if __name__ == "__main__":
    import sys

    # The ticker may be given as the first command-line argument; GOOG is the default.
    ticker = sys.argv[1] if len(sys.argv) > 1 else 'GOOG'
    stock_analyst = StockAnalyst(debug=False)
    # To inspect the raw agent messages instead of the formatted summary:
    # for message in stock_analyst.get_stock_analyses(ticker):
    #     message.pretty_print()
    print(stock_analyst.get_formatted_stock_summary(ticker))