import pandas as pd
import matplotlib.pyplot as plt

# Aggregations each plot type understands; anything else falls back to "sum".
_TIME_SERIES_AGGS = ("sum", "mean", "count", "median")
_BAR_AGGS = ("sum", "mean", "median")


def _safe_title(text):
    """Return ``text`` as a string, mapping ``None`` to an empty string."""
    if text is None:
        return ""
    return str(text)


def _effective_agg(agg_func, supported):
    """Return ``agg_func`` if it is in ``supported``, otherwise ``"sum"``."""
    return agg_func if agg_func in supported else "sum"


def create_time_series_plot(df, date_col, value_col, agg_func="sum", freq="M",
                            category_col=None):
    """
    Create a time series line plot.

    If category_col is provided (for example "Region" or "Model") and exists
    in ``df``, one line is drawn per category with a legend. Otherwise the
    whole series is aggregated over time and drawn as a single line.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    date_col : str
        Column name with datetime values (e.g. "Date"). Non-datetime columns
        are converted with ``pd.to_datetime``.
    value_col : str
        Numeric column to aggregate (e.g. "Estimated_Deliveries").
    agg_func : str
        Aggregation method: "sum", "mean", "count", "median". Any other
        value falls back to "sum", and the plot title reports "sum".
    freq : str
        Resampling frequency ("M" = month, "Q" = quarter).
    category_col : str or None
        Optional column to group by (e.g. "Region" or "Model").

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None if a required column is
        missing, the dates cannot be parsed, or (in category mode) no rows
        remain after dropping missing values.
    """
    if date_col not in df.columns or value_col not in df.columns:
        return None

    data = df.copy()

    # Make sure the date column is datetime.
    if not pd.api.types.is_datetime64_any_dtype(data[date_col]):
        try:
            data[date_col] = pd.to_datetime(data[date_col])
        except Exception:
            # Deliberate best-effort: any parsing failure means "no plot".
            return None

    # Use the aggregation that is actually applied for both the computation
    # and the title, so an unknown agg_func is not mislabelled.
    agg_name = _effective_agg(agg_func, _TIME_SERIES_AGGS)

    # No usable category: aggregate the whole series over time.
    if category_col is None or category_col not in df.columns:
        values = data.set_index(date_col)[value_col]
        grouped = getattr(values.resample(freq), agg_name)()

        fig, ax = plt.subplots()
        ax.plot(grouped.index, grouped.values)
        ax.set_xlabel("Date")
        ax.set_ylabel(_safe_title(value_col))
        ax.set_title(f"Time Series of {value_col} ({agg_name})")
        ax.grid(True)
        fig.tight_layout()
        return fig

    # Category given (e.g. Region / Model): draw one line per category.
    data = data[[date_col, value_col, category_col]].dropna()
    if data.empty:
        return None

    # Group by category and by date bucket (Grouper applies the frequency).
    per_group = data.groupby(
        [category_col, pd.Grouper(key=date_col, freq=freq)]
    )[value_col]
    grouped = getattr(per_group, agg_name)()

    # grouped has a (category, date) MultiIndex; unstack the category level
    # so that the index is the date and each column is one category.
    table = grouped.unstack(0)

    fig, ax = plt.subplots()
    for col in table.columns:
        # A category may have no value in some buckets; skip all-NaN ones.
        series = table[col].dropna()
        if series.empty:
            continue
        ax.plot(series.index, series.values, label=str(col))

    ax.set_xlabel("Date")
    ax.set_ylabel(_safe_title(value_col))
    ax.set_title(f"Time Series of {value_col} by {category_col} ({agg_name})")
    ax.grid(True)
    ax.legend(title=_safe_title(category_col))
    fig.tight_layout()
    return fig


def create_distribution_plot(df, numeric_col, kind="hist", bins=30):
    """
    Create a distribution plot for a numeric column.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    numeric_col : str
        Numeric column to visualize.
    kind : str
        "box" for a box plot; any other value draws a histogram.
    bins : int
        Number of bins for the histogram.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None if the column is missing
        or contains no non-missing values.
    """
    if numeric_col not in df.columns:
        return None

    series = df[numeric_col].dropna()
    if series.empty:
        return None

    fig, ax = plt.subplots()
    label = _safe_title(numeric_col)

    if kind == "box":
        # Vertical is the default orientation, so no orientation argument is
        # passed; this avoids the `vert` keyword that newer Matplotlib
        # versions are phasing out.
        ax.boxplot(series.values)
        ax.set_xticks([1])
        ax.set_xticklabels([label])
        ax.set_ylabel(label)
        ax.set_title(f"Box Plot of {numeric_col}")
    else:
        # Default to histogram.
        ax.hist(series.values, bins=bins)
        ax.set_xlabel(label)
        ax.set_ylabel("Frequency")
        ax.set_title(f"Histogram of {numeric_col}")

    fig.tight_layout()
    return fig


def create_category_bar_plot(df, category_col, value_col=None,
                             agg_func="count", top_n=10):
    """
    Create a bar chart for a categorical column.

    - If value_col is None or agg_func == "count":
      show counts of each category.
    - Otherwise: aggregate value_col per category using sum / mean / median.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    category_col : str
        Name of the categorical column (e.g. "Region").
    value_col : str or None
        Numeric column to aggregate (e.g. "Estimated_Deliveries"), or None.
    agg_func : str
        "sum", "mean", "median", or "count". Any other value falls back to
        "sum", and the plot title reports "sum".
    top_n : int
        Show only the top N categories.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None if a required column is
        missing or there is no data to plot.
    """
    if category_col not in df.columns:
        return None

    if value_col is None or agg_func == "count":
        # Pure counts mode (value_counts ignores missing categories).
        counts = df[category_col].value_counts().head(top_n)
        if counts.empty:
            # Same contract as the aggregated branch: nothing to plot.
            return None
        y_values = counts.values
        x_labels = counts.index.astype(str).tolist()
        title = f"Top {top_n} {category_col} by Count"
        y_label = "Count"
    else:
        if value_col not in df.columns:
            return None

        data = df[[category_col, value_col]].dropna()
        if data.empty:
            return None

        agg_name = _effective_agg(agg_func, _BAR_AGGS)
        grouped = getattr(data.groupby(category_col)[value_col], agg_name)()
        grouped = grouped.sort_values(ascending=False).head(top_n)

        y_values = grouped.values
        x_labels = grouped.index.astype(str).tolist()
        title = f"Top {top_n} {category_col} by {agg_name} of {value_col}"
        y_label = _safe_title(value_col)

    fig, ax = plt.subplots()
    ax.bar(x_labels, y_values)
    ax.set_xlabel(_safe_title(category_col))
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.tick_params(axis="x", rotation=45)
    fig.tight_layout()
    return fig


def create_scatter_plot(df, x_col, y_col, category_col=None):
    """
    Create a scatter plot for two numeric columns.

    If category_col is given (e.g. "Region" or "Model") and exists in ``df``,
    points are split by that category and a legend is shown.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    x_col : str
        Column for x-axis.
    y_col : str
        Column for y-axis.
    category_col : str or None
        Optional column to group points by.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None if a required column is
        missing or no rows remain after dropping missing values.
    """
    if x_col not in df.columns or y_col not in df.columns:
        return None

    has_category = bool(category_col) and category_col in df.columns
    selected = [x_col, y_col]
    if has_category:
        selected.append(category_col)

    data = df[selected].dropna()
    if data.empty:
        return None

    fig, ax = plt.subplots()

    if has_category:
        # One scatter call per category so each gets its own legend entry.
        for cat_value, group_df in data.groupby(category_col):
            ax.scatter(group_df[x_col], group_df[y_col], label=str(cat_value))
        ax.set_title(f"Scatter Plot: {x_col} vs {y_col} by {category_col}")
        ax.legend(title=_safe_title(category_col))
    else:
        # Simple scatter, no categories.
        ax.scatter(data[x_col], data[y_col])
        ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")

    ax.set_xlabel(_safe_title(x_col))
    ax.set_ylabel(_safe_title(y_col))
    ax.grid(True)
    fig.tight_layout()
    return fig


def create_correlation_heatmap(df, numeric_cols):
    """
    Create a correlation heatmap for numeric columns, with the numeric
    values displayed on the cells.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    numeric_cols : list of str
        Column names to include in the correlation matrix. Names that are
        not in ``df`` or whose dtype is not numeric are skipped, so the axis
        labels always match the rows and columns of the matrix.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the heatmap, or None if fewer than two
        usable numeric columns remain.
    """
    # Filter up front: a non-numeric column would make corr() raise on
    # recent pandas, or be dropped silently on older versions and leave the
    # labels misaligned with the matrix.
    cols = [
        c for c in numeric_cols
        if c in df.columns and pd.api.types.is_numeric_dtype(df[c])
    ]
    if len(cols) < 2:
        return None

    corr = df[cols].corr()

    fig, ax = plt.subplots()
    cax = ax.matshow(corr.values)
    fig.colorbar(cax)

    ticks = range(len(cols))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(cols, rotation=45, ha="left")
    ax.set_yticklabels(cols)
    ax.set_title("Correlation Heatmap", pad=20)

    # Write each correlation value on top of its cell (x = column, y = row).
    for i, row in enumerate(corr.values):
        for j, value in enumerate(row):
            ax.text(j, i, f"{value:.2f}", va="center", ha="center", fontsize=8)

    fig.tight_layout()
    return fig