# Spaces:
# Sleeping
# Sleeping
import os
import tempfile
from typing import Dict, Any

import plotly.graph_objects as go
from llama_index.core.tools import FunctionTool
def _placeholder_figure(title: str, note: str) -> "go.Figure":
    """Build an axis-less figure that shows the title and a note as text."""
    placeholder = go.Figure()
    placeholder.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='text',
                                     text=[f"Figure for {title}", note],
                                     textfont_size=12))
    placeholder.update_layout(xaxis_visible=False, yaxis_visible=False,
                              title_text=f"Figure for {title}")
    return placeholder


def make_figure(
    title: str,
    content: str,
    chart_type: str,
    data: Dict[str, Any]
) -> str:
    """Create a Plotly figure based on chart_type and data, save it as a PNG,
    and return the filepath to the generated image.

    An unsupported chart_type, or any error raised while building the chart,
    does not fail the call: a text-only placeholder figure describing the
    problem is rendered instead.

    Args:
        title (str): The main title of the learning unit.
        content (str): The raw content of the learning unit. Not used when
            building the figure.
        chart_type (str): The type of chart to generate (e.g., "bar_chart",
            "line_graph", "pie_chart", "scatter_plot", "histogram").
        data (Dict[str, Any]): A dictionary containing the data for the chart.
            Expected keys depend on chart_type:
            - "bar_chart": {"labels": List[str], "values": List[float],
              "x_label": str, "y_label": str}
            - "line_graph": {"x": List[float], "y": List[float],
              "x_label": str, "y_label": str}
            - "pie_chart": {"sizes": List[float], "labels": List[str]}
            - "scatter_plot": {"x": List[float], "y": List[float],
              "x_label": str, "y_label": str}
            - "histogram": {"values": List[float], "bins": int,
              "x_label": str, "y_label": str}
            Missing keys fall back to empty data and generic axis labels.

    Returns:
        str: The filepath to the generated image file. The file is created
        with delete=False, so the caller is responsible for removing it.

    Raises:
        Exception: Whatever fig.write_image raises is propagated, after the
            temporary file has been removed.
    """
    fig = go.Figure()
    try:
        if chart_type == "bar_chart":
            labels = data.get("labels", [])
            values = data.get("values", [])
            fig.add_trace(go.Bar(x=labels, y=values, marker_color='skyblue'))
            fig.update_layout(title_text=f"Bar Chart for {title}",
                              xaxis_title=data.get("x_label", "Category"),
                              yaxis_title=data.get("y_label", "Value"))
        elif chart_type == "line_graph":
            x = data.get("x", [])
            y = data.get("y", [])
            fig.add_trace(go.Scatter(x=x, y=y, mode='lines+markers',
                                     marker_color='purple'))
            fig.update_layout(title_text=f"Line Graph for {title}",
                              xaxis_title=data.get("x_label", "X-axis"),
                              yaxis_title=data.get("y_label", "Y-axis"))
        elif chart_type == "pie_chart":
            sizes = data.get("sizes", [])
            labels = data.get("labels", [])
            fig.add_trace(go.Pie(labels=labels, values=sizes, hole=.3))
            fig.update_layout(title_text=f"Pie Chart for {title}")
        elif chart_type == "scatter_plot":
            x = data.get("x", [])
            y = data.get("y", [])
            fig.add_trace(go.Scatter(x=x, y=y, mode='markers', marker_color='red'))
            fig.update_layout(title_text=f"Scatter Plot for {title}",
                              xaxis_title=data.get("x_label", "X-axis"),
                              yaxis_title=data.get("y_label", "Y-axis"))
        elif chart_type == "histogram":
            values = data.get("values", [])
            bins = data.get("bins", 10)
            fig.add_trace(go.Histogram(x=values, nbinsx=bins,
                                       marker_color='green'))
            fig.update_layout(title_text=f"Histogram for {title}",
                              xaxis_title=data.get("x_label", "Value"),
                              yaxis_title=data.get("y_label", "Frequency"))
        else:
            # Handle unsupported chart types
            fig = _placeholder_figure(
                title, f"(Unsupported Chart Type: {chart_type})")
    except Exception as e:
        # Deliberate best-effort: report the problem inside the image rather
        # than failing the call. Start from a fresh figure so a half-built
        # trace or layout does not show through under the error text.
        fig = _placeholder_figure(title, f"(Error generating figure: {e})")

    # Reserve a unique path, then close the handle *before* Plotly writes to
    # it: writing by name to a still-open temporary file is not portable.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png',
                                     prefix='plotly_figure_') as temp_file:
        image_path = temp_file.name
    try:
        fig.write_image(image_path, format='png', width=800, height=500, scale=2)
    except Exception:
        # Export failed: do not leave an empty orphan file behind.
        try:
            os.remove(image_path)
        except OSError:
            pass
        raise
    return image_path
# Wrap make_figure as a LlamaIndex FunctionTool registered under the name
# "make_figure".
make_figure_tool = FunctionTool.from_defaults(
    fn=make_figure,
    name="make_figure",
)