| | |
import logging
import re
import sqlite3
from contextlib import closing
from pathlib import Path

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
| |
|
| | |
# Log line layout shared by every logger in the process.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("Private_MCP_Server")

# SQLite database file (a relative path, so resolved against the working
# directory) that stores the ``holdings`` table.
DB_FILE = "portfolio.db"
| |
|
| | |
| | |
| | |
# Llama 3 chat model accessed through Ollama; temperature 0 keeps the
# generated SQL as repeatable as possible.
llm = ChatOllama(model="llama3", temperature=0)

# System instructions constraining the model to a single read-only query
# over the ``holdings`` table.
_TEXT_TO_SQL_SYSTEM_PROMPT = """You are a Text-to-SQL assistant. Convert the question to a read-only SQLite query for the 'holdings' table.
Schema: symbol (TEXT), shares (INTEGER), average_cost (REAL).
RULES:
1. SELECT only. No INSERT/UPDATE/DELETE.
2. Output ONLY the SQL query. No markdown.
"""

text_to_sql_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _TEXT_TO_SQL_SYSTEM_PROMPT),
        ("human", "Question: {question}"),
    ]
)

# prompt -> model -> plain string holding the generated SQL.
sql_generation_chain = text_to_sql_prompt | llm | StrOutputParser()

app = FastAPI(title="Aegis Private MCP Server")
| |
|
# Seed rows for an empty ``holdings`` table, as (symbol, shares, average_cost).
_DUMMY_HOLDINGS = [
    ("AAPL", 5000, 180.20), ("MSFT", 3000, 350.50), ("GOOGL", 1500, 140.10),
    ("NVDA", 800, 450.00), ("AMD", 2000, 110.30), ("INTC", 4000, 35.40),
    ("CRM", 1200, 220.10), ("ADBE", 600, 550.20), ("ORCL", 2500, 115.50),
    ("CSCO", 3500, 52.10),

    ("JPM", 2000, 150.40), ("BAC", 5000, 32.10), ("GS", 500, 340.50),
    ("V", 1000, 240.20), ("MA", 800, 380.10),

    ("WMT", 1500, 160.30), ("TGT", 1000, 130.50), ("COST", 400, 550.10),
    ("KO", 3000, 58.20), ("PEP", 2500, 170.40), ("PG", 2000, 150.10),
    ("NKE", 1200, 105.30), ("SBUX", 1800, 95.40),

    ("JNJ", 2500, 160.20), ("PFE", 4000, 35.10), ("UNH", 600, 480.50),
    ("LLY", 400, 580.10), ("MRK", 2000, 110.20),

    ("XOM", 3000, 105.40), ("CVX", 2000, 150.20), ("GE", 1500, 110.50),
    ("CAT", 800, 280.10), ("BA", 500, 210.30),

    ("TSLA", 1000, 220.90), ("F", 5000, 12.10), ("GM", 4000, 35.40),
]


# NOTE(review): FastAPI deprecates ``on_event`` in favour of lifespan
# handlers; kept as-is to avoid changing how ``app`` is constructed.
@app.on_event("startup")
async def startup_db():
    """Create the ``holdings`` table and seed it with dummy data if empty.

    Runs at application startup. Any failure is logged with its traceback and
    then swallowed, so the server still starts; later queries against the
    database will fail instead.
    """
    try:
        # ``with conn`` only commits/rolls back the transaction; it does not
        # close a sqlite3 connection, so ``closing`` guarantees the release.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            with conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS holdings (
                        symbol TEXT PRIMARY KEY,
                        shares INTEGER,
                        average_cost REAL
                    )
                    """
                )

                cursor.execute("SELECT count(*) FROM holdings")
                if cursor.fetchone()[0] == 0:
                    logger.info("Populating database with diverse dummy data...")
                    cursor.executemany(
                        "INSERT INTO holdings (symbol, shares, average_cost) VALUES (?, ?, ?)",
                        _DUMMY_HOLDINGS,
                    )
                    # Commit before reporting success so the log reflects durable state.
                    conn.commit()
                    logger.info("Database populated successfully.")
                else:
                    logger.info("Database already contains data.")
    except Exception as e:
        logger.exception(f"Failed to initialize database: {e}")
| |
|
| |
|
def execute_safe_query(query: str, params=None):
    """Execute a read-only SQL query against the portfolio database.

    This is the security gate for LLM-generated SQL, so read-only access is
    enforced twice:

    * the statement text must start with ``SELECT``, and
    * the database is opened in SQLite read-only mode (``mode=ro``), so the
      engine itself rejects any write.

    ``sqlite3`` also refuses to run more than one statement per ``execute``
    call, which blocks ``SELECT ...; DROP ...`` stacking.

    Args:
        query: SQL text to run.
        params: Optional sequence or mapping of bound parameters.

    Returns:
        A list of dicts, one per row, keyed by column name. ``NULL`` values
        are reported as ``0``.

    Raises:
        HTTPException: 403 if the query is not a SELECT; 500 on any database
            error (including a missing or unreadable database file).
    """
    if not query.strip().upper().startswith("SELECT"):
        logger.error(f"SECURITY VIOLATION: Attempted to execute non-SELECT query: {query}")
        raise HTTPException(status_code=403, detail="Forbidden: Only SELECT queries are allowed.")

    # A ``file:`` URI with ``mode=ro`` opens the same file read-only;
    # ``as_uri`` takes care of percent-encoding the path.
    read_only_uri = Path(DB_FILE).resolve().as_uri() + "?mode=ro"

    try:
        # ``closing`` is required: the sqlite3 connection context manager does
        # not close the connection on exit.
        with closing(sqlite3.connect(read_only_uri, uri=True)) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(query, params or ()).fetchall()
    except sqlite3.Error as e:
        logger.error(f"Database error executing query '{query}': {e}")
        raise HTTPException(status_code=500, detail=f"Database query failed: {e}") from e

    # Downstream consumers receive 0 in place of NULL.
    return [
        {column: (0 if value is None else value) for column, value in dict(row).items()}
        for row in rows
    ]
| |
|
# Ticker-like words for the offline fallback: standalone runs of 1-5 letters,
# matched against the upper-cased question.
_TICKER_PATTERN = re.compile(r"\b[A-Z]{1,5}\b")

# Common question words the ticker pattern would otherwise mistake for symbols.
_FALLBACK_IGNORED_WORDS = frozenset({
    "WHAT", "IS", "THE", "TO", "OF", "FOR", "IN", "AND", "OR", "SHOW", "ME",
    "DATA", "STOCK", "PRICE", "DO", "WE", "OWN", "HAVE", "ANY", "EXPOSURE",
    "CURRENT",
})

# A Markdown code-fence opener (``` or ```sql) at the start, or a closer at the end.
_CODE_FENCE_PATTERN = re.compile(r"^```[A-Za-z]*\s*|\s*```$")


def _strip_code_fences(sql: str) -> str:
    """Return *sql* without surrounding whitespace or Markdown code fences.

    The prompt asks for bare SQL, but a fenced reply (```sql ... ```) would
    otherwise fail the SELECT-only check and be rejected with a 403.
    """
    return _CODE_FENCE_PATTERN.sub("", sql.strip()).strip()


def _fallback_sql(question: str, known_symbols: set) -> str:
    """Build a holdings query from *question* without the LLM.

    Chooses the first ticker-like word that is a symbol actually held. If no
    held symbol is mentioned, it uses the first ticker-like word that is not a
    common question word, so an unknown ticker yields an empty result rather
    than every row. With no candidate at all, it selects every holding.
    """
    candidates = [
        word
        for word in _TICKER_PATTERN.findall(question.upper())
        if word not in _FALLBACK_IGNORED_WORDS
    ]
    if not candidates:
        return "SELECT * FROM holdings"
    symbol = next((word for word in candidates if word in known_symbols), candidates[0])
    # Interpolation is safe: the pattern restricts ``symbol`` to 1-5 ASCII letters A-Z.
    return f"SELECT * FROM holdings WHERE symbol='{symbol}'"


@app.post("/portfolio_data")
async def get_portfolio_data(payload: dict):
    """Answer a natural-language portfolio question from the local database.

    The question is converted to SQL by the Llama 3 chain. If that call
    raises (e.g. Ollama is offline), a keyword-based fallback builds the query
    instead. Database access runs in a worker thread so the blocking sqlite3
    calls do not stall the event loop.

    Args:
        payload: JSON body; must contain a non-empty string ``question``.

    Returns:
        On success, a dict with ``status="success"``, ``question``,
        ``generated_sql`` and ``data`` (rows from ``execute_safe_query``).
        On an unexpected error, a dict with ``status="error"``, ``message``
        and empty ``data``, returned with HTTP 200.

    Raises:
        HTTPException: 400 for a missing or non-string question; 403/500
            propagated from ``execute_safe_query``.
    """
    question = payload.get("question")
    if question is not None and not isinstance(question, str):
        raise HTTPException(status_code=400, detail="'question' must be a string.")
    if not question or not question.strip():
        raise HTTPException(status_code=400, detail="'question' is a required field.")

    logger.info(f"Received portfolio data question: '{question}'")

    try:
        try:
            raw_sql = await sql_generation_chain.ainvoke({"question": question})
        except Exception as llm_error:
            logger.warning(f"LLM generation failed (likely Ollama offline): {llm_error}. Using fallback logic.")
            held = await run_in_threadpool(execute_safe_query, "SELECT symbol FROM holdings")
            generated_sql = _fallback_sql(question, {row["symbol"] for row in held})
            logger.info(f"Fallback SQL generated: {generated_sql}")
        else:
            logger.info(f"Llama 3 generated SQL: {raw_sql}")
            generated_sql = _strip_code_fences(raw_sql)

        # sqlite3 blocks; keep it off the event loop.
        results = await run_in_threadpool(execute_safe_query, generated_sql)
        logger.info(f"Successfully executed query and found {len(results)} records.")

        return {"status": "success", "question": question, "generated_sql": generated_sql, "data": results}

    except HTTPException:
        # Deliberate HTTP errors (403/500 from the query gate) pass through unchanged.
        raise
    except Exception as e:
        logger.critical(f"An unexpected error occurred in the portfolio data endpoint: {e}", exc_info=True)
        return {"status": "error", "message": str(e), "data": []}
| |
|
@app.get("/")
def read_root():
    """Health-check style endpoint: return a fixed "server is up" message."""
    status_text = "Aegis Private MCP Server is operational."
    return dict(message=status_text)
| |
|
| | |
if __name__ == "__main__":
    # Listen on the loopback interface only, so just local clients can connect.
    uvicorn.run(app, port=8003, host="127.0.0.1")