from __future__ import annotations

import io
import re
import uuid
from pathlib import Path

import numpy as np
import pandas as pd
from statsmodels.tsa.stattools import adfuller

from models import OutlierInfo, UploadWarnings

# ─── Constants ────────────────────────────────────────────────────────────────

MAX_SERIES = 20
MIN_ROWS = 15
MAX_WINDOW = 512  # Chronos-Bolt context limit
DATE_FORMATS = [
    "%d/%m/%Y",   # 15/04/2024 (most common in India)
    "%d-%m-%Y",   # 15-04-2024
    "%d/%m/%y",   # 15/04/24
    "%d-%m-%y",   # 15-04-24
    "%d-%b-%Y",   # 15-Apr-2024
    "%d-%b-%y",   # 15-Apr-24
    "%d %B %Y",   # 15 April 2024
    "%B %d, %Y",  # April 15, 2024
    "%Y-%m-%d",   # 2024-04-15 (ISO)
    "%m/%d/%Y",   # 04/15/2024 (US format, tried last)
    "%Y%m%d",     # 20240415
]


# Friendly error codes, matched in the frontend ERROR_MAP
class IngestionError(Exception):
    def __init__(self, code: str, message: str):
        self.code = code
        self.message = message
        super().__init__(message)

# ─── Entry point ──────────────────────────────────────────────────────────────

def ingest(file_bytes: bytes, filename: str) -> dict:
    """
    Full ingestion pipeline. Takes raw file bytes, returns a clean
    session dict ready for forecasting.

    Raises IngestionError with a user-friendly message if the file
    can't be used.
    """
    session_id = str(uuid.uuid4())
    raw_df = _load_file(file_bytes, filename)
    raw_df = _strip_empty(raw_df)
    date_col, value_cols = _detect_columns(raw_df)
    is_multi = len(value_cols) > 1
    # For multi-series we return the list and let the user pick.
    # Actual forecasting uses one series at a time.
    series_list = value_cols if is_multi else []
    # Work with the first value column for the upload preview.
    # User can change this in ColumnPicker before forecasting.
    value_col = value_cols[0]

    df = _parse_dates(raw_df, date_col)
    df = _clean_values(df, value_col)
    df = _sort_and_dedup(df, date_col)
    df, gap_fraction = _handle_gaps(df, date_col, value_col)
    _validate_length(df)

    frequency = _detect_frequency(df, date_col)
    outliers = _find_outliers(df, value_col, date_col)
    warnings = _build_warnings(df, value_col, frequency, gap_fraction)
    preview = _make_preview(df, date_col, value_col)

    # Store the cleaned dataframe in a simple in-memory session store.
    _SESSION_STORE[session_id] = {
        "df": df,
        "date_col": date_col,
        "value_col": value_col,
        "frequency": frequency,
        "warnings": warnings,
        "is_multi": is_multi,
        "value_cols": value_cols,
    }
    return {
        "session_id": session_id,
        "detected_date_col": date_col,
        "detected_value_col": value_col,
        "columns": list(raw_df.columns),
        "series_list": series_list,
        "preview": preview,
        "frequency": frequency,
        "n_rows": len(df),
        "outliers": outliers,
        "warnings": warnings.__dict__,
    }


def get_session(session_id: str) -> dict:
    if session_id not in _SESSION_STORE:
        raise IngestionError("SESSION_NOT_FOUND", "Session expired. Please upload your file again.")
    return _SESSION_STORE[session_id]


def prepare_series(
    session_id: str,
    date_col: str,
    value_col: str,
    outlier_action: str = "include",
    series_name: str | None = None,
) -> dict:
    """
    Called just before forecasting. Validates the user's column selection,
    applies the outlier action ("include" keeps values as-is, "cap" clips
    them to the 1st/99th percentiles), and returns the windowed numpy array.
    """
    session = get_session(session_id)
    df = session["df"].copy()
    # Re-clean with the user's confirmed column choice (may differ from auto-detect)
    if value_col not in df.columns:
        raise IngestionError(
            "NON_NUMERIC",
            f"Column '{value_col}' not found. Please pick a valid column."
        )
    df = _clean_values(df, value_col)
    if outlier_action == "cap":
        df = _cap_outliers(df, value_col)
    # Drop rows with missing values first so the date index built below stays
    # aligned with the series even when NaNs sit in the middle of the data.
    valid = df.dropna(subset=[value_col])
    series = valid[value_col].values.astype(np.float64)
    if len(series) < MIN_ROWS:
        raise IngestionError(
            "TOO_FEW_ROWS",
            f"Need at least {MIN_ROWS} data points. You have {len(series)}."
        )
    # Keep only the most recent window if the series is longer than Chronos
    # can handle.
    if len(series) > MAX_WINDOW:
        series = series[-MAX_WINDOW:]
        valid = valid.iloc[-MAX_WINDOW:]
    # Build the matching date index for the windowed series.
    dates = valid[date_col].dt.strftime("%Y-%m-%d").tolist()
    warnings = session["warnings"]
    return {
        "series": series,
        "dates": dates,
        "frequency": session["frequency"],
        "n_rows": len(series),
        "is_financial": warnings.non_stationary,
        "is_intermittent": warnings.intermittent,
        "warnings": warnings,
    }


# ─── File loading ─────────────────────────────────────────────────────────────

def _load_file(file_bytes: bytes, filename: str) -> pd.DataFrame:
    ext = Path(filename).suffix.lower()
    if ext in (".xlsx", ".xls"):
        try:
            return pd.read_excel(io.BytesIO(file_bytes), header=0)
        except Exception:
            raise IngestionError(
                "UNSUPPORTED_FORMAT",
                "Could not read the Excel file. Try saving it as CSV and uploading again."
            )
    if ext == ".csv":
        delimiter = _detect_delimiter(file_bytes)
        try:
            return pd.read_csv(io.BytesIO(file_bytes), sep=delimiter, header=0)
        except Exception:
            raise IngestionError(
                "UNSUPPORTED_FORMAT",
                "Could not read the CSV file. Make sure it has headers in the first row."
            )
    raise IngestionError(
        "UNSUPPORTED_FORMAT",
        "Please upload a CSV or Excel (.xlsx) file."
    )


def _detect_delimiter(file_bytes: bytes) -> str:
    # Sample the first 2KB to avoid reading large files just for detection
    sample = file_bytes[:2048].decode("utf-8", errors="ignore")
    counts = {d: sample.count(d) for d in (",", ";", "\t", "|")}
    return max(counts, key=counts.get)
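# Illustrative: for b"date;sales\n01/04/2024;120" the counts are
# {",": 0, ";": 2, "\t": 0, "|": 0}, so ";" is chosen. When no candidate
# appears at all, "," wins by default, since max() keeps the first key on ties.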


def _strip_empty(df: pd.DataFrame) -> pd.DataFrame:
    # Drop rows and columns that are entirely empty (common in Excel exports)
    df = df.dropna(how="all")
    df = df.loc[:, df.notna().any()]
    df.columns = [str(c).strip() for c in df.columns]
    return df.reset_index(drop=True)


# ─── Column detection ─────────────────────────────────────────────────────────

def _detect_columns(df: pd.DataFrame) -> tuple[str, list[str]]:
    """
    Returns (date_col, [value_col, ...]).

    Date column = first column where >80% of sampled values parse as a date.
    Value columns = all numeric columns that aren't the date column.
    """
    date_col = None
    for col in df.columns:
        if _col_is_date(df[col]):
            date_col = col
            break
    if date_col is None:
        raise IngestionError(
            "NO_DATE_COL",
            "We couldn't find a date column. Make sure one column has dates like 15/04/2024."
        )
    value_cols = []
    for col in df.columns:
        if col == date_col:
            continue
        cleaned = df[col].apply(_parse_indian_number)
        numeric_frac = cleaned.notna().mean()
        if numeric_frac > 0.7:
            value_cols.append(col)
    if not value_cols:
        raise IngestionError(
            "NON_NUMERIC",
            "We couldn't find a numeric column to forecast. "
            "Make sure one column has your sales or price numbers."
        )
    if len(value_cols) > MAX_SERIES:
        value_cols = value_cols[:MAX_SERIES]
    return date_col, value_cols


def _col_is_date(series: pd.Series) -> bool:
    sample = series.dropna().astype(str).head(20)
    if len(sample) == 0:
        return False
    # pd.notna() rejects both None and NaT; the fallback parser can return
    # NaT for NA-like strings such as "nan".
    successes = sum(1 for v in sample if pd.notna(_try_parse_date(v)))
    return successes / len(sample) > 0.8


# ─── Date parsing ─────────────────────────────────────────────────────────────

def _parse_dates(df: pd.DataFrame, date_col: str) -> pd.DataFrame:
    df = df.copy()
    df[date_col] = df[date_col].astype(str).apply(_try_parse_date)
    unparseable = df[date_col].isna().sum()
    if unparseable / len(df) > 0.3:
        raise IngestionError(
            "BAD_DATES",
            "More than 30% of dates couldn't be read. "
            "Please use a format like 15/04/2024 or 2024-04-15."
        )
    return df


def _try_parse_date(value: str):
    value = str(value).strip()
    for fmt in DATE_FORMATS:
        try:
            return pd.to_datetime(value, format=fmt)
        except (ValueError, TypeError):
            continue
    # Last resort: let pandas guess
    try:
        return pd.to_datetime(value, dayfirst=True)
    except Exception:
        return None
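# Because day-first formats come earlier in DATE_FORMATS, an ambiguous value
# like "04/05/2024" resolves to 4 May 2024, not April 5; the US "%m/%d/%Y"
# format is only reached when no day-first format matches (e.g. "04/15/2024").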


# ─── Value cleaning ───────────────────────────────────────────────────────────

def _clean_values(df: pd.DataFrame, value_col: str) -> pd.DataFrame:
    df = df.copy()
    if df[value_col].dtype == object:
        df[value_col] = df[value_col].apply(_parse_indian_number)
    non_numeric = df[value_col].isna().mean()
    if non_numeric > 0.8:
        raise IngestionError(
            "NON_NUMERIC",
            f"Column '{value_col}' has too many non-numeric values. "
            "Please pick the column with your sales or price numbers."
        )
    if df[value_col].notna().all() and df[value_col].nunique() == 1:
        raise IngestionError(
            "ALL_IDENTICAL",
            "All values in this column are identical. Forecasting won't be useful here."
        )
    return df


def _parse_indian_number(value) -> float | None:
    """
    Handles ₹2,300 / 23.5 lakh / 2 crore / 1,23,456 and plain floats.
    Returns None if the value genuinely can't be parsed as a number.
    """
    if pd.isna(value):
        return None
    text = str(value).strip().lower()
    text = re.sub(r"[₹$£\s]", "", text)
    text = re.sub(r"^rs\.?\s*", "", text)  # strip "Rs" / "Rs." prefix
    # Crore / lakh shorthand
    crore_match = re.search(r"([\d.]+)\s*crore", text)
    lakh_match = re.search(r"([\d.]+)\s*lakh", text)
    if crore_match:
        return float(crore_match.group(1)) * 1e7
    if lakh_match:
        return float(lakh_match.group(1)) * 1e5
    # Strip Indian-style commas (1,23,456 -> 123456 and 1,234 -> 1234)
    text = re.sub(r",", "", text)
    # Remove any trailing unit words
    text = re.sub(r"[a-z]+$", "", text).strip()
    try:
        return float(text)
    except ValueError:
        return None
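# Illustrative inputs and outputs:
#   "₹2,300"       -> 2300.0
#   "Rs. 1,23,456" -> 123456.0
#   "23.5 lakh"    -> 2350000.0   (23.5 * 1e5)
#   "2 crore"      -> 20000000.0  (2 * 1e7)
#   "12 units"     -> 12.0        (trailing unit word stripped)
#   "n/a"          -> None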


# ─── Sorting and deduplication ────────────────────────────────────────────────

def _sort_and_dedup(df: pd.DataFrame, date_col: str) -> pd.DataFrame:
    df = df.copy()
    df = df.dropna(subset=[date_col])
    df = df.sort_values(date_col)
    df = df.drop_duplicates(subset=[date_col], keep="first")
    return df.reset_index(drop=True)


# ─── Gap handling ─────────────────────────────────────────────────────────────

def _handle_gaps(
    df: pd.DataFrame, date_col: str, value_col: str
) -> tuple[pd.DataFrame, float]:
    """
    Fills small gaps via interpolation. Returns the filled dataframe
    and the original gap fraction so warnings can be set.
    """
    df = df.copy()
    missing = df[value_col].isna()
    gap_fraction = missing.mean()
    if gap_fraction > 0.3:
        raise IngestionError(
            "TOO_MANY_GAPS",
            f"About {gap_fraction:.0%} of your values are missing. "
            "Please fill in the gaps and try again."
        )
    if missing.any():
        df[value_col] = df[value_col].interpolate(method="linear", limit_direction="both")
    return df, gap_fraction
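# Illustrative: values [100, NaN, 130, 140] have gap_fraction 0.25, under the
# 0.3 ceiling, so the NaN is linearly interpolated to 115. limit_direction="both"
# also fills leading/trailing NaNs (pandas pads them from the nearest valid
# value rather than extrapolating the slope).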


# ─── Validation ───────────────────────────────────────────────────────────────

def _validate_length(df: pd.DataFrame) -> None:
    if len(df) < MIN_ROWS:
        raise IngestionError(
            "TOO_FEW_ROWS",
            f"We need at least {MIN_ROWS} data points to make a forecast. "
            f"Your file has {len(df)} rows."
        )


# ─── Frequency detection ──────────────────────────────────────────────────────

def _detect_frequency(df: pd.DataFrame, date_col: str) -> str:
    if len(df) < 3:
        return "unknown"
    # Use fractional days: .dt.days truncates, which would turn sub-daily
    # gaps into 0 and a 1.9-day gap into 1.
    deltas = df[date_col].diff().dropna().dt.total_seconds() / 86400
    median_gap = deltas.median()
    if median_gap < 0.1:
        return "hourly"
    if median_gap <= 1.5:
        return "daily"
    if median_gap <= 8:
        return "weekly"
    if median_gap <= 35:
        return "monthly"
    if median_gap <= 100:
        return "quarterly"
    if median_gap <= 400:
        return "annually"
    return "unknown"


# ─── Outlier detection ────────────────────────────────────────────────────────

def _find_outliers(df: pd.DataFrame, value_col: str, date_col: str) -> list[OutlierInfo]:
    """
    IQR method (Tukey 1977). Flags values more than 3×IQR beyond Q1/Q3.
    These are shown to the user for confirmation, not auto-removed.
    """
    values = df[value_col].dropna()
    q1, q3 = values.quantile(0.25), values.quantile(0.75)
    iqr = q3 - q1
    lower, upper = q1 - 3 * iqr, q3 + 3 * iqr
    outliers = []
    for idx, row in df.iterrows():
        v = row[value_col]
        if pd.notna(v) and (v < lower or v > upper):
            outliers.append(OutlierInfo(
                row_index=int(idx),
                date=str(row[date_col].date()),
                value=float(v),
            ))
    return outliers
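# Worked example: with Q1 = 100 and Q3 = 180, IQR = 80 and 3*IQR = 240, so the
# fences are 100 - 240 = -140 and 180 + 240 = 420; a value of 500 gets flagged.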


def _cap_outliers(df: pd.DataFrame, value_col: str) -> pd.DataFrame:
    df = df.copy()
    cap = df[value_col].quantile(0.99)
    floor = df[value_col].quantile(0.01)
    df[value_col] = df[value_col].clip(lower=floor, upper=cap)
    return df


# ─── Warnings ─────────────────────────────────────────────────────────────────

def _build_warnings(
    df: pd.DataFrame,
    value_col: str,
    frequency: str,
    gap_fraction: float,
) -> UploadWarnings:
    series = df[value_col].dropna().values
    # Intermittent demand: more than 30% of values are zero
    zero_frac = (series == 0).mean()
    intermittent = bool(zero_frac > 0.3)
    # Financial / non-stationary series: an ADF p-value > 0.05 means we cannot
    # reject a unit root, so treat the series as non-stationary.
    # Only meaningful on longer series.
    non_stationary = False
    if len(series) >= 20:
        try:
            p_value = adfuller(series, autolag="AIC")[1]
            non_stationary = bool(p_value > 0.05)
        except Exception:
            pass
    # Less than a year of weekly data counts as a short series.
    short_series = len(series) < 52 and frequency == "weekly"
    large_gaps = bool(0.1 < gap_fraction <= 0.3)
    return UploadWarnings(
        intermittent=intermittent,
        non_stationary=non_stationary,
        short_series=short_series,
        large_gaps=large_gaps,
    )


# ─── Preview ──────────────────────────────────────────────────────────────────

def _make_preview(df: pd.DataFrame, date_col: str, value_col: str) -> list[dict]:
    rows = df[[date_col, value_col]].head(5)
    return [
        {"date": str(row[date_col].date()), "value": row[value_col]}
        for _, row in rows.iterrows()
    ]


# ─── In-memory session store ──────────────────────────────────────────────────

_SESSION_STORE: dict[str, dict] = {}
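

if __name__ == "__main__":
    # Minimal smoke test -- a sketch for local runs, not part of the pipeline.
    # Assumes `models` and statsmodels are importable here; it simply exercises
    # ingest() and prepare_series() on 20 days of synthetic daily sales.
    _csv = "Date,Sales\n" + "\n".join(
        f"{day:02d}/04/2024,{1000 + day * 10}" for day in range(1, 21)
    )
    info = ingest(_csv.encode("utf-8"), "sales.csv")
    print(info["frequency"], info["n_rows"])  # expected: daily 20
    prepped = prepare_series(info["session_id"], "Date", "Sales")
    print(prepped["series"][:3])  # expected: [1010. 1020. 1030.]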