# DataViz-Agent / src/models.py
# Muhammad Mustehson
# Update Old Code
# 4a84072
from datetime import date, datetime
from enum import Enum

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
class SmallCardNum:
    """Empty marker type with no attributes or methods.

    NOTE(review): the name suggests it tags small-cardinality numeric data;
    confirm against the code that references it.
    """
class Continuous:
    """Empty marker type with no attributes or methods.

    NOTE(review): the name suggests it tags continuous numeric data;
    confirm against the code that references it.
    """
class DateTime:
    """Empty marker type with no attributes or methods.

    NOTE(review): the name suggests it tags date/time data; confirm against
    the code that references it.
    """
class Nominal:
    """Empty marker type with no attributes or methods.

    NOTE(review): the name suggests it tags nominal (categorical) data;
    confirm against the code that references it.
    """
# Routing decision for a user query, encoded as a single integer label.
# (Comments instead of a class docstring: pydantic would copy a model
# docstring into the JSON schema "description".)
class Route(BaseModel):
    # Required. Codes per the description: 0 = irrelevant/vague/incomplete,
    # 1 = visualizable, 2 = SQL-only.
    label: int = Field(
        ...,
        description="Classify user queries as: 0 for Irrelevant/Vague/Incomplete, 1 for Visualizable, and 2 for SQL-only.",
    )
# A SQL statement together with a short human-readable explanation.
# Both fields are required.
class SQLQueryModel(BaseModel):
    sql_query: str = Field(description="SQL query to execute.")
    explanation: str = Field(description="Short explanation of the SQL query.")
class DataPoint(BaseModel):
    # One chart data point. It is accepted when (x, y) are both set or when
    # (bin_start, bin_end, frequency) are all set; see validate_keys.
    # Comments rather than a class docstring: pydantic copies a model
    # docstring into its JSON schema "description".
    x: int | float | str | None = None
    y: int | float | str | None = None
    bin_start: int | float | None = None
    bin_end: int | float | None = None
    frequency: int | float | None = None

    @field_validator("bin_start", "bin_end", "frequency", "x", "y", mode="before")
    @classmethod
    def to_native(cls, field_value):
        """Convert numpy/pandas scalars to plain Python values before validation.

        Args:
            field_value: Raw value supplied for one of the five fields.

        Returns:
            ``float`` for any numpy integer or floating scalar. The original
            already mapped np.int64 to float; this extends that to every
            numpy width such as int32, uint8 and float16.

            A ``"YYYY-MM-DD"`` string for ``date``, ``datetime``,
            ``pd.Timestamp`` and ``np.datetime64`` values.

            The value unchanged otherwise.
        """
        if field_value is None:
            return None
        if isinstance(field_value, np.integer | np.floating):
            return float(field_value)
        if isinstance(field_value, np.datetime64):
            # np.datetime64 has no strftime(); calling it directly raised
            # AttributeError. Convert through pandas first.
            field_value = pd.Timestamp(field_value)
        # `date` also covers datetime and pd.Timestamp, which subclass it.
        # Plain dates were previously passed through and then rejected by
        # the int | float | str field types.
        if isinstance(field_value, date):
            return field_value.strftime("%Y-%m-%d")
        return field_value

    @model_validator(mode="before")
    @classmethod
    def validate_keys(cls, values):
        """Require a complete (x, y) pair or a complete histogram bin.

        Runs on the raw input before the field validators, so it sees the
        values exactly as supplied.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If neither group is fully populated. Pydantic
                surfaces this as a ``ValidationError``.
        """
        # Before-mode validators may receive input that is not a mapping,
        # e.g. an existing model instance. Defer those to pydantic's normal
        # handling instead of crashing on `.get`.
        if not hasattr(values, "get"):
            return values
        has_xy = values.get("x") is not None and values.get("y") is not None
        has_bin = all(
            values.get(key) is not None
            for key in ("bin_start", "bin_end", "frequency")
        )
        # NOTE(review): despite "but not a mix" in the message, a point that
        # completes one group while also carrying keys of the other is
        # accepted. Only "neither group complete" is rejected.
        if not (has_xy or has_bin):
            raise ValueError(
                "Invalid input: Must provide either (x, y) OR (bin_start, bin_end, frequency), but not a mix."
            )
        return values
class Data(BaseModel):
    # A series of DataPoint records for one chart; empty by default.
    data: list[DataPoint] = Field(default_factory=list)

    @classmethod
    def validate_data(cls, data):
        """Build a validated ``Data`` instance from raw points.

        Args:
            data: A list of DataPoint instances or mappings accepted by
                DataPoint.

        Returns:
            The validated ``Data`` instance.

        Raises:
            ValueError: If validation fails. The message carries the first
                pydantic error, and the original ``ValidationError`` is
                chained as ``__cause__`` so the full error list is not lost.
        """
        try:
            return cls(data=data)
        except ValidationError as e:
            raise ValueError(f"Invalid data format: {e.errors()[0]}") from e
class TableData(BaseModel):
    # Tabular result held as a pandas DataFrame. Datetime columns are turned
    # into strings on validation, and model_dump() returns
    # {column: [values, ...]}. Comments rather than a class docstring:
    # pydantic copies a model docstring into its JSON schema.

    # pd.DataFrame is not a pydantic-native type. `model_config` replaces the
    # deprecated v1-style inner `class Config`; a plain dict is accepted.
    model_config = {"arbitrary_types_allowed": True}

    # The original `Field(default_factory=None)` registers no default at all,
    # which silently made the field required. Default to an empty frame.
    data: pd.DataFrame = Field(default_factory=pd.DataFrame)

    @model_validator(mode="after")
    def timestamp_to_str(self):
        """Replace every datetime column, naive or tz-aware, with its string form.

        Returns:
            self, as pydantic after-validators must.
        """
        # 'datetime' alone does not select tz-aware columns; 'datetimetz' does.
        datetime_cols = self.data.select_dtypes(
            include=["datetime", "datetimetz"]
        ).columns
        if len(datetime_cols) == 0:
            return self
        # Convert on a copy. Pydantic keeps the caller's DataFrame object
        # as-is, so writing into self.data would also rewrite the caller's frame.
        converted = self.data.copy()
        for col in datetime_cols:
            # No truthiness test on `col`: the old `if col:` silently skipped
            # columns named 0 or "".
            converted[col] = converted[col].astype(str)
        self.data = converted
        return self

    def model_dump(self, *args, **kwargs):  # noqa: ARG002
        """Return the table as ``{column: [values, ...]}``.

        Pydantic dump options passed in ``args``/``kwargs`` are ignored.
        """
        return self.data.to_dict(orient="list")
class Charts(BaseModel):
    # Optional data series, one per supported chart kind.
    bar: Data | None = None
    line: Data | None = None
    pie: Data | None = None
    hist: Data | None = None

    @model_validator(mode="after")
    def process_charts_data(self):
        """Stringify the x values of bar and pie points in place.

        Line and hist series are left untouched. A point whose x is None
        keeps None instead of becoming the literal string "None". DataPoint
        only allows x=None on bin-style points, so "None" would never be a
        meaningful category label.

        Returns:
            self, as pydantic after-validators must.
        """
        for chart in (self.bar, self.pie):
            if chart is None:
                continue
            for point in chart.data:
                if point.x is not None and not isinstance(point.x, str):
                    point.x = str(point.x)
        return self
class PlotType(str, Enum):
    # Chart kinds; each member's value is its lowercase name.
    # Plain strings replace the original accidental one-element tuples
    # `("bar",)`. Those only worked because Enum unpacks a tuple value into
    # str("bar") for str-mixin enums. Comments are used instead of a
    # docstring because pydantic copies an enum docstring into its schema.
    bar = "bar"
    line = "line"
    pie = "pie"
    hist = "hist"
class PlotConfig(BaseModel):
    # Display settings for one chart. Every field except `legend`
    # (default True) is required.
    type: PlotType = Field(
        ...,
        description="Type of plot, e.g., 'bar', 'line', 'pie'. Supported types depend on ShadCN implementation.",
    )
    title: str = Field(..., description="Title of the plot to display above the plot.")
    x_axis_label: str = Field(..., description="Label for the X-axis of the plot.")
    y_axis_label: str = Field(..., description="Label for the Y-axis of the plot.")
    legend: bool = Field(True, description="Flag to display a legend for the plot.")