Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| def _safe_title(text): | |
| """Small helper to avoid None or empty titles.""" | |
| if text is None: | |
| return "" | |
| return str(text) | |
def create_time_series_plot(df, date_col, value_col, agg_func="sum", freq="M", category_col=None):
    """
    Create a time series line plot.

    If category_col is provided (for example "Region" or "Model") and exists
    in ``df``, one line per category is drawn with a legend; otherwise the
    whole series is aggregated over time as a single line.

    Parameters
    ----------
    df : pd.DataFrame
        Input data. It is not modified.
    date_col : str
        Column name with datetime values (e.g. "Date"). Non-datetime columns
        are converted with ``pd.to_datetime``.
    value_col : str
        Numeric column to aggregate (e.g. "Estimated_Deliveries").
    agg_func : str
        Aggregation method: "sum", "mean", "count", "median". Any other value
        falls back to "sum", and the plot title then says "sum" so the label
        always matches what was actually computed.
    freq : str
        Resampling frequency, passed to pandas unchanged
        ("M" = month, "Q" = quarter).
    category_col : str or None
        Optional column to group by (e.g. "Region" or "Model"). Ignored when
        it is not a column of ``df``.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None if a required column is
        missing, the dates cannot be parsed, or no row has all required
        fields populated.
    """
    if date_col not in df.columns or value_col not in df.columns:
        return None

    use_categories = category_col is not None and category_col in df.columns

    # Copy only the columns that are used, so the caller's frame is never
    # mutated by the datetime conversion below.
    needed = [date_col, value_col]
    if use_categories:
        needed.append(category_col)
    data = df[needed].copy()

    # Make sure the date column is datetime
    if not pd.api.types.is_datetime64_any_dtype(data[date_col]):
        try:
            data[date_col] = pd.to_datetime(data[date_col])
        except Exception:
            # Best effort: unparseable dates mean "no plot", not a crash.
            return None

    # Nothing to draw when no row has every required field. This keeps the
    # single-series path consistent with the per-category path instead of
    # returning a blank (or all-zero) figure.
    complete_rows = data.dropna()
    if complete_rows.empty:
        return None

    # Unknown names fall back to "sum"; `method` is also what the title shows.
    method = agg_func if agg_func in ("mean", "count", "median") else "sum"

    if not use_categories:
        # Rows with a missing value are kept here on purpose: the aggregations
        # skip NaN anyway, and keeping them preserves the plotted date range.
        resampled = data.set_index(date_col)[value_col].resample(freq)
        series = getattr(resampled, method)()

        fig, ax = plt.subplots()
        ax.plot(series.index, series.values)
        ax.set_xlabel("Date")
        ax.set_ylabel(_safe_title(value_col))
        ax.set_title(f"Time Series of {value_col} ({method})")
        ax.grid(True)
        fig.tight_layout()
        return fig

    # Group by category and date bucket; the result is a Series indexed by
    # (category, date). Unstacking level 0 gives dates as the index and one
    # column per category.
    per_category = complete_rows.groupby(
        [category_col, pd.Grouper(key=date_col, freq=freq)]
    )[value_col]
    table = getattr(per_category, method)().unstack(0)

    fig, ax = plt.subplots()
    for name in table.columns:
        # A category has NaN for date buckets where it has no rows; drop them
        # and skip categories that end up with nothing to draw.
        line = table[name].dropna()
        if line.empty:
            continue
        ax.plot(line.index, line.values, label=str(name))

    ax.set_xlabel("Date")
    ax.set_ylabel(_safe_title(value_col))
    ax.set_title(f"Time Series of {value_col} by {category_col} ({method})")
    ax.grid(True)
    ax.legend(title=_safe_title(category_col))
    fig.tight_layout()
    return fig
def create_distribution_plot(df, numeric_col, kind="hist", bins=30):
    """
    Draw the distribution of one column as a histogram or a box plot.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    numeric_col : str
        Numeric column to visualize. Missing values are dropped first.
    kind : str
        "box" draws a box plot; any other value draws a histogram.
    bins : int
        Number of bins for the histogram (unused for box plots).

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None when the column does not
        exist or holds no non-missing values.
    """
    if numeric_col not in df.columns:
        return None

    observed = df[numeric_col].dropna()
    if observed.empty:
        return None

    col_label = _safe_title(numeric_col)
    fig, ax = plt.subplots()

    if kind != "box":
        # Anything other than "box" is treated as a histogram request.
        ax.hist(observed.values, bins=bins)
        ax.set_xlabel(col_label)
        ax.set_ylabel("Frequency")
        ax.set_title(f"Histogram of {numeric_col}")
    else:
        ax.boxplot(observed.values, vert=True)
        # A single box sits at x position 1; label it with the column name.
        ax.set_xticks([1])
        ax.set_xticklabels([col_label])
        ax.set_ylabel(col_label)
        ax.set_title(f"Box Plot of {numeric_col}")

    fig.tight_layout()
    return fig
def create_category_bar_plot(df, category_col, value_col=None, agg_func="count", top_n=10):
    """
    Create a bar chart for a categorical column.

    - If value_col is None or agg_func == "count":
      show the number of rows in each category.
    - Otherwise:
      aggregate value_col per category using sum / mean / median.

    This works nicely with columns like "Region" or "Model"
    for the Tesla dataset.

    Parameters
    ----------
    df : pd.DataFrame
        Input data. It is only read, never modified.
    category_col : str
        Name of the categorical column (e.g. "Region").
    value_col : str or None
        Numeric column to aggregate (e.g. "Estimated_Deliveries"), or None.
    agg_func : str
        "sum", "mean", "median", or "count". Any other value falls back to
        "sum", and the title then says "sum" so the label matches the data.
    top_n : int
        Show only the top N categories (largest values first).

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None when a required column is
        missing or there is nothing to draw (no non-missing data, or
        ``top_n`` selects no categories).
    """
    if category_col not in df.columns:
        return None

    if value_col is None or agg_func == "count":
        # Pure counts mode; value_counts already sorts descending.
        totals = df[category_col].value_counts().head(top_n)
        title = f"Top {top_n} {category_col} by Count"
        y_label = "Count"
    else:
        if value_col not in df.columns:
            return None

        pairs = df[[category_col, value_col]].dropna()
        if pairs.empty:
            return None

        # Unknown names fall back to "sum"; `method` is also used in the title.
        method = agg_func if agg_func in ("mean", "median") else "sum"
        per_category = getattr(pairs.groupby(category_col)[value_col], method)()
        totals = per_category.sort_values(ascending=False).head(top_n)
        title = f"Top {top_n} {category_col} by {method} of {value_col}"
        y_label = _safe_title(value_col)

    # Return None rather than an empty chart, in both modes.
    if totals.empty:
        return None

    fig, ax = plt.subplots()
    ax.bar(totals.index.astype(str).tolist(), totals.values)
    ax.set_xlabel(_safe_title(category_col))
    ax.set_ylabel(y_label)
    ax.set_title(title)
    # Category names are often long; tilt them so they do not overlap.
    ax.tick_params(axis="x", rotation=45)
    fig.tight_layout()
    return fig
def create_scatter_plot(df, x_col, y_col, category_col=None):
    """
    Plot ``y_col`` against ``x_col`` as a scatter chart.

    When category_col names an existing column (e.g. "Region" or "Model"),
    each category is drawn as its own point series and a legend is added.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    x_col : str
        Column for x-axis.
    y_col : str
        Column for y-axis.
    category_col : str or None
        Optional column to group points by.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the plot, or None when an axis column is
        missing or no row has all selected columns populated.
    """
    if x_col not in df.columns or y_col not in df.columns:
        return None

    # Select only the columns in play so dropna ignores unrelated gaps.
    selected = [x_col, y_col]
    if category_col and category_col in df.columns:
        selected = selected + [category_col]

    points = df[selected].dropna()
    if points.empty:
        return None

    fig, ax = plt.subplots()
    grouped_mode = category_col is not None and category_col in points.columns

    if grouped_mode:
        # One scatter call per category so each gets its own legend entry.
        for label, subset in points.groupby(category_col):
            ax.scatter(subset[x_col], subset[y_col], label=str(label))
        ax.set_title(f"Scatter Plot: {x_col} vs {y_col} by {category_col}")
        ax.legend(title=_safe_title(category_col))
    else:
        ax.scatter(points[x_col], points[y_col])
        ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")

    ax.set_xlabel(_safe_title(x_col))
    ax.set_ylabel(_safe_title(y_col))
    ax.grid(True)
    fig.tight_layout()
    return fig
def create_correlation_heatmap(df, numeric_cols):
    """
    Create a correlation heatmap for numeric columns,
    with the numeric values displayed on the cells.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    numeric_cols : list of str
        Column names to include in the correlation matrix. Names that are
        not columns of ``df``, or whose column is not of a numeric dtype,
        are ignored.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure object with the heatmap, or None when fewer than two
        usable numeric columns remain.
    """
    if numeric_cols is None:
        return None

    # Keep only existing, numeric columns. Passing non-numeric columns to
    # DataFrame.corr() either raises or yields a smaller matrix than `cols`,
    # which would break the tick labels and cell annotations below.
    cols = [
        c for c in numeric_cols
        if c in df.columns and pd.api.types.is_numeric_dtype(df[c])
    ]
    if len(cols) < 2:
        return None

    corr = df[cols].corr()

    fig, ax = plt.subplots()
    image = ax.matshow(corr.values)
    fig.colorbar(image)

    tick_positions = range(len(cols))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    # matshow puts the x labels on top; anchor them left so rotated text
    # starts at its tick.
    ax.set_xticklabels(cols, rotation=45, ha="left")
    ax.set_yticklabels(cols)
    ax.set_title("Correlation Heatmap", pad=20)

    # Write each coefficient on its cell (x = column index, y = row index).
    for row_idx, row in enumerate(corr.values):
        for col_idx, value in enumerate(row):
            ax.text(col_idx, row_idx, f"{value:.2f}", va="center", ha="center", fontsize=8)

    fig.tight_layout()
    return fig