CS5130_finalProject / visualizations.py
Khang Nguyen
initial commit
aa893a9
import pandas as pd
import matplotlib.pyplot as plt
def _safe_title(text):
"""Small helper to avoid None or empty titles."""
if text is None:
return ""
return str(text)
def create_time_series_plot(df, date_col, value_col, agg_func="sum", freq="M", category_col=None):
    """
    Create a time series line plot.

    If category_col is provided (for example "Region" or "Model"),
    it will draw one line per category with a legend.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    date_col : str
        Column name with datetime values (e.g. "Date"). Non-datetime columns
        are converted with ``pd.to_datetime``.
    value_col : str
        Numeric column to aggregate (e.g. "Estimated_Deliveries").
    agg_func : str
        Aggregation method: "sum", "mean", "count", "median".
        Any other value falls back to "sum", and the title then says "sum".
    freq : str
        Resampling frequency ("M" = month, "Q" = quarter).
        Note: pandas >= 2.2 warns that "M" is deprecated in favour of "ME".
    category_col : str or None
        Optional column to group by (e.g. "Region" or "Model"). If it is None
        or not a column of ``df``, a single aggregated line is drawn.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot. Returns None if a required column is
        missing, the dates cannot be parsed, or no rows remain to plot.
    """
    if date_col not in df.columns or value_col not in df.columns:
        return None

    # Unknown aggregation names fall back to "sum". Normalize up front so the
    # chart title names the aggregation that was actually computed.
    if agg_func not in ("sum", "mean", "count", "median"):
        agg_func = "sum"

    use_category = category_col is not None and category_col in df.columns

    # Copy only the columns the plot needs rather than the whole frame.
    # dict.fromkeys drops repeated names while keeping their order.
    wanted = [date_col, value_col] + ([category_col] if use_category else [])
    data = df[list(dict.fromkeys(wanted))].copy()

    # Make sure the date column is datetime.
    if not pd.api.types.is_datetime64_any_dtype(data[date_col]):
        try:
            data[date_col] = pd.to_datetime(data[date_col])
        except Exception:
            # Deliberate best effort: unparseable dates mean "no figure".
            return None

    if not use_category:
        # Aggregate the whole series over time. Rows without a date cannot be
        # placed on the time axis, so drop them before resampling.
        data = data.dropna(subset=[date_col])
        if data.empty:
            return None
        resampled = data.set_index(date_col)[value_col].resample(freq)
        # agg_func is one of the whitelisted names above; each one is a method
        # on the resampler.
        grouped = getattr(resampled, agg_func)()

        fig, ax = plt.subplots()
        ax.plot(grouped.index, grouped.values)
        ax.set_xlabel("Date")
        ax.set_ylabel(_safe_title(value_col))
        ax.set_title(f"Time Series of {value_col} ({agg_func})")
        ax.grid(True)
        fig.tight_layout()
        return fig

    # One line per category: drop incomplete rows first.
    data = data.dropna()
    if data.empty:
        return None

    # Group by category and by date bucket (Grouper applies the frequency).
    # agg_func is whitelisted above and is a method on SeriesGroupBy.
    by_category = data.groupby(
        [category_col, pd.Grouper(key=date_col, freq=freq)]
    )[value_col]
    grouped = getattr(by_category, agg_func)()

    # grouped has a (category, date) MultiIndex. Unstacking level 0 gives
    # dates as the index and one column per category.
    table = grouped.unstack(0)

    fig, ax = plt.subplots()
    plotted_any = False
    for col in table.columns:
        # Unstacking leaves NaN where a category has no data in a bucket.
        # Skip categories that end up with nothing at all.
        series = table[col].dropna()
        if series.empty:
            continue
        ax.plot(series.index, series.values, label=str(col))
        plotted_any = True

    ax.set_xlabel("Date")
    ax.set_ylabel(_safe_title(value_col))
    ax.set_title(f"Time Series of {value_col} by {category_col} ({agg_func})")
    ax.grid(True)
    if plotted_any:
        # Only add a legend when labelled lines exist; otherwise matplotlib
        # warns about a legend with no artists.
        ax.legend(title=_safe_title(category_col))
    fig.tight_layout()
    return fig
def create_distribution_plot(df, numeric_col, kind="hist", bins=30):
    """
    Create a distribution plot for a numeric column.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    numeric_col : str
        Numeric column to visualize. NaN values are dropped first.
    kind : str
        "hist" for histogram, "box" for box plot. Any other value draws a
        histogram.
    bins : int
        Number of bins for histogram.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot. Returns None if the column is
        missing or has no non-NaN values.
    """
    if numeric_col not in df.columns:
        return None

    series = df[numeric_col].dropna()
    if series.empty:
        return None

    label = _safe_title(numeric_col)
    fig, ax = plt.subplots()

    if kind == "box":
        # Vertical is boxplot's default orientation. The old explicit
        # vert=True was redundant, and `vert` is being deprecated in recent
        # Matplotlib in favour of `orientation`.
        ax.boxplot(series.values)
        ax.set_xticks([1])
        ax.set_xticklabels([label])
        ax.set_ylabel(label)
        ax.set_title(f"Box Plot of {numeric_col}")
    else:
        # Default to histogram.
        ax.hist(series.values, bins=bins)
        ax.set_xlabel(label)
        ax.set_ylabel("Frequency")
        ax.set_title(f"Histogram of {numeric_col}")

    fig.tight_layout()
    return fig
def create_category_bar_plot(df, category_col, value_col=None, agg_func="count", top_n=10):
    """
    Create a bar chart for a categorical column.

    - If value_col is None or agg_func == "count":
      Show counts of each category (value_col is ignored).
    - Otherwise value_col is aggregated per category using sum / mean /
      median. Any other agg_func falls back to "sum", and the title then
      says "sum".

    This works nicely with columns like "Region" or "Model"
    for the Tesla dataset.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    category_col : str
        Name of the categorical column (e.g. "Region").
    value_col : str or None
        Numeric column to aggregate (e.g. "Estimated_Deliveries"), or None.
    agg_func : str
        "sum", "mean", "median", or "count".
    top_n : int
        Show only the top N categories (largest first).

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot. Returns None if a column is missing
        or there is nothing to plot.
    """
    if category_col not in df.columns:
        return None

    if value_col is None or agg_func == "count":
        # Pure counts mode: number of rows per category, largest first.
        ranked = df[category_col].value_counts().head(top_n)
        title = f"Top {top_n} {category_col} by Count"
        y_label = "Count"
    else:
        if value_col not in df.columns:
            return None
        # Unknown aggregation names fall back to sum. Normalize so the title
        # reports what was actually computed.
        if agg_func not in ("sum", "mean", "median"):
            agg_func = "sum"
        data = df[[category_col, value_col]].dropna()
        if data.empty:
            return None
        ranked = (
            data.groupby(category_col)[value_col]
            .agg(agg_func)
            .sort_values(ascending=False)
            .head(top_n)
        )
        title = f"Top {top_n} {category_col} by {agg_func} of {value_col}"
        y_label = _safe_title(value_col)

    # Covers e.g. an empty frame or an all-NaN category column in counts mode,
    # matching the None returned for empty data in the aggregate mode.
    if ranked.empty:
        return None

    fig, ax = plt.subplots()
    ax.bar(ranked.index.astype(str).tolist(), ranked.values)
    ax.set_xlabel(_safe_title(category_col))
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.tick_params(axis="x", rotation=45)
    fig.tight_layout()
    return fig
def create_scatter_plot(df, x_col, y_col, category_col=None):
    """
    Create a scatter plot for two numeric columns.

    If category_col is given (e.g. "Region" or "Model"),
    points will be split by that category and a legend will be shown.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    x_col : str
        Column for x-axis.
    y_col : str
        Column for y-axis.
    category_col : str or None
        Optional column to group points by. It is ignored when empty or not a
        column of ``df``.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot. Returns None if a column is missing
        or no complete rows remain.
    """
    if x_col not in df.columns or y_col not in df.columns:
        return None

    use_category = bool(category_col) and category_col in df.columns

    # dict.fromkeys removes repeated names (x_col == y_col, or the category
    # column also used as an axis). Without this, the selection would have
    # duplicate columns, data[name] would be 2-D, and groupby would raise.
    wanted = [x_col, y_col] + ([category_col] if use_category else [])
    data = df[list(dict.fromkeys(wanted))].dropna()
    if data.empty:
        return None

    fig, ax = plt.subplots()
    if use_category:
        # One scatter per category, with legend.
        for cat_value, group_df in data.groupby(category_col):
            ax.scatter(group_df[x_col], group_df[y_col], label=str(cat_value))
        ax.set_title(f"Scatter Plot: {x_col} vs {y_col} by {category_col}")
        ax.legend(title=_safe_title(category_col))
    else:
        # Simple scatter, no categories.
        ax.scatter(data[x_col], data[y_col])
        ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")

    ax.set_xlabel(_safe_title(x_col))
    ax.set_ylabel(_safe_title(y_col))
    ax.grid(True)
    fig.tight_layout()
    return fig
def create_correlation_heatmap(df, numeric_cols):
    """
    Create a correlation heatmap for numeric columns,
    with the numeric values displayed on the cells.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    numeric_cols : iterable of str
        Column names to include in the correlation matrix. Names that are
        missing from ``df`` or not numeric are skipped, and repeats are
        ignored.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the heatmap. Returns None if fewer than two
        usable numeric columns remain.
    """
    if numeric_cols is None:
        return None

    # Keep existing, numeric, de-duplicated columns in their given order.
    # Passing non-numeric columns to corr() raises on pandas >= 2. On older
    # pandas they are silently dropped, which misaligns the tick labels
    # with the matrix.
    cols = [
        c
        for c in dict.fromkeys(numeric_cols)
        if c in df.columns and pd.api.types.is_numeric_dtype(df[c])
    ]
    if len(cols) < 2:
        return None

    matrix = df[cols].corr().to_numpy()

    fig, ax = plt.subplots()
    # Pin the color scale to the full correlation range [-1, 1] so colors are
    # meaningful and comparable, instead of auto-scaling to this matrix's
    # min/max.
    image = ax.matshow(matrix, vmin=-1, vmax=1)
    fig.colorbar(image)

    positions = range(len(cols))
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(cols, rotation=45, ha="left")
    ax.set_yticklabels(cols)
    ax.set_title("Correlation Heatmap", pad=20)

    # Write each correlation value at its cell centre (x = column, y = row).
    for i, row in enumerate(matrix):
        for j, value in enumerate(row):
            ax.text(j, i, f"{value:.2f}", va="center", ha="center", fontsize=8)

    fig.tight_layout()
    return fig