from collections.abc import Mapping
from datetime import date, datetime
from enum import Enum

import numpy as np
import pandas as pd
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    ValidationError,
    field_validator,
    model_validator,
)


class SmallCardNum:
    """Marker class with no behavior in this module.

    NOTE(review): presumably tags small-cardinality numeric columns — confirm
    against its users.
    """


class Continuous:
    """Marker class with no behavior in this module.

    NOTE(review): presumably tags continuous numeric columns — confirm.
    """


class DateTime:
    """Marker class with no behavior in this module.

    NOTE(review): presumably tags date/time columns — confirm.
    """


class Nominal:
    """Marker class with no behavior in this module.

    NOTE(review): presumably tags categorical columns — confirm.
    """


class Route(BaseModel):
    """Routing label for a user query (see the field description for codes)."""

    label: int = Field(
        description="Classify user queries as: 0 for Irrelevant/Vague/Incomplete, 1 for Visualizable, and 2 for SQL-only."
    )


class SQLQueryModel(BaseModel):
    """An SQL query together with a short human-readable explanation."""

    sql_query: str = Field(..., description="SQL query to execute.")
    explanation: str = Field(..., description="Short explanation of the SQL query.")


class DataPoint(BaseModel):
    """A single chart data point.

    A point is either an ``(x, y)`` pair or a histogram bin
    ``(bin_start, bin_end, frequency)``; ``validate_keys`` requires at least
    one of the two groups to be fully present.
    """

    x: int | float | str | None = None
    y: int | float | str | None = None
    bin_start: int | float | None = None
    bin_end: int | float | None = None
    frequency: int | float | None = None

    @field_validator("bin_start", "bin_end", "frequency", "x", "y", mode="before")
    @classmethod
    def to_native(cls, field_value):
        """Convert numpy / date-like values to plain Python types before validation.

        - numpy integer scalars -> ``int``; numpy floating scalars -> ``float``
        - ``np.datetime64`` and ``date``/``datetime``/``pd.Timestamp`` ->
          ``"YYYY-MM-DD"`` string
        - anything else is returned unchanged for normal pydantic validation.
        """
        if isinstance(field_value, np.integer):
            # Keep integers integral: converting to float would later render
            # categorical labels such as years as "2020.0".
            return int(field_value)
        if isinstance(field_value, np.floating):
            return float(field_value)
        if isinstance(field_value, np.datetime64):
            # np.datetime64 has no strftime(); route it through pandas.
            return pd.Timestamp(field_value).strftime("%Y-%m-%d")
        if isinstance(field_value, datetime | date):
            # pd.Timestamp subclasses datetime, which subclasses date.
            return field_value.strftime("%Y-%m-%d")
        return field_value

    @model_validator(mode="before")
    @classmethod
    def validate_keys(cls, values):
        """Require a complete ``(x, y)`` pair or a complete histogram bin.

        Raises:
            ValueError: if neither group is fully provided (pydantic surfaces
                this as a ``ValidationError``).
        """
        if not isinstance(values, Mapping):
            # Non-mapping input: defer to pydantic's own type validation
            # instead of failing here with AttributeError on .get().
            return values
        has_xy = values.get("x") is not None and values.get("y") is not None
        has_bin = all(
            values.get(key) is not None for key in ("bin_start", "bin_end", "frequency")
        )
        # NOTE(review): the message says "not a mix", but a point providing
        # both groups is currently accepted; only "neither group" is rejected.
        if not (has_xy or has_bin):
            raise ValueError(
                "Invalid input: Must provide either (x, y) OR (bin_start, bin_end, frequency), but not a mix."
            )
        return values


class Data(BaseModel):
    """A list of chart data points."""

    data: list[DataPoint] = Field(default_factory=list)

    @classmethod
    def validate_data(cls, data):
        """Build a ``Data`` instance from raw points.

        Raises:
            ValueError: if validation fails; the message includes the first
                pydantic error and the original error is chained as the cause.
        """
        try:
            return cls(data=data)
        except ValidationError as e:
            raise ValueError(f"Invalid data format: {e.errors()[0]}") from e


class TableData(BaseModel):
    """Tabular result wrapping a ``pandas.DataFrame``."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    data: pd.DataFrame = Field(default_factory=pd.DataFrame)

    @model_validator(mode="after")
    def timestamp_to_str(self):
        """Convert every datetime column (naive or tz-aware) to strings."""
        datetime_cols = self.data.select_dtypes(
            include=["datetime", "datetimetz"]
        ).columns
        if len(datetime_cols):
            # Work on a copy: the caller's DataFrame object is stored as-is,
            # so assigning columns in place would mutate the caller's data.
            converted = self.data.copy()
            for col in datetime_cols:
                converted[col] = converted[col].astype(str)
            self.data = converted
        return self

    def model_dump(self, *args, **kwargs):  # noqa: ARG002
        """Return the table as ``{column: [values, ...]}``; arguments are ignored."""
        return self.data.to_dict(orient="list")


class Charts(BaseModel):
    """Per-chart-type data; chart types that are not set stay ``None``."""

    bar: Data | None = None
    line: Data | None = None
    pie: Data | None = None
    hist: Data | None = None

    @model_validator(mode="after")
    def process_charts_data(self):
        """Stringify ``x`` of every bar and pie point (categorical axes).

        Points whose ``x`` is ``None`` are left untouched rather than being
        turned into the literal string ``"None"``.
        """
        for chart in (self.bar, self.pie):
            if chart is None:
                continue
            for point in chart.data:
                if point.x is not None and not isinstance(point.x, str):
                    point.x = str(point.x)
        return self


class PlotType(str, Enum):
    """Supported plot kinds; members compare equal to their string values."""

    bar = "bar"
    line = "line"
    pie = "pie"
    hist = "hist"


class PlotConfig(BaseModel):
    """Presentation settings for a single plot."""

    type: PlotType = Field(
        description="Type of plot, e.g., 'bar', 'line', 'pie'. Supported types depend on ShadCN implementation.",
    )
    title: str = Field(description="Title of the plot to display above the plot.")
    x_axis_label: str = Field(description="Label for the X-axis of the plot.")
    y_axis_label: str = Field(description="Label for the Y-axis of the plot.")
    legend: bool = Field(
        default=True, description="Flag to display a legend for the plot."
    )