import json
import os
from random import randint, random

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import requests
from matplotlib.figure import Figure
from openai import OpenAI

# Flag to indicate MCP server mode; passed to demo.launch() at the bottom.
mcp_server = True

# SEC API settings
SEC_API_URL = "https://data.sec.gov/api/xbrl/companyfacts/CIK{}.json"
# Sent as the User-Agent header on every SEC request; override via env var.
USER_AGENT = os.environ.get("USER_AGENT", "Your Name your.email@example.com")
# Seconds to wait for the SEC API before giving up on a request.
SEC_REQUEST_TIMEOUT = 30
# Columns of the DataFrame returned by fetch_comprehensive_income_net_of_tax.
_SEC_COLUMNS = ["Frame", "Value", "Period", "Form", "Filed"]

# Sample CIK list (display label -> 10-digit CIK)
CIK_OPTIONS = {
    "Tesla (TSLA)": "0001318605",
    "Apple (AAPL)": "0000320193",
    "Microsoft (MSFT)": "0000789019",
    "Amazon (AMZN)": "0001018724",
    "Alphabet (GOOGL)": "0001652044",
    "Meta Platforms (META)": "0001326801",
    "NVIDIA (NVDA)": "0001045810",
    "Berkshire Hathaway (BRK.A)": "0001067983",
    "JPMorgan Chase (JPM)": "0000019617",
    "Johnson & Johnson (JNJ)": "0000200406",
    "Visa (V)": "0001403161",
    "Procter & Gamble (PG)": "0000080424",
    "UnitedHealth Group (UNH)": "0000731766",
    "Home Depot (HD)": "0000354950",
    "Mastercard (MA)": "0001141391",
    "Exxon Mobil (XOM)": "0000034088",
    "Pfizer (PFE)": "0000078003",
    "Coca-Cola (KO)": "0000021344",
    "PepsiCo (PEP)": "0000077476",
    "Walmart (WMT)": "0000104169",
}

# SambaNova API settings
# NOTE(review): SAMBANOVA_API_URL is not used by the client below, which talks
# to SAMBANOVA_BASE_URL through the OpenAI-compatible SDK; kept for reference.
SAMBANOVA_API_URL = "https://api.cloud.sambanova.ai/v1/chat/completions"
SAMBANOVA_BASE_URL = "https://api.sambanova.ai/v1"
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")  # Set in your environment
SAMBANOVA_MODEL = "Llama-4-Maverick-17B-128E-Instruct"


def fetch_comprehensive_income_net_of_tax(cik, timeout=SEC_REQUEST_TIMEOUT):
    """
    Fetch 'ComprehensiveIncomeNetOfTax' USD values from SEC 10-Q filings for a given CIK.

    Args:
        cik (str): Central Index Key (CIK) of the company. Shorter CIKs are
            left-padded with zeros to 10 digits before building the URL.
        timeout (float): Seconds to wait for the SEC API response.

    Returns:
        pd.DataFrame: One row per USD fact from a 10-Q that carries a 'frame',
        with columns Frame, Value, Period, Form and Filed. The frame is empty
        (but keeps those columns) when no such facts exist. On a network, HTTP
        or JSON-decoding error, a one-row DataFrame with a single 'Error'
        column holding the error message.
    """
    headers = {"User-Agent": USER_AGENT}
    url = SEC_API_URL.format(str(cik).strip().zfill(10))
    print(f"Fetching data from SEC API for CIK: {cik} at URL: {url}")

    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        # Fail on 4xx/5xx instead of trying to parse an error page as JSON.
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on older requests versions.
        print(f"Error fetching SEC data: {e}")
        return pd.DataFrame({"Error": [str(e)]})

    # Navigate directly to the desired metric; missing levels yield no entries.
    item_data = (
        data.get("facts", {})
        .get("us-gaap", {})
        .get("ComprehensiveIncomeNetOfTax", {})
    )
    usd_entries = item_data.get("units", {}).get("USD", [])

    filtered_entries = [
        {
            "Frame": entry.get("frame"),
            "Value": entry.get("val"),
            "Period": f"{entry.get('fy')}{entry.get('fp')}",
            "Form": entry.get("form"),
            "Filed": entry.get("filed"),
        }
        for entry in usd_entries
        if entry.get("form") == "10-Q" and entry.get("frame")
    ]
    return pd.DataFrame(filtered_entries, columns=_SEC_COLUMNS)


def get_sambanova_response(query, data, model=SAMBANOVA_MODEL):
    """
    Ask the SambaNova-hosted LLM to answer *query* using the SEC data as context.

    Args:
        query (str): The user's natural-language question.
        data (pd.DataFrame): SEC data; omitted from the prompt (sent as {})
            when empty or when it is an 'Error' frame.
        model (str): Model name passed to the chat completions API.

    Returns:
        str: The model's reply, or "Error: <message>" if the call fails.
    """
    has_data = data is not None and not data.empty and "Error" not in data.columns
    # DataFrame.to_json uses the same column-oriented layout as to_dict() but
    # also serializes numpy scalar types, which json.dumps rejects.
    sec_json = data.to_json() if has_data else json.dumps({})
    context = f"SEC data: {sec_json}. User query: {query}"
    messages = [
        {
            "role": "system",
            "content": "You are a financial data assistant. Provide concise answers based on SEC data, including trends or summaries where applicable.",
        },
        {"role": "user", "content": context},
    ]
    try:
        sambanova_client = OpenAI(
            api_key=SAMBANOVA_API_KEY,
            base_url=SAMBANOVA_BASE_URL,
        )
        response = sambanova_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.1,
            top_p=0.1,
        )
        return response.choices[0].message.content
    except Exception as e:
        # External-service boundary: show the failure as text instead of
        # crashing the request (also covers a missing API key).
        return f"Error: {e}"


def visualize_data(data):
    """
    Prepare Frame/Value series for a line plot, with 'Frame' as the x-axis.

    Args:
        data (pd.DataFrame): DataFrame containing 'Value' and 'Frame' columns.

    Returns:
        dict: Always a dict with keys:
            - "plot_data": {"Frame": [...], "Value": [...]} sorted by Frame,
              or {} when there is nothing to plot;
            - "plot_visible": True only when plot_data is populated;
            - "message": "" on success, otherwise why nothing can be plotted.
    """

    def _nothing(message):
        return {"plot_data": {}, "plot_visible": False, "message": message}

    if data is None or data.empty or "Error" in data.columns:
        return _nothing("No data to visualize")
    if "Value" not in data.columns or "Frame" not in data.columns:
        return _nothing("Missing 'Value' or 'Frame' in data")

    df = data.copy()
    df["Value"] = pd.to_numeric(df["Value"], errors="coerce")
    df = df[df["Value"].notna() & df["Frame"].notna()]
    if df.empty:
        return _nothing("No valid data to plot")

    # Frames are sorted as strings; assumes SEC frame labels (e.g. "CY2023Q1")
    # order chronologically that way -- TODO confirm for all frame formats.
    df_sorted = df.sort_values(by="Frame")
    return {
        "plot_data": {
            "Frame": df_sorted["Frame"].tolist(),
            "Value": df_sorted["Value"].tolist(),
        },
        "plot_visible": True,
        "message": "",
    }


def _error_result(message):
    """Build an mcp_query result describing a failure (UI shows 'error')."""
    return {
        "error": message,
        "response": message,
        "data": {},
        "plot_data": {},
        "plot_visible": False,
    }


# MCP server endpoint to handle queries
def mcp_query(query_data):
    """
    Answer a query about one company's ComprehensiveIncomeNetOfTax 10-Q data.

    Args:
        query_data (dict): {"query": str, "cik_name": str}. cik_name must be a
            key of CIK_OPTIONS; it defaults to "Apple (AAPL)".

    Returns:
        dict: On success, keys "response" (LLM text), "data"
        (DataFrame.to_dict() of the SEC rows), "plot_data" and "plot_visible".
        On failure, the same keys with empty values plus an "error" message.
    """
    query = query_data.get("query") or ""
    cik_name = query_data.get("cik_name", "Apple (AAPL)")
    cik = CIK_OPTIONS.get(cik_name)
    print(f"Received query: {query} for CIK: {cik}")

    if not cik or not query.strip():
        return _error_result("❌ Invalid CIK or query")

    df = fetch_comprehensive_income_net_of_tax(cik)
    if "Error" in df.columns:
        return _error_result(f"❌ Error fetching data: {df['Error'].iloc[0]}")
    if df.empty:
        return _error_result(
            f"❌ No 10-Q ComprehensiveIncomeNetOfTax data found for {cik_name}."
        )

    response = get_sambanova_response(query, df)
    v_data = visualize_data(df)
    return {
        "response": response,
        "data": df.to_dict(),
        "plot_data": v_data.get("plot_data", {}),
        "plot_visible": v_data.get("plot_visible", False),
    }


def _build_trend_figure(frames, values, tick_step=2):
    """Return a matplotlib Figure of *values* plotted against *frames* labels."""
    # A standalone Figure is not registered with pyplot's global figure
    # manager, so figures created per request are freed with the object
    # instead of accumulating until plt.close() is called.
    fig = Figure(figsize=(12, 4))
    ax = fig.subplots()
    positions = list(range(len(frames)))
    ax.plot(positions, values, marker="o")
    ax.set_title("Trend Over Time")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Value")
    ax.grid(True)
    # Label only every `tick_step`-th frame so long histories stay readable.
    step = max(1, int(tick_step))
    ax.set_xticks(positions[::step])
    ax.set_xticklabels([str(f) for f in frames[::step]], rotation=45, ha="right")
    # tight_layout alone keeps the title and rotated labels inside the canvas.
    fig.tight_layout()
    return fig


def process_interface(cik_name, query):
    """
    Run a query for the selected company and build the UI outputs.

    Args:
        cik_name (str): Company label, one of the CIK_OPTIONS keys.
        query (str): Natural-language question about the company's data.

    Returns:
        tuple: (response text, table DataFrame or update, plot update).
    """
    hidden_plot = gr.update(value=None, visible=False)
    if not query or not query.strip():
        return "❌ Please enter a query.", gr.update(value=None), hidden_plot

    result = mcp_query({"query": query, "cik_name": cik_name})
    if "error" in result:
        return result["error"], gr.update(value=None), hidden_plot

    df = pd.DataFrame(result["data"]) if result.get("data") else pd.DataFrame()

    plot_data = result.get("plot_data")
    if not (result.get("plot_visible") and plot_data):
        return result["response"], df, hidden_plot

    fig = _build_trend_figure(plot_data["Frame"], plot_data["Value"])
    # Explicit visible=True re-shows the plot after an earlier hidden state.
    return result["response"], df, gr.update(value=fig, visible=True)


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# SEC Data Query Interface")

    with gr.Row():
        cik_dropdown = gr.Dropdown(
            # Same labels, in the same order, as CIK_OPTIONS.
            choices=list(CIK_OPTIONS),
            value="Apple (AAPL)",
            label="Select Company",
        )
        query_input = gr.Textbox(
            label="Enter your query (e.g., 'Show trends')",
            lines=1,
            value="Show trends",
        )

    with gr.Row():
        submit_button = gr.Button("Submit")

    with gr.Row():
        gr.Markdown("### 📝 Response")
    with gr.Row():
        output_text = gr.Textbox(interactive=False)

    with gr.Row():
        gr.Markdown("### 📈 Visualization")
    with gr.Row():
        output_plot = gr.Plot(label=".", visible=True)

    with gr.Row():
        gr.Markdown("### 📊 Financial Metrics")
    with gr.Row():
        output_table = gr.DataFrame()

    submit_button.click(
        fn=process_interface,
        inputs=[cik_dropdown, query_input],
        outputs=[output_text, output_table, output_plot],
    )

if __name__ == "__main__":
    demo.launch(mcp_server=mcp_server)