Spaces:
Sleeping
Sleeping
| from datetime import datetime | |
| from enum import Enum | |
| import numpy as np | |
| import pandas as pd | |
| from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator | |
class SmallCardNum:
    """Marker class with no attributes or behavior of its own.

    NOTE(review): the name suggests it tags small-cardinality columns; how it is
    consumed is not visible here — confirm against callers.
    """
class Continuous:
    """Marker class with no attributes or behavior of its own.

    NOTE(review): the name suggests it tags continuous-valued columns; how it is
    consumed is not visible here — confirm against callers.
    """
class DateTime:
    """Marker class with no attributes or behavior of its own.

    NOTE(review): the name suggests it tags datetime columns; how it is consumed
    is not visible here — confirm against callers.
    """
class Nominal:
    """Marker class with no attributes or behavior of its own.

    NOTE(review): the name suggests it tags nominal (categorical) columns; how it
    is consumed is not visible here — confirm against callers.
    """
class Route(BaseModel):
    # Single required integer routing label for a user query.
    # No class docstring on purpose: pydantic copies it into the JSON schema.
    # The Field description is schema text as well, so it is kept verbatim.
    label: int = Field(
        description="Classify user queries as: 0 for Irrelevant/Vague/Incomplete, 1 for Visualizable, and 2 for SQL-only.",
    )
class SQLQueryModel(BaseModel):
    # A generated SQL statement together with a short human-readable explanation.
    # Both fields are required (default=... marks a pydantic field as required).
    # No class docstring on purpose: pydantic copies it into the JSON schema.
    sql_query: str = Field(default=..., description="SQL query to execute.")
    explanation: str = Field(
        default=...,
        description="Short explanation of the SQL query.",
    )
class DataPoint(BaseModel):
    # One chart datum: either an (x, y) pair or a histogram triple
    # (bin_start, bin_end, frequency); see `validate_keys`.
    # No class docstring on purpose: pydantic copies it into the JSON schema.
    # NOTE(review): `field_validator` / `model_validator` are imported by this file,
    # but no decorators are visible on the methods below (they take `cls`). They look
    # lost in extraction — confirm against upstream.
    x: int | float | str | None = None
    y: int | float | str | None = None
    bin_start: int | float | None = None
    bin_end: int | float | None = None
    frequency: int | float | None = None

    def to_native(cls, field_value):
        """Convert numpy/pandas scalar values into plain Python values.

        - numpy integer scalars -> ``int``; numpy floating scalars -> ``float``
          (integers stay integral, so 3 is not turned into 3.0);
        - ``datetime`` / ``pd.Timestamp`` / ``np.datetime64`` -> ``"%Y-%m-%d"``
          string, or ``None`` for a missing timestamp (``NaT``);
        - anything else is returned unchanged.
        """
        if isinstance(field_value, np.integer):
            return int(field_value)
        if isinstance(field_value, np.floating):
            return float(field_value)
        if isinstance(field_value, np.datetime64):
            # np.datetime64 has no strftime(); pd.Timestamp provides one and maps
            # a NaT datetime64 to pd.NaT.
            field_value = pd.Timestamp(field_value)
        if field_value is pd.NaT:
            # pd.NaT passes isinstance(..., datetime) but its strftime() raises.
            return None
        if isinstance(field_value, datetime):  # also covers pd.Timestamp
            return field_value.strftime("%Y-%m-%d")
        return field_value

    def validate_keys(cls, values):
        """Require exactly one complete group of coordinates.

        A point must carry either both ``x`` and ``y``, or all of ``bin_start``,
        ``bin_end`` and ``frequency`` — but not both groups at once.

        Args:
            values: raw input mapping of field name -> value.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: if neither group, or both groups, are fully provided.
        """
        if not isinstance(values, dict):
            # Non-mapping input (e.g. an already-built model) has no .get();
            # leave it to pydantic's own validation.
            return values
        has_xy = all(values.get(key) is not None for key in ("x", "y"))
        has_bins = all(
            values.get(key) is not None
            for key in ("bin_start", "bin_end", "frequency")
        )
        # Exactly one group must be complete: reject "neither" and "both".
        if has_xy == has_bins:
            raise ValueError(
                "Invalid input: Must provide either (x, y) OR (bin_start, bin_end, frequency), but not a mix."
            )
        return values
class Data(BaseModel):
    # A list of chart points; defaults to an empty list.
    # No class docstring on purpose: pydantic copies it into the JSON schema.
    # NOTE(review): `validate_data` takes `cls` but no decorator is visible here
    # (presumably `@classmethod` was lost in extraction) — confirm upstream.
    data: list[DataPoint] = Field(default_factory=list)

    def validate_data(cls, data):
        """Build an instance from raw point data.

        Args:
            data: value for the ``data`` field.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: if pydantic validation fails. The message includes only
                the first error; the full ``ValidationError`` is kept as
                ``__cause__``.
        """
        try:
            return cls(data=data)
        except ValidationError as e:
            raise ValueError(f"Invalid data format: {e.errors()[0]}") from e
class TableData(BaseModel):
    # Tabular result held as a pandas DataFrame.
    # No class docstring on purpose: pydantic copies it into the JSON schema.
    # NOTE(review): `timestamp_to_str` returns `self`, which suggests it was an
    # `@model_validator(mode="after")`; no decorator is visible — confirm upstream.

    # pydantic cannot validate a DataFrame natively, so arbitrary types must be allowed.
    # (Pydantic v2 `model_config` replaces the deprecated inner `class Config`.)
    model_config = {"arbitrary_types_allowed": True}

    # `default_factory=None` meant "no factory", which silently made this field
    # required. An empty DataFrame is the evident intended default.
    data: pd.DataFrame = Field(default_factory=pd.DataFrame)

    def timestamp_to_str(self):
        """Convert every datetime column (naive or tz-aware) to ``str`` values.

        Returns:
            ``self``, with ``self.data`` replaced by the converted frame.
        """
        datetime_cols = self.data.select_dtypes(
            include=["datetime", "datetimetz"]
        ).columns
        if len(datetime_cols):
            # astype() returns a new frame, so the caller's DataFrame is not
            # mutated. It also converts columns whose name is falsy (e.g. 0),
            # which the previous `if col:` check skipped.
            self.data = self.data.astype({col: str for col in datetime_cols})
        return self

    def model_dump(self, *args, **kwargs):  # noqa: ARG002
        """Return the table as ``{column: [values, ...]}``.

        Any pydantic dump options passed in are accepted but ignored.
        """
        return self.data.to_dict(orient="list")
class Charts(BaseModel):
    # Optional point data for each chart kind.
    # No class docstring on purpose: pydantic copies it into the JSON schema.
    # NOTE(review): `process_charts_data` returns `self`, which suggests it was an
    # `@model_validator(mode="after")`; no decorator is visible — confirm upstream.
    bar: Data | None = None
    line: Data | None = None
    pie: Data | None = None
    hist: Data | None = None

    def process_charts_data(self):
        """Stringify the ``x`` value of every bar and pie point, in place.

        Line and hist data are left untouched. Points with no ``x`` (e.g.
        histogram-style points) are skipped rather than labelled ``"None"``.

        Returns:
            ``self``.
        """
        for chart in (self.bar, self.pie):
            if chart is None:
                continue
            for point in chart.data:
                if point.x is not None and not isinstance(point.x, str):
                    point.x = str(point.x)
        return self
class PlotType(str, Enum):
    # Supported chart kinds. Each member is a `str` whose value is the member name.
    # (Comments instead of a docstring: pydantic copies an enum docstring into
    # the JSON schema.)
    # Plain string values; the old `("bar",)` tuples produced the same members
    # only because the str mixin unpacks a tuple into constructor arguments.
    bar = "bar"
    line = "line"
    pie = "pie"
    hist = "hist"
| class PlotConfig(BaseModel): | |
| type: PlotType = Field( | |
| description="Type of plot, e.g., 'bar', 'line', 'pie'. Supported types depend on ShadCN implementation.", | |
| ) | |
| title: str = Field(description="Title of the plot to display above the plot.") | |
| x_axis_label: str = Field(description="Label for the X-axis of the plot.") | |
| y_axis_label: str = Field(description="Label for the Y-axis of the plot.") | |
| legend: bool = Field( | |
| default=True, description="Flag to display a legend for the plot." | |
| ) | |