Spaces:
Sleeping
Sleeping
File size: 6,940 Bytes
01f0e50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
from pydantic import BaseModel
import json
from dotenv import load_dotenv
from datetime import datetime
from market import get_share_price
from database import write_account, read_account, write_log
# Starting cash for a freshly created or reset account.
INITIAL_BALANCE: float = 10000.0
# Fractional price adjustment: buys pay price * (1 + SPREAD), sells receive
# price * (1 - SPREAD).
SPREAD: float = 0.002

# Pull configuration from a .env file; override=True lets .env values replace
# variables already present in the process environment.
load_dotenv(override=True)
class Transaction(BaseModel):
    """One executed trade.

    Buys are recorded with a positive ``quantity`` and sells with a negative
    one, so ``total()`` is the signed cash amount paid for the trade.
    """

    symbol: str
    quantity: int
    price: float  # per-share execution price
    timestamp: str
    rationale: str

    def total(self) -> float:
        """Return the signed trade value, ``price * quantity``."""
        return self.price * self.quantity

    def __repr__(self):
        # Show an unsigned share count; the sign lives only in ``quantity``.
        share_count = abs(self.quantity)
        return "{} shares of {} at {} each.".format(share_count, self.symbol, self.price)
class Account(BaseModel):
    """A simulated trading account persisted via the ``database`` helpers.

    Tracks a cash ``balance``, share ``holdings`` keyed by symbol, the full
    ``transactions`` history and a time series of portfolio valuations.
    Every method that changes state persists the account with ``save()``.
    """

    name: str
    balance: float
    strategy: str
    holdings: dict[str, int]
    transactions: list[Transaction]
    portfolio_value_time_series: list[tuple[str, float]]

    @classmethod
    def get(cls, name: str) -> "Account":
        """Load the account called *name*, creating and persisting a new one if absent.

        Account names are case-insensitive: they are stored under ``name.lower()``.
        """
        key = name.lower()
        fields = read_account(key)
        if not fields:
            fields = {
                "name": key,
                "balance": INITIAL_BALANCE,
                "strategy": "",
                "holdings": {},
                "transactions": [],
                "portfolio_value_time_series": [],
            }
            # Persist under the same lower-cased key used by read_account()
            # above and by save(); the raw name may differ in case from the
            # key that is later read back.
            write_account(key, fields)
        return cls(**fields)

    def save(self) -> None:
        """Persist the full model state under the lower-cased account name."""
        write_account(self.name.lower(), self.model_dump())

    def reset(self, strategy: str) -> None:
        """Restore the starting balance, clear holdings and history, set *strategy*, and save."""
        self.balance = INITIAL_BALANCE
        self.strategy = strategy
        self.holdings = {}
        self.transactions = []
        self.portfolio_value_time_series = []
        self.save()

    def deposit(self, amount: float) -> None:
        """Deposit funds into the account.

        Raises:
            ValueError: if *amount* is not positive.
        """
        if amount <= 0:
            raise ValueError("Deposit amount must be positive.")
        self.balance += amount
        print(f"Deposited ${amount}. New balance: ${self.balance}")
        self.save()

    def withdraw(self, amount: float) -> None:
        """Withdraw funds from the account, ensuring it doesn't go negative.

        Raises:
            ValueError: if *amount* is not positive or exceeds the balance.
        """
        # Without this guard a negative amount passes the funds check and
        # silently increases the balance.
        if amount <= 0:
            raise ValueError("Withdrawal amount must be positive.")
        if amount > self.balance:
            raise ValueError("Insufficient funds for withdrawal.")
        self.balance -= amount
        print(f"Withdrew ${amount}. New balance: ${self.balance}")
        self.save()

    def buy_shares(self, symbol: str, quantity: int, rationale: str) -> str:
        """Buy shares of a stock if sufficient funds are available.

        Shares are bought at the current price marked up by ``SPREAD``. The
        trade is recorded, the cost is debited, and the account is saved and
        logged.

        Returns:
            A completion message followed by the JSON from ``report()``.

        Raises:
            ValueError: if *quantity* is not positive, the symbol's price is 0
                (treated as an unrecognized symbol), or funds are insufficient.
        """
        # A non-positive quantity would credit cash and create negative holdings.
        if quantity <= 0:
            raise ValueError("Quantity must be positive.")
        price = get_share_price(symbol)
        # A zero price signals an unrecognized symbol.
        if price == 0:
            raise ValueError(f"Unrecognized symbol {symbol}")
        buy_price = price * (1 + SPREAD)
        total_cost = buy_price * quantity
        if total_cost > self.balance:
            raise ValueError("Insufficient funds to buy shares.")

        self.holdings[symbol] = self.holdings.get(symbol, 0) + quantity
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.transactions.append(
            Transaction(
                symbol=symbol,
                quantity=quantity,
                price=buy_price,
                timestamp=timestamp,
                rationale=rationale,
            )
        )
        self.balance -= total_cost
        self.save()
        write_log(self.name, "account", f"Bought {quantity} of {symbol}")
        return "Completed. Latest details:\n" + self.report()

    def sell_shares(self, symbol: str, quantity: int, rationale: str) -> str:
        """Sell shares of a stock if the user has enough shares.

        Shares are sold at the current price marked down by ``SPREAD``. The
        trade is recorded with a negative quantity, the proceeds are credited,
        and the account is saved and logged.

        Returns:
            A completion message followed by the JSON from ``report()``.

        Raises:
            ValueError: if *quantity* is not positive, not enough shares are
                held, or the symbol's price is 0.
        """
        # A non-positive quantity would increase holdings (or raise KeyError
        # for an unheld symbol) instead of selling anything.
        if quantity <= 0:
            raise ValueError("Quantity must be positive.")
        if self.holdings.get(symbol, 0) < quantity:
            raise ValueError(f"Cannot sell {quantity} shares of {symbol}. Not enough shares held.")
        price = get_share_price(symbol)
        # Same zero-price rule as buy_shares(); otherwise shares would be
        # given away for nothing.
        if price == 0:
            raise ValueError(f"Unrecognized symbol {symbol}")
        sell_price = price * (1 - SPREAD)
        total_proceeds = sell_price * quantity

        self.holdings[symbol] -= quantity
        # Drop fully liquidated positions so holdings only lists open ones.
        if self.holdings[symbol] == 0:
            del self.holdings[symbol]
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.transactions.append(
            Transaction(
                symbol=symbol,
                quantity=-quantity,  # negative quantity marks a sell
                price=sell_price,
                timestamp=timestamp,
                rationale=rationale,
            )
        )
        self.balance += total_proceeds
        self.save()
        write_log(self.name, "account", f"Sold {quantity} of {symbol}")
        return "Completed. Latest details:\n" + self.report()

    def calculate_portfolio_value(self) -> float:
        """Return the cash balance plus the current market value of all holdings."""
        total_value = self.balance
        for symbol, quantity in self.holdings.items():
            total_value += get_share_price(symbol) * quantity
        return total_value

    def calculate_profit_loss(self, portfolio_value: float) -> float:
        """Return profit or loss relative to the net cash spent on shares.

        ``net_spend`` sums every transaction's signed total (sells have negative
        quantities, so their proceeds subtract). When *portfolio_value* comes
        from ``calculate_portfolio_value()``, the cash balance cancels out and
        the result is the market value of holdings minus ``net_spend``.
        """
        net_spend = sum(transaction.total() for transaction in self.transactions)
        return portfolio_value - net_spend - self.balance

    def get_holdings(self) -> dict[str, int]:
        """Report the current holdings of the user."""
        return self.holdings

    def get_profit_loss(self) -> float:
        """Report the user's profit or loss at current market prices."""
        # calculate_profit_loss() requires the portfolio value; calling it with
        # no argument raised TypeError.
        return self.calculate_profit_loss(self.calculate_portfolio_value())

    def list_transactions(self) -> list[dict]:
        """List all transactions made by the user, as plain dicts."""
        return [transaction.model_dump() for transaction in self.transactions]

    def report(self) -> str:
        """Return a JSON string representing the account.

        The JSON contains every model field plus ``total_portfolio_value`` and
        ``total_profit_loss``. As a side effect, the current valuation is
        appended to ``portfolio_value_time_series``, the account is saved,
        and the retrieval is logged.
        """
        portfolio_value = self.calculate_portfolio_value()
        self.portfolio_value_time_series.append(
            (datetime.now().strftime("%Y-%m-%d %H:%M:%S"), portfolio_value)
        )
        self.save()
        pnl = self.calculate_profit_loss(portfolio_value)
        data = self.model_dump()
        data["total_portfolio_value"] = portfolio_value
        data["total_profit_loss"] = pnl
        write_log(self.name, "account", "Retrieved account details")
        return json.dumps(data)

    def get_strategy(self) -> str:
        """Return the strategy of the account."""
        write_log(self.name, "account", "Retrieved strategy")
        return self.strategy

    def change_strategy(self, strategy: str) -> str:
        """At your discretion, if you choose to, call this to change your investment strategy for the future."""
        self.strategy = strategy
        self.save()
        write_log(self.name, "account", "Changed strategy")
        return "Changed strategy"
# Example of usage:
if __name__ == "__main__":
    # Pydantic models take keyword fields only; Account.get() loads or creates
    # the persisted account.
    account = Account.get("John Doe")
    account.deposit(1000)
    # buy_shares/sell_shares require a rationale argument.
    account.buy_shares("AAPL", 5, "Example purchase")
    account.sell_shares("AAPL", 2, "Example sale")
    print(f"Current Holdings: {account.get_holdings()}")
    portfolio_value = account.calculate_portfolio_value()
    print(f"Total Portfolio Value: {portfolio_value}")
    # Pass the portfolio value explicitly, as calculate_profit_loss() requires it.
    print(f"Profit/Loss: {account.calculate_profit_loss(portfolio_value)}")
    print(f"Transactions: {account.list_transactions()}")