ahudock commited on
Commit
514d93c
·
1 Parent(s): ce5dc13

Convert to simiple tools that return one scalar value, along with a helper function to improve efficiency

Browse files
Files changed (1) hide show
  1. app.py +78 -97
app.py CHANGED
@@ -4,7 +4,7 @@ import requests
4
  import pytz
5
  import yaml
6
  import yfinance as yf
7
- from typing import Dict, Any
8
  from tools.final_answer import FinalAnswerTool
9
 
10
  from Gradio_UI import GradioUI
@@ -40,132 +40,113 @@ def get_current_time_in_timezone(timezone: str) -> str:
40
  return f"Error fetching time for timezone '{timezone}': {str(e)}"
41
 
42
  @tool
43
- def yf_get_ticker_price(ticker: str) -> float:
44
  """
45
- Return the latest closing price for a stock ticker.
46
 
47
  Args:
48
- ticker: A string representing a stock ticker.
49
- """
50
- stock = yf.Ticker(ticker)
51
-
52
- # Get current underlying price
53
- hist = stock.history(period="1d")
54
- if hist.empty:
55
- return {"error": f"No price data found for {ticker}."}
56
 
57
- return float(hist["Close"].iloc[-1])
 
 
 
 
58
 
59
  @tool
60
- def yf_get_atm_straddle_price(ticker: str, dte: int = 0) -> Dict[str, Any]:
61
  """
62
- Calculate the price of the at-the-money (ATM) straddle for a given stock ticker,
63
- including the implied volatility (IV) of the ATM call and put options. The ATM
64
- strike is defined as the strike closest to the current underlying price. The
65
- expiration date is selected based on the specified days-to-expiration (DTE).
66
 
67
  Args:
68
- ticker (str): The stock ticker symbol, e.g., AAPL.
69
- dte (int): Desired days-to-expiration. The tool selects the expiration date
70
- closest to this value. A value of 0 selects the nearest available
71
- expiration. Defaults to 0.
72
 
73
  Returns:
74
- Dict[str, Any]: A dictionary containing:
75
- - 'underlying_price' (float): The current stock price.
76
- - 'expiration' (str): The selected option expiration date.
77
- - 'strike' (float): The ATM strike price selected.
78
- - 'call_price' (float): The last traded price of the ATM call option.
79
- - 'put_price' (float): The last traded price of the ATM put option.
80
- - 'call_iv' (float): The implied volatility of the ATM call option.
81
- - 'put_iv' (float): The implied volatility of the ATM put option.
82
- - 'straddle_price' (float): The total ATM straddle price (call + put).
83
  """
84
- stock = yf.Ticker(ticker)
85
-
86
- # Get current underlying price
87
- underlying_price = yf_get_ticker_price(stock)
88
 
89
- # Get option expiration dates
90
- expirations = stock.options
91
- if not expirations:
92
- return {"error": f"No options data available for {ticker}."}
93
-
94
- # Convert expiration strings to datetime objects
95
- exp_dates = [datetime.strptime(exp, "%Y-%m-%d") for exp in expirations]
96
- today = datetime.now()
97
 
98
- # Compute DTE for each expiration
99
- dtes = [(exp - today).days for exp in exp_dates]
 
100
 
101
- # Select expiration closest to requested DTE
102
- target_idx = min(range(len(dtes)), key=lambda i: abs(dtes[i] - dte))
103
- expiration = expirations[target_idx]
 
 
104
 
105
- # Load option chain for selected expiration
106
- opt_chain = stock.option_chain(expiration)
107
- calls = opt_chain.calls
108
- puts = opt_chain.puts
109
 
110
- if calls.empty or puts.empty:
111
- return {"error": f"Options chain incomplete for {ticker} on {expiration}."}
 
112
 
113
- # Find ATM strike
114
- calls["diff"] = (calls["strike"] - underlying_price).abs()
115
- atm_strike = float(calls.sort_values("diff").iloc[0]["strike"])
 
 
116
 
117
- # Extract ATM call and put rows
118
- call_row = calls[calls["strike"] == atm_strike].iloc[0]
119
- put_row = puts[puts["strike"] == atm_strike].iloc[0]
 
120
 
121
- call_price = float(call_row.get("lastPrice", 0.0))
122
- put_price = float(put_row.get("lastPrice", 0.0))
 
123
 
124
- call_iv = float(call_row.get("impliedVolatility", 0.0))
125
- put_iv = float(put_row.get("impliedVolatility", 0.0))
 
 
 
126
 
127
- straddle_price = call_price + put_price
 
 
 
128
 
129
- return {
130
- "underlying_price": underlying_price,
131
- "expiration": expiration,
132
- "strike": atm_strike,
133
- "call_price": call_price,
134
- "put_price": put_price,
135
- "call_iv": call_iv,
136
- "put_iv": put_iv,
137
- "straddle_price": straddle_price,
138
- }
139
 
 
 
 
 
 
140
 
141
  @tool
142
- def estimate_expected_move(ticker:str, dte:int = 0) -> Dict[str, Any]:
143
  """
144
- Roughly estimates the expected move for a given stock ticker by adding/subtracting the price of the ATM straddle for
145
- a given expiration to the underlying price.
146
 
147
  Args:
148
- ticker (str): The stock ticker to estimate the expected move for.
149
- dte (int): Days to expiration.
150
 
151
  Returns:
152
- Dict[str, Any]: A dictionary containing:
153
- - 'underlying_price' (float): The current stock price.
154
- - 'em' (float): The expected move (price of the ATM straddle).
155
- - 'em_lower' (float): The lower expected move.
156
- - 'em_upper' (float): The upper expected move.
157
- """
158
- atm_straddle = yf_get_atm_straddle_price(ticker, dte)
159
- underlying_price = yf_get_ticker_price(ticker)
160
- em_upper = underlying_price + atm_straddle
161
- em_lower = underlying_price - atm_straddle
162
-
163
- return {
164
- "underlying_price": underlying_price,
165
- "em": atm_straddle,
166
- "em_upper": em_upper,
167
- "em_lower": em_lower,
168
- }
169
 
170
  ddgs = DuckDuckGoSearchTool();
171
 
@@ -190,7 +171,7 @@ with open("prompts.yaml", 'r') as stream:
190
 
191
  agent = CodeAgent(
192
  model=model,
193
- tools=[final_answer, ddgs, yf_get_ticker_price, yf_get_atm_straddle_price, estimate_expected_move], ## add your tools here (don't remove final answer)
194
  max_steps=6,
195
  verbosity_level=1,
196
  grammar=None,
 
4
  import pytz
5
  import yaml
6
  import yfinance as yf
7
+ from typing import Tuple, Dict, Any
8
  from tools.final_answer import FinalAnswerTool
9
 
10
  from Gradio_UI import GradioUI
 
40
  return f"Error fetching time for timezone '{timezone}': {str(e)}"
41
 
42
  @tool
43
+ def yf_underlying_price(ticker: str, dte: int = 0) -> float:
44
  """
45
+ Get the current underlying stock price for the given ticker.
46
 
47
  Args:
48
+ ticker (str): The stock ticker symbol, e.g., AAPL.
49
+ dte (int): Days-to-expiration used to select the option chain.
50
+ Defaults to 0 (nearest expiration).
 
 
 
 
 
51
 
52
+ Returns:
53
+ float: The current underlying stock price.
54
+ """
55
+ underlying_price, _, _, _, _ = _get_atm_option_data(ticker, dte)
56
+ return underlying_price
57
 
58
  @tool
59
+ def yf_atm_strike(ticker: str, dte: int = 0) -> float:
60
  """
61
+ Get the at-the-money (ATM) strike price for the given ticker.
 
 
 
62
 
63
  Args:
64
+ ticker (str): The stock ticker symbol.
65
+ dte (int): Desired days-to-expiration. Defaults to 0.
 
 
66
 
67
  Returns:
68
+ float: The ATM strike price.
 
 
 
 
 
 
 
 
69
  """
70
+ _, _, atm_strike, _, _ = _get_atm_option_data(ticker, dte)
71
+ return atm_strike
 
 
72
 
73
+ @tool
74
+ def yf_atm_call_price(ticker: str, dte: int = 0) -> float:
75
+ """
76
+ Get the last traded price of the ATM call option.
 
 
 
 
77
 
78
+ Args:
79
+ ticker (str): The stock ticker symbol.
80
+ dte (int): Desired days-to-expiration. Defaults to 0.
81
 
82
+ Returns:
83
+ float: The ATM call option price.
84
+ """
85
+ _, _, _, call_row, _ = _get_atm_option_data(ticker, dte)
86
+ return float(call_row.get("lastPrice", 0.0))
87
 
88
+ @tool
89
+ def yf_atm_put_price(ticker: str, dte: int = 0) -> float:
90
+ """
91
+ Get the last traded price of the ATM put option.
92
 
93
+ Args:
94
+ ticker (str): The stock ticker symbol.
95
+ dte (int): Desired days-to-expiration. Defaults to 0.
96
 
97
+ Returns:
98
+ float: The ATM put option price.
99
+ """
100
+ _, _, _, _, put_row = _get_atm_option_data(ticker, dte)
101
+ return float(put_row.get("lastPrice", 0.0))
102
 
103
+ @tool
104
+ def yf_atm_call_iv(ticker: str, dte: int = 0) -> float:
105
+ """
106
+ Get the implied volatility (IV) of the ATM call option.
107
 
108
+ Args:
109
+ ticker (str): The stock ticker symbol.
110
+ dte (int): Desired days-to-expiration. Defaults to 0.
111
 
112
+ Returns:
113
+ float: The implied volatility of the ATM call option.
114
+ """
115
+ _, _, _, call_row, _ = _get_atm_option_data(ticker, dte)
116
+ return float(call_row.get("impliedVolatility", 0.0))
117
 
118
+ @tool
119
+ def yf_atm_put_iv(ticker: str, dte: int = 0) -> float:
120
+ """
121
+ Get the implied volatility (IV) of the ATM put option.
122
 
123
+ Args:
124
+ ticker (str): The stock ticker symbol.
125
+ dte (int): Desired days-to-expiration. Defaults to 0.
 
 
 
 
 
 
 
126
 
127
+ Returns:
128
+ float: The implied volatility of the ATM put option.
129
+ """
130
+ _, _, _, _, put_row = _get_atm_option_data(ticker, dte)
131
+ return float(put_row.get("impliedVolatility", 0.0))
132
 
133
  @tool
134
+ def yf_atm_straddle_price(ticker: str, dte: int = 0) -> float:
135
  """
136
+ Calculate the price of the ATM straddle (call + put).
 
137
 
138
  Args:
139
+ ticker (str): The stock ticker symbol.
140
+ dte (int): Desired days-to-expiration. Defaults to 0.
141
 
142
  Returns:
143
+ float: The total ATM straddle price.
144
+ """
145
+ _, _, _, call_row, put_row = _get_atm_option_data(ticker, dte)
146
+ call_price = float(call_row.get("lastPrice", 0.0))
147
+ put_price = float(put_row.get("lastPrice", 0.0))
148
+ return call_price + put_price
149
+
 
 
 
 
 
 
 
 
 
 
150
 
151
  ddgs = DuckDuckGoSearchTool();
152
 
 
171
 
172
  agent = CodeAgent(
173
  model=model,
174
+ tools=[final_answer, ddgs, yf_underlying_price, yf_atm_strike, yf_atm_call_price, yf_atm_put_price, yf_atm_call_iv, yf_atm_put_iv, yf_atm_straddle_price], ## add your tools here (don't remove final answer)
175
  max_steps=6,
176
  verbosity_level=1,
177
  grammar=None,