Spaces:
Running
Running
File size: 7,715 Bytes
5d2eba0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | # private_mcp.py
from fastapi import FastAPI, HTTPException
import uvicorn
import sqlite3
import logging
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_ollama import ChatOllama
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("Private_MCP_Server")
# --- Database Configuration ---
DB_FILE = "portfolio.db"
# --- LLM Configuration (Local Llama 3) ---
# This connects to the Ollama application running on your machine.
# Make sure Ollama and the llama3 model are running.
llm = ChatOllama(model="llama3", temperature=0)
# --- Text-to-SQL Prompt Engineering ---
# This prompt is carefully designed to make Llama 3 generate ONLY safe SQL queries.
text_to_sql_prompt = ChatPromptTemplate.from_messages([
("system",
"""You are a Text-to-SQL assistant. Convert the question to a read-only SQLite query for the 'holdings' table.
Schema: symbol (TEXT), shares (INTEGER), average_cost (REAL).
RULES:
1. SELECT only. No INSERT/UPDATE/DELETE.
2. Output ONLY the SQL query. No markdown.
"""),
("human", "Question: {question}")
])
# Create the LangChain chain for Text-to-SQL
sql_generation_chain = text_to_sql_prompt | llm | StrOutputParser()
# --- FastAPI App ---
app = FastAPI(title="Aegis Private MCP Server")
@app.on_event("startup")
async def startup_db():
"""Initialize the database with dummy data if it doesn't exist."""
try:
with sqlite3.connect(DB_FILE) as conn:
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS holdings (
symbol TEXT PRIMARY KEY,
shares INTEGER,
average_cost REAL
)
""")
# Check if data exists
cursor.execute("SELECT count(*) FROM holdings")
if cursor.fetchone()[0] == 0:
logger.info("Populating database with diverse dummy data...")
# Expanded list of companies across sectors
dummy_data = [
# Tech
('AAPL', 5000, 180.20), ('MSFT', 3000, 350.50), ('GOOGL', 1500, 140.10), ('NVDA', 800, 450.00), ('AMD', 2000, 110.30),
('INTC', 4000, 35.40), ('CRM', 1200, 220.10), ('ADBE', 600, 550.20), ('ORCL', 2500, 115.50), ('CSCO', 3500, 52.10),
# Finance
('JPM', 2000, 150.40), ('BAC', 5000, 32.10), ('GS', 500, 340.50), ('V', 1000, 240.20), ('MA', 800, 380.10),
# Retail & Consumer
('WMT', 1500, 160.30), ('TGT', 1000, 130.50), ('COST', 400, 550.10), ('KO', 3000, 58.20), ('PEP', 2500, 170.40),
('PG', 2000, 150.10), ('NKE', 1200, 105.30), ('SBUX', 1800, 95.40),
# Healthcare
('JNJ', 2500, 160.20), ('PFE', 4000, 35.10), ('UNH', 600, 480.50), ('LLY', 400, 580.10), ('MRK', 2000, 110.20),
# Energy & Industrial
('XOM', 3000, 105.40), ('CVX', 2000, 150.20), ('GE', 1500, 110.50), ('CAT', 800, 280.10), ('BA', 500, 210.30),
# Auto
('TSLA', 1000, 220.90), ('F', 5000, 12.10), ('GM', 4000, 35.40)
]
cursor.executemany("INSERT INTO holdings (symbol, shares, average_cost) VALUES (?, ?, ?)", dummy_data)
conn.commit()
logger.info("Database populated successfully.")
else:
logger.info("Database already contains data.")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
def execute_safe_query(query: str, params=None):
"""
Executes a SQL query after a basic safety check.
This is a critical security function.
"""
# SECURITY CHECK: Ensure the query is read-only.
if not query.strip().upper().startswith("SELECT"):
logger.error(f"SECURITY VIOLATION: Attempted to execute non-SELECT query: {query}")
raise HTTPException(status_code=403, detail="Forbidden: Only SELECT queries are allowed.")
try:
with sqlite3.connect(DB_FILE) as conn:
conn.row_factory = sqlite3.Row # Makes results dict-like
cursor = conn.cursor()
if params:
cursor.execute(query, params)
else:
cursor.execute(query)
results = [dict(row) for row in cursor.fetchall()]
# Sanitize results: Replace None with 0 (common for SUM on empty set)
for row in results:
for key, value in row.items():
if value is None:
row[key] = 0
return results
except sqlite3.Error as e:
logger.error(f"Database error executing query '{query}': {e}")
raise HTTPException(status_code=500, detail=f"Database query failed: {e}")
@app.post("/portfolio_data")
async def get_portfolio_data(payload: dict):
"""
Takes a natural language question, converts it to SQL using Llama 3,
and executes it against the internal portfolio database.
"""
question = payload.get("question")
if not question:
raise HTTPException(status_code=400, detail="'question' is a required field.")
logger.info(f"Received portfolio data question: '{question}'")
try:
# Step 1: Generate the SQL query using the local LLM
try:
generated_sql = await sql_generation_chain.ainvoke({"question": question})
logger.info(f"Llama 3 generated SQL: {generated_sql}")
except Exception as llm_error:
logger.warning(f"LLM generation failed (likely Ollama offline): {llm_error}. Using fallback logic.")
# Fallback Logic: Dynamic symbol extraction
import re
q_upper = question.upper()
# Look for common ticker patterns (1-5 uppercase letters)
matches = re.findall(r'\b[A-Z]{1,5}\b', q_upper)
found_symbol = None
ignored_words = ["WHAT", "IS", "THE", "TO", "OF", "FOR", "IN", "AND", "OR", "SHOW", "ME", "DATA", "STOCK", "PRICE", "DO", "WE", "OWN", "HAVE", "ANY", "EXPOSURE", "CURRENT"]
for match in matches:
if match not in ignored_words:
found_symbol = match
break
if found_symbol:
generated_sql = f"SELECT * FROM holdings WHERE symbol='{found_symbol}'"
else:
generated_sql = "SELECT * FROM holdings" # Default to showing all
logger.info(f"Fallback SQL generated: {generated_sql}")
# Step 2: Execute the query using our secure function
results = execute_safe_query(generated_sql)
logger.info(f"Successfully executed query and found {len(results)} records.")
return {"status": "success", "question": question, "generated_sql": generated_sql, "data": results}
except HTTPException as http_exc:
# Re-raise HTTP exceptions from our secure executor
raise http_exc
except Exception as e:
logger.critical(f"An unexpected error occurred in the portfolio data endpoint: {e}")
# Don't crash the client, return an empty success with error note
return {"status": "error", "message": str(e), "data": []}
@app.get("/")
def read_root():
return {"message": "Aegis Private MCP Server is operational."}
# --- Main Execution ---
if __name__ == "__main__":
# This server runs on port 8003
uvicorn.run(app, host="127.0.0.1", port=8003) |