# _Bot / server.py
# Author: akashraut
# Commit 1abadd0 (verified): "Update server.py"
# server.py
import asyncio
import json
import sys

import yfinance as yf
from langchain_core.tools import BaseTool, tool
from pydantic import BaseModel, Field
from statsmodels.tsa.holtwinters import ExponentialSmoothing
# --- MCP Server Logic ---
class MCPToolServer:
    """Minimal JSON-RPC 2.0 server that exposes LangChain tools over stdio.

    Requests arrive as one JSON object per line on stdin, and each response is
    written as one JSON line on stdout. Supported methods:

    * ``discover`` -- list every registered tool's name, description and
      argument schema.
    * ``execute`` -- run ``params["tool_name"]`` with ``params["tool_args"]``.
    """

    def __init__(self, tools: "list[BaseTool]"):
        # Index tools by name so ``execute`` can resolve them in one lookup.
        self.tools = {t.name: t for t in tools}

    @staticmethod
    def _error(req_id, code: int, message: str) -> dict:
        """Build a JSON-RPC 2.0 error response object."""
        return {
            "jsonrpc": "2.0",
            "error": {"code": code, "message": message},
            "id": req_id,
        }

    async def _handle_request(self, req: dict) -> dict:
        """Dispatch one decoded request and return its JSON-RPC response.

        Malformed requests and failing tools are reported as JSON-RPC error
        objects instead of raising, so one bad request cannot stop ``serve``.
        """
        # Valid JSON that is not an object (e.g. ``[1, 2]``) is not a request.
        if not isinstance(req, dict):
            return self._error(None, -32600, "Invalid Request")
        method, req_id = req.get("method"), req.get("id")
        # An explicit ``"params": null`` is treated like an absent ``params``.
        params = req.get("params")
        if params is None:
            params = {}
        if method == "discover":
            result = [
                {"name": t.name, "description": t.description, "args_schema": t.args}
                for t in self.tools.values()
            ]
            return {"jsonrpc": "2.0", "result": result, "id": req_id}
        if method == "execute":
            if not isinstance(params, dict):
                return self._error(req_id, -32602, "Invalid params")
            tool_name, tool_args = params.get("tool_name"), params.get("tool_args", {})
            # Only string names can match a tool; checking first also avoids a
            # TypeError when the name is an unhashable value such as a list.
            tool_to_exec = self.tools.get(tool_name) if isinstance(tool_name, str) else None
            if tool_to_exec is not None:
                try:
                    result = await tool_to_exec.ainvoke(tool_args)
                except Exception as e:
                    # Report tool failures to the caller as an internal error.
                    return self._error(req_id, -32603, str(e))
                return {"jsonrpc": "2.0", "result": result, "id": req_id}
        # Unknown methods and unknown tool names share this error.
        return self._error(req_id, -32601, "Method not found")

    def _encode(self, response: dict) -> bytes:
        """Serialize *response* as one newline-terminated JSON line (bytes)."""
        try:
            payload = json.dumps(response)
        except (TypeError, ValueError) as e:
            # A tool returned a value JSON cannot represent; report that
            # instead of letting the exception escape the serve loop.
            payload = json.dumps(
                self._error(
                    response.get("id"), -32603, f"Result is not JSON-serializable: {e}"
                )
            )
        return payload.encode() + b"\n"

    async def serve(self):
        """Answer newline-delimited JSON-RPC requests from stdin until EOF."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        await loop.connect_read_pipe(lambda: asyncio.StreamReaderProtocol(reader), sys.stdin)
        # StreamWriter.drain() waits on its protocol's flow control, which
        # FlowControlMixin provides for the stdout pipe.
        transport, protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, sys.stdout
        )
        writer = asyncio.StreamWriter(transport, protocol, None, loop)
        while line := await reader.readline():
            if not line.strip():
                continue  # blank line: nothing to answer
            try:
                request = json.loads(line)
            except ValueError:
                # JSONDecodeError, or bytes that are not valid UTF-8
                # (UnicodeDecodeError); both subclass ValueError.
                response = self._error(None, -32700, "Parse error")
            else:
                response = await self._handle_request(request)
            writer.write(self._encode(response))
            await writer.drain()
# --- Tool Definitions ---
# Supported commodity names (lower-case lookup keys) mapped to the ticker
# symbols passed to yfinance.
COMMODITY_TICKERS = dict(gold="GC=F", silver="SI=F")
class PriceInput(BaseModel):
    """Arguments accepted by the ``get_current_price`` tool."""

    # Field descriptions are included in the JSON schema LangChain exposes as
    # the tool's ``args``, which ``discover`` returns to clients as
    # ``args_schema``; they tell the caller which values are accepted.
    commodity_name: str = Field(
        ...,
        description="Commodity to price: 'gold' or 'silver' (case-insensitive).",
    )
@tool(args_schema=PriceInput)
async def get_current_price(commodity_name: str) -> str:
    """Gets the most recent price for gold or silver."""
    # NOTE: the @tool decorator publishes the docstring above as the tool
    # description returned by ``discover``; implementation notes live here.
    # Returns a human-readable sentence, or an error string for unsupported
    # names and fetch failures; it does not raise for those cases.
    ticker = COMMODITY_TICKERS.get(commodity_name.strip().lower())
    if not ticker:
        return f"Error: '{commodity_name}' is not supported. Use 'gold' or 'silver'."
    no_data = f"Error fetching price: no recent data for '{commodity_name}'."
    try:
        # yfinance performs blocking network I/O; run it in a worker thread so
        # the event loop stays responsive. A 5-day window still contains a
        # last close in case the latest day returns no rows.
        history = await asyncio.to_thread(yf.Ticker(ticker).history, period="5d")
        if "Close" not in history.columns:
            return no_data
        closes = history["Close"].dropna()
        if closes.empty:
            return no_data
        price = closes.iloc[-1]
    except Exception as e:
        return f"Error fetching price: {e}"
    return f"The current price of {commodity_name} is approx. ${price:.2f} USD."
class ForecastInput(BaseModel):
    """Arguments accepted by the ``get_price_forecast`` tool."""

    # Descriptions are included in the schema ``discover`` returns as
    # ``args_schema``. The 3-5 day range is deliberately NOT enforced here
    # (no ge/le constraints): the tool itself validates it and returns a
    # friendly error string instead of a schema-validation exception.
    commodity_name: str = Field(
        ...,
        description="Commodity to forecast: 'gold' or 'silver' (case-insensitive).",
    )
    forecast_days: int = Field(
        ...,
        description="Number of days to forecast; must be 3, 4, or 5.",
    )
# Minimum number of daily closes required before fitting: a small floor so the
# trend model has something to estimate from (six months of daily data is
# normally well above this).
_MIN_HISTORY_POINTS = 10


def _forecast_closes(ticker: str, steps: int):
    """Fit Holt's additive-trend model to ~6 months of closes for *ticker*.

    Returns the next *steps* forecast values, or ``None`` when too little
    usable history is available. Blocking: performs network I/O and fitting.
    """
    history = yf.Ticker(ticker).history(period="6mo")
    if "Close" not in history.columns:
        return None
    closes = history["Close"].dropna()
    if len(closes) < _MIN_HISTORY_POINTS:
        return None
    # Fit on the raw values so statsmodels does not need to infer a frequency
    # from the irregular trading-day date index.
    model = ExponentialSmoothing(closes.to_numpy(), trend="add")
    return model.fit().forecast(steps=steps)


@tool(args_schema=ForecastInput)
async def get_price_forecast(commodity_name: str = None, forecast_days: int = None, **kwargs) -> str:
    """
    Generates a 3 to 5 day price forecast for gold or silver.
    Also handles positional-style input like ['gold', 5].
    """
    # NOTE: the @tool decorator publishes the docstring above as the tool
    # description returned by ``discover``; implementation notes live here.
    # Returns one "Day N: $X.XX USD" line per forecast day, or an error string;
    # it does not raise for invalid arguments or data/model failures.
    #
    # Named parameters never land in **kwargs, so the only extra input looked
    # at is a positional-style list passed as ``args=['gold', 5]``.
    positional = kwargs.get("args")
    if not commodity_name and isinstance(positional, list) and len(positional) >= 2:
        commodity_name, forecast_days = positional[0], positional[1]

    if not commodity_name:
        return "Error: Please provide a commodity (gold or silver)."
    if forecast_days is None:
        return "Error: Please provide forecast_days (3, 4, or 5)."
    try:
        forecast_days = int(forecast_days)
    except (TypeError, ValueError):
        return "Error: forecast_days must be 3, 4, or 5."

    commodity_name = str(commodity_name).strip().lower()
    ticker = COMMODITY_TICKERS.get(commodity_name)
    if not ticker:
        return f"Error: '{commodity_name}' is not supported. Use 'gold' or 'silver'."
    if not 3 <= forecast_days <= 5:
        return "Error: forecast_days must be 3, 4, or 5."

    try:
        # Network I/O and model fitting block; keep them off the event loop.
        forecast = await asyncio.to_thread(_forecast_closes, ticker, forecast_days)
    except Exception as e:
        return f"Error during forecasting: {e}"
    if forecast is None:
        return "Not enough historical data for forecasting."
    return "\n".join(
        f"Day {day}: ${val:.2f} USD" for day, val in enumerate(forecast, start=1)
    )
# --- EntryPoint ---
def _main() -> None:
    """Register the commodity tools and serve JSON-RPC over stdio until EOF."""
    registered = [get_current_price, get_price_forecast]
    asyncio.run(MCPToolServer(tools=registered).serve())


if __name__ == "__main__":
    _main()