MCP_TradingAgents / accounts.py
Anindhya's picture
Upload folder using huggingface_hub
01f0e50 verified
from pydantic import BaseModel
import json
from dotenv import load_dotenv
from datetime import datetime
from market import get_share_price
from database import write_account, read_account, write_log
# Account parameters.
INITIAL_BALANCE = 1e4  # starting cash for a newly created or reset account
SPREAD = 2e-3  # fractional mark-up applied to buys and mark-down applied to sells

# Load variables from a local .env file; override=True lets values from the
# file replace variables already present in the process environment.
load_dotenv(override=True)
class Transaction(BaseModel):
    """One executed trade recorded on an Account.

    ``quantity`` is signed: ``Account.buy_shares`` records a positive
    quantity and ``Account.sell_shares`` records a negative one. ``price`` is
    the per-share fill price with the spread already applied.
    """

    symbol: str
    quantity: int
    price: float
    timestamp: str  # written by Account as "%Y-%m-%d %H:%M:%S"
    rationale: str

    def total(self) -> float:
        """Return the signed trade value, ``quantity * price``."""
        signed_value = self.price * self.quantity
        return signed_value

    def __repr__(self):
        # The sign is dropped here, so buys and sells render the same way.
        shares = abs(self.quantity)
        return "{} shares of {} at {} each.".format(shares, self.symbol, self.price)
class Account(BaseModel):
    """A trading account persisted through the ``database`` helpers.

    State is a cash ``balance``, per-symbol share ``holdings``, the list of
    executed ``transactions`` and a history of portfolio valuations. Every
    mutating method persists the account via ``save()``. Buys fill at the
    quoted share price marked up by ``SPREAD``; sells fill at the quoted
    price marked down by ``SPREAD``.
    """

    name: str
    balance: float
    strategy: str
    holdings: dict[str, int]  # symbol -> shares held; fully sold symbols are removed
    transactions: list[Transaction]
    # (timestamp, total portfolio value) pairs appended on every report().
    portfolio_value_time_series: list[tuple[str, float]]

    @staticmethod
    def _timestamp() -> str:
        """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    @classmethod
    def get(cls, name: str):
        """Load the account called *name*, creating and persisting it if absent.

        Names are case-insensitive: the account is always read and written
        under ``name.lower()``. A new account starts with ``INITIAL_BALANCE``
        cash, an empty strategy, and no holdings or history.
        """
        key = name.lower()
        fields = read_account(key)
        if not fields:
            fields = {
                "name": key,
                "balance": INITIAL_BALANCE,
                "strategy": "",
                "holdings": {},
                "transactions": [],
                "portfolio_value_time_series": [],
            }
            # Write under the same lower-cased key used by read_account above
            # and by save(), so a later get() with any casing finds it.
            write_account(key, fields)
        return cls(**fields)

    def save(self):
        """Persist the whole account under its lower-cased name."""
        write_account(self.name.lower(), self.model_dump())

    def reset(self, strategy: str):
        """Restore starting cash, clear holdings and history, set *strategy*, and save."""
        self.balance = INITIAL_BALANCE
        self.strategy = strategy
        self.holdings = {}
        self.transactions = []
        self.portfolio_value_time_series = []
        self.save()

    def deposit(self, amount: float):
        """Deposit funds into the account.

        Raises:
            ValueError: if *amount* is not positive.
        """
        if amount <= 0:
            raise ValueError("Deposit amount must be positive.")
        self.balance += amount
        print(f"Deposited ${amount}. New balance: ${self.balance}")
        self.save()

    def withdraw(self, amount: float):
        """Withdraw funds from the account, ensuring it doesn't go negative.

        Raises:
            ValueError: if *amount* is not positive or exceeds the balance.
        """
        # Without this guard a negative amount would silently increase the balance.
        if amount <= 0:
            raise ValueError("Withdrawal amount must be positive.")
        if amount > self.balance:
            raise ValueError("Insufficient funds for withdrawal.")
        self.balance -= amount
        print(f"Withdrew ${amount}. New balance: ${self.balance}")
        self.save()

    def buy_shares(self, symbol: str, quantity: int, rationale: str) -> str:
        """Buy *quantity* shares of *symbol* at the quoted price plus ``SPREAD``.

        Records a Transaction, debits the cash balance, saves, and logs.

        Returns:
            "Completed. Latest details:" followed by the JSON from ``report()``.

        Raises:
            ValueError: if *quantity* is not positive, the symbol prices at 0
                (treated as unrecognized), or the cost exceeds the balance.
        """
        # A non-positive quantity would credit cash and reduce holdings.
        if quantity <= 0:
            raise ValueError("Quantity must be a positive number of shares.")
        price = get_share_price(symbol)
        # Reject unknown symbols before any funds arithmetic.
        if price == 0:
            raise ValueError(f"Unrecognized symbol {symbol}")
        buy_price = price * (1 + SPREAD)
        total_cost = buy_price * quantity
        if total_cost > self.balance:
            raise ValueError("Insufficient funds to buy shares.")
        self.holdings[symbol] = self.holdings.get(symbol, 0) + quantity
        self.transactions.append(
            Transaction(
                symbol=symbol,
                quantity=quantity,
                price=buy_price,
                timestamp=self._timestamp(),
                rationale=rationale,
            )
        )
        self.balance -= total_cost
        self.save()
        write_log(self.name, "account", f"Bought {quantity} of {symbol}")
        return "Completed. Latest details:\n" + self.report()

    def sell_shares(self, symbol: str, quantity: int, rationale: str) -> str:
        """Sell *quantity* held shares of *symbol* at the quoted price minus ``SPREAD``.

        Records a Transaction with a negative quantity, credits the cash
        balance, drops the symbol from holdings when none remain, saves, and logs.

        Returns:
            "Completed. Latest details:" followed by the JSON from ``report()``.

        Raises:
            ValueError: if *quantity* is not positive, exceeds the shares held,
                or the symbol prices at 0.
        """
        # A negative quantity would otherwise pass the holdings check and act
        # as an unchecked purchase.
        if quantity <= 0:
            raise ValueError("Quantity must be a positive number of shares.")
        held = self.holdings.get(symbol, 0)
        if held < quantity:
            raise ValueError(f"Cannot sell {quantity} shares of {symbol}. Not enough shares held.")
        price = get_share_price(symbol)
        # Same zero-price convention as buy_shares; avoids giving shares away for nothing.
        if price == 0:
            raise ValueError(f"Unrecognized symbol {symbol}")
        sell_price = price * (1 - SPREAD)
        total_proceeds = sell_price * quantity
        remaining = held - quantity
        if remaining == 0:
            del self.holdings[symbol]
        else:
            self.holdings[symbol] = remaining
        self.transactions.append(
            Transaction(
                symbol=symbol,
                quantity=-quantity,  # negative quantity marks a sell
                price=sell_price,
                timestamp=self._timestamp(),
                rationale=rationale,
            )
        )
        self.balance += total_proceeds
        self.save()
        write_log(self.name, "account", f"Sold {quantity} of {symbol}")
        return "Completed. Latest details:\n" + self.report()

    def calculate_portfolio_value(self):
        """Return cash balance plus holdings valued at ``get_share_price`` (no spread)."""
        total_value = self.balance
        for symbol, quantity in self.holdings.items():
            total_value += get_share_price(symbol) * quantity
        return total_value

    def calculate_profit_loss(self, portfolio_value: "float | None" = None):
        """Return trading profit or loss.

        Computed as the value of current holdings (``portfolio_value`` minus
        cash) minus the net amount spent on trades (buy costs minus sell
        proceeds). Deposits and withdrawals therefore do not count as profit.

        Args:
            portfolio_value: total value as returned by
                ``calculate_portfolio_value()``; computed when omitted.
        """
        if portfolio_value is None:
            portfolio_value = self.calculate_portfolio_value()
        net_trade_spend = sum(transaction.total() for transaction in self.transactions)
        return portfolio_value - net_trade_spend - self.balance

    def get_holdings(self):
        """Return the holdings dict (the live object, not a copy)."""
        return self.holdings

    def get_profit_loss(self):
        """Report the account's current profit or loss."""
        return self.calculate_profit_loss(self.calculate_portfolio_value())

    def list_transactions(self):
        """Return every transaction as a plain dict."""
        return [transaction.model_dump() for transaction in self.transactions]

    def report(self) -> str:
        """Return a JSON string describing the account.

        Side effects: appends the current portfolio value to
        ``portfolio_value_time_series``, saves the account, and writes a log line.
        The JSON contains all model fields plus ``total_portfolio_value`` and
        ``total_profit_loss``.
        """
        portfolio_value = self.calculate_portfolio_value()
        self.portfolio_value_time_series.append((self._timestamp(), portfolio_value))
        self.save()
        pnl = self.calculate_profit_loss(portfolio_value)
        data = self.model_dump()
        data["total_portfolio_value"] = portfolio_value
        data["total_profit_loss"] = pnl
        write_log(self.name, "account", "Retrieved account details")
        return json.dumps(data)

    def get_strategy(self) -> str:
        """Return the strategy of the account (and log the access)."""
        write_log(self.name, "account", "Retrieved strategy")
        return self.strategy

    def change_strategy(self, strategy: str) -> str:
        """ At your discretion, if you choose to, call this to change your investment strategy for the future """
        self.strategy = strategy
        self.save()
        write_log(self.name, "account", "Changed strategy")
        return "Changed strategy"
# Example of usage (persists to the account store through Account.get/save):
if __name__ == "__main__":
    # Account is a pydantic model: it takes keyword fields only, so load or
    # create it through Account.get rather than a positional constructor call.
    account = Account.get("John Doe")
    account.deposit(1000)
    # buy_shares / sell_shares require a rationale string.
    account.buy_shares("AAPL", 5, "Example purchase")
    account.sell_shares("AAPL", 2, "Example sale")
    print(f"Current Holdings: {account.get_holdings()}")
    portfolio_value = account.calculate_portfolio_value()
    print(f"Total Portfolio Value: {portfolio_value}")
    print(f"Profit/Loss: {account.calculate_profit_loss(portfolio_value)}")
    print(f"Transactions: {account.list_transactions()}")