Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import os | |
| import logging | |
| import time | |
class ChartGenerator:
    """Render line or bar charts from a pandas DataFrame and save them as a PNG.

    Every call writes to the same file,
    ``<project_root>/static/images/chart.png``, where ``<project_root>`` is the
    parent of the directory containing this module.
    """

    # Name of the rendered image; each call overwrites the previous chart.
    CHART_FILENAME = 'chart.png'

    def __init__(self, data=None):
        """Store *data*, or load the bundled sample spreadsheet when omitted.

        Args:
            data: DataFrame to plot. When ``None``, it is read with
                ``pd.read_excel`` from ``<project_root>/data/sample_data.xlsx``.
        """
        logging.info("Initializing ChartGenerator")
        if data is not None:
            self.data = data
        else:
            self.data = pd.read_excel(
                os.path.join(self._project_root(), 'data', 'sample_data.xlsx')
            )

    @staticmethod
    def _project_root():
        """Return the parent of the directory that contains this module."""
        # abspath so a relative __file__ (e.g. a script run from its own
        # directory) still resolves to the real parent instead of the CWD.
        return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    @staticmethod
    def _as_column_list(y):
        """Return *y* as a list of column names.

        A string (or any non-iterable value, such as an integer column label)
        names a single column; iterating a string would instead yield its
        individual characters.
        """
        if isinstance(y, str) or not hasattr(y, '__iter__'):
            return [y]
        return list(y)

    def _validate_columns(self, x_col, y_cols):
        """Raise ValueError unless *x_col* and every entry of *y_cols* exist."""
        if not y_cols:
            raise ValueError("At least one y column is required")
        # dict.fromkeys de-duplicates while preserving the requested order.
        requested = dict.fromkeys([x_col, *y_cols])
        missing_cols = [col for col in requested if col not in self.data.columns]
        if missing_cols:
            logging.error(f"Missing columns in data: {missing_cols}")
            logging.info(f"Available columns: {list(self.data.columns)}")
            raise ValueError(f"Missing columns in data: {missing_cols}")

    def _draw_series(self, ax, x_col, y_cols, chart_type, color):
        """Draw one bar or line (with circle markers) series per y column on *ax*."""
        x_values = self.data[x_col]
        if chart_type == 'bar' and len(y_cols) > 1:
            # Several bar series at identical x positions would paint over one
            # another; lay them out side by side around categorical slots.
            slots = range(len(x_values))
            width = 0.8 / len(y_cols)  # 0.8 is matplotlib's default bar width
            for i, y in enumerate(y_cols):
                offset = (i - (len(y_cols) - 1) / 2) * width
                ax.bar([s + offset for s in slots], self.data[y],
                       width=width, label=y, color=color)
            ax.set_xticks(list(slots))
            ax.set_xticklabels([str(v) for v in x_values])
            return
        for y in y_cols:
            if chart_type == 'bar':
                ax.bar(x_values, self.data[y], label=y, color=color)
            else:
                ax.plot(x_values, self.data[y], label=y, color=color, marker='o')

    def _save_figure(self, fig):
        """Write *fig* to the chart file and return its absolute path.

        Raises:
            FileNotFoundError: If the file is absent after saving.
        """
        output_dir = os.path.join(self._project_root(), 'static', 'images')
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir, exist_ok=True)
            logging.info(f"Created output directory: {output_dir}")
        full_path = os.path.join(output_dir, self.CHART_FILENAME)
        # Remove any stale chart first so the existence check below proves
        # that this call actually produced the file.
        if os.path.exists(full_path):
            os.remove(full_path)
            logging.info(f"Removed existing chart file: {full_path}")
        # Save with high DPI for better quality
        fig.savefig(full_path, dpi=300, bbox_inches='tight', facecolor='white')
        if not os.path.exists(full_path):
            logging.error(f"Failed to create chart file at {full_path}")
            raise FileNotFoundError(f"Chart file was not created at {full_path}")
        file_size = os.path.getsize(full_path)
        logging.info(f"Chart generated and saved to {full_path} (size: {file_size} bytes)")
        return full_path

    def generate_chart(self, plot_args):
        """Plot the requested columns and save the chart as a PNG image.

        Args:
            plot_args: Mapping with keys
                ``'x'`` (required): name of the x-axis column;
                ``'y'`` (required): one column name or a sequence of names;
                ``'chart_type'`` (optional): ``'bar'`` draws bars, anything
                    else (default ``'line'``) draws lines with circle markers;
                ``'color'`` (optional): a matplotlib colour used for every
                    series.

        Returns:
            str: ``os.path.join('static', 'images', 'chart.png')``, the chart
            location relative to the project root.

        Raises:
            KeyError: If ``'x'`` or ``'y'`` is absent from *plot_args*.
            ValueError: If no y column is given or a requested column is
                missing from the data.
            FileNotFoundError: If the image file does not exist after saving.
        """
        start_time = time.perf_counter()
        logging.info(f"Generating chart with arguments: {plot_args}")

        # Validate columns before plotting
        x_col = plot_args['x']
        y_cols = self._as_column_list(plot_args['y'])
        self._validate_columns(x_col, y_cols)

        # An explicit None falls back to the default instead of crashing in .title().
        chart_type = plot_args.get('chart_type') or 'line'
        color = plot_args.get('color')

        # Drop any figures left open by earlier calls before creating a new one.
        plt.close('all')
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            self._draw_series(ax, x_col, y_cols, chart_type, color)
            ax.set_xlabel(x_col)
            ax.set_ylabel('Value')
            ax.set_title(f'{chart_type.title()} Chart')
            ax.legend()
            ax.grid(True, alpha=0.3)
            # Rotate x tick labels once there are enough rows to crowd them.
            if len(self.data) > 5:
                ax.tick_params(axis='x', labelrotation=45)
            self._save_figure(fig)
        finally:
            # Release the figure even if drawing or saving raised.
            plt.close(fig)

        elapsed = time.perf_counter() - start_time
        logging.info(f"Chart generation took {elapsed:.2f} seconds")
        return os.path.join('static', 'images', self.CHART_FILENAME)