DataViz-Agent / src /utils.py
Muhammad Mustehson
Update Old Code
4a84072
import matplotlib.pyplot as plt
def plot_chart(
    chart_type,
    data,
    title=None,
    x_axis_label=None,
    y_axis_label=None,
    **kwargs,
):
    """Render ``data`` as a matplotlib chart and return the figure.

    Args:
        chart_type: One of ``"bar"``, ``"line"``, ``"pie"`` or ``"hist"``.
        data: DataFrame-like object (accessed via ``.shape``, ``.head``,
            ``.iloc`` and ``.columns``).
            - ``"bar"`` / ``"line"`` / ``"pie"``: column 0 supplies the x values
              or slice labels, column 1 the values. Only the first 20 rows are
              plotted.
            - ``"hist"``: columns 0, 1 and 2 are bin start, bin end and
              frequency. All rows are plotted.
        title: Optional chart title. It is only set when truthy.
        x_axis_label: Overrides the default x-axis label (bar/line/hist).
            Defaults to the first column name for bar/line, and to
            ``"Value Range"`` for hist.
        y_axis_label: Overrides the default y-axis label (bar/line/hist).
            Defaults to the second column name for bar/line, and to
            ``"Frequency"`` for hist.
        **kwargs: Accepted for caller compatibility; currently ignored.

    Returns:
        The ``matplotlib.figure.Figure`` holding the chart. It stays
        registered with pyplot until the caller runs ``plt.close(fig)``.

    Raises:
        ValueError: If ``chart_type`` is unsupported or ``data`` has too few
            columns for the requested chart. Validation runs before any
            figure is created. If drawing itself fails, the figure is closed
            before the error propagates, so failures never leave stray
            figures behind.
    """
    # Validate first so invalid input never allocates a pyplot figure.
    if chart_type in {"bar", "line", "pie"}:
        if data.shape[1] < 2:
            raise ValueError("DataFrame must have at least two columns")
    elif chart_type == "hist":
        if data.shape[1] < 3:
            raise ValueError("Histogram DataFrame must have 3 columns")
    else:
        raise ValueError(f"Unsupported chart type: {chart_type}")

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        if chart_type == "hist":
            bin_start = data.iloc[:, 0]
            bin_end = data.iloc[:, 1]
            frequency = data.iloc[:, 2]
            # One bar per bin: left edge at bin_start, width spanning to bin_end.
            ax.bar(bin_start, frequency, width=bin_end - bin_start, align="edge")
            ax.set_xlabel(x_axis_label or "Value Range")
            ax.set_ylabel(y_axis_label or "Frequency")
        else:
            # bar / line / pie share the same input shape: first 20 rows,
            # column 0 as x/labels and column 1 as values.
            data = data.head(20)
            x = data.iloc[:, 0]
            y = data.iloc[:, 1]
            if chart_type == "pie":
                ax.pie(y, labels=x, autopct="%1.1f%%")
                ax.axis("equal")  # equal aspect so the pie renders as a circle
            else:
                if chart_type == "bar":
                    ax.bar(x, y)
                else:  # "line"
                    ax.plot(x, y, marker="o")
                ax.set_xlabel(x_axis_label or data.columns[0])
                ax.set_ylabel(y_axis_label or data.columns[1])

        if title:
            ax.set_title(title)
        fig.tight_layout()
    except BaseException:
        # Cleanup-and-reraise: don't leave a half-drawn figure registered
        # with pyplot when matplotlib rejects the data.
        plt.close(fig)
        raise
    return fig