|
|
import os |
|
|
import json |
|
|
import requests |
|
|
from openai import OpenAI |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Credentials are read from the environment once, at import time.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
STOCK_API_KEY = os.getenv("STOCK_API_KEY")

# The OpenAI key is mandatory: the chat client created below needs it, so fail fast.
if not OPENAI_API_KEY:
    raise ValueError("❌ OPENAI_API_KEY not set. Please add it in environment variables.")

# The stock key is only needed for price lookups, and the lookup function reads
# it from the environment itself on each call, so a missing key only warns here.
if not STOCK_API_KEY:
    print("⚠️ Warning: STOCK_API_KEY not found. Stock lookup will fail until set.")

# Shared OpenAI client used by the chat handler.
client = OpenAI(api_key=OPENAI_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
def _describe_daily_series(symbol: str, data) -> str:
    """Turn a decoded Alpha Vantage TIME_SERIES_DAILY payload into a reply string.

    Args:
        symbol: Ticker symbol the request was made for (used only in the message).
        data: The decoded JSON body of the response.

    Returns:
        A sentence with the most recent closing price, or a message explaining
        why no price could be extracted (notice, API error, or missing data).
    """
    # A non-object body (list, string, number) cannot contain a time series.
    if not isinstance(data, dict):
        return "No data available for this symbol."

    if "Note" in data:
        # A "Note" field is treated as a throttling notice (HTTP status was 200).
        return "API rate limit exceeded. Please wait and try again later."

    if "Information" in data:
        # Alpha Vantage also delivers quota / premium-endpoint notices under
        # "Information"; surface that text instead of misreporting "no data".
        return f"API notice: {data['Information']}"

    if "Error Message" in data:
        return f"Invalid symbol or API error: {data['Error Message']}"

    series = data.get("Time Series (Daily)")
    if not isinstance(series, dict) or not series:
        return "No data available for this symbol."

    # Pick the newest day explicitly instead of trusting JSON key order.
    # Assumes date keys are ISO "YYYY-MM-DD" strings (Alpha Vantage's format),
    # so the lexicographic maximum is the most recent trading day.
    day_data = series[max(series)]
    latest_price = day_data.get("4. close") if isinstance(day_data, dict) else None
    if not latest_price:
        return "Price data not available."
    return f"The latest closing price for {symbol} is ${latest_price}"


def get_latest_stock_price(symbol: str) -> str:
    """Fetch the most recent daily closing price for *symbol* from Alpha Vantage.

    The API key is read from the ``STOCK_API_KEY`` environment variable on
    every call. Expected failures are reported in the returned string rather
    than raised, so the result can be shown to the user or handed to the chat
    model as a tool result.

    Args:
        symbol: Ticker symbol to look up, e.g. ``"AAPL"``.

    Returns:
        A sentence with the latest closing price, or a message describing why
        it could not be retrieved (missing key, network/HTTP failure, non-JSON
        body, rate limit or API notice, API error, or missing data).
    """
    api_key = os.environ.get("STOCK_API_KEY")
    if not api_key:
        return "API key not found. Please set STOCK_API_KEY environment variable."

    base_url = "https://www.alphavantage.co/query"
    params = {
        "function": "TIME_SERIES_DAILY",
        "symbol": symbol,
        "apikey": api_key,
    }

    # Keep the try narrow: only the HTTP call raises RequestException.
    try:
        response = requests.get(base_url, params=params, timeout=10)
    except requests.exceptions.RequestException as e:
        return f"Network error: {e}"

    if response.status_code != 200:
        return f"API request failed with status code {response.status_code}"

    try:
        data = response.json()
    except ValueError:
        # The JSON decode error raised by requests subclasses ValueError.
        return "Invalid response received from API (not JSON)."

    return _describe_daily_series(symbol, data)
|
|
|
|
|
|
|
|
|
|
|
# Tool schema advertised to the chat model: one function, get_latest_stock_price,
# taking a single required string argument named "symbol".
tools = [
    dict(
        type="function",
        function=dict(
            name="get_latest_stock_price",
            description="Get the latest stock price for a given symbol",
            parameters=dict(
                type="object",
                properties=dict(
                    symbol=dict(
                        type="string",
                        description="The symbol of the stock",
                    ),
                ),
                required=["symbol"],
            ),
        ),
    ),
]
|
|
|
|
|
|
|
|
def _run_tool_call(call) -> str:
    """Execute one model-requested tool call and return the text for its tool message.

    Problems (unknown tool, unparsable or missing arguments) are returned as
    text so the model can explain them, instead of aborting the conversation.
    """
    if call.function.name != "get_latest_stock_price":
        return "Unknown tool called."

    try:
        args = json.loads(call.function.arguments or "{}")
    except json.JSONDecodeError:
        args = {}

    if not isinstance(args, dict) or "symbol" not in args:
        return "Symbol not found in tool call arguments."

    # get_latest_stock_price already returns a complete sentence (or an error
    # message), so pass it through unchanged rather than wrapping it again.
    return get_latest_stock_price(args["symbol"])


def stock_chat(user_message: str, model: str = "gpt-3.5-turbo-1106") -> str:
    """Answer a stock question, letting the model call the price-lookup tool.

    Flow: one completion with ``tools`` enabled; if the model requests tool
    calls, every requested call is executed and answered, then a second
    completion (without tools) produces the final reply.

    Args:
        user_message: The user's question.
        model: Chat model name used for both completion requests.

    Returns:
        The model's reply text, or ``"No reply"`` when the model returns no content.
    """
    messages = [
        {
            # Behavioural instructions belong in the system role, not a user turn.
            "role": "system",
            "content": "You are a stock bot. Answer only finance stock related questions.",
        },
        {"role": "user", "content": user_message},
    ]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice="auto",
    )
    msg = response.choices[0].message

    tool_calls = getattr(msg, "tool_calls", None)
    if not tool_calls:
        return msg.content or "No reply"

    # Echo the assistant turn with *all* of its tool calls; each tool_call_id in
    # that turn must be answered by a matching "tool" message below. The model
    # may request several calls at once (e.g. two symbols in one question).
    messages.append(
        {
            "role": "assistant",
            "content": msg.content,
            "tool_calls": [
                {
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.function.name,
                        "arguments": call.function.arguments or "{}",
                    },
                }
                for call in tool_calls
            ],
        }
    )
    for call in tool_calls:
        messages.append(
            {
                "role": "tool",
                "tool_call_id": call.id,
                "content": _run_tool_call(call),
            }
        )

    # No tools on the follow-up request, so the model has to answer in text.
    final = client.chat.completions.create(
        model=model,
        messages=messages,
    )
    return final.choices[0].message.content or "No reply"
|
|
|
|
|
|
|
|
def chatbot_interface(message):
    """Gradio callback that forwards the textbox contents to ``stock_chat``.

    Returns the reply string produced by ``stock_chat`` for *message*.
    """
    reply = stock_chat(message)
    return reply
|
|
|
|
|
# Single-textbox Gradio UI: each submission is routed through chatbot_interface
# and the reply is shown in a scrollable, copyable output textbox.
iface = gr.Interface(
    fn=chatbot_interface,
    inputs=gr.Textbox(label="Ask about a stock (e.g., 'What is the latest stock price for AAPL?')"),
    outputs=gr.Textbox(
        label="Response",
        lines=5,
        max_lines=100,
        autoscroll=False,
        show_copy_button=True,
    ),
    title="📈 StockBot - AI Stock Assistant",
    description="Ask questions about stock prices. Data provided by AlphaVantage and OpenAI GPT-3.5.",
)


# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|
|
|