| from __future__ import annotations |
|
|
import hashlib
import json
import logging
import os
import re
import secrets
import threading
import time
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timezone
from decimal import Decimal
from statistics import mean
from typing import Any
|
|
| from fastapi import Depends, FastAPI, Header, status |
| from fastapi.responses import JSONResponse |
| from pydantic import BaseModel, Field |
|
|
| try: |
| import psycopg |
| except ImportError: |
| psycopg = None |
|
|
| try: |
| import sqlglot |
| from sqlglot import exp |
| except ImportError: |
| sqlglot = None |
| exp = None |
|
|
| try: |
| import numpy as np |
| except ImportError: |
| np = None |
|
|
| try: |
| import timesfm |
| except ImportError: |
| timesfm = None |
|
|
|
|
# Module-wide logger. NOTE(review): the guard checks this named logger's own
# handlers (normally empty), not the root logger's, so basicConfig may still
# run inside an app that configured logging elsewhere — confirm intended.
LOGGER = logging.getLogger("pig_query_api")
if not LOGGER.handlers:
    logging.basicConfig(
        level=os.getenv("PIG_QUERY_LOG_LEVEL", "INFO").upper(),
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
|
|
|
|
| SERVICE_NAME = "pig-query-api" |
| API_KEY = os.getenv("PIG_QUERY_API_KEY", "change-me") |
| |
| PG_HOST = "47.107.138.68" |
| PG_PORT = 5432 |
| PG_DATABASE = "postgres" |
| PG_USER = "postgres" |
| PG_PASSWORD = "63BvPoC5JiHu16k8Gh" |
| PG_DSN = f"postgresql://{PG_USER}:{PG_PASSWORD}@{PG_HOST}:{PG_PORT}/{PG_DATABASE}" |
| MAX_ROWS = int(os.getenv("PIG_QUERY_MAX_ROWS", "100")) |
| STATEMENT_TIMEOUT_MS = int(os.getenv("PIG_QUERY_STATEMENT_TIMEOUT_MS", "3000")) |
| DEFAULT_SELECT_LIMIT = int(os.getenv("PIG_QUERY_DEFAULT_SELECT_LIMIT", "50")) |
| TIMESFM_DEVICE = os.getenv("TIMESFM_DEVICE", "cpu").lower() |
| TIMESFM_CHECKPOINT = os.getenv("TIMESFM_CHECKPOINT", "google/timesfm-2.0-500m-pytorch") |
| TIMESFM_BACKEND = os.getenv("TIMESFM_BACKEND", "cpu").lower() |
| TIMESFM_MAX_HORIZON = int(os.getenv("TIMESFM_MAX_HORIZON", "15")) |
| TIMESFM_BATCH_SIZE = int(os.getenv("TIMESFM_BATCH_SIZE", "32")) |
| NON_DOUBAO_MODEL_ID = os.getenv("NON_DOUBAO_MODEL_ID", "NON_DOUBAO_MODEL_ID") |
| SERIES_LOOKBACK_DAYS = int(os.getenv("PREDICT_SERIES_LOOKBACK_DAYS", "90")) |
| SERIES_MIN_POINTS = int(os.getenv("PREDICT_SERIES_MIN_POINTS", "8")) |
| PREDICT_LOG_ENABLED = os.getenv("PREDICT_LOG_ENABLED", "1").lower() in {"1", "true", "yes", "on"} |
|
|
| TIMESFM_MODEL_LOCK = threading.Lock() |
| TIMESFM_MODEL: Any | None = None |
| TIMESFM_INIT_ERROR: str | None = None |
|
|
# Tables that /query SQL is allowed to reference (schema-qualified; most use
# quoted Chinese identifiers). NOTE(review): "public.生猪价格" is the only
# Chinese name without inner double quotes — harmless because
# _normalize_identifier strips quotes anyway, but confirm it was intentional.
ALLOWED_TABLES = frozenset(
    {
        'public."采食量标准"',
        'public."体重标准"',
        'public."料肉比标准"',
        'public."猪仔信息"',
        'public."日常生长数据1"',
        "public.weighing_records",
        'public."饲料库存记录"',
        'public."玉米价格"',
        'public."猪仔价格"',
        "public.生猪价格",
        'public."猪肉批发价"',
    }
)


# Keywords that immediately disqualify a statement (writes, DDL, maintenance).
DENY_KEYWORDS = (
    "insert",
    "update",
    "delete",
    "drop",
    "alter",
    "truncate",
    "create",
    "grant",
    "revoke",
    "copy",
    "merge",
    "vacuum",
    "analyze",
    "comment",
)


# Line/block comment markers; any occurrence rejects the statement.
COMMENT_PATTERN = re.compile(r"(--|/\*|\*/)")
# Detects an explicit LIMIT clause.
LIMIT_PATTERN = re.compile(r"\blimit\s+(\d+)\b", re.IGNORECASE)
# Common aggregate calls; aggregate-only queries are exempt from LIMIT checks.
AGGREGATE_PATTERN = re.compile(
    r"\b(count|sum|avg|min|max|string_agg|array_agg|json_agg)\s*\(",
    re.IGNORECASE,
)
# Statement must start with SELECT or WITH.
SQL_START_PATTERN = re.compile(r"^\s*(select|with)\b", re.IGNORECASE)
# Regex fallback for FROM/JOIN table extraction (supports CJK identifiers).
TABLE_REF_PATTERN = re.compile(r"\b(from|join)\s+((?:public\.)?\"?[A-Za-z0-9_\u4e00-\u9fff]+\"?)", re.IGNORECASE)
# Extracts "未来N天"-style horizons from a natural-language question.
HORIZON_PATTERN = re.compile(r"(?:未来|接下来|后续)?\s*(\d{1,2})\s*(?:天|日)")


# Prediction module ids accepted by /predict/run.
MODULE_IDS = {
    "m1_price_forecast",
    "m2_slaughter_window",
    "m3_feeding_plan",
    "m4_disease_risk",
    "m5_inventory_procurement",
    "m6_piglet_buy",
}
|
|
|
|
| def _normalize_identifier(value: str) -> str: |
| parts = [part.strip().strip('"') for part in value.split(".") if part.strip()] |
| return ".".join(part.lower() for part in parts) |
|
|
|
|
# Normalized lookup set containing both the schema-qualified name and the bare
# table name for every whitelisted table.
ALLOWED_TABLE_NAMES = frozenset(
    name
    for table in ALLOWED_TABLES
    for qualified in (_normalize_identifier(table),)
    for name in (qualified, qualified.split(".", 1)[-1])
)
|
|
|
|
class ErrorDetail(BaseModel):
    """Machine-readable error payload embedded in response envelopes."""

    code: str  # stable error code, e.g. "E_SQL_DENIED"
    message: str  # human-readable description
    detail: dict[str, Any] = Field(default_factory=dict)  # extra context
|
|
|
|
class QuerySuccessData(BaseModel):
    """Successful /query result: columns plus JSON-safe row values."""

    columns: list[str]
    rows: list[list[Any]]
    row_count: int
    sql_executed: str  # the statement actually run (post-normalization)
    empty: bool  # True when no rows were returned
    trace_id: str
|
|
|
|
class QueryResponse(BaseModel):
    """Envelope for /query: exactly one of `data` or `err` is populated."""

    ok: bool
    data: QuerySuccessData | None = None
    err: ErrorDetail | None = None
|
|
|
|
class QueryRequest(BaseModel):
    """Body for /query: the user's question plus the SQL to validate and run."""

    user_query: str = Field(..., min_length=1, max_length=2000)
    sql: str = Field(..., min_length=1, max_length=20000)
    trace_id: str | None = Field(default=None, max_length=128)  # optional client correlation id
|
|
|
|
class HealthResponse(BaseModel):
    """Payload for GET /api/v1/health."""

    ok: bool = True
    service: str
    time: str  # ISO-8601 UTC timestamp with trailing "Z"
    sql_parser: str  # "sqlglot" or "regex-fallback"
    db_driver: str  # "psycopg" or "missing"
|
|
|
|
class PredictRunRequest(BaseModel):
    """Body for POST /api/v1/predict/run."""

    question: str = Field(..., min_length=1, max_length=2000)
    module_id: str | None = Field(default=None, max_length=64)  # inferred from question when absent
    trace_id: str | None = Field(default=None, max_length=128)
|
|
|
|
class PredictRunData(BaseModel):
    """Successful prediction run: forecasts, recommendation and audit info."""

    module_id: str
    db_used: bool  # whether database series fed the forecast
    model_trace: list[str]  # models/steps used, for auditing
    forecast: dict[str, Any]  # per-series point forecasts
    recommendation: dict[str, Any]
    risk_flags: list[str]
    backend_draft_summary: str  # text draft for the summarization model
    trace_id: str
|
|
|
|
class PredictRunResponse(BaseModel):
    """Envelope for /predict/run: exactly one of `data` or `err` is populated."""

    ok: bool
    data: PredictRunData | None = None
    err: ErrorDetail | None = None
|
|
|
|
class PredictHealthResponse(BaseModel):
    """Payload for GET /api/v1/predict/health."""

    ok: bool
    service: str
    time: str
    timesfm_enabled: bool  # timesfm package importable
    timesfm_ready: bool  # model initialized successfully
    timesfm_device: str
    timesfm_checkpoint: str
    timesfm_init_error: str | None = None  # last initialization failure, if any
|
|
|
|
class QueryAPIError(Exception):
    """Domain error carrying an API error code, detail dict and HTTP status.

    Raised throughout the service and converted to the standard JSON envelope
    by the registered exception handler.
    """

    def __init__(self, code: str, message: str, detail: dict[str, Any] | None = None, status_code: int = 400):
        super().__init__(message)
        self.code = code  # stable machine-readable code, e.g. "E_SQL_DENIED"
        self.message = message
        self.detail = detail or {}
        self.status_code = status_code  # HTTP status used in the response
|
|
|
|
@dataclass(frozen=True)
class PreparedQuery:
    """A fully validated SQL statement plus the tables it references."""

    sql: str  # normalized statement text
    tables: set[str]  # normalized table identifiers found in the statement
|
|
|
|
def normalize_sql(sql: str) -> str:
    """Strip BOM characters and surrounding whitespace from raw SQL text."""
    without_bom = sql.replace("\ufeff", "")
    return without_bom.strip()
|
|
|
|
def reject_empty_sql(sql: str) -> None:
    """Raise when the (already normalized) SQL text is empty."""
    if sql:
        return
    raise QueryAPIError("E_SQL_EMPTY", "SQL 不能为空")
|
|
|
|
def reject_multi_statement(sql: str) -> None:
    """Reject SQL that still contains ';' after trailing semicolons are removed."""
    body = sql.rstrip(";").strip()
    if ";" not in body:
        return
    raise QueryAPIError("E_SQL_MULTI", "SQL 只允许单条语句")
|
|
|
|
def reject_comments(sql: str) -> None:
    """Reject SQL containing line ('--') or block ('/*', '*/') comment markers."""
    if COMMENT_PATTERN.search(sql) is not None:
        raise QueryAPIError("E_SQL_COMMENT", "SQL 不允许注释")
|
|
|
|
def reject_dangerous_keywords(sql: str) -> None:
    """Reject SQL containing any whole-word write/DDL/maintenance keyword."""
    lowered = sql.lower()
    denied = next(
        (kw for kw in DENY_KEYWORDS if re.search(rf"\b{re.escape(kw)}\b", lowered)),
        None,
    )
    if denied is not None:
        raise QueryAPIError(
            "E_SQL_DENIED",
            "SQL 不在允许范围内",
            {"reason": f"contains denied keyword: {denied}"},
        )
|
|
|
|
def parse_sql(sql: str) -> Any | None:
    """Parse *sql* with sqlglot's postgres dialect.

    Returns None when sqlglot is not installed (the regex fallback is used
    downstream); raises QueryAPIError("E_SQL_PARSE") when parsing fails.
    """
    if sqlglot is None:
        return None
    try:
        return sqlglot.parse_one(sql, read="postgres")
    except Exception as exc:  # sqlglot raises several exception types
        raise QueryAPIError("E_SQL_PARSE", "SQL 解析失败", {"reason": str(exc)}) from exc
|
|
|
|
def ensure_select_only(ast: Any | None, sql: str) -> None:
    """Allow only SELECT/WITH statements; deny DML/DDL nodes found in the AST.

    The textual check always runs; the AST scan is skipped when sqlglot is
    unavailable.
    """
    if SQL_START_PATTERN.match(sql) is None:
        raise QueryAPIError("E_SQL_TYPE", "只允许 SELECT 或 WITH 查询")

    if ast is None or exp is None:
        return

    for denied in (
        exp.Insert,
        exp.Update,
        exp.Delete,
        exp.Drop,
        exp.Alter,
        exp.Create,
        exp.Command,
        exp.Copy,
        exp.Merge,
    ):
        if ast.find(denied) is not None:
            raise QueryAPIError(
                "E_SQL_DENIED",
                "SQL 不在允许范围内",
                {"reason": f"contains {denied.__name__}"},
            )
|
|
|
|
def _extract_tables_with_sqlglot(ast: Any) -> set[str]:
    """Collect normalized table names from a sqlglot AST, excluding CTE names."""
    tables: set[str] = set()
    cte_names: set[str] = set()
    if exp is None:
        return tables

    # CTE aliases look like table references later in the tree; remember them
    # so they are not mistaken for real tables.
    for cte in ast.find_all(exp.CTE):
        alias = getattr(cte, "alias", None)
        if alias:
            cte_names.add(_normalize_identifier(alias))

    for table in ast.find_all(exp.Table):
        # NOTE(review): relies on sqlglot's Table exposing .db (schema) and
        # .name — stable in recent sqlglot versions, but re-verify on upgrade.
        db = getattr(table, "db", None)
        name = getattr(table, "name", None)
        if not name:
            continue
        normalized = _normalize_identifier(".".join(part for part in (db, name) if part))
        if normalized in cte_names:
            continue
        tables.add(normalized)
    return tables
|
|
|
|
def _extract_tables_with_regex(sql: str) -> set[str]:
    """Fallback table extraction: scan FROM/JOIN clauses with a regex."""
    return {
        _normalize_identifier(reference)
        for _keyword, reference in TABLE_REF_PATTERN.findall(sql)
    }
|
|
|
|
def extract_tables(ast: Any | None, sql: str) -> set[str]:
    """Collect referenced table names, preferring the AST over regex scanning."""
    if ast is None or sqlglot is None:
        return _extract_tables_with_regex(sql)
    found = _extract_tables_with_sqlglot(ast)
    # An empty AST result falls back to the regex scan as a safety net.
    return found if found else _extract_tables_with_regex(sql)
|
|
|
|
def ensure_whitelisted_tables(tables: set[str]) -> None:
    """Require every referenced table (qualified or bare form) to be whitelisted."""
    if not tables:
        raise QueryAPIError("E_SQL_TABLES", "未识别到可查询表")

    rejected = [
        table
        for table in tables
        if not ({table, table.split(".", 1)[-1]} & ALLOWED_TABLE_NAMES)
    ]
    if rejected:
        raise QueryAPIError(
            "E_SQL_DENIED",
            "SQL 不在允许范围内",
            {"reason": "contains non-whitelisted table", "tables": sorted(rejected)},
        )
|
|
|
|
def ensure_limit_if_needed(sql: str) -> None:
    """Demand an explicit LIMIT unless the query is a pure aggregate.

    GROUP BY queries always need a LIMIT; aggregate-only queries are exempt.
    """
    if LIMIT_PATTERN.search(sql) is not None:
        return

    if "group by" in sql.lower():
        raise QueryAPIError("E_LIMIT_REQUIRED", f"包含 GROUP BY 的查询必须显式带 LIMIT {DEFAULT_SELECT_LIMIT} 以内")

    if AGGREGATE_PATTERN.search(sql) is not None:
        return

    raise QueryAPIError("E_LIMIT_REQUIRED", f"明细查询必须显式带 LIMIT,建议 LIMIT {DEFAULT_SELECT_LIMIT}")
|
|
|
|
def prepare_sql_for_execution(sql: str) -> PreparedQuery:
    """Run the full validation pipeline and return the vetted statement.

    Order matters: cheap text checks first, then parsing, then table and
    LIMIT policy. Any failure raises QueryAPIError.
    """
    cleaned = normalize_sql(sql)
    for check in (reject_empty_sql, reject_multi_statement, reject_comments, reject_dangerous_keywords):
        check(cleaned)
    tree = parse_sql(cleaned)
    ensure_select_only(tree, cleaned)
    referenced = extract_tables(tree, cleaned)
    ensure_whitelisted_tables(referenced)
    ensure_limit_if_needed(cleaned)
    return PreparedQuery(sql=cleaned, tables=referenced)
|
|
|
|
| def _now_utc_iso() -> str: |
| return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") |
|
|
|
|
def _clamp_horizon(value: int | None) -> int:
    """Clamp a forecast horizon to [7, TIMESFM_MAX_HORIZON]; default when absent."""
    if value is None:
        return min(15, TIMESFM_MAX_HORIZON)
    bounded = min(TIMESFM_MAX_HORIZON, int(value))
    return max(7, bounded)
|
|
|
|
def infer_horizon_days(question: str) -> int:
    """Read a '未来N天'-style horizon from the question, defaulting to 15."""
    found = HORIZON_PATTERN.search(question)
    days = int(found.group(1)) if found else 15
    return _clamp_horizon(days)
|
|
|
|
def ensure_non_doubao_model() -> None:
    """Fail fast when the summary-model id is configured as a Doubao model."""
    if "doubao" not in NON_DOUBAO_MODEL_ID.strip().lower():
        return
    raise QueryAPIError(
        "E_MODEL_POLICY",
        "专业预测总结模型禁止配置为豆包",
        {"NON_DOUBAO_MODEL_ID": NON_DOUBAO_MODEL_ID},
        status_code=500,
    )
|
|
|
|
def infer_module_id_from_question(question: str) -> str:
    """Map keywords in the question to a prediction module id.

    Rules are checked in order; the first match wins and price forecasting is
    the fallback.
    """
    text = question.lower()
    rules = (
        (("库存", "采购", "补货", "饲料还够"), "m5_inventory_procurement"),
        (("猪仔", "补栏", "买入", "购入"), "m6_piglet_buy"),
        (("疫病", "异常", "预警", "风险", "病"), "m4_disease_risk"),
        (("饲喂", "喂料", "料肉比", "采食"), "m3_feeding_plan"),
        (("出栏", "卖猪", "窗口"), "m2_slaughter_window"),
    )
    for keywords, module in rules:
        if any(word in text for word in keywords):
            return module
    return "m1_price_forecast"
|
|
|
|
def normalize_module_id(module_id: str | None, question: str) -> str:
    """Use the caller-provided module id when valid, else infer from the text."""
    if module_id and module_id in MODULE_IDS:
        return module_id
    # Unknown/missing id: fall back to keyword inference.
    return infer_module_id_from_question(question)
|
|
|
|
def _query_rows(sql: str, params: tuple[Any, ...] = ()) -> list[tuple[Any, ...]]:
    """Run a read-only, timeout-bounded query and return all rows.

    Raises:
        QueryAPIError: E_DRIVER_MISSING / E_DSN_MISSING when prerequisites
            are absent, E_PREDICT_DB on any database failure.
    """
    _require_psycopg()
    if not PG_DSN:
        raise QueryAPIError("E_DSN_MISSING", "未配置 PIG_QUERY_PG_DSN", status_code=500)
    try:
        with psycopg.connect(PG_DSN, autocommit=False) as conn:
            with conn.cursor() as cur:
                # Read-only transaction plus a statement timeout bound every
                # prediction query.
                cur.execute("SET TRANSACTION READ ONLY")
                cur.execute(f"SET LOCAL statement_timeout = '{STATEMENT_TIMEOUT_MS}ms'")
                cur.execute(sql, params)
                return cur.fetchall()
    except QueryAPIError:
        raise
    except Exception as exc:
        raise QueryAPIError("E_PREDICT_DB", "预测数据读取失败", {"reason": str(exc)}, status_code=500) from exc
|
|
|
|
| def _rows_to_series(name: str, rows: list[tuple[Any, Any]]) -> dict[str, Any]: |
| timestamps: list[str] = [] |
| values: list[float] = [] |
| for ts, value in rows: |
| if value is None: |
| continue |
| try: |
| numeric_value = float(value) |
| except Exception: |
| continue |
| timestamps.append(ts.isoformat() if hasattr(ts, "isoformat") else str(ts)) |
| values.append(numeric_value) |
| return {"name": name, "timestamps": timestamps, "values": values} |
|
|
|
|
def _query_date_price_series(table_name: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Load the newest `limit_rows` (日期, 价格) points, returned oldest-first.

    NOTE(review): `table_name` is interpolated into the SQL text (identifiers
    cannot be bound as parameters). All current call sites pass literal table
    names; never route user input here.
    """
    rows = _query_rows(
        f"""
        SELECT d, v
        FROM (
            SELECT "日期"::date AS d, "价格"::float8 AS v
            FROM public."{table_name}"
            WHERE "日期" IS NOT NULL
            ORDER BY "日期" DESC
            LIMIT %s
        ) t
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
|
|
|
|
def _query_text_date_price_series(table_name: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Like _query_date_price_series, but for tables storing 日期 as text.

    Only rows whose 日期 matches YYYY-MM-DD are cast; others become NULL and
    are filtered out. Same identifier-interpolation caveat: `table_name` must
    be a literal, never user input.
    """
    rows = _query_rows(
        f"""
        SELECT d, v
        FROM (
            SELECT
                CASE
                    WHEN "日期" ~ '^[0-9]{{4}}-[0-9]{{2}}-[0-9]{{2}}$' THEN "日期"::date
                    ELSE NULL
                END AS d,
                "价格"::float8 AS v
            FROM public."{table_name}"
            ORDER BY d DESC NULLS LAST
            LIMIT %s
        ) t
        WHERE d IS NOT NULL
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
|
|
|
|
def _query_growth_avg_series(value_field: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Daily AVG of a 日常生长数据1 column over the newest `limit_rows` days.

    GROUP BY / ORDER BY reference the output alias `d`, which PostgreSQL
    permits. `value_field` is interpolated as an identifier — call sites pass
    literal column names only; never route user input here.
    """
    rows = _query_rows(
        f"""
        SELECT d, avg_v
        FROM (
            SELECT
                CASE
                    WHEN "采集日期" ~ '^[0-9]{{4}}-[0-9]{{2}}-[0-9]{{2}}$' THEN "采集日期"::date
                    ELSE NULL
                END AS d,
                AVG("{value_field}"::float8) AS avg_v
            FROM public."日常生长数据1"
            GROUP BY d
            ORDER BY d DESC NULLS LAST
            LIMIT %s
        ) t
        WHERE d IS NOT NULL
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
|
|
|
|
def _query_inventory_series(limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Newest `limit_rows` feed-inventory readings ordered by record number.

    The series is indexed by 编号 rather than a date; _rows_to_series
    stringifies those integer indices as the timestamps.
    """
    rows = _query_rows(
        """
        SELECT idx, v
        FROM (
            SELECT
                "编号"::int AS idx,
                "原有库存_kg"::float8 AS v
            FROM public."饲料库存记录"
            WHERE "原有库存_kg" IS NOT NULL
            ORDER BY "编号" DESC
            LIMIT %s
        ) t
        ORDER BY idx ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series("feed_inventory_kg", rows)
|
|
|
|
def build_series_from_db(module_id: str, question: str, horizon_days: int | None = None) -> dict[str, Any]:
    """Fetch the input time series a prediction module needs from Postgres.

    Each module maps to a set of lazily-invoked fetchers; an individual source
    failure is tolerated (recorded in `source_errors`), and only a completely
    empty result raises.

    Raises:
        QueryAPIError: E_PREDICT_INPUT for an unknown module id,
            E_PREDICT_DATA_EMPTY when no series produced any data.
    """
    horizon = _clamp_horizon(horizon_days)
    query_plan: list[tuple[str, Any]] = []
    if module_id == "m1_price_forecast":
        query_plan = [
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m2_slaughter_window":
        query_plan = [
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
        ]
    elif module_id == "m3_feeding_plan":
        query_plan = [
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
            ("avg_feed_kg", lambda: _query_growth_avg_series("采食量kg", "avg_feed_kg")),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m4_disease_risk":
        query_plan = [
            ("avg_feed_kg", lambda: _query_growth_avg_series("采食量kg", "avg_feed_kg")),
            ("avg_water_l", lambda: _query_growth_avg_series("饮水量l", "avg_water_l")),
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
        ]
    elif module_id == "m5_inventory_procurement":
        query_plan = [
            ("feed_inventory_kg", _query_inventory_series),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m6_piglet_buy":
        query_plan = [
            ("piglet_price", lambda: _query_date_price_series("猪仔价格", "piglet_price")),
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
        ]
    else:
        raise QueryAPIError("E_PREDICT_INPUT", "未知的 module_id", {"module_id": module_id})

    series: list[dict[str, Any]] = []
    source_errors: list[dict[str, Any]] = []
    for series_name, fetcher in query_plan:
        try:
            item = fetcher()
        except QueryAPIError as exc:
            source_errors.append(
                {
                    "series": series_name,
                    "code": exc.code,
                    "message": exc.message,
                    "detail": exc.detail,
                }
            )
            LOGGER.warning(
                "predict_series_fetch_failed module_id=%s series=%s code=%s detail=%s",
                module_id,
                series_name,
                exc.code,
                json.dumps(exc.detail, ensure_ascii=False),
            )
            continue
        # Keep only series that actually contain data points.
        if item.get("values"):
            series.append(item)

    # Fix: `series` already holds only non-empty items (filtered above), so the
    # previous second filtering pass was redundant and has been removed.
    if not series:
        raise QueryAPIError(
            "E_PREDICT_DATA_EMPTY",
            "预测所需数据库数据不足",
            {"module_id": module_id, "question": question, "source_errors": source_errors},
            status_code=400,
        )

    return {
        "module_id": module_id,
        "question": question,
        "horizon_days": horizon,
        "series": series,
        "source_errors": source_errors,
    }
|
|
|
|
| def _accepts_kwarg(callable_obj: Any, arg_name: str) -> bool: |
| try: |
| import inspect |
|
|
| signature = inspect.signature(callable_obj) |
| return arg_name in signature.parameters |
| except Exception: |
| return False |
|
|
|
|
def _create_timesfm_hparams() -> Any:
    """Build TimesFmHparams, passing only kwargs this timesfm version accepts."""
    hparams_cls = getattr(timesfm, "TimesFmHparams", None)
    if hparams_cls is None:
        raise QueryAPIError("E_TIMESFM_INIT", "timesfm 缺少 TimesFmHparams", status_code=500)

    # Probe the constructor signature so the code works across timesfm
    # releases that renamed/removed individual hyperparameters.
    kwargs: dict[str, Any] = {}
    if _accepts_kwarg(hparams_cls, "backend"):
        kwargs["backend"] = TIMESFM_BACKEND
    if _accepts_kwarg(hparams_cls, "per_core_batch_size"):
        kwargs["per_core_batch_size"] = TIMESFM_BATCH_SIZE
    if _accepts_kwarg(hparams_cls, "horizon_len"):
        kwargs["horizon_len"] = TIMESFM_MAX_HORIZON
    if _accepts_kwarg(hparams_cls, "device"):
        kwargs["device"] = TIMESFM_DEVICE

    # Architecture overrides for the 2.0-500m PyTorch checkpoint.
    # NOTE(review): these look like the published 500m config (50 layers,
    # 1280 dims, no positional embedding) — confirm against the model card.
    if TIMESFM_CHECKPOINT == "google/timesfm-2.0-500m-pytorch":
        if _accepts_kwarg(hparams_cls, "input_patch_len"):
            kwargs["input_patch_len"] = 32
        if _accepts_kwarg(hparams_cls, "output_patch_len"):
            kwargs["output_patch_len"] = 128
        if _accepts_kwarg(hparams_cls, "num_layers"):
            kwargs["num_layers"] = 50
        if _accepts_kwarg(hparams_cls, "model_dims"):
            kwargs["model_dims"] = 1280
        if _accepts_kwarg(hparams_cls, "use_positional_embedding"):
            kwargs["use_positional_embedding"] = False

    return hparams_cls(**kwargs)
|
|
|
|
def _create_timesfm_checkpoint() -> Any:
    """Build a TimesFmCheckpoint (or fall back to the raw repo-id string)."""
    checkpoint_cls = getattr(timesfm, "TimesFmCheckpoint", None)
    if checkpoint_cls is None:
        # Very old timesfm versions took the checkpoint id directly.
        return TIMESFM_CHECKPOINT

    if _accepts_kwarg(checkpoint_cls, "huggingface_repo_id"):
        return checkpoint_cls(huggingface_repo_id=TIMESFM_CHECKPOINT)
    return checkpoint_cls(TIMESFM_CHECKPOINT)
|
|
|
|
def get_timesfm_model() -> Any:
    """Lazily build and cache the TimesFM model (double-checked locking).

    On failure the cache stays None — so later calls retry — and the error
    text is kept in TIMESFM_INIT_ERROR for the predict health endpoint.

    Raises:
        QueryAPIError: E_TIMESFM_INIT when timesfm is missing or every
            constructor variant fails.
    """
    global TIMESFM_MODEL
    global TIMESFM_INIT_ERROR
    # Fast path: already initialized, no lock needed.
    if TIMESFM_MODEL is not None:
        return TIMESFM_MODEL
    with TIMESFM_MODEL_LOCK:
        # Re-check under the lock in case another thread won the race.
        if TIMESFM_MODEL is not None:
            return TIMESFM_MODEL
        if timesfm is None:
            TIMESFM_INIT_ERROR = "timesfm package not installed"
            raise QueryAPIError("E_TIMESFM_INIT", "TimesFM 未安装,请在 python311 环境安装 timesfm[torch]", status_code=500)

        model_cls = getattr(timesfm, "TimesFm", None)
        if model_cls is None:
            TIMESFM_INIT_ERROR = "timesfm package missing TimesFm class"
            raise QueryAPIError("E_TIMESFM_INIT", "TimesFM 初始化失败:缺少 TimesFm 类", status_code=500)

        try:
            hparams = _create_timesfm_hparams()
            checkpoint = _create_timesfm_checkpoint()
            # Constructor shapes vary across timesfm releases; try each.
            init_attempts = [
                lambda: model_cls(hparams=hparams, checkpoint=checkpoint),
                lambda: model_cls(hparams, checkpoint),
                lambda: model_cls(checkpoint=checkpoint),
            ]
            last_error = None
            for attempt in init_attempts:
                try:
                    TIMESFM_MODEL = attempt()
                    TIMESFM_INIT_ERROR = None
                    return TIMESFM_MODEL
                except Exception as exc:
                    last_error = str(exc)
            TIMESFM_INIT_ERROR = last_error or "unknown init error"
        except QueryAPIError:
            raise
        except Exception as exc:
            TIMESFM_INIT_ERROR = str(exc)

        raise QueryAPIError(
            "E_TIMESFM_INIT",
            "TimesFM 初始化失败",
            {"reason": TIMESFM_INIT_ERROR, "checkpoint": TIMESFM_CHECKPOINT},
            status_code=500,
        )
|
|
|
|
def ensure_timesfm_ready() -> None:
    """Force TimesFM initialization now; raises QueryAPIError when it fails."""
    get_timesfm_model()
|
|
|
|
| def _normalize_timesfm_points(raw_output: Any, horizon_days: int) -> list[float]: |
| payload = raw_output |
| if isinstance(payload, tuple) and payload: |
| payload = payload[0] |
|
|
| if np is not None and isinstance(payload, np.ndarray): |
| values = payload |
| elif isinstance(payload, list): |
| values = payload |
| elif isinstance(payload, dict): |
| candidates = payload.get("point_forecast") or payload.get("forecast") or payload.get("mean") |
| if candidates is None: |
| raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 返回结构无法解析", {"raw_type": str(type(payload))}, status_code=500) |
| values = candidates |
| else: |
| raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 返回结构无法解析", {"raw_type": str(type(payload))}, status_code=500) |
|
|
| if np is not None and isinstance(values, np.ndarray): |
| if values.ndim >= 2: |
| points = values[0].tolist() |
| else: |
| points = values.tolist() |
| else: |
| if values and isinstance(values[0], list): |
| points = values[0] |
| else: |
| points = values |
|
|
| numeric_points: list[float] = [] |
| for item in points: |
| try: |
| numeric_points.append(float(item)) |
| except Exception: |
| continue |
|
|
| if not numeric_points: |
| raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 未返回有效预测点", status_code=500) |
|
|
| if len(numeric_points) < horizon_days: |
| numeric_points.extend([numeric_points[-1]] * (horizon_days - len(numeric_points))) |
| return numeric_points[:horizon_days] |
|
|
|
|
def run_timesfm_forecast(series: list[float], horizon_days: int) -> dict[str, Any]:
    """Forecast `horizon_days` points ahead from one numeric series.

    The input is right-padded with its last value up to SERIES_MIN_POINTS,
    and several forecast() call signatures are attempted to span timesfm API
    versions.

    Raises:
        QueryAPIError: E_TIMESFM_INFER when the series is empty or every
            invocation variant fails.
    """
    if not series:
        raise QueryAPIError("E_TIMESFM_INFER", "输入时序为空", status_code=400)

    model = get_timesfm_model()
    padded_series = [float(v) for v in series]
    if len(padded_series) < SERIES_MIN_POINTS:
        # Repeat the last observation so very short series remain usable.
        padded_series.extend([padded_series[-1]] * (SERIES_MIN_POINTS - len(padded_series)))

    # Different timesfm releases take different forecast() signatures.
    forecast_attempts = [
        lambda: model.forecast([padded_series], freq=[0]),
        lambda: model.forecast([padded_series], horizon_len=horizon_days, freq=[0]),
        lambda: model.forecast([padded_series], horizon_len=horizon_days),
        lambda: model.forecast([padded_series], horizon_days),
    ]
    if np is not None:
        np_input = np.asarray([padded_series], dtype=float)
        forecast_attempts.extend(
            [
                lambda: model.forecast(np_input, freq=[0]),
                lambda: model.forecast(np_input, horizon_len=horizon_days, freq=[0]),
            ]
        )

    last_error = None
    for attempt in forecast_attempts:
        try:
            raw = attempt()
            points = _normalize_timesfm_points(raw, horizon_days)
            return {
                "horizon_days": horizon_days,
                "point_forecast": points,
                "input_last_value": padded_series[-1],
            }
        except QueryAPIError:
            raise
        except Exception as exc:
            last_error = str(exc)
            continue

    raise QueryAPIError(
        "E_TIMESFM_INFER",
        "TimesFM 推理失败",
        {"reason": last_error or "unknown forecast error"},
        status_code=500,
    )
|
|
|
|
def run_module_logic(module_id: str, series_pack: dict[str, Any], forecast_map: dict[str, dict[str, Any]]) -> dict[str, Any]:
    """Turn per-series forecasts into a module-specific recommendation.

    Trend heuristics compare the first vs. last forecast point of the series
    each module cares about; a confidence and summary metric are appended.
    """
    horizon_days = int(series_pack.get("horizon_days", 15))
    recommendation: dict[str, Any] = {}
    risk_flags: list[str] = []

    def _series_direction(name: str) -> str:
        # "up" / "down" / "flat" from first vs last forecast point.
        points = forecast_map.get(name, {}).get("point_forecast", [])
        if len(points) < 2 or points[-1] == points[0]:
            return "flat"
        return "up" if points[-1] > points[0] else "down"

    if module_id == "m1_price_forecast":
        trend = _series_direction("hog_price")
        recommendation["action"] = "择机延后出栏" if trend == "up" else "关注近期出栏窗口"
        recommendation["reason"] = f"生猪价格未来{horizon_days}天趋势:{trend}"
    elif module_id == "m2_slaughter_window":
        price_trend = _series_direction("hog_price")
        weight_trend = _series_direction("avg_weight_kg")
        both_up = price_trend == "up" and weight_trend == "up"
        recommendation["optimal_window"] = "未来3-7天" if both_up else "未来1-3天滚动评估"
        recommendation["reason"] = f"价格趋势={price_trend}, 体重趋势={weight_trend}"
    elif module_id == "m3_feeding_plan":
        feed_trend = _series_direction("avg_feed_kg")
        recommendation["feeding_adjustment"] = "维持配方并小幅增量" if feed_trend == "up" else "控制喂量并优化配方"
        recommendation["reason"] = f"采食趋势={feed_trend}"
    elif module_id == "m4_disease_risk":
        water = forecast_map.get("avg_water_l", {}).get("point_forecast", [])
        feed = forecast_map.get("avg_feed_kg", {}).get("point_forecast", [])
        if water and feed and water[-1] < water[0] and feed[-1] < feed[0]:
            risk_flags.append("采食与饮水同步下滑,建议排查健康风险")
        recommendation["risk_level"] = "high" if risk_flags else "medium"
        recommendation["action"] = "加强巡检与体温采样"
    elif module_id == "m5_inventory_procurement":
        inventory = forecast_map.get("feed_inventory_kg", {}).get("point_forecast", [])
        if inventory and min(inventory) <= 0:
            risk_flags.append("预测库存触底,需立即补货")
        recommendation["procurement_window"] = "未来1周内分批采购"
        recommendation["action"] = "结合价格低位滚动补货"
    elif module_id == "m6_piglet_buy":
        piglet_trend = _series_direction("piglet_price")
        recommendation["buy_window"] = "未来3-5天" if piglet_trend == "down" else "等待价格回调"
        recommendation["reason"] = f"猪仔价格趋势={piglet_trend}"
    else:
        raise QueryAPIError("E_PREDICT_INPUT", "未知 module_id", {"module_id": module_id}, status_code=400)

    every_point = [p for forecast in forecast_map.values() for p in forecast.get("point_forecast", [])]
    recommendation.setdefault("confidence", 0.65 if every_point else 0.0)
    recommendation.setdefault("summary_metric", round(mean(every_point), 4) if every_point else 0.0)

    summary = f"{module_id} 已完成未来{horizon_days}天预测,建议:{json.dumps(recommendation, ensure_ascii=False)}"
    return {
        "forecast": forecast_map,
        "recommendation": recommendation,
        "risk_flags": risk_flags,
        "backend_draft_summary": summary,
    }
|
|
|
|
def _insert_prediction_audit(
    trace_id: str,
    module_id: str,
    question: str,
    horizon_days: int,
    payload: PredictRunData,
) -> None:
    """Best-effort audit trail: upsert the run row, then insert its outputs.

    Disabled via PREDICT_LOG_ENABLED. Every failure is logged and swallowed
    so that auditing can never break the prediction response.
    """
    if not PREDICT_LOG_ENABLED:
        return
    if psycopg is None or not PG_DSN:
        LOGGER.warning("predict_audit_insert_skipped trace_id=%s reason=psycopg_or_dsn_missing", trace_id)
        return
    try:
        with psycopg.connect(PG_DSN, autocommit=True) as conn:
            with conn.cursor() as cur:
                # Upsert keyed on trace_id so retries update rather than duplicate.
                cur.execute(
                    """
                    INSERT INTO public.ai_prediction_run (
                        trace_id, module_id, question, horizon_days, db_used, model_trace, status, created_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s::jsonb, %s, now())
                    ON CONFLICT (trace_id) DO UPDATE
                    SET module_id = EXCLUDED.module_id,
                        question = EXCLUDED.question,
                        horizon_days = EXCLUDED.horizon_days,
                        db_used = EXCLUDED.db_used,
                        model_trace = EXCLUDED.model_trace,
                        status = EXCLUDED.status
                    RETURNING id
                    """,
                    (
                        trace_id,
                        module_id,
                        question,
                        horizon_days,
                        payload.db_used,
                        json.dumps(payload.model_trace, ensure_ascii=False),
                        "ok",
                    ),
                )
                row = cur.fetchone()
                if not row:
                    return
                run_id = int(row[0])
                cur.execute(
                    """
                    INSERT INTO public.ai_prediction_output (
                        run_id, forecast, recommendation, risk_flags, backend_draft_summary, created_at
                    )
                    VALUES (%s, %s::jsonb, %s::jsonb, %s::jsonb, %s, now())
                    """,
                    (
                        run_id,
                        json.dumps(payload.forecast, ensure_ascii=False),
                        json.dumps(payload.recommendation, ensure_ascii=False),
                        json.dumps(payload.risk_flags, ensure_ascii=False),
                        payload.backend_draft_summary,
                    ),
                )
    except Exception as exc:
        LOGGER.warning("predict_audit_insert_skipped trace_id=%s reason=%s", trace_id, str(exc))
|
|
|
|
def _require_psycopg() -> None:
    """Abort with a 500-style error when the psycopg driver is not installed."""
    if psycopg is not None:
        return
    raise QueryAPIError("E_DRIVER_MISSING", "缺少 psycopg 依赖,请先安装后再启动服务", status_code=500)
|
|
|
|
| def _to_jsonable(value: Any) -> Any: |
| if value is None: |
| return None |
| if isinstance(value, (str, int, float, bool)): |
| return value |
| if isinstance(value, Decimal): |
| return float(value) |
| if isinstance(value, (datetime, date)): |
| return value.isoformat() |
| if isinstance(value, bytes): |
| return value.decode("utf-8", errors="replace") |
| return str(value) |
|
|
|
|
def run_query(sql: str, trace_id: str) -> QuerySuccessData:
    """Execute a validated SELECT and package rows for the JSON envelope.

    The transaction is read-only and bounded by STATEMENT_TIMEOUT_MS; at most
    MAX_ROWS rows are fetched.

    Raises:
        QueryAPIError: E_DRIVER_MISSING / E_DSN_MISSING for missing
            prerequisites, E_DB_QUERY on any database failure.
    """
    _require_psycopg()
    if not PG_DSN:
        raise QueryAPIError("E_DSN_MISSING", "未配置 PIG_QUERY_PG_DSN", status_code=500)

    started = time.perf_counter()
    try:
        with psycopg.connect(PG_DSN, autocommit=False) as conn:
            with conn.cursor() as cur:
                cur.execute("SET TRANSACTION READ ONLY")
                cur.execute(f"SET LOCAL statement_timeout = '{STATEMENT_TIMEOUT_MS}ms'")
                cur.execute(sql)
                # Cap the result size instead of fetching everything.
                rows = cur.fetchmany(MAX_ROWS)
                columns = [desc.name for desc in cur.description] if cur.description else []
    except QueryAPIError:
        raise
    except Exception as exc:
        raise QueryAPIError("E_DB_QUERY", "数据库查询失败", {"reason": str(exc)}, status_code=500) from exc

    elapsed_ms = round((time.perf_counter() - started) * 1000, 2)
    normalized_rows = [[_to_jsonable(value) for value in row] for row in rows]
    data = QuerySuccessData(
        columns=columns,
        rows=normalized_rows,
        row_count=len(normalized_rows),
        sql_executed=sql,
        empty=len(normalized_rows) == 0,
        trace_id=trace_id,
    )

    # Log a short hash rather than the raw SQL to keep log lines compact.
    LOGGER.info(
        "query_ok trace_id=%s sql_hash=%s row_count=%s duration_ms=%s",
        trace_id,
        hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16],
        data.row_count,
        elapsed_ms,
    )
    return data
|
|
|
|
async def verify_api_key(x_api_key: str = Header(default="", alias="X-API-Key")) -> None:
    """FastAPI dependency validating the ``X-API-Key`` request header.

    Fix: uses a constant-time comparison (``secrets.compare_digest``) instead
    of ``!=`` so response timing cannot leak key prefix information. Bytes are
    compared because the str overload requires ASCII-only input.

    Raises:
        QueryAPIError: E_UNAUTHORIZED (HTTP 401) on a missing or wrong key.
    """
    if not x_api_key or not secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise QueryAPIError("E_UNAUTHORIZED", "无效的 API Key", status_code=status.HTTP_401_UNAUTHORIZED)
|
|
|
|
def build_error_response(exc: QueryAPIError, trace_id: str | None = None) -> JSONResponse:
    """Serialize a QueryAPIError into the standard JSON error envelope."""
    detail = dict(exc.detail)
    if trace_id:
        # Attach the trace id only when the error detail has not set one.
        detail.setdefault("trace_id", trace_id)
    envelope = QueryResponse(
        ok=False,
        data=None,
        err=ErrorDetail(code=exc.code, message=exc.message, detail=detail),
    )
    return JSONResponse(status_code=exc.status_code, content=envelope.model_dump())
|
|
|
|
# FastAPI application instance; route handlers and error hooks attach below.
app = FastAPI(title=SERVICE_NAME, version="1.0.0")
|
|
|
|
@app.exception_handler(QueryAPIError)
async def query_api_error_handler(_: Any, exc: QueryAPIError) -> JSONResponse:
    """Translate any QueryAPIError raised during a request into the envelope."""
    return build_error_response(exc)
|
|
|
|
@app.exception_handler(Exception)
async def unhandled_exception_handler(_: Any, exc: Exception) -> JSONResponse:
    """Last-resort handler: log the full stack trace, answer with a generic 500."""
    LOGGER.exception("unhandled error: %s", exc)
    wrapped = QueryAPIError("E_UNKNOWN", "系统错误", {"reason": str(exc)}, status_code=500)
    return build_error_response(wrapped)
|
|
|
|
@app.get("/api/v1/health", response_model=HealthResponse)
async def health() -> HealthResponse:
    """Liveness probe reporting which optional dependencies were importable."""
    # Second-resolution UTC timestamp with a trailing "Z" instead of "+00:00".
    now_utc = datetime.now(timezone.utc).replace(microsecond=0)
    return HealthResponse(
        service=SERVICE_NAME,
        time=now_utc.isoformat().replace("+00:00", "Z"),
        sql_parser="sqlglot" if sqlglot else "regex-fallback",
        db_driver="psycopg" if psycopg else "missing",
    )
|
|
|
|
@app.get("/api/v1/predict/health", response_model=PredictHealthResponse)
async def predict_health() -> PredictHealthResponse:
    """Readiness probe for the TimesFM forecasting stack.

    Collects initialization problems (model guard, missing package, checkpoint
    load failure) into a single semicolon-joined message; ``timesfm_ready`` is
    True only when the package is installed and the model loads.
    """
    ready_flag = False
    message = TIMESFM_INIT_ERROR
    try:
        ensure_non_doubao_model()
    except QueryAPIError as guard_err:
        # NOTE: matches original behavior — the guard error replaces (does not
        # append to) any pre-existing TIMESFM_INIT_ERROR text.
        message = guard_err.message
    if timesfm is None:
        missing = "timesfm package not installed"
        message = missing if message is None else f"{message}; {missing}"
    else:
        try:
            ensure_timesfm_ready()
            ready_flag = True
        except QueryAPIError as init_err:
            message = init_err.message if message is None else f"{message}; {init_err.message}"
    return PredictHealthResponse(
        ok=True,
        service=SERVICE_NAME,
        time=_now_utc_iso(),
        timesfm_enabled=timesfm is not None,
        timesfm_ready=ready_flag,
        timesfm_device=TIMESFM_DEVICE,
        timesfm_checkpoint=TIMESFM_CHECKPOINT,
        timesfm_init_error=message,
    )
|
|
|
|
@app.post("/api/v1/predict/run", response_model=PredictRunResponse)
async def predict_run(request: PredictRunRequest, _: None = Depends(verify_api_key)) -> PredictRunResponse:
    """Run a TimesFM forecast for the requested module and persist an audit row.

    Raises:
        QueryAPIError: ``E_PREDICT_DATA_EMPTY`` (HTTP 400) when no series in
            the database pack yields a forecastable value list.
    """
    trace = request.trace_id or str(uuid.uuid4())
    module = normalize_module_id(request.module_id, request.question)
    ensure_non_doubao_model()
    ensure_timesfm_ready()
    horizon = infer_horizon_days(request.question)

    pack = build_series_from_db(module_id=module, question=request.question, horizon_days=horizon)
    # The series builder may adjust the horizon; trust its value when present.
    horizon = int(pack.get("horizon_days", horizon))

    forecasts: dict[str, dict[str, Any]] = {}
    for item in pack.get("series", []):
        points = item.get("values") or []
        if not isinstance(points, list) or not points:
            continue  # skip series with missing/empty/non-list values
        label = str(item.get("name") or "series")
        forecasts[label] = run_timesfm_forecast(points, horizon_days=horizon)

    if not forecasts:
        raise QueryAPIError("E_PREDICT_DATA_EMPTY", "缺少可预测序列", {"module_id": module}, status_code=400)

    outputs = run_module_logic(module, pack, forecasts)
    data = PredictRunData(
        module_id=module,
        db_used=True,
        model_trace=[f"timesfm:{TIMESFM_CHECKPOINT}", f"summary:{NON_DOUBAO_MODEL_ID}"],
        forecast=outputs["forecast"],
        recommendation=outputs["recommendation"],
        risk_flags=outputs["risk_flags"],
        backend_draft_summary=outputs["backend_draft_summary"],
        trace_id=trace,
    )
    _insert_prediction_audit(
        trace_id=trace,
        module_id=module,
        question=request.question,
        horizon_days=horizon,
        payload=data,
    )
    LOGGER.info(
        "predict_run_ok trace_id=%s module_id=%s series_count=%s",
        trace,
        module,
        len(forecasts),
    )
    return PredictRunResponse(ok=True, data=data, err=None)
|
|
|
|
@app.post("/api/v1/query", response_model=QueryResponse)
async def query(request: QueryRequest, _: None = Depends(verify_api_key)) -> QueryResponse:
    """Validate/prepare the submitted SQL, execute it read-only, return rows."""
    trace = request.trace_id or str(uuid.uuid4())
    prepared = prepare_sql_for_execution(request.sql)
    # Log the referenced tables as a stable (sorted) JSON array for grepability.
    table_json = json.dumps(sorted(prepared.tables), ensure_ascii=False)
    LOGGER.info(
        "query_request trace_id=%s user_query=%s tables=%s",
        trace,
        request.user_query,
        table_json,
    )
    result = run_query(prepared.sql, trace)
    return QueryResponse(ok=True, data=result, err=None)
|
|
|
|
if __name__ == "__main__":
    # Dev/local entry point; in production run via an external ASGI server
    # (e.g. `uvicorn app:app`) instead of invoking this module directly.
    import uvicorn


    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False)