Spaces:
Running
Running
| from random import randint, random | |
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
| import os | |
| import json | |
| from openai import OpenAI | |
| import matplotlib.pyplot as plt | |
# Flag indicating MCP server mode.
mcp_server: bool = True

# SEC EDGAR XBRL "company facts" endpoint; "{}" is replaced by the company's CIK.
SEC_API_URL: str = "https://data.sec.gov/api/xbrl/companyfacts/CIK{}.json"
# SEC asks automated clients to identify themselves via the User-Agent header;
# override this placeholder through the USER_AGENT environment variable.
USER_AGENT: str = os.environ.get("USER_AGENT", "Your Name your.email@example.com")

# Display name -> 10-digit, zero-padded CIK.
CIK_OPTIONS = {
    "Tesla (TSLA)": "0001318605",
    "Apple (AAPL)": "0000320193",
    "Microsoft (MSFT)": "0000789019",
}

# SambaNova chat-completions endpoint.
# NOTE(review): the live client code uses base_url "https://api.sambanova.ai/v1",
# not this URL — confirm which host is current.
SAMBANOVA_API_URL: str = "https://api.cloud.sambanova.ai/v1/chat/completions"
# API key taken from the environment; None when SAMBANOVA_API_KEY is unset.
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
def fetch_comprehensive_income_net_of_tax(cik, timeout=10):
    """
    Fetch 'ComprehensiveIncomeNetOfTax' USD values from SEC 10-Q filings for a given CIK.

    Args:
        cik (str | int): Central Index Key (CIK) of the company. Values shorter than
            10 characters are left-padded with zeros (e.g. 320193 -> "0000320193").
        timeout (float): Seconds to wait for the SEC API before giving up, so a
            stalled connection cannot hang the caller indefinitely.

    Returns:
        pd.DataFrame: One row per 10-Q entry that has a 'frame' value, with columns
        'Frame', 'Value', 'Period' ('fy' and 'fp' concatenated), 'Form' and 'Filed'.
        On a network, HTTP-status or JSON-decoding failure, a one-row DataFrame with
        a single 'Error' column holding the message is returned instead.
    """
    cik = str(cik).strip().zfill(10)
    headers = {"User-Agent": USER_AGENT}
    url = SEC_API_URL.format(cik)
    print(f"Fetching data from SEC API for CIK: {cik} at URL: {url}")
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        # Treat 4xx/5xx responses as failures instead of parsing an error body
        # as if it were company facts.
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests versions whose JSON decode
        # errors are not RequestException subclasses.
        print(f"Error fetching SEC data: {e}")
        return pd.DataFrame({"Error": [str(e)]})

    item_data = data.get("facts", {}).get("us-gaap", {}).get("ComprehensiveIncomeNetOfTax", {})
    usd_entries = item_data.get("units", {}).get("USD", [])
    columns = ["Frame", "Value", "Period", "Form", "Filed"]
    rows = [
        {
            "Frame": entry.get("frame"),
            "Value": entry.get("val"),
            "Period": f"{entry.get('fy')}{entry.get('fp')}",
            "Form": entry.get("form"),
            "Filed": entry.get("filed"),
        }
        for entry in usd_entries
        if entry.get("form") == "10-Q" and entry.get("frame")
    ]
    # Explicit columns keep the schema stable even when no entries match.
    return pd.DataFrame(rows, columns=columns)
# Generate response using SambaNova
def get_sambanova_response(query, data):
    """
    Ask the SambaNova-hosted LLM to answer ``query`` using SEC data as context.

    Args:
        query (str): The user's question.
        data (pd.DataFrame): SEC facts to embed in the prompt. An empty frame, or
            one that has an 'Error' column, is sent as an empty JSON object.

    Returns:
        str: The model's reply, or "Error: <message>" if building the client or
        calling the API raises.
    """
    usable = not data.empty and "Error" not in data.columns
    sec_json = json.dumps(data.to_dict() if usable else {})
    user_prompt = f"SEC data: {sec_json}. User query: {query}"
    system_prompt = (
        "You are a financial data assistant. "
        "Provide concise answers based on SEC data, "
        "including trends or summaries where applicable."
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        client = OpenAI(
            api_key=SAMBANOVA_API_KEY,
            base_url="https://api.sambanova.ai/v1",
        )
        completion = client.chat.completions.create(
            model="Llama-4-Maverick-17B-128E-Instruct",
            messages=conversation,
            temperature=0.1,
            top_p=0.1,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as exc:
        # Best-effort: report the failure as text rather than raising into the UI.
        return f"Error: {exc!s}"
# Method to visualize numerical data
def visualize_data(data):
    """
    Prepare 'Frame'/'Value' series for a line plot, sorted by 'Frame'.

    Args:
        data (pd.DataFrame): DataFrame expected to contain 'Value' and 'Frame' columns.

    Returns:
        dict: Always a dict (never a tuple) with keys:
            - "plot_data": {"Frame": [...], "Value": [...]} as plain lists, or {}
              when nothing can be plotted;
            - "plot_visible": True only when "plot_data" holds at least one point;
            - "message": reason nothing could be plotted, or "" on success.
    """
    def _not_plottable(message):
        # Same shape as the success result so callers can use .get() safely.
        return {"plot_data": {}, "plot_visible": False, "message": message}

    if data is None or data.empty or "Error" in data.columns:
        return _not_plottable("No data to visualize")
    if "Value" not in data.columns or "Frame" not in data.columns:
        return _not_plottable("Missing 'Value' or 'Frame' in data")

    df = data[["Frame", "Value"]].copy()
    # Non-numeric values become NaN and are dropped with missing frames below.
    df["Value"] = pd.to_numeric(df["Value"], errors="coerce")
    df = df[df["Value"].notna() & df["Frame"].notna()]
    if df.empty:
        return _not_plottable("No valid data to plot")

    # Lexical sort; for SEC frame labels like "CY2023Q1" this is presumably
    # chronological. A stable sort keeps input order for duplicate frames.
    df = df.sort_values(by="Frame", kind="stable")
    return {
        # Plain lists are JSON-serializable and independent of the DataFrame index.
        "plot_data": {
            "Frame": df["Frame"].tolist(),
            "Value": df["Value"].tolist(),
        },
        "plot_visible": True,
        "message": "",
    }
# MCP server endpoint to handle queries
def mcp_query(query_data):
    """
    Answer a natural-language question about one company's SEC data.

    Args:
        query_data (dict | None): Expected keys are "query" (the question) and
            "cik_name" (a key of CIK_OPTIONS; defaults to "Apple (AAPL)").

    Returns:
        dict: On success, {"response", "data", "plot_data", "plot_visible"}.
        On invalid input or a failed SEC fetch, {"error": <message>}, which is
        the shape process_interface checks for. Errors are returned rather than
        raised so the UI can display them.
    """
    query_data = query_data or {}
    query = query_data.get("query", "")
    cik_name = query_data.get("cik_name", "Apple (AAPL)")
    # Reject missing or whitespace-only questions before doing any work.
    if not isinstance(query, str) or not query.strip():
        return {"error": "Invalid CIK or query"}

    cik = CIK_OPTIONS.get(cik_name)
    print(f"Received query: {query} for CIK: {cik}")
    if not cik:
        return {"error": "Invalid CIK or query"}

    df = fetch_comprehensive_income_net_of_tax(cik)
    if "Error" in df.columns:
        reason = df["Error"].iloc[0] if len(df) else "unknown error"
        return {"error": f"Error fetching data: {reason}"}
    if df.empty:
        return {"error": "Error fetching data: no matching SEC entries found"}

    response = get_sambanova_response(query, df)
    v_data = visualize_data(df)
    # Tolerate a non-dict result (visualize_data's failure paths have returned a
    # tuple) instead of crashing on .get().
    if not isinstance(v_data, dict):
        v_data = {}
    return {
        "response": response,
        "data": df.to_dict(),
        "plot_data": v_data.get("plot_data", {}),
        "plot_visible": v_data.get("plot_visible", False),
    }
def _build_trend_figure(plot_data):
    """Build a matplotlib Figure plotting 'Value' against 'Frame' from ``plot_data``."""
    # Object-oriented Figure API: no pyplot global state, so figures are not
    # accumulated across requests and concurrent handlers do not share a
    # "current figure".
    from matplotlib.figure import Figure

    plot_df = pd.DataFrame(plot_data)
    fig = Figure(figsize=(12, 4))
    ax = fig.subplots()
    ax.plot(plot_df["Frame"], plot_df["Value"], marker="o")
    ax.set_title("Trend Over Time")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Value")
    ax.grid(True)
    # Show every 2nd tick and rotate the labels so they do not overlap.
    ax.set_xticks(ax.get_xticks()[::2])
    for label in ax.get_xticklabels():
        label.set_rotation(45)
        label.set_horizontalalignment("right")
    fig.tight_layout()
    return fig


def process_interface(cik_name, query):
    """
    Gradio click handler: answer ``query`` for the company named ``cik_name``.

    Args:
        cik_name (str): Selected company display name.
        query (str): The user's question.

    Returns:
        tuple: (response text, table value or update, plot update), matching the
        [output_text, output_table, output_plot] outputs.
    """
    # Plot values must be a figure or None; hide the plot whenever there is nothing to show.
    hidden_plot = gr.update(value=None, visible=False)
    if not query or not query.strip():
        return "❌ Please enter a query.", gr.update(value=None), hidden_plot

    try:
        result = mcp_query({"query": query, "cik_name": cik_name})
    except Exception as exc:
        # UI boundary: show the failure in the response box instead of a generic error.
        return f"Error: {exc}", gr.update(value=None), hidden_plot
    if "error" in result:
        return result["error"], gr.update(value=None), hidden_plot

    data = result.get("data")
    df = pd.DataFrame(data) if data else pd.DataFrame()

    if result.get("plot_visible") and result.get("plot_data"):
        # visible=True re-shows the plot if an earlier request hid it.
        plot = gr.update(value=_build_trend_figure(result["plot_data"]), visible=True)
    else:
        plot = hidden_plot
    return result.get("response", ""), df, plot
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# SEC Data Query Interface")
    with gr.Row():
        # Choices come from CIK_OPTIONS so every selectable company has a CIK mapping.
        cik_dropdown = gr.Dropdown(
            choices=list(CIK_OPTIONS),
            value="Apple (AAPL)",
            label="Select Company",
        )
        query_input = gr.Textbox(
            label="Enter your query (e.g., 'Show trends')",
            lines=1,
            value="Show trends",
        )
    with gr.Row():
        submit_button = gr.Button("Submit")
    with gr.Row():
        gr.Markdown("### 📝 Response")
    with gr.Row():
        output_text = gr.Textbox(interactive=False)
    with gr.Row():
        gr.Markdown("### 📈 Visualization")
    with gr.Row():
        output_plot = gr.Plot(label=".", visible=True)
    with gr.Row():
        gr.Markdown("### 📊 Financial Metrics")
    with gr.Row():
        output_table = gr.DataFrame()
    # The output order must match process_interface's return tuple.
    submit_button.click(
        fn=process_interface,
        inputs=[cik_dropdown, query_input],
        outputs=[output_text, output_table, output_plot],
    )
if __name__ == "__main__":
    # Use the module-level flag instead of a duplicated literal.
    demo.launch(mcp_server=mcp_server)