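# Repository-viewer header captured with the file (kept as comments so the module parses):
# MonteWalk / tools / visualizer.py
# Obotu's picture
# final touches
# dc45523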
"""
Visualization Tools for MonteWalk
Provides flexible charting capabilities for market data, risk analysis, and backtesting.
"""
import io
import base64
import logging
from typing import List, Dict, Any, Optional, Union
import matplotlib
matplotlib.use('Agg') # Non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# Module-level logger named after this module; no handlers are configured here.
logger = logging.getLogger(__name__)
# Set dark theme matching Gradio UI.
# Both calls below run at import time and change global matplotlib/seaborn
# state: the built-in 'dark_background' style plus a five-color seaborn palette.
plt.style.use('dark_background')
sns.set_palette(["#60a5fa", "#a855f7", "#ec4899", "#10b981", "#f59e0b"])
def _encode_figure(fig) -> str:
    """Render a matplotlib figure to a base64-encoded PNG data URI and close it.

    Args:
        fig: The matplotlib figure to serialize.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.

    Raises:
        Whatever ``fig.savefig`` raises; the figure is still closed in that case.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='#0f172a')
    finally:
        # Close unconditionally: a figure left open after a failed render stays
        # registered with pyplot and accumulates in a long-running process.
        plt.close(fig)
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{img_base64}"
def plot_line(
    data: Union[Dict[str, List], pd.DataFrame],
    x_label: str = "X",
    y_label: str = "Y",
    title: str = "Line Chart",
    labels: Optional[List[str]] = None
) -> str:
    """Create a line chart and return it as a base64 PNG data URI.

    Example: plot_line({'x': [1, 2, 3], 'y': [10, 20, 15]}, title="Growth")

    Args:
        data: Either a dict with keys ``'x'`` and ``'y'`` or a DataFrame.
            For a dict, ``'y'`` may be a single series or a sequence of series
            (lists, tuples, arrays or Series) that all share ``'x'``.
            For a DataFrame, the first column is the x axis and every other
            column is drawn as its own line; a single-column frame is drawn
            against its index.
        x_label: Label for the x axis.
        y_label: Label for the y axis.
        title: Chart title.
        labels: Legend labels for multi-series dict input; missing entries
            fall back to ``"Series N"``.

    Returns:
        The rendered chart as a ``data:image/png;base64,...`` string.

    Raises:
        KeyError: If dict input lacks ``'x'`` or ``'y'``.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        if isinstance(data, dict):
            x_values = data['x']
            y_values = data['y']
            # Peek at the first element without indexing so that an empty 'y'
            # yields an empty chart instead of an IndexError.
            first = next(iter(y_values), None)
            if isinstance(first, (list, tuple, np.ndarray, pd.Series)):
                # Multiple lines
                for i, y_data in enumerate(y_values):
                    label = labels[i] if labels and i < len(labels) else f"Series {i+1}"
                    ax.plot(x_values, y_data, linewidth=2, label=label, marker='o', markersize=4)
                ax.legend(loc='best', framealpha=0.9)
            else:
                # Single line
                ax.plot(x_values, y_values, linewidth=2, color='#60a5fa', marker='o', markersize=4)
        else:
            # DataFrame
            if len(data.columns) == 1:
                # Only one column: there is no separate x column, so plot the
                # values against the index rather than drawing nothing.
                only_col = data.columns[0]
                ax.plot(data.index, data[only_col], linewidth=2, label=only_col,
                        marker='o', markersize=4)
            else:
                x_column = data.iloc[:, 0]
                for col in data.columns[1:]:
                    ax.plot(x_column, data[col], linewidth=2, label=col, marker='o', markersize=4)
            # A legend is only useful when more than one line is drawn.
            if len(data.columns) > 2:
                ax.legend(loc='best', framealpha=0.9)

        ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0')
        ax.set_ylabel(y_label, fontsize=11, color='#e2e8f0')
        ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20)
        ax.grid(True, alpha=0.2, linestyle='--')
        ax.tick_params(colors='#94a3b8')
        return _encode_figure(fig)
    finally:
        # Closing an already-closed figure is a no-op; this guarantees the
        # figure is released when plotting fails before encoding.
        plt.close(fig)
def plot_candlestick(
    df: pd.DataFrame,
    title: str = "Candlestick Chart",
    volume: bool = True
) -> str:
    """Create a candlestick chart and return it as a base64 PNG data URI.

    Example: plot_candlestick(df, title="AAPL Daily", volume=True)

    Args:
        df: Price data. A ``'Date'`` column, when present, becomes the index;
            the index is then converted with ``pd.to_datetime``. The input
            frame is not modified.
        title: Chart title.
        volume: Whether to draw the volume panel. It is drawn only when the
            frame actually has a ``'Volume'`` column.

    Returns:
        The rendered chart as a ``data:image/png;base64,...`` string. If the
        candlestick rendering fails for any reason (including mplfinance not
        being importable), a line chart of ``df['Close']`` is returned instead.
    """
    try:
        import mplfinance as mpf

        # Prepare data
        df_copy = df.copy()
        if 'Date' in df_copy.columns:
            df_copy = df_copy.set_index('Date')
        df_copy.index = pd.to_datetime(df_copy.index)

        # Requesting a volume panel without volume data makes the whole plot
        # fail, which would needlessly degrade OHLC-only input to the line
        # chart fallback. Drop the panel instead.
        show_volume = bool(volume) and 'Volume' in df_copy.columns
        if volume and not show_volume:
            logger.warning("No 'Volume' column found; drawing candlesticks without volume panel")

        # Custom style matching our theme
        mc = mpf.make_marketcolors(
            up='#10b981', down='#ef4444',
            edge='inherit',
            wick='inherit',
            volume='#60a5fa',
            alpha=0.9
        )
        s = mpf.make_mpf_style(
            marketcolors=mc,
            gridstyle='--',
            gridcolor='#334155',
            facecolor='#0f172a',
            figcolor='#0f172a',
            edgecolor='#1e293b'
        )

        # Plot
        fig, _axes = mpf.plot(
            df_copy,
            type='candle',
            style=s,
            volume=show_volume,
            title=title,
            ylabel='Price',
            ylabel_lower='Volume',
            figsize=(12, 8),
            returnfig=True
        )
        return _encode_figure(fig)
    except Exception as e:
        # Deliberate best-effort: any failure above degrades to a line chart.
        logger.error(f"Error creating candlestick chart: {e}")
        # Fallback to simple line chart
        return plot_line(
            {'x': list(range(len(df))), 'y': df['Close'].tolist()},
            x_label="Time",
            y_label="Price",
            title=title
        )
def plot_histogram(
    data: Union[List[float], np.ndarray],
    bins: int = 50,
    title: str = "Distribution",
    x_label: str = "Value",
    percentiles: Optional[List[float]] = None
) -> str:
    """Create a histogram and return it as a base64 PNG data URI.

    Example: plot_histogram([1, 2, 2, 3, 3, 3, 4, 4, 5], bins=5, percentiles=[5, 50, 95])

    Args:
        data: Values to bin.
        bins: Number of bins passed to ``Axes.hist``.
        title: Chart title.
        x_label: Label for the x axis (the y axis is always "Frequency").
        percentiles: Percentiles (0-100) to mark with labelled vertical
            lines. Ignored when ``data`` is empty.

    Returns:
        The rendered chart as a ``data:image/png;base64,...`` string.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Plot histogram
        ax.hist(data, bins=bins, color='#60a5fa', alpha=0.7, edgecolor='#94a3b8')

        # Add percentile markers (a percentile of no data is undefined).
        if percentiles and np.size(data) > 0:
            # Vertical lines and text do not alter the y limits, so the label
            # height can be computed once.
            label_y = ax.get_ylim()[1] * 0.9
            for p in percentiles:
                val = np.percentile(data, p)
                ax.axvline(val, color='#ec4899', linestyle='--', linewidth=2, alpha=0.8)
                # ':g' keeps fractional percentiles (e.g. 97.5) intact instead
                # of truncating them, while still printing 5.0 as "5".
                ax.text(val, label_y, f'P{p:g}: {val:.2f}',
                        rotation=90, va='top', ha='right', color='#ec4899', fontweight='bold')

        ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0')
        ax.set_ylabel('Frequency', fontsize=11, color='#e2e8f0')
        ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20)
        ax.grid(True, alpha=0.2, axis='y', linestyle='--')
        ax.tick_params(colors='#94a3b8')
        return _encode_figure(fig)
    finally:
        # No-op if the figure was already closed; releases it on failure.
        plt.close(fig)
def plot_scatter(
    x: List[float],
    y: List[float],
    title: str = "Scatter Plot",
    x_label: str = "X",
    y_label: str = "Y",
    trend_line: bool = True
) -> str:
    """Create a scatter plot and return it as a base64 PNG data URI.

    Example: plot_scatter([1, 2, 3], [2, 4, 6], title="Correlation", trend_line=True)

    Args:
        x: X coordinates.
        y: Y coordinates, same length as ``x``.
        title: Chart title.
        x_label: Label for the x axis.
        y_label: Label for the y axis.
        trend_line: Whether to overlay a least-squares straight line and the
            Pearson correlation coefficient. The overlay is skipped when
            fewer than two finite points exist or all x values are equal,
            because the fit is undefined in those cases.

    Returns:
        The rendered chart as a ``data:image/png;base64,...`` string.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Scatter plot
        ax.scatter(x, y, alpha=0.6, s=50, color='#60a5fa', edgecolors='#94a3b8', linewidth=0.5)

        # Add trend line
        if trend_line and len(x) > 1:
            x_arr = np.asarray(x, dtype=float)
            y_arr = np.asarray(y, dtype=float)
            # Fit only on pairs where both coordinates are finite.
            finite = np.isfinite(x_arr) & np.isfinite(y_arr)
            x_fit = x_arr[finite]
            y_fit = y_arr[finite]
            if x_fit.size > 1 and x_fit.max() > x_fit.min():
                slope, intercept = np.polyfit(x_fit, y_fit, 1)
                # A straight line is fully described by its two end points.
                x_ends = np.array([x_fit.min(), x_fit.max()])
                ax.plot(x_ends, slope * x_ends + intercept,
                        color='#ec4899', linestyle='--', linewidth=2, alpha=0.8)

                # Pearson correlation coefficient (not R²)
                correlation = np.corrcoef(x_fit, y_fit)[0, 1]
                ax.text(0.05, 0.95, f'Correlation: {correlation:.3f}',
                        transform=ax.transAxes, fontsize=10, color='#f8fafc',
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='#1e293b', alpha=0.8))

        ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0')
        ax.set_ylabel(y_label, fontsize=11, color='#e2e8f0')
        ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20)
        ax.grid(True, alpha=0.2, linestyle='--')
        ax.tick_params(colors='#94a3b8')
        return _encode_figure(fig)
    finally:
        # No-op if the figure was already closed; releases it on failure.
        plt.close(fig)
def plot_heatmap(
    matrix: Union[np.ndarray, pd.DataFrame],
    labels: Optional[List[str]] = None,
    title: str = "Heatmap"
) -> str:
    """Create an annotated heatmap and return it as a base64 PNG data URI.

    Example: plot_heatmap(df.corr(), title="Correlation Matrix")

    The color scale is fixed to the range [-1, 1].

    Args:
        matrix: Two-dimensional values, as an array-like or a DataFrame.
        labels: Tick labels. They are applied to each axis whose length they
            match. When omitted (or of the wrong length) and ``matrix`` is a
            DataFrame, the column names are used; for a non-square DataFrame
            the row axis uses the index instead.
        title: Chart title.

    Returns:
        The rendered chart as a ``data:image/png;base64,...`` string.

    Raises:
        ValueError: If ``matrix`` is not two-dimensional.
    """
    if isinstance(matrix, pd.DataFrame):
        frame_cols = list(matrix.columns)
        frame_rows = list(matrix.index)
        values = matrix.to_numpy()
    else:
        frame_cols = frame_rows = None
        # Accept nested lists as well as arrays so that [i, j] semantics hold.
        values = np.asarray(matrix)

    if values.ndim != 2:
        raise ValueError(f"Heatmap requires a 2-D matrix, got {values.ndim}-D data")
    n_rows, n_cols = values.shape

    def _pick_labels(size, *candidates):
        """Return the first candidate whose length matches the axis, as strings."""
        for candidate in candidates:
            if candidate is not None and len(candidate) == size:
                return [str(item) for item in candidate]
        return None

    x_labels = _pick_labels(n_cols, labels, frame_cols)
    # Square frames keep the column names on both axes; otherwise the row
    # axis can only be labelled correctly from the index.
    y_labels = _pick_labels(n_rows, labels, frame_cols if n_rows == n_cols else frame_rows)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Create heatmap
        im = ax.imshow(values, cmap='RdYlGn', aspect='auto', vmin=-1, vmax=1)

        # Colorbar
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.tick_params(colors='#94a3b8')

        # Labels
        if x_labels:
            ax.set_xticks(np.arange(n_cols))
            ax.set_xticklabels(x_labels, rotation=45, ha='right', color='#e2e8f0')
        if y_labels:
            ax.set_yticks(np.arange(n_rows))
            ax.set_yticklabels(y_labels, color='#e2e8f0')

        # Annotate cells with values (imshow puts column j at x=j, row i at y=i)
        for (row, col), value in np.ndenumerate(values):
            ax.text(col, row, f'{value:.2f}',
                    ha="center", va="center", color='#0f172a', fontsize=9, fontweight='bold')

        ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20)
        return _encode_figure(fig)
    finally:
        # No-op if the figure was already closed; releases it on failure.
        plt.close(fig)
def plot_bar(
    categories: List[str],
    values: List[float],
    title: str = "Bar Chart",
    x_label: str = "Category",
    y_label: str = "Value",
    horizontal: bool = False
) -> str:
    """Create a bar chart and return it as a base64 PNG data URI.

    Example: plot_bar(["A", "B"], [10, 20], title="Comparison", horizontal=False)

    Bars with negative values are drawn in red, all others in blue, and a
    zero baseline is drawn across the value axis.

    Args:
        categories: Category names, one per bar.
        values: Bar values, same length as ``categories``.
        title: Chart title.
        x_label: Label for the category axis.
        y_label: Label for the value axis.
        horizontal: Draw horizontal bars. The category axis is then the
            vertical one, so ``x_label`` and ``y_label`` swap physical axes.

    Returns:
        The rendered chart as a ``data:image/png;base64,...`` string.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        colors = ['#60a5fa' if v >= 0 else '#ef4444' for v in values]

        if horizontal:
            ax.barh(categories, values, color=colors, alpha=0.8, edgecolor='#94a3b8')
            ax.set_xlabel(y_label, fontsize=11, color='#e2e8f0')
            ax.set_ylabel(x_label, fontsize=11, color='#e2e8f0')
        else:
            ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='#94a3b8')
            ax.set_xlabel(x_label, fontsize=11, color='#e2e8f0')
            ax.set_ylabel(y_label, fontsize=11, color='#e2e8f0')
            # Rotate the labels on this axes directly rather than through
            # pyplot's implicit "current axes".
            for tick_label in ax.get_xticklabels():
                tick_label.set_rotation(45)
                tick_label.set_ha('right')

        ax.set_title(title, fontsize=14, fontweight='bold', color='#f8fafc', pad=20)
        ax.grid(True, alpha=0.2, axis='x' if horizontal else 'y', linestyle='--')
        ax.tick_params(colors='#94a3b8')

        # Zero baseline along the value axis.
        if horizontal:
            ax.axvline(0, color='#94a3b8', linewidth=0.8)
        else:
            ax.axhline(0, color='#94a3b8', linewidth=0.8)
        return _encode_figure(fig)
    finally:
        # No-op if the figure was already closed; releases it on failure.
        plt.close(fig)
def plot_data(
    data: Union[Dict, pd.DataFrame, List],
    chart_type: str = "auto",
    title: str = "Chart",
    **kwargs
) -> str:
    """Universal plotting entry point that routes to a specific chart function.

    Example: plot_data([1, 2, 3, 4, 5], chart_type="auto", title="Auto Histogram")

    Args:
        data: The data to plot. Its expected shape depends on the chart type:
            a DataFrame (candlestick, line, heatmap), a flat numeric sequence
            (histogram), or a dict with ``'x'``/``'y'`` (line, scatter) or
            ``'categories'``/``'values'`` (bar).
        chart_type: One of ``"auto"``, ``"line"``, ``"candlestick"``,
            ``"histogram"``, ``"scatter"``, ``"heatmap"`` or ``"bar"``.
            With ``"auto"`` the type is inferred:
            DataFrame with Open/High/Low/Close columns -> candlestick,
            any other DataFrame -> line,
            non-empty flat sequence of numbers -> histogram,
            dict with 'x' and 'y' -> line,
            dict with 'categories' and 'values' -> bar.
        title: Chart title, forwarded to the chart function.
        **kwargs: Extra keyword arguments forwarded to the chart function.

    Returns:
        The chart as a ``data:image/png;base64,...`` string, or a string
        starting with ``"Error creating chart: "`` if anything fails. This
        function logs failures and does not raise.
    """
    try:
        # Auto-detect chart type
        if chart_type == "auto":
            if isinstance(data, pd.DataFrame):
                has_ohlc = all(col in data.columns for col in ['Open', 'High', 'Low', 'Close'])
                chart_type = "candlestick" if has_ohlc else "line"
            elif isinstance(data, (list, tuple, np.ndarray)):
                if len(data) == 0:
                    raise ValueError("Cannot infer chart type from empty data")
                # NumPy integer scalars are not instances of int, so they are
                # listed explicitly; otherwise integer arrays would be missed.
                if isinstance(data[0], (int, float, np.integer, np.floating)):
                    chart_type = "histogram"
            elif isinstance(data, dict):
                if 'x' in data and 'y' in data:
                    chart_type = "line"
                elif 'categories' in data and 'values' in data:
                    chart_type = "bar"

            if chart_type == "auto":
                raise ValueError(
                    f"Could not infer chart type for data of type {type(data).__name__}; "
                    "pass chart_type explicitly"
                )

        # Route to specific function
        if chart_type == "line":
            return plot_line(data, title=title, **kwargs)
        elif chart_type == "candlestick":
            return plot_candlestick(data, title=title, **kwargs)
        elif chart_type == "histogram":
            return plot_histogram(data, title=title, **kwargs)
        elif chart_type == "scatter":
            return plot_scatter(data.get('x', []), data.get('y', []), title=title, **kwargs)
        elif chart_type == "heatmap":
            return plot_heatmap(data, title=title, **kwargs)
        elif chart_type == "bar":
            return plot_bar(data.get('categories', []), data.get('values', []), title=title, **kwargs)
        else:
            raise ValueError(f"Unknown chart type: {chart_type}")
    except Exception as e:
        # Top-level boundary: report the failure as a string instead of raising.
        logger.error(f"Error creating chart: {e}", exc_info=True)
        return f"Error creating chart: {str(e)}"