# nova / app.py — Hugging Face Space source file.
# (Web-UI residue converted to comments: uploaded by katenovaa,
#  commit f9d1272 "Update app.py", verified.)
from __future__ import annotations
import hashlib
import json
import logging
import os
import re
import threading
import time
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timezone
from decimal import Decimal
from statistics import mean
from typing import Any
from fastapi import Depends, FastAPI, Header, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
try:
import psycopg
except ImportError: # pragma: no cover - exercised in runtime environments without psycopg
psycopg = None
try:
import sqlglot
from sqlglot import exp
except ImportError: # pragma: no cover - exercised in runtime environments without sqlglot
sqlglot = None
exp = None
try:
import numpy as np
except ImportError: # pragma: no cover - optional dependency for TimesFM runtime
np = None
try:
import timesfm
except ImportError: # pragma: no cover - exercised in runtime environments without timesfm
timesfm = None
LOGGER = logging.getLogger("pig_query_api")
if not LOGGER.handlers:
logging.basicConfig(
level=os.getenv("PIG_QUERY_LOG_LEVEL", "INFO").upper(),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
SERVICE_NAME = "pig-query-api"
API_KEY = os.getenv("PIG_QUERY_API_KEY", "change-me")
# Hardcoded PostgreSQL config requested by project owner.
PG_HOST = "47.107.138.68"
PG_PORT = 5432
PG_DATABASE = "postgres"
PG_USER = "postgres"
PG_PASSWORD = "63BvPoC5JiHu16k8Gh"
PG_DSN = f"postgresql://{PG_USER}:{PG_PASSWORD}@{PG_HOST}:{PG_PORT}/{PG_DATABASE}"
MAX_ROWS = int(os.getenv("PIG_QUERY_MAX_ROWS", "100"))
STATEMENT_TIMEOUT_MS = int(os.getenv("PIG_QUERY_STATEMENT_TIMEOUT_MS", "3000"))
DEFAULT_SELECT_LIMIT = int(os.getenv("PIG_QUERY_DEFAULT_SELECT_LIMIT", "50"))
TIMESFM_DEVICE = os.getenv("TIMESFM_DEVICE", "cpu").lower()
TIMESFM_CHECKPOINT = os.getenv("TIMESFM_CHECKPOINT", "google/timesfm-2.0-500m-pytorch")
TIMESFM_BACKEND = os.getenv("TIMESFM_BACKEND", "cpu").lower()
TIMESFM_MAX_HORIZON = int(os.getenv("TIMESFM_MAX_HORIZON", "15"))
TIMESFM_BATCH_SIZE = int(os.getenv("TIMESFM_BATCH_SIZE", "32"))
NON_DOUBAO_MODEL_ID = os.getenv("NON_DOUBAO_MODEL_ID", "NON_DOUBAO_MODEL_ID")
SERIES_LOOKBACK_DAYS = int(os.getenv("PREDICT_SERIES_LOOKBACK_DAYS", "90"))
SERIES_MIN_POINTS = int(os.getenv("PREDICT_SERIES_MIN_POINTS", "8"))
PREDICT_LOG_ENABLED = os.getenv("PREDICT_LOG_ENABLED", "1").lower() in {"1", "true", "yes", "on"}
TIMESFM_MODEL_LOCK = threading.Lock()
TIMESFM_MODEL: Any | None = None
TIMESFM_INIT_ERROR: str | None = None
ALLOWED_TABLES = frozenset(
{
'public."采食量标准"',
'public."体重标准"',
'public."料肉比标准"',
'public."猪仔信息"',
'public."日常生长数据1"',
"public.weighing_records",
'public."饲料库存记录"',
'public."玉米价格"',
'public."猪仔价格"',
"public.生猪价格",
'public."猪肉批发价"',
}
)
DENY_KEYWORDS = (
"insert",
"update",
"delete",
"drop",
"alter",
"truncate",
"create",
"grant",
"revoke",
"copy",
"merge",
"vacuum",
"analyze",
"comment",
)
COMMENT_PATTERN = re.compile(r"(--|/\*|\*/)")
LIMIT_PATTERN = re.compile(r"\blimit\s+(\d+)\b", re.IGNORECASE)
AGGREGATE_PATTERN = re.compile(
r"\b(count|sum|avg|min|max|string_agg|array_agg|json_agg)\s*\(",
re.IGNORECASE,
)
SQL_START_PATTERN = re.compile(r"^\s*(select|with)\b", re.IGNORECASE)
TABLE_REF_PATTERN = re.compile(r"\b(from|join)\s+((?:public\.)?\"?[A-Za-z0-9_\u4e00-\u9fff]+\"?)", re.IGNORECASE)
HORIZON_PATTERN = re.compile(r"(?:未来|接下来|后续)?\s*(\d{1,2})\s*(?:天|日)")
MODULE_IDS = {
"m1_price_forecast",
"m2_slaughter_window",
"m3_feeding_plan",
"m4_disease_risk",
"m5_inventory_procurement",
"m6_piglet_buy",
}
def _normalize_identifier(value: str) -> str:
parts = [part.strip().strip('"') for part in value.split(".") if part.strip()]
return ".".join(part.lower() for part in parts)
# Pre-computed lookup of acceptable table identifiers: each whitelisted table
# is stored both fully qualified ("public.foo") and bare ("foo") so queries
# may reference either form.
ALLOWED_TABLE_NAMES = frozenset(
    {
        normalized
        for table in ALLOWED_TABLES
        for normalized in (
            _normalize_identifier(table),
            _normalize_identifier(table).split(".", 1)[-1],
        )
    }
)
# --- Pydantic schemas for request/response payloads ---------------------------
# Comments (not docstrings) are used so generated OpenAPI descriptions stay
# unchanged.


# Machine-readable error payload placed in the "err" field of responses.
class ErrorDetail(BaseModel):
    code: str
    message: str
    detail: dict[str, Any] = Field(default_factory=dict)


# Successful /api/v1/query payload: column names, row data, and audit info.
class QuerySuccessData(BaseModel):
    columns: list[str]
    rows: list[list[Any]]
    row_count: int
    sql_executed: str
    empty: bool
    trace_id: str


# Envelope for /api/v1/query: exactly one of data/err is populated.
class QueryResponse(BaseModel):
    ok: bool
    data: QuerySuccessData | None = None
    err: ErrorDetail | None = None


# Incoming /api/v1/query body: the user's question, the SQL to vet and run,
# and an optional caller-supplied trace id.
class QueryRequest(BaseModel):
    user_query: str = Field(..., min_length=1, max_length=2000)
    sql: str = Field(..., min_length=1, max_length=20000)
    trace_id: str | None = Field(default=None, max_length=128)


# /api/v1/health payload reporting which optional dependencies are active.
class HealthResponse(BaseModel):
    ok: bool = True
    service: str
    time: str
    sql_parser: str
    db_driver: str


# Incoming /api/v1/predict/run body.
class PredictRunRequest(BaseModel):
    question: str = Field(..., min_length=1, max_length=2000)
    module_id: str | None = Field(default=None, max_length=64)
    trace_id: str | None = Field(default=None, max_length=128)


# Successful /api/v1/predict/run payload: forecasts plus recommendations.
class PredictRunData(BaseModel):
    module_id: str
    db_used: bool
    model_trace: list[str]
    forecast: dict[str, Any]
    recommendation: dict[str, Any]
    risk_flags: list[str]
    backend_draft_summary: str
    trace_id: str


# Envelope for /api/v1/predict/run.
class PredictRunResponse(BaseModel):
    ok: bool
    data: PredictRunData | None = None
    err: ErrorDetail | None = None


# /api/v1/predict/health payload describing TimesFM availability.
class PredictHealthResponse(BaseModel):
    ok: bool
    service: str
    time: str
    timesfm_enabled: bool
    timesfm_ready: bool
    timesfm_device: str
    timesfm_checkpoint: str
    timesfm_init_error: str | None = None
class QueryAPIError(Exception):
    """Service error carrying a stable code, message, detail dict, and HTTP status."""

    def __init__(self, code: str, message: str, detail: dict[str, Any] | None = None, status_code: int = 400):
        super().__init__(message)
        self.code = code  # stable machine-readable error code (e.g. "E_SQL_DENIED")
        self.message = message  # human-readable, user-facing message
        self.detail = detail or {}
        self.status_code = status_code
@dataclass(frozen=True)
class PreparedQuery:
    """A validated, normalized SQL statement plus the tables it references."""

    sql: str
    tables: set[str]
def normalize_sql(sql: str) -> str:
    """Strip UTF-8 BOM characters and surrounding whitespace from raw SQL."""
    without_bom = sql.replace("\ufeff", "")
    return without_bom.strip()
def reject_empty_sql(sql: str) -> None:
    """Raise E_SQL_EMPTY when no SQL text remains after normalization."""
    if sql:
        return
    raise QueryAPIError("E_SQL_EMPTY", "SQL 不能为空")
def reject_multi_statement(sql: str) -> None:
    """Reject input containing more than one statement (embedded ';').

    Trailing semicolons are tolerated; any remaining ';' indicates a second
    statement and is refused.
    """
    body = sql.rstrip(";").strip()
    if ";" not in body:
        return
    raise QueryAPIError("E_SQL_MULTI", "SQL 只允许单条语句")
def reject_comments(sql: str) -> None:
    """Reject SQL that contains comment tokens (--, /*, */)."""
    if COMMENT_PATTERN.search(sql) is None:
        return
    raise QueryAPIError("E_SQL_COMMENT", "SQL 不允许注释")
def reject_dangerous_keywords(sql: str) -> None:
    """Reject SQL mentioning any write/DDL keyword from DENY_KEYWORDS."""
    lowered = sql.lower()
    offending = next(
        (kw for kw in DENY_KEYWORDS if re.search(rf"\b{re.escape(kw)}\b", lowered)),
        None,
    )
    if offending is None:
        return
    raise QueryAPIError(
        "E_SQL_DENIED",
        "SQL 不在允许范围内",
        {"reason": f"contains denied keyword: {offending}"},
    )
def parse_sql(sql: str) -> Any | None:
    """Parse SQL with sqlglot (postgres dialect); None when sqlglot is absent.

    Raises QueryAPIError(E_SQL_PARSE) when the text cannot be parsed.
    """
    if sqlglot is None:
        return None
    try:
        return sqlglot.parse_one(sql, read="postgres")
    except Exception as exc:  # pragma: no cover - depends on sqlglot internals
        raise QueryAPIError("E_SQL_PARSE", "SQL 解析失败", {"reason": str(exc)}) from exc
def ensure_select_only(ast: Any | None, sql: str) -> None:
    """Allow only SELECT/WITH statements; reject ASTs containing write nodes."""
    if SQL_START_PATTERN.match(sql) is None:
        raise QueryAPIError("E_SQL_TYPE", "只允许 SELECT 或 WITH 查询")
    if ast is None or exp is None:
        # Without a parsed AST (sqlglot unavailable) only the textual checks
        # above and the keyword denylist protect us.
        return
    for denied in (
        exp.Insert,
        exp.Update,
        exp.Delete,
        exp.Drop,
        exp.Alter,
        exp.Create,
        exp.Command,
        exp.Copy,
        exp.Merge,
    ):
        if ast.find(denied):
            raise QueryAPIError("E_SQL_DENIED", "SQL 不在允许范围内", {"reason": f"contains {denied.__name__}"})
def _extract_tables_with_sqlglot(ast: Any) -> set[str]:
    """Collect normalized table names from a sqlglot AST, skipping CTE aliases."""
    tables: set[str] = set()
    cte_names: set[str] = set()
    if exp is None:
        return tables
    # CTE aliases are query-local names, not physical tables; remember them
    # so they are excluded from the whitelist check.
    for cte in ast.find_all(exp.CTE):
        alias = getattr(cte, "alias", None)
        if alias:
            cte_names.add(_normalize_identifier(alias))
    for table in ast.find_all(exp.Table):
        db = getattr(table, "db", None)
        name = getattr(table, "name", None)
        if not name:
            continue
        normalized = _normalize_identifier(".".join(part for part in (db, name) if part))
        if normalized in cte_names:
            continue
        tables.add(normalized)
    return tables
def _extract_tables_with_regex(sql: str) -> set[str]:
    """Fallback table extraction: scan FROM/JOIN targets with a regex."""
    found: set[str] = set()
    for _, identifier in TABLE_REF_PATTERN.findall(sql):
        found.add(_normalize_identifier(identifier))
    return found
def extract_tables(ast: Any | None, sql: str) -> set[str]:
    """Collect referenced tables, preferring the AST, else the regex fallback."""
    if ast is None or sqlglot is None:
        return _extract_tables_with_regex(sql)
    from_ast = _extract_tables_with_sqlglot(ast)
    return from_ast if from_ast else _extract_tables_with_regex(sql)
def ensure_whitelisted_tables(tables: set[str]) -> None:
    """Raise unless every referenced table appears on the whitelist."""
    if not tables:
        raise QueryAPIError("E_SQL_TABLES", "未识别到可查询表")
    # A table passes if either its qualified or bare name is whitelisted.
    rejected = [
        table
        for table in tables
        if not ({table, table.split(".", 1)[-1]} & ALLOWED_TABLE_NAMES)
    ]
    if rejected:
        raise QueryAPIError(
            "E_SQL_DENIED",
            "SQL 不在允许范围内",
            {"reason": "contains non-whitelisted table", "tables": sorted(rejected)},
        )
def ensure_limit_if_needed(sql: str) -> None:
    """Demand an explicit LIMIT unless the query is a pure aggregate."""
    if LIMIT_PATTERN.search(sql) is not None:
        return
    if "group by" in sql.lower():
        raise QueryAPIError("E_LIMIT_REQUIRED", f"包含 GROUP BY 的查询必须显式带 LIMIT {DEFAULT_SELECT_LIMIT} 以内")
    if AGGREGATE_PATTERN.search(sql) is not None:
        # Aggregates without GROUP BY return a single row; no LIMIT needed.
        return
    raise QueryAPIError("E_LIMIT_REQUIRED", f"明细查询必须显式带 LIMIT,建议 LIMIT {DEFAULT_SELECT_LIMIT}")
def prepare_sql_for_execution(sql: str) -> PreparedQuery:
    """Run the full validation pipeline and return the query ready to execute.

    Order matters: cheap textual rejections first, then parsing, statement
    type and table whitelisting, and finally the LIMIT policy. Any failure
    raises QueryAPIError with a specific error code.
    """
    normalized = normalize_sql(sql)
    reject_empty_sql(normalized)
    reject_multi_statement(normalized)
    reject_comments(normalized)
    reject_dangerous_keywords(normalized)
    ast = parse_sql(normalized)
    ensure_select_only(ast, normalized)
    tables = extract_tables(ast, normalized)
    ensure_whitelisted_tables(tables)
    ensure_limit_if_needed(normalized)
    return PreparedQuery(sql=normalized, tables=tables)
def _now_utc_iso() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def _clamp_horizon(value: int | None) -> int:
    """Clamp a requested horizon into [7, TIMESFM_MAX_HORIZON]; None -> default."""
    if value is None:
        return min(15, TIMESFM_MAX_HORIZON)
    bounded = min(TIMESFM_MAX_HORIZON, int(value))
    return max(7, bounded)
def infer_horizon_days(question: str) -> int:
    """Extract a day horizon (e.g. "未来7天") from the question; default 15."""
    found = HORIZON_PATTERN.search(question)
    requested = int(found.group(1)) if found else 15
    return _clamp_horizon(requested)
def ensure_non_doubao_model() -> None:
    """Guard against configuring the summary model as any Doubao variant."""
    if "doubao" not in NON_DOUBAO_MODEL_ID.strip().lower():
        return
    raise QueryAPIError(
        "E_MODEL_POLICY",
        "专业预测总结模型禁止配置为豆包",
        {"NON_DOUBAO_MODEL_ID": NON_DOUBAO_MODEL_ID},
        status_code=500,
    )
def infer_module_id_from_question(question: str) -> str:
    """Map a free-text question onto a module id using keyword heuristics.

    Rules are evaluated in priority order; the first match wins and the
    default is the price-forecast module.
    """
    text = question.lower()
    keyword_rules = (
        (("库存", "采购", "补货", "饲料还够"), "m5_inventory_procurement"),
        (("猪仔", "补栏", "买入", "购入"), "m6_piglet_buy"),
        (("疫病", "异常", "预警", "风险", "病"), "m4_disease_risk"),
        (("饲喂", "喂料", "料肉比", "采食"), "m3_feeding_plan"),
        (("出栏", "卖猪", "窗口"), "m2_slaughter_window"),
    )
    for tokens, module_id in keyword_rules:
        if any(token in text for token in tokens):
            return module_id
    return "m1_price_forecast"
def normalize_module_id(module_id: str | None, question: str) -> str:
    """Use the caller-provided module id when valid; otherwise infer from text."""
    is_valid = bool(module_id) and module_id in MODULE_IDS
    if is_valid:
        return module_id
    return infer_module_id_from_question(question)
def _query_rows(sql: str, params: tuple[Any, ...] = ()) -> list[tuple[Any, ...]]:
    """Run a read-only parameterized query and return all rows.

    Opens a fresh connection per call, forces a read-only transaction and a
    statement timeout, and wraps driver failures in E_PREDICT_DB (500).
    """
    _require_psycopg()
    if not PG_DSN:
        raise QueryAPIError("E_DSN_MISSING", "未配置 PIG_QUERY_PG_DSN", status_code=500)
    try:
        with psycopg.connect(PG_DSN, autocommit=False) as conn:
            with conn.cursor() as cur:
                cur.execute("SET TRANSACTION READ ONLY")
                # STATEMENT_TIMEOUT_MS comes from int() config, so this
                # f-string interpolation cannot inject SQL.
                cur.execute(f"SET LOCAL statement_timeout = '{STATEMENT_TIMEOUT_MS}ms'")
                cur.execute(sql, params)
                return cur.fetchall()
    except QueryAPIError:
        raise
    except Exception as exc:
        raise QueryAPIError("E_PREDICT_DB", "预测数据读取失败", {"reason": str(exc)}, status_code=500) from exc
def _rows_to_series(name: str, rows: list[tuple[Any, Any]]) -> dict[str, Any]:
timestamps: list[str] = []
values: list[float] = []
for ts, value in rows:
if value is None:
continue
try:
numeric_value = float(value)
except Exception:
continue
timestamps.append(ts.isoformat() if hasattr(ts, "isoformat") else str(ts))
values.append(numeric_value)
return {"name": name, "timestamps": timestamps, "values": values}
def _query_date_price_series(table_name: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Load the most recent daily prices from a date-typed table, oldest first.

    table_name is interpolated into the SQL; callers only pass fixed internal
    table names, never user input.
    """
    rows = _query_rows(
        f"""
        SELECT d, v
        FROM (
            SELECT "日期"::date AS d, "价格"::float8 AS v
            FROM public."{table_name}"
            WHERE "日期" IS NOT NULL
            ORDER BY "日期" DESC
            LIMIT %s
        ) t
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
def _query_text_date_price_series(table_name: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Load recent prices from a table whose 日期 column is stored as text.

    Only strict YYYY-MM-DD strings are cast to dates; anything else becomes
    NULL and is dropped by the outer WHERE.
    NOTE(review): LIMIT applies inside the subquery before the NULL filter,
    so malformed-date rows shrink the effective window — confirm intended.
    """
    rows = _query_rows(
        f"""
        SELECT d, v
        FROM (
            SELECT
                CASE
                    WHEN "日期" ~ '^[0-9]{{4}}-[0-9]{{2}}-[0-9]{{2}}$' THEN "日期"::date
                    ELSE NULL
                END AS d,
                "价格"::float8 AS v
            FROM public."{table_name}"
            ORDER BY d DESC NULLS LAST
            LIMIT %s
        ) t
        WHERE d IS NOT NULL
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
def _query_growth_avg_series(value_field: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Load daily averages of a numeric column from 日常生长数据1, oldest first.

    采集日期 is stored as text; only strict YYYY-MM-DD values form day
    buckets, the rest collapse into a NULL bucket filtered out afterwards.
    value_field is interpolated into the SQL; callers pass fixed internal
    column names only.
    """
    rows = _query_rows(
        f"""
        SELECT d, avg_v
        FROM (
            SELECT
                CASE
                    WHEN "采集日期" ~ '^[0-9]{{4}}-[0-9]{{2}}-[0-9]{{2}}$' THEN "采集日期"::date
                    ELSE NULL
                END AS d,
                AVG("{value_field}"::float8) AS avg_v
            FROM public."日常生长数据1"
            GROUP BY d
            ORDER BY d DESC NULLS LAST
            LIMIT %s
        ) t
        WHERE d IS NOT NULL
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
def _query_inventory_series(limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Load the latest feed-inventory readings, ordered by record number 编号."""
    rows = _query_rows(
        """
        SELECT idx, v
        FROM (
            SELECT
                "编号"::int AS idx,
                "原有库存_kg"::float8 AS v
            FROM public."饲料库存记录"
            WHERE "原有库存_kg" IS NOT NULL
            ORDER BY "编号" DESC
            LIMIT %s
        ) t
        ORDER BY idx ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series("feed_inventory_kg", rows)
def build_series_from_db(module_id: str, question: str, horizon_days: int | None = None) -> dict[str, Any]:
    """Fetch the input time series each prediction module needs from the DB.

    Per-series fetch failures are logged and collected in source_errors
    rather than aborting; only when no series at all could be loaded does
    this raise E_PREDICT_DATA_EMPTY.
    """
    horizon = _clamp_horizon(horizon_days)
    # (name, fetcher) pairs; lambdas defer the DB round-trips until needed.
    query_plan: list[tuple[str, Any]] = []
    if module_id == "m1_price_forecast":
        query_plan = [
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m2_slaughter_window":
        query_plan = [
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
        ]
    elif module_id == "m3_feeding_plan":
        query_plan = [
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
            ("avg_feed_kg", lambda: _query_growth_avg_series("采食量kg", "avg_feed_kg")),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m4_disease_risk":
        query_plan = [
            ("avg_feed_kg", lambda: _query_growth_avg_series("采食量kg", "avg_feed_kg")),
            ("avg_water_l", lambda: _query_growth_avg_series("饮水量l", "avg_water_l")),
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
        ]
    elif module_id == "m5_inventory_procurement":
        query_plan = [
            ("feed_inventory_kg", _query_inventory_series),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m6_piglet_buy":
        query_plan = [
            ("piglet_price", lambda: _query_date_price_series("猪仔价格", "piglet_price")),
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
        ]
    else:
        raise QueryAPIError("E_PREDICT_INPUT", "未知的 module_id", {"module_id": module_id})
    series: list[dict[str, Any]] = []
    source_errors: list[dict[str, Any]] = []
    for series_name, fetcher in query_plan:
        try:
            item = fetcher()
        except QueryAPIError as exc:
            # Best effort: record the failure and keep loading other series.
            source_errors.append(
                {
                    "series": series_name,
                    "code": exc.code,
                    "message": exc.message,
                    "detail": exc.detail,
                }
            )
            LOGGER.warning(
                "predict_series_fetch_failed module_id=%s series=%s code=%s detail=%s",
                module_id,
                series_name,
                exc.code,
                json.dumps(exc.detail, ensure_ascii=False),
            )
            continue
        if len(item.get("values", [])) > 0:
            series.append(item)
    # NOTE(review): `series` already holds only non-empty items, so this
    # second filter is redundant but harmless.
    non_empty_series = [item for item in series if len(item.get("values", [])) > 0]
    if not non_empty_series:
        raise QueryAPIError(
            "E_PREDICT_DATA_EMPTY",
            "预测所需数据库数据不足",
            {"module_id": module_id, "question": question, "source_errors": source_errors},
            status_code=400,
        )
    return {
        "module_id": module_id,
        "question": question,
        "horizon_days": horizon,
        "series": non_empty_series,
        "source_errors": source_errors,
    }
def _accepts_kwarg(callable_obj: Any, arg_name: str) -> bool:
try:
import inspect
signature = inspect.signature(callable_obj)
return arg_name in signature.parameters
except Exception:
return False
def _create_timesfm_hparams() -> Any:
    """Build TimesFmHparams, passing only kwargs this timesfm version accepts."""
    hparams_cls = getattr(timesfm, "TimesFmHparams", None)
    if hparams_cls is None:
        raise QueryAPIError("E_TIMESFM_INIT", "timesfm 缺少 TimesFmHparams", status_code=500)
    kwargs: dict[str, Any] = {}
    if _accepts_kwarg(hparams_cls, "backend"):
        kwargs["backend"] = TIMESFM_BACKEND
    if _accepts_kwarg(hparams_cls, "per_core_batch_size"):
        kwargs["per_core_batch_size"] = TIMESFM_BATCH_SIZE
    if _accepts_kwarg(hparams_cls, "horizon_len"):
        kwargs["horizon_len"] = TIMESFM_MAX_HORIZON
    if _accepts_kwarg(hparams_cls, "device"):
        kwargs["device"] = TIMESFM_DEVICE
    # The TimesFM 2.0 500m checkpoint needs these fixed architecture
    # parameters; otherwise the loaded state_dict does not match the model.
    if TIMESFM_CHECKPOINT == "google/timesfm-2.0-500m-pytorch":
        if _accepts_kwarg(hparams_cls, "input_patch_len"):
            kwargs["input_patch_len"] = 32
        if _accepts_kwarg(hparams_cls, "output_patch_len"):
            kwargs["output_patch_len"] = 128
        if _accepts_kwarg(hparams_cls, "num_layers"):
            kwargs["num_layers"] = 50
        if _accepts_kwarg(hparams_cls, "model_dims"):
            kwargs["model_dims"] = 1280
        if _accepts_kwarg(hparams_cls, "use_positional_embedding"):
            kwargs["use_positional_embedding"] = False
    return hparams_cls(**kwargs)
def _create_timesfm_checkpoint() -> Any:
    """Build a TimesFmCheckpoint for the configured repo, across API variants."""
    checkpoint_cls = getattr(timesfm, "TimesFmCheckpoint", None)
    if checkpoint_cls is None:
        # Fall back to the bare repo-id string when the class is absent.
        return TIMESFM_CHECKPOINT
    if _accepts_kwarg(checkpoint_cls, "huggingface_repo_id"):
        return checkpoint_cls(huggingface_repo_id=TIMESFM_CHECKPOINT)
    return checkpoint_cls(TIMESFM_CHECKPOINT)
def get_timesfm_model() -> Any:
    """Return the process-wide TimesFM model, initializing it on first use.

    Thread-safe via double-checked locking on TIMESFM_MODEL_LOCK. On any
    failure the reason is recorded in TIMESFM_INIT_ERROR and a
    QueryAPIError(E_TIMESFM_INIT, 500) is raised.
    """
    global TIMESFM_MODEL
    global TIMESFM_INIT_ERROR
    if TIMESFM_MODEL is not None:
        return TIMESFM_MODEL
    with TIMESFM_MODEL_LOCK:
        if TIMESFM_MODEL is not None:
            return TIMESFM_MODEL
        if timesfm is None:
            TIMESFM_INIT_ERROR = "timesfm package not installed"
            raise QueryAPIError("E_TIMESFM_INIT", "TimesFM 未安装,请在 python311 环境安装 timesfm[torch]", status_code=500)
        model_cls = getattr(timesfm, "TimesFm", None)
        if model_cls is None:
            TIMESFM_INIT_ERROR = "timesfm package missing TimesFm class"
            raise QueryAPIError("E_TIMESFM_INIT", "TimesFM 初始化失败:缺少 TimesFm 类", status_code=500)
        try:
            hparams = _create_timesfm_hparams()
            checkpoint = _create_timesfm_checkpoint()
            # Constructor signatures vary across timesfm releases; try the
            # known spellings in order until one succeeds.
            init_attempts = [
                lambda: model_cls(hparams=hparams, checkpoint=checkpoint),
                lambda: model_cls(hparams, checkpoint),
                lambda: model_cls(checkpoint=checkpoint),
            ]
            last_error = None
            for attempt in init_attempts:
                try:
                    TIMESFM_MODEL = attempt()
                    TIMESFM_INIT_ERROR = None
                    return TIMESFM_MODEL
                except Exception as exc:  # pragma: no cover - runtime compatibility fallback
                    last_error = str(exc)
            TIMESFM_INIT_ERROR = last_error or "unknown init error"
            # BUG FIX: previously execution fell through here and returned
            # None when every constructor attempt failed; now it raises the
            # same E_TIMESFM_INIT error as other initialization failures.
            raise QueryAPIError(
                "E_TIMESFM_INIT",
                "TimesFM 初始化失败",
                {"reason": TIMESFM_INIT_ERROR, "checkpoint": TIMESFM_CHECKPOINT},
                status_code=500,
            )
        except QueryAPIError:
            raise
        except Exception as exc:  # pragma: no cover - runtime compatibility fallback
            TIMESFM_INIT_ERROR = str(exc)
            raise QueryAPIError(
                "E_TIMESFM_INIT",
                "TimesFM 初始化失败",
                {"reason": TIMESFM_INIT_ERROR, "checkpoint": TIMESFM_CHECKPOINT},
                status_code=500,
            )
def ensure_timesfm_ready() -> None:
    """Eagerly initialize TimesFM, surfacing init failures as QueryAPIError."""
    get_timesfm_model()
def _normalize_timesfm_points(raw_output: Any, horizon_days: int) -> list[float]:
    """Coerce model.forecast output into exactly horizon_days floats.

    Accepts the tuple/ndarray/list/dict shapes returned by different timesfm
    releases; pads with the last value when too short and truncates when too
    long. Raises E_TIMESFM_INFER when nothing usable can be extracted.
    """
    payload = raw_output
    # A tuple return is unpacked to its first element (assumed to be the
    # point forecast — confirm against the installed timesfm version).
    if isinstance(payload, tuple) and payload:
        payload = payload[0]
    if np is not None and isinstance(payload, np.ndarray):
        values = payload
    elif isinstance(payload, list):
        values = payload
    elif isinstance(payload, dict):
        candidates = payload.get("point_forecast") or payload.get("forecast") or payload.get("mean")
        if candidates is None:
            raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 返回结构无法解析", {"raw_type": str(type(payload))}, status_code=500)
        values = candidates
    else:
        raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 返回结构无法解析", {"raw_type": str(type(payload))}, status_code=500)
    # Keep only the forecast for the single input series (batch index 0).
    if np is not None and isinstance(values, np.ndarray):
        if values.ndim >= 2:
            points = values[0].tolist()
        else:
            points = values.tolist()
    else:
        if values and isinstance(values[0], list):
            points = values[0]
        else:
            points = values
    numeric_points: list[float] = []
    for item in points:
        try:
            numeric_points.append(float(item))
        except Exception:
            continue
    if not numeric_points:
        raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 未返回有效预测点", status_code=500)
    if len(numeric_points) < horizon_days:
        numeric_points.extend([numeric_points[-1]] * (horizon_days - len(numeric_points)))
    return numeric_points[:horizon_days]
def run_timesfm_forecast(series: list[float], horizon_days: int) -> dict[str, Any]:
    """Run TimesFM point forecasting over one input series.

    The input is padded (repeating its last value) up to SERIES_MIN_POINTS,
    then several forecast() call signatures are tried to cope with API
    differences between timesfm releases. The first attempt omits
    horizon_len, so the model's default horizon may come back;
    _normalize_timesfm_points pads/truncates to horizon_days afterwards.
    """
    if not series:
        raise QueryAPIError("E_TIMESFM_INFER", "输入时序为空", status_code=400)
    model = get_timesfm_model()
    padded_series = [float(v) for v in series]
    if len(padded_series) < SERIES_MIN_POINTS:
        padded_series.extend([padded_series[-1]] * (SERIES_MIN_POINTS - len(padded_series)))
    forecast_attempts = [
        lambda: model.forecast([padded_series], freq=[0]),
        lambda: model.forecast([padded_series], horizon_len=horizon_days, freq=[0]),
        lambda: model.forecast([padded_series], horizon_len=horizon_days),
        lambda: model.forecast([padded_series], horizon_days),
    ]
    if np is not None:
        np_input = np.asarray([padded_series], dtype=float)
        forecast_attempts.extend(
            [
                lambda: model.forecast(np_input, freq=[0]),
                lambda: model.forecast(np_input, horizon_len=horizon_days, freq=[0]),
            ]
        )
    last_error = None
    for attempt in forecast_attempts:
        try:
            raw = attempt()
            points = _normalize_timesfm_points(raw, horizon_days)
            return {
                "horizon_days": horizon_days,
                "point_forecast": points,
                "input_last_value": padded_series[-1],
            }
        except QueryAPIError:
            raise
        except Exception as exc:  # pragma: no cover - runtime compatibility fallback
            last_error = str(exc)
            continue
    raise QueryAPIError(
        "E_TIMESFM_INFER",
        "TimesFM 推理失败",
        {"reason": last_error or "unknown forecast error"},
        status_code=500,
    )
def run_module_logic(module_id: str, series_pack: dict[str, Any], forecast_map: dict[str, dict[str, Any]]) -> dict[str, Any]:
    """Turn raw per-series forecasts into a module-specific recommendation.

    Each module branch inspects the trend of its key series (last forecast
    point vs. first) and emits a recommendation dict, optional risk flags,
    and a Chinese draft summary string.
    """
    horizon_days = int(series_pack.get("horizon_days", 15))
    recommendation: dict[str, Any] = {}
    risk_flags: list[str] = []

    def _series_direction(name: str) -> str:
        # Trend = sign of (last point - first point); "flat" when unknown.
        forecast = forecast_map.get(name, {}).get("point_forecast", [])
        if len(forecast) < 2:
            return "flat"
        if forecast[-1] > forecast[0]:
            return "up"
        if forecast[-1] < forecast[0]:
            return "down"
        return "flat"

    if module_id == "m1_price_forecast":
        hog_dir = _series_direction("hog_price")
        recommendation["action"] = "择机延后出栏" if hog_dir == "up" else "关注近期出栏窗口"
        recommendation["reason"] = f"生猪价格未来{horizon_days}天趋势:{hog_dir}"
    elif module_id == "m2_slaughter_window":
        hog_dir = _series_direction("hog_price")
        weight_dir = _series_direction("avg_weight_kg")
        recommendation["optimal_window"] = "未来3-7天" if hog_dir == "up" and weight_dir == "up" else "未来1-3天滚动评估"
        recommendation["reason"] = f"价格趋势={hog_dir}, 体重趋势={weight_dir}"
    elif module_id == "m3_feeding_plan":
        feed_dir = _series_direction("avg_feed_kg")
        recommendation["feeding_adjustment"] = "维持配方并小幅增量" if feed_dir == "up" else "控制喂量并优化配方"
        recommendation["reason"] = f"采食趋势={feed_dir}"
    elif module_id == "m4_disease_risk":
        water = forecast_map.get("avg_water_l", {}).get("point_forecast", [])
        feed = forecast_map.get("avg_feed_kg", {}).get("point_forecast", [])
        # Simultaneous decline of feeding and drinking is treated as a signal.
        if water and feed and water[-1] < water[0] and feed[-1] < feed[0]:
            risk_flags.append("采食与饮水同步下滑,建议排查健康风险")
        recommendation["risk_level"] = "high" if risk_flags else "medium"
        recommendation["action"] = "加强巡检与体温采样"
    elif module_id == "m5_inventory_procurement":
        inventory = forecast_map.get("feed_inventory_kg", {}).get("point_forecast", [])
        if inventory and min(inventory) <= 0:
            risk_flags.append("预测库存触底,需立即补货")
        recommendation["procurement_window"] = "未来1周内分批采购"
        recommendation["action"] = "结合价格低位滚动补货"
    elif module_id == "m6_piglet_buy":
        piglet_dir = _series_direction("piglet_price")
        recommendation["buy_window"] = "未来3-5天" if piglet_dir == "down" else "等待价格回调"
        recommendation["reason"] = f"猪仔价格趋势={piglet_dir}"
    else:
        raise QueryAPIError("E_PREDICT_INPUT", "未知 module_id", {"module_id": module_id}, status_code=400)
    all_points = [point for item in forecast_map.values() for point in item.get("point_forecast", [])]
    avg_point = round(mean(all_points), 4) if all_points else 0.0
    recommendation.setdefault("confidence", 0.65 if all_points else 0.0)
    recommendation.setdefault("summary_metric", avg_point)
    summary = f"{module_id} 已完成未来{horizon_days}天预测,建议:{json.dumps(recommendation, ensure_ascii=False)}"
    return {
        "forecast": forecast_map,
        "recommendation": recommendation,
        "risk_flags": risk_flags,
        "backend_draft_summary": summary,
    }
def _insert_prediction_audit(
    trace_id: str,
    module_id: str,
    question: str,
    horizon_days: int,
    payload: PredictRunData,
) -> None:
    """Best-effort audit logging of a prediction run into PostgreSQL.

    Upserts public.ai_prediction_run keyed by trace_id and appends the
    outputs to public.ai_prediction_output. Any failure is logged and
    swallowed so auditing can never break the API response.
    """
    if not PREDICT_LOG_ENABLED:
        return
    if psycopg is None or not PG_DSN:
        LOGGER.warning("predict_audit_insert_skipped trace_id=%s reason=psycopg_or_dsn_missing", trace_id)
        return
    try:
        with psycopg.connect(PG_DSN, autocommit=True) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO public.ai_prediction_run (
                        trace_id, module_id, question, horizon_days, db_used, model_trace, status, created_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s::jsonb, %s, now())
                    ON CONFLICT (trace_id) DO UPDATE
                    SET module_id = EXCLUDED.module_id,
                        question = EXCLUDED.question,
                        horizon_days = EXCLUDED.horizon_days,
                        db_used = EXCLUDED.db_used,
                        model_trace = EXCLUDED.model_trace,
                        status = EXCLUDED.status
                    RETURNING id
                    """,
                    (
                        trace_id,
                        module_id,
                        question,
                        horizon_days,
                        payload.db_used,
                        json.dumps(payload.model_trace, ensure_ascii=False),
                        "ok",
                    ),
                )
                row = cur.fetchone()
                if not row:
                    return
                run_id = int(row[0])
                cur.execute(
                    """
                    INSERT INTO public.ai_prediction_output (
                        run_id, forecast, recommendation, risk_flags, backend_draft_summary, created_at
                    )
                    VALUES (%s, %s::jsonb, %s::jsonb, %s::jsonb, %s, now())
                    """,
                    (
                        run_id,
                        json.dumps(payload.forecast, ensure_ascii=False),
                        json.dumps(payload.recommendation, ensure_ascii=False),
                        json.dumps(payload.risk_flags, ensure_ascii=False),
                        payload.backend_draft_summary,
                    ),
                )
    except Exception as exc:
        LOGGER.warning("predict_audit_insert_skipped trace_id=%s reason=%s", trace_id, str(exc))
def _require_psycopg() -> None:
    """Raise E_DRIVER_MISSING (500) when the psycopg driver is not installed."""
    if psycopg is None:
        raise QueryAPIError("E_DRIVER_MISSING", "缺少 psycopg 依赖,请先安装后再启动服务", status_code=500)
def _to_jsonable(value: Any) -> Any:
if value is None:
return None
if isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, Decimal):
return float(value)
if isinstance(value, (datetime, date)):
return value.isoformat()
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def run_query(sql: str, trace_id: str) -> QuerySuccessData:
    """Execute validated SQL read-only and package the result for the API.

    At most MAX_ROWS rows are fetched; cell values are coerced to
    JSON-friendly primitives. Driver errors become E_DB_QUERY (500).
    """
    _require_psycopg()
    if not PG_DSN:
        raise QueryAPIError("E_DSN_MISSING", "未配置 PIG_QUERY_PG_DSN", status_code=500)
    started = time.perf_counter()
    try:
        with psycopg.connect(PG_DSN, autocommit=False) as conn:
            with conn.cursor() as cur:
                cur.execute("SET TRANSACTION READ ONLY")
                # STATEMENT_TIMEOUT_MS is an int from config; safe to inline.
                cur.execute(f"SET LOCAL statement_timeout = '{STATEMENT_TIMEOUT_MS}ms'")
                cur.execute(sql)
                rows = cur.fetchmany(MAX_ROWS)
                columns = [desc.name for desc in cur.description] if cur.description else []
    except QueryAPIError:
        raise
    except Exception as exc:
        raise QueryAPIError("E_DB_QUERY", "数据库查询失败", {"reason": str(exc)}, status_code=500) from exc
    elapsed_ms = round((time.perf_counter() - started) * 1000, 2)
    normalized_rows = [[_to_jsonable(value) for value in row] for row in rows]
    data = QuerySuccessData(
        columns=columns,
        rows=normalized_rows,
        row_count=len(normalized_rows),
        sql_executed=sql,
        empty=len(normalized_rows) == 0,
        trace_id=trace_id,
    )
    # Log a hash of the SQL rather than the full text to keep log lines short.
    LOGGER.info(
        "query_ok trace_id=%s sql_hash=%s row_count=%s duration_ms=%s",
        trace_id,
        hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16],
        data.row_count,
        elapsed_ms,
    )
    return data
async def verify_api_key(x_api_key: str = Header(default="", alias="X-API-Key")) -> None:
    """FastAPI dependency: validate the X-API-Key header.

    Uses a constant-time comparison so the key cannot be probed
    byte-by-byte via response timing. Raises E_UNAUTHORIZED (401) on a
    missing or mismatching key.
    """
    import secrets

    if not x_api_key or not secrets.compare_digest(x_api_key, API_KEY):
        raise QueryAPIError("E_UNAUTHORIZED", "无效的 API Key", status_code=status.HTTP_401_UNAUTHORIZED)
def build_error_response(exc: QueryAPIError, trace_id: str | None = None) -> JSONResponse:
    """Serialize a QueryAPIError into the standard {ok, data, err} envelope."""
    detail = dict(exc.detail)
    if trace_id and "trace_id" not in detail:
        detail["trace_id"] = trace_id
    body = QueryResponse(
        ok=False,
        data=None,
        err=ErrorDetail(code=exc.code, message=exc.message, detail=detail),
    )
    return JSONResponse(status_code=exc.status_code, content=body.model_dump())
app = FastAPI(title=SERVICE_NAME, version="1.0.0")


@app.exception_handler(QueryAPIError)
async def query_api_error_handler(_: Any, exc: QueryAPIError) -> JSONResponse:
    """Convert any raised QueryAPIError into the standard error envelope."""
    return build_error_response(exc)


@app.exception_handler(Exception)
async def unhandled_exception_handler(_: Any, exc: Exception) -> JSONResponse:  # pragma: no cover - defensive fallback
    """Last-resort handler: log the traceback, return a generic 500 envelope."""
    LOGGER.exception("unhandled error: %s", exc)
    return build_error_response(
        QueryAPIError("E_UNKNOWN", "系统错误", {"reason": str(exc)}, status_code=500)
    )
@app.get("/api/v1/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(
service=SERVICE_NAME,
time=datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z"),
sql_parser="sqlglot" if sqlglot else "regex-fallback",
db_driver="psycopg" if psycopg else "missing",
)
@app.get("/api/v1/predict/health", response_model=PredictHealthResponse)
async def predict_health() -> PredictHealthResponse:
ready = False
error_text = TIMESFM_INIT_ERROR
try:
ensure_non_doubao_model()
except QueryAPIError as exc:
error_text = exc.message
if timesfm is not None:
try:
ensure_timesfm_ready()
ready = True
except QueryAPIError as exc:
error_text = exc.message if error_text is None else f"{error_text}; {exc.message}"
else:
pkg_error = "timesfm package not installed"
error_text = pkg_error if error_text is None else f"{error_text}; {pkg_error}"
return PredictHealthResponse(
ok=True,
service=SERVICE_NAME,
time=_now_utc_iso(),
timesfm_enabled=timesfm is not None,
timesfm_ready=ready,
timesfm_device=TIMESFM_DEVICE,
timesfm_checkpoint=TIMESFM_CHECKPOINT,
timesfm_init_error=error_text,
)
@app.post("/api/v1/predict/run", response_model=PredictRunResponse)
async def predict_run(request: PredictRunRequest, _: None = Depends(verify_api_key)) -> PredictRunResponse:
trace_id = request.trace_id or str(uuid.uuid4())
module_id = normalize_module_id(request.module_id, request.question)
ensure_non_doubao_model()
ensure_timesfm_ready()
horizon_days = infer_horizon_days(request.question)
series_pack = build_series_from_db(module_id=module_id, question=request.question, horizon_days=horizon_days)
horizon_days = int(series_pack.get("horizon_days", horizon_days))
forecast_map: dict[str, dict[str, Any]] = {}
for series_item in series_pack.get("series", []):
series_name = str(series_item.get("name") or "series")
values = series_item.get("values") or []
if not isinstance(values, list) or not values:
continue
forecast_map[series_name] = run_timesfm_forecast(values, horizon_days=horizon_days)
if not forecast_map:
raise QueryAPIError("E_PREDICT_DATA_EMPTY", "缺少可预测序列", {"module_id": module_id}, status_code=400)
module_output = run_module_logic(module_id, series_pack, forecast_map)
data = PredictRunData(
module_id=module_id,
db_used=True,
model_trace=[f"timesfm:{TIMESFM_CHECKPOINT}", f"summary:{NON_DOUBAO_MODEL_ID}"],
forecast=module_output["forecast"],
recommendation=module_output["recommendation"],
risk_flags=module_output["risk_flags"],
backend_draft_summary=module_output["backend_draft_summary"],
trace_id=trace_id,
)
_insert_prediction_audit(
trace_id=trace_id,
module_id=module_id,
question=request.question,
horizon_days=horizon_days,
payload=data,
)
LOGGER.info(
"predict_run_ok trace_id=%s module_id=%s series_count=%s",
trace_id,
module_id,
len(forecast_map),
)
return PredictRunResponse(ok=True, data=data, err=None)
@app.post("/api/v1/query", response_model=QueryResponse)
async def query(request: QueryRequest, _: None = Depends(verify_api_key)) -> QueryResponse:
trace_id = request.trace_id or str(uuid.uuid4())
prepared = prepare_sql_for_execution(request.sql)
LOGGER.info(
"query_request trace_id=%s user_query=%s tables=%s",
trace_id,
request.user_query,
json.dumps(sorted(prepared.tables), ensure_ascii=False),
)
data = run_query(prepared.sql, trace_id)
return QueryResponse(ok=True, data=data, err=None)
if __name__ == "__main__":  # pragma: no cover - manual run path
    # Local/dev entry point; production deployments run uvicorn externally.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False)