File size: 4,300 Bytes
9eecab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from utils.logger import logger
import matplotlib.pyplot as plt
import plotext as plt_terminal
import os


class VisualizationAgent:
    """Answer free-text plotting requests against datasets held in a registry.

    A query such as ``"histogram of price in sales"`` is resolved to a dataset,
    a column and a chart type. The chart is drawn in the terminal with plotext
    and also saved as a PNG under ``OUTPUT_DIR`` with matplotlib.
    """

    # Directory the PNG files are written to (created in __init__).
    OUTPUT_DIR = "output"
    # Bin count used for BOTH the terminal and the PNG histogram so the two
    # views agree.
    HIST_BINS = 20
    # Bar charts are refused for columns with more distinct values than this.
    MAX_BAR_CATEGORIES = 50
    # Characters replaced by "_" when a dataset/column name becomes part of a
    # file name: path separators plus characters Windows forbids in names.
    _UNSAFE_FILENAME_CHARS = '/\\:*?"<>|'

    def __init__(self, registry):
        """Store the dataset registry and make sure the output directory exists.

        Args:
            registry: object exposing ``list_datasets()`` (an indexable
                sequence of names) and ``load_dataframe(name)`` returning a
                pandas DataFrame, as used by ``handle``.
        """
        self.registry = registry
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

    # ------------------------------------------------------------------
    # Query parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _longest_match(query, candidates):
        """Return the candidate with the longest name contained in *query*.

        Matching is case-insensitive on ``str(candidate)``. Preferring the
        longest name stops a short name (e.g. ``age``) from shadowing a longer
        one that contains it (e.g. ``page_views``). Ties keep the earliest
        candidate. Empty names never match. Returns None if nothing matches.
        """
        q = query.lower()
        best, best_len = None, 0
        for candidate in candidates:
            name = str(candidate).lower()
            if len(name) > best_len and name in q:
                best, best_len = candidate, len(name)
        return best

    def _detect_dataset(self, query, datasets):
        """Return the dataset named in *query*, or ``datasets[0]`` if none is."""
        match = self._longest_match(query, datasets)
        if match is not None:
            return match

        logger.info("Dataset not specified, using default dataset.")
        return datasets[0]

    def _detect_column(self, query, columns):
        """Return the column named in *query*, or None if none is mentioned."""
        return self._longest_match(query, columns)

    @staticmethod
    def _detect_chart_type(query, *names):
        """Return ``"hist"``, ``"bar"`` or None from the chart keywords in *query*.

        The given *names* (the detected column and dataset) are blanked out
        first. Otherwise a column called e.g. ``history`` or ``barometer``
        would itself be read as a chart keyword.
        """
        text = query.lower()
        for name in sorted((str(n).lower() for n in names), key=len, reverse=True):
            if name:
                text = text.replace(name, " ")

        # "hist" also covers "histogram"; "bar" also covers "bar chart".
        if "hist" in text:
            return "hist"
        if "bar" in text:
            return "bar"
        return None

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------

    @classmethod
    def _safe_filename_part(cls, value):
        """Return ``str(value)`` with unsafe file-name characters replaced by "_"."""
        return "".join(
            "_" if ch in cls._UNSAFE_FILENAME_CHARS or not ch.isprintable() else ch
            for ch in str(value)
        )

    def _save_png(self, dataset, column, kind, draw, title, xlabel, ylabel):
        """Draw a matplotlib figure via ``draw(ax)``, save it and return its path.

        The figure is closed in ``finally`` so a failure while drawing or
        saving does not leave an open figure behind.
        """
        filename = (
            f"{self._safe_filename_part(dataset)}_"
            f"{self._safe_filename_part(column)}_{kind}.png"
        )
        filepath = os.path.join(self.OUTPUT_DIR, filename)

        fig, ax = plt.subplots()
        try:
            draw(ax)
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            fig.savefig(filepath)
        finally:
            plt.close(fig)
        return filepath

    def _render_histogram(self, df, dataset, column):
        """Plot a histogram of *column* in the terminal and as a PNG.

        Returns a status message.
        """
        series = df[column].dropna()
        if series.empty:
            logger.warning(f"Column '{column}' has no non-null values. Skipping histogram.")
            return f"Column '{column}' has no values to plot."
        # Only integer / unsigned / float data can be binned.
        if series.dtype.kind not in ("i", "u", "f"):
            logger.warning(f"Column '{column}' is not numeric. Skipping histogram.")
            return f"Column '{column}' is not numeric. Try a bar chart instead."

        logger.info(f"Generating histogram for {column} in {dataset}")
        title = f"Histogram of {column}"
        label = str(column)

        # Terminal plot
        plt_terminal.clear_figure()
        plt_terminal.hist(series.tolist(), bins=self.HIST_BINS)
        plt_terminal.title(title)
        plt_terminal.xlabel(label)
        plt_terminal.ylabel("Frequency")
        plt_terminal.show()

        # PNG copy, binned exactly like the terminal view.
        filepath = self._save_png(
            dataset,
            column,
            "hist",
            lambda ax: series.hist(ax=ax, bins=self.HIST_BINS),
            title,
            label,
            "Frequency",
        )

        logger.info(f"Histogram saved → {filepath}")
        return f"Histogram generated in terminal. PNG saved to {filepath}"

    def _render_bar_chart(self, df, dataset, column):
        """Plot value counts of *column* in the terminal and as a PNG.

        Returns a status message.
        """
        unique_values = df[column].nunique()

        if unique_values > self.MAX_BAR_CATEGORIES:
            logger.warning(
                f"Column '{column}' has {unique_values} unique values. Skipping bar chart."
            )
            return f"Column '{column}' has {unique_values} unique values. Too many to visualize meaningfully."
        if unique_values == 0:
            logger.warning(f"Column '{column}' has no non-null values. Skipping bar chart.")
            return f"Column '{column}' has no values to plot."

        logger.info(f"Generating bar chart for {column} in {dataset}")

        counts = df[column].value_counts()
        title = f"Bar Chart of {column}"
        label = str(column)

        # Terminal plot (plotext needs plain lists; labels as strings)
        plt_terminal.clear_figure()
        plt_terminal.bar(counts.index.astype(str).tolist(), counts.tolist())
        plt_terminal.title(title)
        plt_terminal.xlabel(label)
        plt_terminal.ylabel("Count")
        plt_terminal.show()

        # PNG copy
        filepath = self._save_png(
            dataset,
            column,
            "bar",
            lambda ax: counts.plot(kind="bar", ax=ax),
            title,
            label,
            "Count",
        )

        logger.info(f"Bar chart saved → {filepath}")
        return f"Bar chart generated in terminal. PNG saved to {filepath}"

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def handle(self, query):
        """Resolve *query* to a dataset, column and chart, then render it.

        Args:
            query: free-text request, e.g. ``"bar chart of region in sales"``.

        Returns:
            A human-readable status string. Exceptions raised while loading
            the dataset or plotting are logged and reported as a message
            rather than propagated.
        """
        q = query.lower()

        try:
            datasets = self.registry.list_datasets()

            if not datasets:
                logger.warning("VisualizationAgent called with no datasets loaded.")
                return "No datasets available."

            dataset = self._detect_dataset(q, datasets)
            df = self.registry.load_dataframe(dataset)
            columns = df.columns.tolist()

        except Exception as e:
            logger.error(f"Failed loading dataset in VisualizationAgent | {e}")
            return "Failed to load dataset."

        try:
            column = self._detect_column(q, columns)

            if column is None:
                logger.warning("Column not detected for visualization.")
                return "Column not found in dataset."

            chart = self._detect_chart_type(q, column, dataset)
            if chart == "hist":
                return self._render_histogram(df, dataset, column)
            if chart == "bar":
                return self._render_bar_chart(df, dataset, column)

            return "Visualization query not understood."

        except Exception as e:
            logger.error(f"Visualization failed | Query: {query} | Error: {e}")
            return "Visualization agent error."