""" Visualization Tools for MonteWalk Provides flexible charting capabilities for market data, risk analysis, and backtesting. """ import io import base64 import logging from typing import List, Dict, Any, Optional, Union import matplotlib matplotlib.use('Agg') # Non-interactive backend import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import numpy as np logger = logging.getLogger(__name__) # Set dark theme matching Gradio UI plt.style.use('dark_background') sns.set_palette(["#60a5fa", "#a855f7", "#ec4899", "#10b981", "#f59e0b"]) def _encode_figure(fig) -> str: """Convert matplotlib figure to base64-encoded PNG string.""" buf = io.BytesIO() fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='#0f172a') buf.seek(0) img_base64 = base64.b64encode(buf.read()).decode('utf-8') plt.close(fig) return f"data:image/png;base64,{img_base64}" def plot_line( data: Union[Dict[str, List], pd.DataFrame], x_label: str = "X", y_label: str = "Y", title: str = "Line Chart", labels: Optional[List[str]] = None ) -> str: """Create line chart. Example: plot_line({'x': [1, 2, 3], 'y': [10, 20, 15]}, title="Growth")""" fig, ax = plt.subplots(figsize=(12, 6)) if isinstance(data, dict): if 'y' in data and isinstance(data['y'][0], list): # Multiple lines for i, y_data in enumerate(data['y']): label = labels[i] if labels and i < len(labels) else f"Series {i+1}" ax.plot(data['x'], y_data, linewidth=2, label=label, marker='o', markersize=4) ax.legend(loc='best', framealpha=0.9) else: # Single line ax.plot(data['x'], data['y'], linewidth=2, color='#60a5fa', marker='o', markersize=4) else: # DataFrame for col in data.columns[1:]: ax.plot(data.iloc[:, 0], data[col], linewidth=2, label=col, marker='o', markersize=4) if len(data.columns) > 2: ax.legend(loc='best', framealpha=0.9) ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0') ax.set_ylabel(y_label, fontsize=11, color='#e2e8f0') ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20) ax.grid(True, alpha=0.2, linestyle='--') ax.tick_params(colors='#94a3b8') return _encode_figure(fig) def plot_candlestick( df: pd.DataFrame, title: str = "Candlestick Chart", volume: bool = True ) -> str: """Create candlestick chart. Example: plot_candlestick(df, title="AAPL Daily", volume=True)""" try: import mplfinance as mpf # Prepare data df_copy = df.copy() if 'Date' in df_copy.columns: df_copy.set_index('Date', inplace=True) df_copy.index = pd.to_datetime(df_copy.index) # Custom style matching our theme mc = mpf.make_marketcolors( up='#10b981', down='#ef4444', edge='inherit', wick='inherit', volume='#60a5fa', alpha=0.9 ) s = mpf.make_mpf_style( marketcolors=mc, gridstyle='--', gridcolor='#334155', facecolor='#0f172a', figcolor='#0f172a', edgecolor='#1e293b' ) # Plot fig, axes = mpf.plot( df_copy, type='candle', style=s, volume=volume, title=title, ylabel='Price', ylabel_lower='Volume', figsize=(12, 8), returnfig=True ) return _encode_figure(fig) except Exception as e: logger.error(f"Error creating candlestick chart: {e}") # Fallback to simple line chart return plot_line( {'x': list(range(len(df))), 'y': df['Close'].tolist()}, x_label="Time", y_label="Price", title=title ) def plot_histogram( data: Union[List[float], np.ndarray], bins: int = 50, title: str = "Distribution", x_label: str = "Value", percentiles: Optional[List[float]] = None ) -> str: """Create histogram. Example: plot_histogram([1, 2, 2, 3, 3, 3, 4, 4, 5], bins=5, percentiles=[5, 50, 95])""" fig, ax = plt.subplots(figsize=(12, 6)) # Plot histogram n, bins_edges, patches = ax.hist( data, bins=bins, color='#60a5fa', alpha=0.7, edgecolor='#94a3b8' ) # Add percentile markers if percentiles: for p in percentiles: val = np.percentile(data, p) ax.axvline(val, color='#ec4899', linestyle='--', linewidth=2, alpha=0.8) ax.text(val, ax.get_ylim()[1] * 0.9, f'P{int(p)}: {val:.2f}', rotation=90, va='top', ha='right', color='#ec4899', fontweight='bold') ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0') ax.set_ylabel('Frequency', fontsize=11, color='#e2e8f0') ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20) ax.grid(True, alpha=0.2, axis='y', linestyle='--') ax.tick_params(colors='#94a3b8') return _encode_figure(fig) def plot_scatter( x: List[float], y: List[float], title: str = "Scatter Plot", x_label: str = "X", y_label: str = "Y", trend_line: bool = True ) -> str: """Create scatter plot. Example: plot_scatter([1, 2, 3], [2, 4, 6], title="Correlation", trend_line=True)""" fig, ax = plt.subplots(figsize=(10, 8)) # Scatter plot ax.scatter(x, y, alpha=0.6, s=50, color='#60a5fa', edgecolors='#94a3b8', linewidth=0.5) # Add trend line if trend_line and len(x) > 1: z = np.polyfit(x, y, 1) p = np.poly1d(z) ax.plot(x, p(x), color='#ec4899', linestyle='--', linewidth=2, alpha=0.8) # Calculate R² correlation = np.corrcoef(x, y)[0, 1] ax.text(0.05, 0.95, f'Correlation: {correlation:.3f}', transform=ax.transAxes, fontsize=10, color='#f8fafc', verticalalignment='top', bbox=dict(boxstyle='round', facecolor='#1e293b', alpha=0.8)) ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0') ax.set_ylabel(y_label, fontsize=11, color='#e2e8f0') ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20) ax.grid(True, alpha=0.2, linestyle='--') ax.tick_params(colors='#94a3b8') return _encode_figure(fig) def plot_heatmap( matrix: Union[np.ndarray, pd.DataFrame], labels: Optional[List[str]] = None, title: str = "Heatmap" ) -> str: """Create heatmap. Example: plot_heatmap(df.corr(), title="Correlation Matrix")""" fig, ax = plt.subplots(figsize=(10, 8)) if isinstance(matrix, pd.DataFrame): labels = matrix.columns.tolist() matrix = matrix.values # Create heatmap im = ax.imshow(matrix, cmap='RdYlGn', aspect='auto', vmin=-1, vmax=1) # Colorbar cbar = plt.colorbar(im, ax=ax) cbar.ax.tick_params(colors='#94a3b8') # Labels if labels: ax.set_xticks(np.arange(len(labels))) ax.set_yticks(np.arange(len(labels))) ax.set_xticklabels(labels, rotation=45, ha='right', color='#e2e8f0') ax.set_yticklabels(labels, color='#e2e8f0') # Annotate cells with values for i in range(len(matrix)): for j in range(len(matrix[0])): text = ax.text(j, i, f'{matrix[i, j]:.2f}', ha="center", va="center", color='#0f172a', fontsize=9, fontweight='bold') ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20) return _encode_figure(fig) def plot_bar( categories: List[str], values: List[float], title: str = "Bar Chart", x_label: str = "Category", y_label: str = "Value", horizontal: bool = False ) -> str: """Create bar chart. Example: plot_bar(["A", "B"], [10, 20], title="Comparison", horizontal=False)""" fig, ax = plt.subplots(figsize=(12, 6)) colors = ['#60a5fa' if v >= 0 else '#ef4444' for v in values] if horizontal: ax.barh(categories, values, color=colors, alpha=0.8, edgecolor='#94a3b8') ax.set_xlabel(y_label, fontsize=11, color='#e2e8f0') ax.set_ylabel(x_label, fontsize=11, color='#e2e8f0') else: ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='#94a3b8') ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0') ax.set_ylabel(y_label, fontsize=11, color='#e2e8f0') plt.xticks(rotation=45, ha='right') ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20) ax.grid(True, alpha=0.2, axis='y' if not horizontal else 'x', linestyle='--') ax.tick_params(colors='#94a3b8') ax.axhline(0, color='#94a3b8', linewidth=0.8) if not horizontal else ax.axvline(0, color='#94a3b8', linewidth=0.8) return _encode_figure(fig) def plot_data( data: Union[Dict, pd.DataFrame, List], chart_type: str = "auto", title: str = "Chart", **kwargs ) -> str: """Universal plotting function. Example: plot_data([1, 2, 3, 4, 5], chart_type="auto", title="Auto Histogram")""" try: # Auto-detect chart type if chart_type == "auto": if isinstance(data, pd.DataFrame) and all(col in data.columns for col in ['Open', 'High', 'Low', 'Close']): chart_type = "candlestick" elif isinstance(data, (list, np.ndarray)) and isinstance(data[0], (int, float)): chart_type = "histogram" elif isinstance(data, dict): if 'x' in data and 'y' in data: chart_type = "line" elif 'categories' in data and 'values' in data: chart_type = "bar" else: chart_type = "line" # Route to specific function if chart_type == "line": return plot_line(data, title=title, **kwargs) elif chart_type == "candlestick": return plot_candlestick(data, title=title, **kwargs) elif chart_type == "histogram": return plot_histogram(data, title=title, **kwargs) elif chart_type == "scatter": return plot_scatter(data.get('x', []), data.get('y', []), title=title, **kwargs) elif chart_type == "heatmap": return plot_heatmap(data, title=title, **kwargs) elif chart_type == "bar": return plot_bar(data.get('categories', []), data.get('values', []), title=title, **kwargs) else: raise ValueError(f"Unknown chart type: {chart_type}") except Exception as e: logger.error(f"Error creating chart: {e}", exc_info=True) return f"Error creating chart: {str(e)}"