# File size: 8,765 Bytes
# 8c47be2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
from random import randint, random
import gradio as gr
import pandas as pd
import requests
import os
import json
from openai import OpenAI
import matplotlib.pyplot as plt

# Flag indicating that the app is meant to run with Gradio's MCP server enabled
mcp_server = True

# SEC API settings
# URL template for the XBRL "company facts" document; the `{}` placeholder
# is filled with the company's CIK.
SEC_API_URL = "https://data.sec.gov/api/xbrl/companyfacts/CIK{}.json"
# Sent as the User-Agent header on every SEC request. The fallback value is a
# placeholder identity.
# NOTE(review): presumably SEC expects real contact details here -- set the
# USER_AGENT environment variable before deploying; confirm against SEC policy.
USER_AGENT = os.environ.get("USER_AGENT", "Your Name your.email@example.com")

# Sample CIK list: display label shown in the UI -> 10-digit, zero-padded CIK
CIK_OPTIONS = {
    "Tesla (TSLA)": "0001318605",
    "Apple (AAPL)": "0000320193",
    "Microsoft (MSFT)": "0000789019",
    "Amazon (AMZN)": "0001018724",
    "Alphabet (GOOGL)": "0001652044",
    "Meta Platforms (META)": "0001326801",
    "NVIDIA (NVDA)": "0001045810",
    "Berkshire Hathaway (BRK.A)": "0001067983",
    "JPMorgan Chase (JPM)": "0000019617",
    "Johnson & Johnson (JNJ)": "0000200406",
    "Visa (V)": "0001403161",
    "Procter & Gamble (PG)": "0000080424",
    "UnitedHealth Group (UNH)": "0000731766",
    "Home Depot (HD)": "0000354950",
    "Mastercard (MA)": "0001141391",
    "Exxon Mobil (XOM)": "0000034088",
    "Pfizer (PFE)": "0000078003",
    "Coca-Cola (KO)": "0000021344",
    "PepsiCo (PEP)": "0000077476",
    "Walmart (WMT)": "0000104169"
}

# SambaNova API settings
# NOTE(review): get_sambanova_response() builds its client with
# base_url "https://api.sambanova.ai/v1" rather than this constant, and the
# two point at different hosts -- confirm which endpoint is intended.
SAMBANOVA_API_URL = "https://api.cloud.sambanova.ai/v1/chat/completions"
# Read from the environment; None when SAMBANOVA_API_KEY is not set.
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")

def fetch_comprehensive_income_net_of_tax(cik, timeout=30):
    """
    Fetch 'ComprehensiveIncomeNetOfTax' USD values from SEC 10-Q filings for a given CIK.

    Only entries whose form is "10-Q" and that carry a non-empty "frame" are
    kept. Any failure while talking to the SEC API (connection error,
    timeout, non-2xx status, body that is not valid JSON) is reported through
    a one-row DataFrame with a single "Error" column instead of an exception,
    so callers can test for `"Error" in df.columns`.

    Args:
        cik (str): Central Index Key (CIK) of the company. Shorter values are
            left-padded with zeros to the 10 digits used in the URL.
        timeout (float): Seconds to wait for the SEC API before giving up.

    Returns:
        pd.DataFrame: Columns "Frame", "Value", "Period", "Form" and "Filed"
        (empty when the company reports no matching facts), or a DataFrame
        with an "Error" column when the request failed.
    """
    columns = ["Frame", "Value", "Period", "Form", "Filed"]
    headers = {"User-Agent": USER_AGENT}
    # The companyfacts URL embeds the CIK as a zero-padded 10-digit number.
    url = SEC_API_URL.format(str(cik).strip().zfill(10))

    print(f"Fetching data from SEC API for CIK: {cik} at URL: {url}")

    try:
        # Without a timeout a stalled connection would block the UI forever.
        response = requests.get(url, headers=headers, timeout=timeout)
        # Turn 403/404/5xx answers into an error instead of parsing their body.
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on older `requests` releases.
        print(f"Error fetching SEC data: {e}")
        return pd.DataFrame({"Error": [str(e)]})

    # Navigate directly to the desired metric; `or {}` tolerates explicit nulls.
    facts = (data.get("facts") if isinstance(data, dict) else None) or {}
    item_data = (facts.get("us-gaap") or {}).get("ComprehensiveIncomeNetOfTax") or {}
    usd_entries = (item_data.get("units") or {}).get("USD") or []

    filtered_entries = [
        {
            "Frame": entry.get("frame"),
            "Value": entry.get("val"),
            "Period": f"{entry.get('fy')}{entry.get('fp')}",
            "Form": entry.get("form"),
            "Filed": entry.get("filed"),
        }
        for entry in usd_entries
        if entry.get("form") == "10-Q" and entry.get("frame")
    ]

    # Passing `columns` keeps the schema stable even when nothing matched.
    return pd.DataFrame(filtered_entries, columns=columns)


# Ask the SambaNova-hosted model to answer a question about the SEC data
def get_sambanova_response(query, data):
    """
    Send the user's question together with the SEC data to the chat model.

    Args:
        query (str): Question typed by the user.
        data (pd.DataFrame): SEC values; an empty frame or one carrying an
            "Error" column is sent to the model as an empty JSON object.

    Returns:
        str: The model's reply, or "Error: <details>" when creating the
        client or calling the API raised an exception.
    """
    # Only forward real rows; error/empty frames become {}.
    usable = not data.empty and 'Error' not in data.columns
    payload = data.to_dict() if usable else {}
    user_prompt = f"SEC data: {json.dumps(payload)}. User query: {query}"

    system_turn = {
        "role": "system",
        "content": "You are a financial data assistant. Provide concise answers based on SEC data, including trends or summaries where applicable."
    }
    user_turn = {"role": "user", "content": user_prompt}

    try:
        client = OpenAI(
            api_key=SAMBANOVA_API_KEY,
            base_url="https://api.sambanova.ai/v1",
        )
        completion = client.chat.completions.create(
            model="Llama-4-Maverick-17B-128E-Instruct",
            messages=[system_turn, user_turn],
            temperature=0.1,
            top_p=0.1,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as e:
        return f"Error: {str(e)}"

# Method to prepare numerical data for plotting
def visualize_data(data):
    """
    Prepare the 'Frame' / 'Value' series for a line plot with 'Frame' on the x-axis.

    No figure is drawn here. The function validates the input, coerces
    'Value' to numbers, drops rows where either column is missing and sorts
    the remainder by 'Frame'.

    Every path returns a dictionary of the same shape, so callers can always
    use `result.get("plot_data")` / `result.get("plot_visible")`.

    Args:
        data (pd.DataFrame): DataFrame containing 'Value' and 'Frame' columns.

    Returns:
        dict: with the keys
            "plot_data": {"Frame": pd.Series, "Value": pd.Series} sorted by
                'Frame', or {} when there is nothing to plot;
            "plot_visible": True when "plot_data" is populated, else False;
            "message": present only when nothing can be plotted, explaining why.
    """
    def _nothing_to_plot(reason):
        # Same keys as the success result so callers need no special casing.
        return {"plot_data": {}, "plot_visible": False, "message": reason}

    if data.empty or "Error" in data.columns:
        return _nothing_to_plot("No data to visualize")

    if "Value" not in data.columns or "Frame" not in data.columns:
        return _nothing_to_plot("Missing 'Value' or 'Frame' in data")

    df = data.copy()
    # Non-numeric values become NaN and are filtered out below.
    df["Value"] = pd.to_numeric(df["Value"], errors="coerce")
    df = df[df["Value"].notna() & df["Frame"].notna()]

    if df.empty:
        return _nothing_to_plot("No valid data to plot")

    # Sort frames in lexical order
    df_sorted = df.sort_values(by="Frame")

    return {
        "plot_data": {
            "Frame": df_sorted["Frame"],
            "Value": df_sorted["Value"],
        },
        "plot_visible": True,
    }

# MCP server endpoint to handle queries
def mcp_query(query_data):
    """
    Answer a query about one company: fetch its SEC data, ask the model, prepare plot data.

    Failures are reported as a dictionary with a single "error" key rather
    than by raising, which is the shape process_interface() checks for.

    Args:
        query_data (dict): Request with the keys
            "query" (str): the user's question (required, non-empty);
            "cik_name" (str): a key of CIK_OPTIONS, defaults to "Apple (AAPL)".

    Returns:
        dict: On success, the keys "response" (model answer), "data"
        (DataFrame as a dict), "plot_data" and "plot_visible".
        On failure, {"error": <message>}.
    """
    query = query_data.get("query", "")
    cik_name = query_data.get("cik_name", "Apple (AAPL)")
    cik = CIK_OPTIONS.get(cik_name)

    print(f"Received query: {query} for CIK: {cik}")

    if not cik or not query:
        return {"error": "Invalid CIK or query"}

    df = fetch_comprehensive_income_net_of_tax(cik)

    if df.empty or "Error" in df.columns:
        return {"error": "Error fetching data"}

    response = get_sambanova_response(query, df)

    v_data = visualize_data(df)
    if not isinstance(v_data, dict):
        # Guard against a non-dict "nothing to plot" result: treat it as no plot.
        v_data = {}

    return {
        "response": response,
        "data": df.to_dict(),
        "plot_data": v_data.get("plot_data", {}),
        "plot_visible": v_data.get("plot_visible", False),
    }

def _build_trend_figure(plot_df):
    """Draw 'Value' against 'Frame' as a line chart and return the matplotlib Figure."""
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(plot_df["Frame"], plot_df["Value"], marker="o")
    ax.set_title("Trend Over Time")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Value")
    ax.grid(True)
    # Keep every second tick and rotate the labels so they stay readable.
    ax.set_xticks(ax.get_xticks()[::2])
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    # tight_layout recomputes the subplot margins, so no manual adjustment is needed.
    fig.tight_layout()
    # Detach the figure from pyplot's global registry; otherwise every request
    # leaves one more open figure behind. The Figure object itself stays usable.
    plt.close(fig)
    return fig


def process_interface(cik_name, query):
    """
    Gradio click handler: run the query and produce the three UI outputs.

    Args:
        cik_name (str): Company label selected in the dropdown.
        query (str): Question typed by the user.

    Returns:
        tuple: (response text, table value, plot update), matching the
        outputs [output_text, output_table, output_plot]. On a missing query
        or a reported error the table is cleared and the plot is hidden.
    """
    # A Plot component only accepts a figure or None as its value.
    hidden_plot = gr.update(value=None, visible=False)

    if not query or not query.strip():
        return "❌ Please enter a query.", gr.update(value=None), hidden_plot

    result = mcp_query({"query": query, "cik_name": cik_name})

    if "error" in result:
        return result["error"], gr.update(value=None), hidden_plot

    df = pd.DataFrame(result["data"]) if result.get("data") else pd.DataFrame()

    # Plot using matplotlib
    plot_data = result.get("plot_data")
    if result.get("plot_visible") and plot_data:
        fig = _build_trend_figure(pd.DataFrame(plot_data))
        # Explicit visible=True re-shows the plot after an earlier error hid it.
        plot = gr.update(value=fig, visible=True)
    else:
        plot = hidden_plot

    return result["response"], df, plot

# Gradio UI: company selector + question on top, then the model's answer,
# the trend plot and the raw table underneath.
with gr.Blocks() as demo:
    gr.Markdown("# SEC Data Query Interface")

    # Inputs side by side in one row.
    with gr.Row():
        cik_dropdown = gr.Dropdown(
            # Derived from CIK_OPTIONS so the dropdown can never offer a
            # company that has no CIK mapping (and vice versa).
            choices=list(CIK_OPTIONS),
            value="Apple (AAPL)",
            label="Select Company"
        )
        query_input = gr.Textbox(
            label="Enter your query (e.g., 'Show trends')",
            lines=1,
            value="Show trends"
        )

    with gr.Row():
        submit_button = gr.Button("Submit")

    with gr.Row():
        gr.Markdown("### 📝 Response")
    with gr.Row():
        output_text = gr.Textbox(interactive=False)

    with gr.Row():
        gr.Markdown("### 📈 Visualization")
    with gr.Row():
        output_plot = gr.Plot(label=".", visible=True)

    with gr.Row():
        gr.Markdown("### 📊 Financial Metrics")
    with gr.Row():
        output_table = gr.DataFrame()

    # The output order must match the tuple returned by process_interface.
    submit_button.click(
        fn=process_interface,
        inputs=[cik_dropdown, query_input],
        outputs=[output_text, output_table, output_plot]
    )
    
if __name__ == "__main__":
    # Drive the launch mode from the module-level `mcp_server` flag so there is
    # a single switch for enabling or disabling the MCP server.
    demo.launch(mcp_server=mcp_server)