Stock_Market_AI / custom_tools.py
roger33303's picture
Update custom_tools.py
942d7b5 verified
import os
from langchain_community.tools import DuckDuckGoSearchResults, RedditSearchRun
from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper
from langchain_community.tools.reddit_search.tool import RedditSearchSchema
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain.tools import Tool , tool
from pydantic import BaseModel
from time import sleep
import re
groq_api= os.getenv('GROQ_API_KEY')
Onews_api = os.getenv('NEWS_API')
from newsdataapi import NewsDataApiClient
import yfinance as yf
import pandas as pd
class RedditInput(BaseModel):
    # Argument schema attached to `reddit_search_tool` via `@tool(args_schema=...)`.
    # Each field is forwarded unchanged into `RedditSearchSchema` by that tool.
    # Free-text search query (required).
    query: str
    # Sort order for the results; "new" when omitted.
    sort: str = "new"
    # Time window for the search; "week" when omitted.
    time_filter: str = "week"
    # Subreddit to search in; "stocks" when omitted.
    subreddit: str = "stocks"
    # Maximum number of results, kept as a string; "5" when omitted.
    limit: str = "5"
class WebSearchInput(BaseModel):
    # Argument schema attached to the `web_search` tool function.
    # Search query; the tool passes it as-is to the DuckDuckGo search runner.
    query: str
class StanderdNewsSearchProtocol(BaseModel):
    # Argument schema shared by the four news tools in this module
    # (tech_news, politics_news, business_news, world_news).
    # Keyword the news tools pass as `qInTitle` to the NewsData client.
    topic: str
class StockFundamentals(BaseModel):
    # Argument schema attached to `fetch_stock_summary`.
    # Company name (not a ticker); the tool resolves it with `resolve_ticker`.
    company_name: str
@tool(args_schema=RedditInput)
def reddit_search_tool(
    query: str,
    sort: str = "new",
    time_filter: str = "week",
    subreddit: str = "stocks",
    limit: str = "5",
) -> str:
    """
    Search Reddit for a given query. Provide query and optionally sort, time_filter, subreddit, and limit.

    Args:
        query: The search query.
        sort: Sort order of the results. Defaults to "new".
        time_filter: Time window to search in. Defaults to "week".
        subreddit: Subreddit to search. Defaults to "stocks".
        limit: Maximum number of results, as a string. Defaults to "5".

    Returns:
        The text produced by the Reddit search run, or a fixed error message
        if anything goes wrong while building or running the search.
    """
    # The defaults above mirror the ones declared on RedditInput, so the
    # function also works when it is called with nothing but `query`
    # instead of raising TypeError for the missing optional arguments.

    # One-second pause before and after the request (presumably throttling).
    sleep(1)
    try:
        search = RedditSearchRun(api_wrapper=RedditSearchAPIWrapper())
        search_params = RedditSearchSchema(
            query=query,
            sort=sort,
            time_filter=time_filter,
            subreddit=subreddit,
            limit=limit,
        )
        result = search.run(tool_input=search_params.model_dump())
    except Exception:
        # Best effort: the caller gets a message it can act on, not a traceback.
        result = "There was an error in ruuning the tool. try again or skip the tool"
    sleep(1)
    return result
def resolve_ticker(company_name: str) -> str:
    """
    Resolves the correct stock ticker for a given company name using web search.
    Example: 'Apple' -> 'AAPL', 'Tesla' -> 'TSLA'

    The company name is searched on DuckDuckGo restricted to finance.yahoo.com,
    and the symbol is read from the first ``finance.yahoo.com/quote/<SYMBOL>``
    URL found in the result text.

    Args:
        company_name: Name of the company to look up.

    Returns:
        The symbol exactly as it appears in the URL. If no quote URL is found,
        or the search itself fails, a human-readable message is returned
        instead of a symbol; this function does not raise for those cases.
    """
    try:
        wrapper = DuckDuckGoSearchAPIWrapper(max_results=1)
        search = DuckDuckGoSearchResults(api_wrapper=wrapper)
        query = f"{company_name} stock ticker site:finance.yahoo.com"
        results = search.invoke(query)
        # Capture only characters that can appear in a symbol (letters, digits,
        # '.', '-', '^', '='). A looser "anything up to the next slash" capture
        # swallows the following result text when the URL has no trailing slash.
        match = re.search(
            r"finance\.yahoo\.com/quote/([A-Za-z0-9.^=-]+)", str(results)
        )
    except Exception:
        # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
        return "Not able to run the tool successfuly."

    if match:
        return match.group(1).strip()
    return f"Not able to find the correct stocks name for {company_name}. Trying again..."
def _latest_indicators(prices):
    """Compute technical indicators on a price history and return the last row.

    ``prices`` must provide 'High', 'Low' and 'Close' columns. The input frame
    is copied, so the caller's data is not modified.
    """
    df = prices.copy()
    close = df['Close']

    df['SMA_10'] = close.rolling(10).mean()
    df['EMA_10'] = close.ewm(span=10).mean()

    # RSI from simple 14-period averages of up-moves and down-moves.
    delta = close.diff()
    gain = delta.where(delta > 0, 0.0)
    loss = -delta.where(delta < 0, 0.0)
    rs = gain.rolling(window=14).mean() / loss.rolling(window=14).mean()
    df['RSI_14'] = 100 - (100 / (1 + rs))

    ema_12 = close.ewm(span=12, adjust=False).mean()
    ema_26 = close.ewm(span=26, adjust=False).mean()
    df['MACD'] = ema_12 - ema_26
    df['MACD_Signal'] = df['MACD'].ewm(span=9, adjust=False).mean()

    rolling_20 = close.rolling(20)
    band_width = 2 * rolling_20.std()
    df['BB_Middle'] = rolling_20.mean()
    df['BB_Upper'] = df['BB_Middle'] + band_width
    df['BB_Lower'] = df['BB_Middle'] - band_width

    # True range compares today's high/low with the PREVIOUS close. Using the
    # same-day close would always reduce to High - Low, because the close lies
    # between the low and the high.
    prev_close = close.shift(1)
    true_range = pd.concat(
        [
            df['High'] - df['Low'],
            (df['High'] - prev_close).abs(),
            (df['Low'] - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    df['ATR_14'] = true_range.rolling(14).mean()

    df['Volatility'] = close.pct_change().rolling(14).std()
    return df.iloc[-1]


@tool(args_schema=StockFundamentals)
def fetch_stock_summary(company_name: str) -> str:
    """
    Fetches a comprehensive stock summary including technical indicators, daily stats for the last 4 days,
    1-month summary, and quarterly trends.

    Args:
        company_name: Full name of the company.

    Returns:
        A multi-section text report, or a message starting with
        "Error fetching stock data for ..." when the ticker cannot be
        resolved, no price history is available, or any step fails.
    """
    # One-second pause before doing any network work (presumably throttling).
    sleep(1)
    try:
        ticker = resolve_ticker(company_name=company_name)
        # resolve_ticker reports failures as plain sentences. A real symbol
        # never contains spaces, so reject anything that is not symbol-shaped
        # instead of passing an error message to yfinance as a ticker.
        if not re.fullmatch(r"[A-Za-z0-9.^=-]+", ticker):
            return (
                f"Error fetching stock data for {company_name}: "
                "could not resolve a ticker symbol."
            )

        stock = yf.Ticker(ticker)
        info = stock.info
        current_price = info.get("currentPrice", "N/A")
        market_cap = info.get("marketCap", "N/A")
        pe_ratio = info.get("trailingPE", "N/A")
        sector = info.get("sector", "N/A")
        industry = info.get("industry", "N/A")
        summary = info.get("longBusinessSummary", "N/A")

        recent_df = stock.history(period="5d")
        month_df = stock.history(period="1mo")
        quarter_df = stock.history(period="3mo")
        if recent_df.empty or month_df.empty or quarter_df.empty:
            # Fail with a clear message rather than an opaque indexing error.
            return (
                f"Error fetching stock data for {company_name}: "
                f"no price history returned for ticker {ticker}."
            )

        # --- Last 4 trading days (open-to-close change per day) ---
        daily_lines = []
        for date, row in recent_df.tail(4).iterrows():
            change = ((row['Close'] - row['Open']) / row['Open']) * 100
            daily_lines.append(
                f"- {date.date()}: Close ${row['Close']:.2f}, Vol: {int(row['Volume'])}, Change: {change:+.2f}%\n"
            )
        daily_info = "\nLast 4 Days:\n" + "".join(daily_lines)

        # --- 1-month summary ---
        avg_close = month_df['Close'].mean()
        high_close = month_df['Close'].max()
        low_close = month_df['Close'].min()
        total_volume = month_df['Volume'].sum()
        month_summary = (
            f"\n1-Month Summary:\n"
            f"- Avg Close: ${avg_close:.2f}\n"
            f"- High: ${high_close:.2f} | Low: ${low_close:.2f}\n"
            f"- Total Volume: {int(total_volume)}"
        )

        # --- 3-month summary ---
        start_price = quarter_df['Close'].iloc[0]
        end_price = quarter_df['Close'].iloc[-1]
        pct_change = ((end_price - start_price) / start_price) * 100
        high_q = quarter_df['Close'].max()
        low_q = quarter_df['Close'].min()
        avg_vol_q = quarter_df['Volume'].mean()
        quarter_summary = (
            f"\nQuarterly Summary (3mo):\n"
            f"- Start Price: ${start_price:.2f} | End Price: ${end_price:.2f}\n"
            f"- % Change: {pct_change:.2f}%\n"
            f"- High: ${high_q:.2f} | Low: ${low_q:.2f}\n"
            f"- Avg Volume: {int(avg_vol_q)}"
        )

        # --- Technical indicators ---
        # Computed on the 3-month history: one month is only about 21 rows,
        # too few for the 20- and 26-period windows to be reliably populated.
        latest = _latest_indicators(quarter_df)
        indicators = (
            f"\nTechnical Indicators:\n"
            f"- SMA(10): {latest['SMA_10']:.2f} | EMA(10): {latest['EMA_10']:.2f}\n"
            f"- RSI(14): {latest['RSI_14']:.2f}\n"
            f"- MACD: {latest['MACD']:.2f} | Signal: {latest['MACD_Signal']:.2f}\n"
            f"- Bollinger Bands: Upper={latest['BB_Upper']:.2f}, Lower={latest['BB_Lower']:.2f}\n"
            f"- ATR(14): {latest['ATR_14']:.2f}\n"
            f"- Volatility (14-day): {latest['Volatility']:.4f}"
        )

        output = (
            f"{ticker.upper()} Summary:\n"
            f"- Current Price: ${current_price}\n"
            f"- Market Cap: {market_cap}\n"
            f"- Sector: {sector} | Industry: {industry}\n"
            f"- PE Ratio: {pe_ratio}\n"
            f"{daily_info}"
            f"{month_summary}"
            f"{quarter_summary}"
            f"{indicators}"
            f"\n\nCompany Overview:\n{summary}"
        )
        return output
    except Exception as e:
        # Tool boundary: report the failure as text so the caller can continue.
        return f"Error fetching stock data for {company_name}: {str(e)}"
@tool(args_schema=WebSearchInput)
def web_search(query: str) -> str:
    """
    Search the internet for anything. A longer query with more details gives higher quality results.

    Args:
        query: Search query.

    Returns:
        The DuckDuckGo search results (at most 2 requested from the wrapper),
        or a fixed error message if the search fails.
    """
    # One-second pause before the request (presumably throttling).
    sleep(1)
    try:
        wrapper = DuckDuckGoSearchAPIWrapper(max_results=2)
        search = DuckDuckGoSearchResults(api_wrapper=wrapper)
        return search.invoke(query)
    except Exception:
        # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
        return "Error in running the tool."
def _fetch_uk_news(category: str, label: str, topic: str) -> str:
    """Shared implementation of the four NewsData.io news tools.

    Requests up to 3 English-language articles for country "gb" in the given
    ``category`` whose title matches ``topic``, and returns one line per
    article in the form "<label> <n>: <text>".

    Args:
        category: NewsData category to query (e.g. "technology").
        label: Prefix used for each returned entry (e.g. "tech_news").
        topic: Keyword passed as ``qInTitle``.

    Returns:
        Newline-separated article summaries, a "no news found" message when
        nothing usable came back, or a fixed error message on failure.
    """
    # One-second pause before the request (presumably throttling).
    sleep(1)
    try:
        client = NewsDataApiClient(
            apikey=Onews_api,
            debug=True,
            folder_path="./news_output",
        )
        response = client.latest_api(
            category=category,
            language="en",
            country="gb",
            size=3,
            qInTitle=topic,
        )
        articles = response['results'] or []
        entries = []
        for position, article in enumerate(articles, start=1):
            # A missing/None description must not discard the whole batch:
            # fall back to the title, and skip the article if both are empty.
            text = article.get("description") or article.get("title")
            if text:
                entries.append(f"{label} {position}: {text}")
    except Exception:
        # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
        return "There was an error. Can't run the tool"

    if not entries:
        return f"No {category} news found for topic '{topic}'."
    # One entry per line so consecutive summaries do not run together.
    return "\n".join(entries)


@tool(args_schema=StanderdNewsSearchProtocol)
def tech_news(topic: str) -> str:
    """
    Fetches recent UK-based technology news headlines and descriptions from NewsData.io
    with a focus on the given topic (matched in the article title).
    Args:
    topic (str): The keyword to search for in technology news article titles.
    Returns:
    str: A newline-separated string of news summaries with topic-specific tech news.
    """
    return _fetch_uk_news(category="technology", label="tech_news", topic=topic)


@tool(args_schema=StanderdNewsSearchProtocol)
def politics_news(topic: str) -> str:
    """
    Fetches recent UK-based politics news headlines and descriptions from NewsData.io
    with a focus on the given topic (matched in the article title).
    Args:
    topic (str): The keyword to search for in politics news article titles.
    Returns:
    str: A newline-separated string of news summaries with topic-specific political news.
    """
    return _fetch_uk_news(category="politics", label="politics_news", topic=topic)


@tool(args_schema=StanderdNewsSearchProtocol)
def business_news(topic: str) -> str:
    """
    Fetches recent UK-based business news headlines and descriptions from NewsData.io
    with a focus on the given topic (matched in the article title).
    Args:
    topic (str): The keyword to search for in business news article titles.
    Returns:
    str: A newline-separated string of news summaries with topic-specific business news.
    """
    return _fetch_uk_news(category="business", label="business_news", topic=topic)


@tool(args_schema=StanderdNewsSearchProtocol)
def world_news(topic: str) -> str:
    """
    Fetches recent world news headlines related to UK and descriptions from NewsData.io
    with a focus on the given topic (matched in the article title).
    Args:
    topic (str): The keyword to search for in World news article titles.
    Returns:
    str: A newline-separated string of news summaries with topic-specific world news.
    """
    return _fetch_uk_news(category="world", label="world_news", topic=topic)
# Agent-facing wrapper around `fetch_stock_summary`; the description below is
# the text that tells the agent when and how to call it.
stock_data_tool = Tool(
    description=(
        "Use this tool to get current stock market data like price, market cap, and historical trend for a specific Company. (e.g., apple or APPLE, NVIDIA or nvidia, TESLA or tesla)."
        "Args: company_name (str): the name of the company (e.g., 'Tesla')"
    ),
    func=fetch_stock_summary,
    name="Stock Market Data",
)
# Agent-facing wrapper around the `web_search` function defined earlier.
# NOTE: this assignment rebinds the module-level name `web_search` to the Tool
# wrapper; `func=web_search` is evaluated first, so it still refers to the
# earlier definition.
web_search = Tool(
    description=(
        "Use this tool to Search and get any general information from the Internet about the stock. This tool takes a query and returns the result."
        "For high Quality results provide a good length query with as much details as posible."
    ),
    func=web_search,
    name="Web Search",
)
# Agent-facing wrapper around the `reddit_search_tool` function defined earlier.
# NOTE: this assignment rebinds the module-level name `reddit_search_tool` to
# the Tool wrapper; `func=reddit_search_tool` is evaluated first, so it still
# refers to the earlier definition.
# The advertised `limit` default is '5' to match the default declared on
# RedditInput (the description previously claimed '10').
reddit_search_tool = Tool(
    name="Reddit Search",
    func=reddit_search_tool,
    description=(
        "Use this tool to search Reddit for recent discussions and sentiments about a stock, event, or topic."
        "Input should be a search query (e.g., 'Do you like tesla?', 'what do you think about Tesla products?' , 'Tesla is a scam')."
        "Args: query (str): The search query (e.g., 'Tesla stock'). sort (str): Sort order ('new', 'hot', etc.). Defaults to 'new'. time_filter (str): Time range ('hour', 'day', 'week', 'month', 'year', 'all'). Defaults to 'week'. subreddit (str): type of subreddit ('stocks', 'products', 'car', 'bikes'). limit (str): Maximum number of results to return. Defaults to '5'."
    ),
)
# Agent-facing wrapper around `tech_news`.
tech_news_tool = Tool(
    description="Use this tool to get the latest technology news articles from the UK that match a topic (e.g., AI, robotics, fintech, Apple, Meta, Tesla).",
    func=tech_news,
    name="Technology News Search",
)
# Agent-facing wrapper around `politics_news`.
politics_news_tool = Tool(
    description="Use this tool to get the latest politicial news articles from the UK that match a topic (e.g., AI, robotics, fintech, Apple, Meta, Tesla).",
    func=politics_news,
    name="Politics News Search",
)
# Agent-facing wrapper around `business_news`.
business_news_tool = Tool(
    description="Use this tool to get the latest Business news articles from the UK that match a topic (e.g., AI, robotics, fintech, Apple, Meta, Tesla).",
    func=business_news,
    name="Business News Search",
)
# Agent-facing wrapper around `world_news`.
world_news_tool = Tool(
    description="Use this tool to get the latest World news (geopolitical) articles from the UK that match a topic (e.g., AI, robotics, fintech, Apple, Meta, Tesla).",
    func=world_news,
    name="World News Search",
)
def get_tools():
    """Return a fresh list of every tool wrapper defined in this module.

    The order is fixed: stock data, Reddit search, web search, then the
    technology, business, politics and world news tools.
    """
    search_and_data = [stock_data_tool, reddit_search_tool, web_search]
    news = [
        tech_news_tool,
        business_news_tool,
        politics_news_tool,
        world_news_tool,
    ]
    return search_and_data + news