Trading_Guru / models /portfolio_optimization.py
shaheerawan3's picture
Update models/portfolio_optimization.py
97dd12e verified
# portfolio_optimization.py
from typing import Dict, List
import numpy as np
import pandas as pd
from scipy.optimize import minimize
import plotly.graph_objects as go # Add this import
import streamlit as st # Add this import
class PortfolioOptimizer:
    """Long-only mean-variance optimizer over a table of periodic asset returns.

    ``returns_df`` holds one column per asset. The annualized figures produced by
    ``optimize_portfolio`` / ``get_efficient_frontier`` scale by 252 periods per
    year, i.e. they assume the rows are daily returns
    (NOTE(review): confirm callers always pass daily data).
    """

    # Periods per year used to annualize mean returns and the covariance matrix.
    _PERIODS_PER_YEAR = 252

    def __init__(self, returns_df: pd.DataFrame):
        """Store the returns and precompute per-period mean returns and covariance."""
        self.returns = returns_df
        self.mean_returns = returns_df.mean()
        self.cov_matrix = returns_df.cov()

    def validate_data(self) -> None:
        """Validate the returns data, filling missing values.

        Missing values are forward-filled, then back-filled (which covers leading
        gaps). The filled frame replaces ``self.returns``. The DataFrame passed to
        the constructor is not modified. ``mean_returns`` and ``cov_matrix`` are
        recomputed from the filled data so that they stay consistent with it.

        Raises:
            ValueError: if the data is empty, still contains NaN after filling
                (e.g. an all-NaN column), has an asset with zero variance, or has
                fewer than 30 rows.
        """
        if self.returns.empty:
            raise ValueError("Returns data is empty")
        # ffill()/bfill() replace the deprecated fillna(method=...). Rebinding,
        # instead of using inplace=True, leaves the caller's DataFrame untouched.
        self.returns = self.returns.ffill().bfill()
        # The cached statistics were computed before filling; refresh them.
        self.mean_returns = self.returns.mean()
        self.cov_matrix = self.returns.cov()
        if self.returns.isna().any().any():
            raise ValueError("Unable to fill all missing values in returns data")
        if (self.returns.std() == 0).any():
            raise ValueError("One or more assets have zero variance")
        if len(self.returns) < 30:  # Minimum required data points
            raise ValueError("Insufficient data points for optimization (minimum 30 required)")

    def portfolio_performance(self, weights: np.ndarray) -> tuple:
        """Calculate per-period (non-annualized) portfolio performance metrics.

        Args:
            weights: portfolio weights, one per asset column (array-like).

        Returns:
            tuple: (returns, volatility, sharpe). ``sharpe`` is returns / volatility
            with no risk-free rate, or 0 when the volatility is not positive.

        Raises:
            ValueError: if the metrics cannot be computed (e.g. a shape mismatch).
        """
        try:
            weights = np.asarray(weights, dtype=float)
            returns = np.sum(self.mean_returns * weights)
            volatility = np.sqrt(np.dot(weights.T, np.dot(self.cov_matrix, weights)))
            sharpe = returns / volatility if volatility > 0 else 0
            return returns, volatility, sharpe
        except Exception as e:
            raise ValueError(f"Error calculating portfolio performance: {str(e)}") from e

    def optimize_portfolio(self, target_return: float, risk_free_rate: float = 0.0) -> Dict:
        """Find the minimum-volatility long-only portfolio that earns ``target_return``.

        Args:
            target_return: desired annualized return (per-period mean x 252). This
                matches the units used by ``get_efficient_frontier`` and by the
                ``'return'`` key of the result.
            risk_free_rate: annualized rate subtracted in the Sharpe ratio.

        Returns:
            On success: ``{'success': True, 'weights': {asset: weight}, 'return',
            'volatility', 'sharpe_ratio'}``, all annualized. On failure:
            ``{'success': False, 'message': str}``. The method does not raise.
        """
        try:
            num_assets = len(self.returns.columns)
            if num_assets < 2:
                return {
                    'success': False,
                    'message': 'Need at least 2 assets for optimization'
                }
            annual_means = np.asarray(self.mean_returns, dtype=float) * self._PERIODS_PER_YEAR
            # With weights in [0, 1] summing to 1, the achievable return lies between
            # the lowest and highest single-asset returns. Anything outside that
            # range is infeasible.
            tol = 1e-9 * max(1.0, float(np.max(np.abs(annual_means))))
            if not (annual_means.min() - tol <= target_return <= annual_means.max() + tol):
                return {
                    'success': False,
                    'message': 'Target return is outside the achievable range for long-only weights'
                }
            # Initial guess (equal weights)
            init_weights = np.full(num_assets, 1.0 / num_assets)
            bounds = tuple((0, 1) for _ in range(num_assets))
            constraints = [
                {'type': 'eq', 'fun': lambda w: np.sum(w) - 1},  # fully invested
                # Hit the requested annualized return.
                {'type': 'eq', 'fun': lambda w: np.dot(w, annual_means) - target_return},
            ]
            result = minimize(
                self._portfolio_volatility,
                init_weights,
                method='SLSQP',
                bounds=bounds,
                constraints=constraints
            )
            if not result.success:
                return {
                    'success': False,
                    'message': 'Optimization failed to converge'
                }
            optimal_weights = result.x
            portfolio_return = np.sum(self.mean_returns * optimal_weights) * self._PERIODS_PER_YEAR
            portfolio_volatility = self._portfolio_volatility(optimal_weights)
            # Guard against a (near-)riskless combination; mirrors portfolio_performance.
            sharpe_ratio = (
                (portfolio_return - risk_free_rate) / portfolio_volatility
                if portfolio_volatility > 0 else 0.0
            )
            return {
                'success': True,
                'weights': dict(zip(self.returns.columns, optimal_weights)),
                'return': portfolio_return,
                'volatility': portfolio_volatility,
                'sharpe_ratio': sharpe_ratio
            }
        except Exception as e:
            return {
                'success': False,
                'message': str(e)
            }

    def _portfolio_volatility(self, weights: np.ndarray) -> float:
        """Return the annualized volatility sqrt(w' (252 * Cov) w)."""
        annual_cov = np.asarray(self.cov_matrix, dtype=float) * self._PERIODS_PER_YEAR
        return np.sqrt(np.dot(weights.T, np.dot(annual_cov, weights)))

    def get_efficient_frontier(self, num_portfolios: int = 100) -> List[Dict]:
        """Generate efficient frontier points.

        Targets are ``num_portfolios`` evenly spaced annualized returns from the
        lowest to the highest single-asset annualized mean return. Targets whose
        optimization fails are skipped, so fewer points may be returned.

        Returns:
            A list of ``{'return', 'volatility', 'weights'}`` dicts (annualized).
        """
        efficient_portfolios = []
        min_ret = self.mean_returns.min() * self._PERIODS_PER_YEAR
        max_ret = self.mean_returns.max() * self._PERIODS_PER_YEAR
        for target in np.linspace(min_ret, max_ret, num_portfolios):
            result = self.optimize_portfolio(target)
            if result['success']:
                efficient_portfolios.append({
                    'return': result['return'],
                    'volatility': result['volatility'],
                    'weights': result['weights']
                })
        return efficient_portfolios

    def _portfolio_return(self, weights: np.ndarray) -> float:
        """Return the per-period (non-annualized) portfolio return."""
        return np.sum(self.mean_returns * weights)

    def _negative_sharpe(self, weights: np.ndarray) -> float:
        """Return the negative per-period Sharpe ratio (for minimizers); 0 if volatility is 0."""
        returns, volatility, sharpe = self.portfolio_performance(weights)
        return -sharpe if volatility > 0 else 0
def plot_signals(data: pd.DataFrame, signals: pd.DataFrame, asset: str) -> go.Figure:
    """Create a plotly candlestick figure of ``data`` with buy/sell signal markers.

    Args:
        data: price frame with 'Open', 'High', 'Low', 'Close' columns and a 'Date'
            column. If it is not RangeIndex-ed, ``reset_index()`` is applied, which
            turns the index into a column. NOTE(review): that column is only called
            'Date' if the index is named 'Date'.
        signals: frame with a 'Position' column; 1 is plotted as a buy marker and -1
            as a sell marker. Presumably it has one row per row of ``data`` — confirm
            with callers.
        asset: label used in the chart title.

    Returns:
        A ``go.Figure``. Buy markers sit at the bar's High and sell markers at its Low.
        The template is dark when ``st.session_state`` has ``theme == 'dark'``,
        otherwise white.
    """
    position = signals['Position']
    # Make sure data is properly indexed
    if not isinstance(data.index, pd.RangeIndex):
        data = data.reset_index()
        if len(position) == len(data):
            # reset_index() gave the prices a fresh 0..n-1 index. A signal mask
            # labelled by the old index (e.g. dates) would no longer align, so
            # match rows by position instead.
            position = pd.Series(position.to_numpy(), index=data.index)

    fig = go.Figure()
    # Add candlestick chart
    fig.add_trace(go.Candlestick(
        x=data['Date'],
        open=data['Open'],
        high=data['High'],
        low=data['Low'],
        close=data['Close'],
        name='Price'
    ))

    # Add buy signals (at the bar high), then sell signals (at the bar low).
    buy_points = data[position == 1]
    if not buy_points.empty:
        fig.add_trace(go.Scatter(
            x=buy_points['Date'],
            y=buy_points['High'],
            mode='markers',
            name='Buy Signal',
            marker=dict(
                symbol='triangle-up',
                size=15,
                color='green'
            )
        ))
    sell_points = data[position == -1]
    if not sell_points.empty:
        fig.add_trace(go.Scatter(
            x=sell_points['Date'],
            y=sell_points['Low'],
            mode='markers',
            name='Sell Signal',
            marker=dict(
                symbol='triangle-down',
                size=15,
                color='red'
            )
        ))

    # .get() avoids an AttributeError when no theme has been stored in the session yet.
    theme = st.session_state.get('theme')
    fig.update_layout(
        title=f'{asset} Trading Signals',
        xaxis_title='Date',
        yaxis_title='Price',
        template='plotly_dark' if theme == 'dark' else 'plotly_white',
        xaxis_rangeslider_visible=False,
        height=600,
        hovermode='x unified'
    )
    return fig