# ko / app.py
# katenovaa's picture
# Update app.py
# a0eb6b2 verified
from __future__ import annotations
import hashlib
import json
import logging
import os
import re
import threading
import time
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from statistics import mean
from typing import Any
from fastapi import Depends, FastAPI, Header, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
try:
import psycopg
except ImportError: # pragma: no cover - exercised in runtime environments without psycopg
psycopg = None
try:
import sqlglot
from sqlglot import exp
except ImportError: # pragma: no cover - exercised in runtime environments without sqlglot
sqlglot = None
exp = None
try:
import numpy as np
except ImportError: # pragma: no cover - optional dependency for TimesFM runtime
np = None
try:
import timesfm
except ImportError: # pragma: no cover - exercised in runtime environments without timesfm
timesfm = None
# Module logger; basicConfig is applied only once per process.
LOGGER = logging.getLogger("pig_query_api")
if not LOGGER.handlers:
    logging.basicConfig(
        level=os.getenv("PIG_QUERY_LOG_LEVEL", "INFO").upper(),
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
SERVICE_NAME = "pig-query-api"
# API key guarding the HTTP endpoints; override the placeholder in production.
API_KEY = os.getenv("PIG_QUERY_API_KEY", "change-me")
# Hardcoded PostgreSQL config requested by project owner.
# SECURITY NOTE(review): live credentials committed in source — prefer env
# vars / secret storage and rotate this password.
PG_HOST = "47.107.138.68"
PG_PORT = 5432
PG_DATABASE = "pig"
PG_USER = "postgres"
PG_PASSWORD = "63BvPoC5JiHu16k8Gh"
PG_DSN = f"postgresql://{PG_USER}:{PG_PASSWORD}@{PG_HOST}:{PG_PORT}/{PG_DATABASE}"
# Query guardrails: max rows, per-statement timeout, suggested default LIMIT.
MAX_ROWS = int(os.getenv("PIG_QUERY_MAX_ROWS", "100"))
STATEMENT_TIMEOUT_MS = int(os.getenv("PIG_QUERY_STATEMENT_TIMEOUT_MS", "3000"))
DEFAULT_SELECT_LIMIT = int(os.getenv("PIG_QUERY_DEFAULT_SELECT_LIMIT", "50"))
# TimesFM runtime configuration (device/backend/checkpoint and batch limits).
TIMESFM_DEVICE = os.getenv("TIMESFM_DEVICE", "cpu").lower()
TIMESFM_CHECKPOINT = os.getenv("TIMESFM_CHECKPOINT", "google/timesfm-2.0-500m-pytorch")
TIMESFM_BACKEND = os.getenv("TIMESFM_BACKEND", "cpu").lower()
TIMESFM_MAX_HORIZON = int(os.getenv("TIMESFM_MAX_HORIZON", "15"))
TIMESFM_BATCH_SIZE = int(os.getenv("TIMESFM_BATCH_SIZE", "32"))
# History length pulled per series, and minimum points required for inference.
SERIES_LOOKBACK_DAYS = int(os.getenv("PREDICT_SERIES_LOOKBACK_DAYS", "90"))
SERIES_MIN_POINTS = int(os.getenv("PREDICT_SERIES_MIN_POINTS", "8"))
# Feature switches parsed from truthy env strings.
PREDICT_LOG_ENABLED = os.getenv("PREDICT_LOG_ENABLED", "1").lower() in {"1", "true", "yes", "on"}
PREDICT_BACKEND_SUMMARY_ENABLED = os.getenv("PIG_PREDICT_BACKEND_SUMMARY", "on").lower() in {
    "1",
    "true",
    "yes",
    "on",
}
# Risk-score band thresholds.
RISK_SCORE_LOW = float(os.getenv("PIG_PREDICT_RISK_SCORE_LOW", "6"))
RISK_SCORE_HIGH = float(os.getenv("PIG_PREDICT_RISK_SCORE_HIGH", "12"))
# Lazily initialized TimesFM model singleton, guarded by a lock.
TIMESFM_MODEL_LOCK = threading.Lock()
TIMESFM_MODEL: Any | None = None
TIMESFM_INIT_ERROR: str | None = None
# Whitelisted tables (schema-qualified, mostly quoted Chinese identifiers);
# queries referencing any table outside this set are rejected.
# NOTE(review): "生猪价格" is listed unquoted, and both "猪仔价格" and "仔猪价格"
# variants appear — confirm against the live schema.
ALLOWED_TABLES = frozenset(
    {
        'public."采食量标准"',
        'public."称重记录"',
        'public."豆粕价"',
        'public."料肉比标准"',
        'public."日常生长数据"',
        'public."生猪产能"',
        "public.生猪价格",
        'public."饲料库存记录"',
        'public."体重标准"',
        'public."玉米价格"',
        'public."仔猪价格"',
        'public."猪肉批发价"',
        'public."猪仔价格"',
        'public."猪仔信息"',
    }
)
# SQL verbs that immediately reject a query (write/DDL/maintenance).
DENY_KEYWORDS = (
    "insert",
    "update",
    "delete",
    "drop",
    "alter",
    "truncate",
    "create",
    "grant",
    "revoke",
    "copy",
    "merge",
    "vacuum",
    "analyze",
    "comment",
)
# SQL comment tokens (line and block comments).
COMMENT_PATTERN = re.compile(r"(--|/\*|\*/)")
# Explicit "LIMIT <n>" clause.
LIMIT_PATTERN = re.compile(r"\blimit\s+(\d+)\b", re.IGNORECASE)
# Aggregate function calls; aggregate-only queries may omit LIMIT.
AGGREGATE_PATTERN = re.compile(
    r"\b(count|sum|avg|min|max|string_agg|array_agg|json_agg)\s*\(",
    re.IGNORECASE,
)
# Statement must start with SELECT or WITH.
SQL_START_PATTERN = re.compile(r"^\s*(select|with)\b", re.IGNORECASE)
# FROM/JOIN <table> references (incl. CJK identifiers) for the regex fallback.
TABLE_REF_PATTERN = re.compile(r"\b(from|join)\s+((?:public\.)?\"?[A-Za-z0-9_\u4e00-\u9fff]+\"?)", re.IGNORECASE)
# "未来/接下来/后续 N 天/日" horizon phrases in user questions (1-2 digits).
HORIZON_PATTERN = re.compile(r"(?:未来|接下来|后续)?\s*(\d{1,2})\s*(?:天|日)")
# All known analysis/prediction module ids.
MODULE_IDS = {
    "m1_price_forecast",
    "m2_slaughter_window",
    "m3_feeding_plan",
    "m4_disease_risk",
    "m5_inventory_procurement",
    "m6_piglet_buy",
    "m7_statistics",
    "m8_anomaly_detection",
}
# Modules answered from live DB aggregates instead of TimesFM forecasts.
REALTIME_MODULE_IDS = frozenset({"m7_statistics", "m8_anomaly_detection"})
def _normalize_identifier(value: str) -> str:
parts = [part.strip().strip('"') for part in value.split(".") if part.strip()]
return ".".join(part.lower() for part in parts)
# Normalized whitelist containing both the schema-qualified form ("public.x")
# and the bare table name ("x"), so either reference style passes validation.
ALLOWED_TABLE_NAMES = frozenset(
    {
        normalized
        for table in ALLOWED_TABLES
        for normalized in (
            _normalize_identifier(table),
            _normalize_identifier(table).split(".", 1)[-1],
        )
    }
)
class ErrorDetail(BaseModel):
    """Structured error payload carried in the `err` field of responses."""
    code: str  # machine-readable error code, e.g. "E_SQL_DENIED"
    message: str  # human-readable message
    detail: dict[str, Any] = Field(default_factory=dict)  # extra context
class QuerySuccessData(BaseModel):
    """Successful query result payload."""
    columns: list[str]  # result column names
    rows: list[list[Any]]  # row values aligned with `columns`
    row_count: int
    sql_executed: str  # SQL that was executed
    empty: bool  # True when the result set has no rows
    trace_id: str
class QueryResponse(BaseModel):
    """Query endpoint envelope: `data` on success, `err` on failure."""
    ok: bool
    data: QuerySuccessData | None = None
    err: ErrorDetail | None = None
class QueryRequest(BaseModel):
    """Query endpoint input: the user's question plus the SQL to validate/run."""
    user_query: str = Field(..., min_length=1, max_length=2000)
    sql: str = Field(..., min_length=1, max_length=20000)
    trace_id: str | None = Field(default=None, max_length=128)  # optional correlation id
class HealthResponse(BaseModel):
    """Health-check payload reporting optional-dependency availability."""
    ok: bool = True
    service: str
    time: str  # ISO timestamp
    sql_parser: str  # sqlglot availability indicator
    db_driver: str  # psycopg availability indicator
class PredictRunRequest(BaseModel):
    """Prediction endpoint input: question and optional explicit module id."""
    question: str = Field(..., min_length=1, max_length=2000)
    module_id: str | None = Field(default=None, max_length=64)  # one of MODULE_IDS, else inferred
    trace_id: str | None = Field(default=None, max_length=128)
class PredictRunData(BaseModel):
    """Successful prediction payload."""
    module_id: str  # module that produced the answer
    db_used: bool  # whether database series were read
    model_trace: list[str]  # processing-step trail
    forecast: dict[str, Any]
    recommendation: dict[str, Any]
    risk_flags: list[str]
    decision_metrics: dict[str, Any] | None = None
    backend_draft_summary: str
    executive_summary: str | None = None
    trace_id: str
class PredictRunResponse(BaseModel):
    """Prediction endpoint envelope: `data` on success, `err` on failure."""
    ok: bool
    data: PredictRunData | None = None
    err: ErrorDetail | None = None
class PredictHealthResponse(BaseModel):
    """Prediction health payload reporting TimesFM readiness."""
    ok: bool
    service: str
    time: str  # ISO timestamp
    timesfm_enabled: bool  # package importable
    timesfm_ready: bool  # model initialized
    timesfm_device: str
    timesfm_checkpoint: str
    timesfm_init_error: str | None = None  # last init failure reason, if any
class QueryAPIError(Exception):
    """Domain error carrying an error code, a detail dict and an HTTP status.

    Raised throughout validation/DB/model code and mapped to the response
    `err` envelope by the API layer.
    """

    def __init__(self, code: str, message: str, detail: dict[str, Any] | None = None, status_code: int = 400):
        super().__init__(message)
        self.code = code  # machine-readable code, e.g. "E_SQL_DENIED"
        self.message = message  # human-readable message
        self.detail = detail or {}  # extra structured context
        self.status_code = status_code  # HTTP status to return
@dataclass(frozen=True)
class PreparedQuery:
    """Validated SQL plus the normalized table names it references."""

    sql: str  # normalized, single-statement SELECT/WITH query
    tables: set[str]  # normalized table identifiers referenced by the query
def normalize_sql(sql: str) -> str:
    """Strip BOM characters and surrounding whitespace from raw SQL text."""
    without_bom = sql.replace("\ufeff", "")
    return without_bom.strip()
def reject_empty_sql(sql: str) -> None:
    """Raise E_SQL_EMPTY when the (already normalized) SQL is blank."""
    if sql:
        return
    raise QueryAPIError("E_SQL_EMPTY", "SQL 不能为空")
def reject_multi_statement(sql: str) -> None:
    """Reject SQL that contains more than one statement.

    Trailing semicolons are tolerated; any interior semicolon is treated as
    a statement separator (semicolons inside string literals will also be
    rejected — a known limitation of this cheap check).
    """
    body = sql.rstrip(";").strip()
    if ";" not in body:
        return
    raise QueryAPIError("E_SQL_MULTI", "SQL 只允许单条语句")
def reject_comments(sql: str) -> None:
    """Reject SQL containing comment tokens (-- or /* */)."""
    if COMMENT_PATTERN.search(sql) is None:
        return
    raise QueryAPIError("E_SQL_COMMENT", "SQL 不允许注释")
def reject_dangerous_keywords(sql: str) -> None:
    """Reject SQL containing any write/DDL verb from the deny list."""
    lowered = sql.lower()
    for keyword in DENY_KEYWORDS:
        if not re.search(rf"\b{re.escape(keyword)}\b", lowered):
            continue
        raise QueryAPIError(
            "E_SQL_DENIED",
            "SQL 不在允许范围内",
            {"reason": f"contains denied keyword: {keyword}"},
        )
def parse_sql(sql: str) -> Any | None:
    """Parse SQL with sqlglot (postgres dialect).

    Returns None when sqlglot is not installed; raises
    QueryAPIError(E_SQL_PARSE) when parsing fails.
    """
    if sqlglot is None:
        return None
    try:
        ast = sqlglot.parse_one(sql, read="postgres")
    except Exception as exc:  # pragma: no cover - depends on sqlglot internals
        raise QueryAPIError("E_SQL_PARSE", "SQL 解析失败", {"reason": str(exc)}) from exc
    return ast
def ensure_select_only(ast: Any | None, sql: str) -> None:
    """Allow only SELECT/WITH statements and deny DML/DDL nodes in the AST.

    The textual check always runs; the AST walk is skipped when sqlglot is
    unavailable or parsing produced nothing.
    """
    if SQL_START_PATTERN.match(sql) is None:
        raise QueryAPIError("E_SQL_TYPE", "只允许 SELECT 或 WITH 查询")
    if ast is None or exp is None:
        return
    denied_types = (
        exp.Insert,
        exp.Update,
        exp.Delete,
        exp.Drop,
        exp.Alter,
        exp.Create,
        exp.Command,
        exp.Copy,
        exp.Merge,
    )
    for node_type in denied_types:
        found = ast.find(node_type)
        if found:
            raise QueryAPIError("E_SQL_DENIED", "SQL 不在允许范围内", {"reason": f"contains {node_type.__name__}"})
def _extract_tables_with_sqlglot(ast: Any) -> set[str]:
    """Collect normalized table names from the AST, excluding CTE aliases."""
    found: set[str] = set()
    if exp is None:
        return found
    cte_aliases: set[str] = set()
    for cte in ast.find_all(exp.CTE):
        alias = getattr(cte, "alias", None)
        if alias:
            cte_aliases.add(_normalize_identifier(alias))
    for table in ast.find_all(exp.Table):
        name = getattr(table, "name", None)
        if not name:
            continue
        db = getattr(table, "db", None)
        qualified = ".".join(part for part in (db, name) if part)
        normalized = _normalize_identifier(qualified)
        # CTE references are not physical tables and must not be whitelisted.
        if normalized not in cte_aliases:
            found.add(normalized)
    return found
def _extract_tables_with_regex(sql: str) -> set[str]:
    """Fallback table extraction via FROM/JOIN regex when sqlglot is absent."""
    return {
        _normalize_identifier(reference)
        for _keyword, reference in TABLE_REF_PATTERN.findall(sql)
    }
def extract_tables(ast: Any | None, sql: str) -> set[str]:
    """Prefer AST-based table extraction; fall back to the regex scan."""
    if ast is None or sqlglot is None:
        return _extract_tables_with_regex(sql)
    from_ast = _extract_tables_with_sqlglot(ast)
    return from_ast if from_ast else _extract_tables_with_regex(sql)
def ensure_whitelisted_tables(tables: set[str]) -> None:
    """Require every referenced table (qualified or bare form) to be whitelisted."""
    if not tables:
        raise QueryAPIError("E_SQL_TABLES", "未识别到可查询表")
    offenders = [
        table
        for table in tables
        if not ({table, table.split(".", 1)[-1]} & ALLOWED_TABLE_NAMES)
    ]
    if offenders:
        raise QueryAPIError(
            "E_SQL_DENIED",
            "SQL 不在允许范围内",
            {"reason": "contains non-whitelisted table", "tables": sorted(offenders)},
        )
def ensure_limit_if_needed(sql: str) -> None:
    """Enforce the LIMIT policy.

    GROUP BY queries must always carry an explicit LIMIT; pure aggregate
    queries (COUNT/SUM/...) are exempt; every other detail query needs one.
    """
    if LIMIT_PATTERN.search(sql):
        return
    if "group by" in sql.lower():
        raise QueryAPIError("E_LIMIT_REQUIRED", f"包含 GROUP BY 的查询必须显式带 LIMIT {DEFAULT_SELECT_LIMIT} 以内")
    if AGGREGATE_PATTERN.search(sql):
        return
    raise QueryAPIError("E_LIMIT_REQUIRED", f"明细查询必须显式带 LIMIT,建议 LIMIT {DEFAULT_SELECT_LIMIT}")
def prepare_sql_for_execution(sql: str) -> PreparedQuery:
    """Run the full validation pipeline and return the vetted query.

    Order: normalize -> non-empty -> single statement -> no comments ->
    no dangerous keywords -> parse -> SELECT-only -> table whitelist ->
    LIMIT policy. Each step raises QueryAPIError on violation.
    """
    cleaned = normalize_sql(sql)
    for check in (reject_empty_sql, reject_multi_statement, reject_comments, reject_dangerous_keywords):
        check(cleaned)
    ast = parse_sql(cleaned)
    ensure_select_only(ast, cleaned)
    referenced = extract_tables(ast, cleaned)
    ensure_whitelisted_tables(referenced)
    ensure_limit_if_needed(cleaned)
    return PreparedQuery(sql=cleaned, tables=referenced)
def _now_utc_iso() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def _clamp_horizon(value: int | None) -> int:
    """Clamp a horizon to [7, TIMESFM_MAX_HORIZON]; None defaults to 15 (capped)."""
    if value is None:
        return min(15, TIMESFM_MAX_HORIZON)
    bounded = min(TIMESFM_MAX_HORIZON, int(value))
    return bounded if bounded >= 7 else 7
def infer_horizon_days(question: str) -> int:
    """Extract a "N 天/日" horizon from the question, clamped; default 15."""
    match = HORIZON_PATTERN.search(question)
    requested = int(match.group(1)) if match else 15
    return _clamp_horizon(requested)
def infer_module_id_from_question(question: str) -> str:
    """Route a natural-language question to a module id by keyword matching.

    Rules are checked in priority order (anomaly detection first); the
    first hit wins and price forecasting (m1) is the fallback.
    """
    text = question.lower()
    keyword_rules: tuple[tuple[str, tuple[str, ...]], ...] = (
        ("m8_anomaly_detection", ("异常", "预警", "风险", "有没有")),
        ("m7_statistics", ("有多少", "多少头", "现在", "当前", "统计", "总共")),
        ("m5_inventory_procurement", ("库存", "采购", "补货", "饲料还够")),
        ("m6_piglet_buy", ("猪仔", "补栏", "买入", "购入")),
        ("m4_disease_risk", ("疫病", "病", "采食", "饮水")),
        ("m3_feeding_plan", ("饲喂", "喂料", "料肉比")),
        ("m2_slaughter_window", ("出栏", "卖猪", "窗口")),
    )
    for module_id, keywords in keyword_rules:
        if any(keyword in text for keyword in keywords):
            return module_id
    return "m1_price_forecast"
def normalize_module_id(module_id: str | None, question: str) -> str:
    """Return the caller-provided module id when valid, else infer from text."""
    if module_id in MODULE_IDS:
        return module_id
    return infer_module_id_from_question(question)
def _query_rows(sql: str, params: tuple[Any, ...] = ()) -> list[tuple[Any, ...]]:
    """Execute a read-only, time-limited query and return all rows.

    Opens a fresh connection per call, forces a READ ONLY transaction and a
    per-statement timeout before running *sql* with *params*.

    Raises:
        QueryAPIError: E_DSN_MISSING when no DSN is configured, E_PREDICT_DB
        on any driver/SQL failure (original exception chained).
    """
    # NOTE(review): _require_psycopg is defined elsewhere in this file —
    # presumably raises when the psycopg driver is unavailable; confirm.
    _require_psycopg()
    if not PG_DSN:
        raise QueryAPIError("E_DSN_MISSING", "未配置 PIG_QUERY_PG_DSN", status_code=500)
    try:
        with psycopg.connect(PG_DSN, autocommit=False) as conn:
            with conn.cursor() as cur:
                cur.execute("SET TRANSACTION READ ONLY")
                # STATEMENT_TIMEOUT_MS is an int parsed at import time, so
                # this f-string interpolation cannot inject SQL.
                cur.execute(f"SET LOCAL statement_timeout = '{STATEMENT_TIMEOUT_MS}ms'")
                cur.execute(sql, params)
                return cur.fetchall()
    except QueryAPIError:
        raise
    except Exception as exc:
        raise QueryAPIError("E_PREDICT_DB", "预测数据读取失败", {"reason": str(exc)}, status_code=500) from exc
def _rows_to_series(name: str, rows: list[tuple[Any, Any]]) -> dict[str, Any]:
timestamps: list[str] = []
values: list[float] = []
for ts, value in rows:
if value is None:
continue
try:
numeric_value = float(value)
except Exception:
continue
timestamps.append(ts.isoformat() if hasattr(ts, "isoformat") else str(ts))
values.append(numeric_value)
return {"name": name, "timestamps": timestamps, "values": values}
def _query_date_price_series(table_name: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Fetch the latest *limit_rows* (date, price) points, oldest first.

    NOTE(review): table_name is interpolated into the SQL f-string; callers
    in this file only pass hardcoded whitelisted names, so no user input
    should reach it — keep it that way.
    """
    rows = _query_rows(
        f"""
        SELECT d, v
        FROM (
            SELECT "日期"::date AS d, "价格"::float8 AS v
            FROM public."{table_name}"
            WHERE "日期" IS NOT NULL
            ORDER BY "日期" DESC
            LIMIT %s
        ) t
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
def _query_text_date_price_series(table_name: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Like _query_date_price_series, for tables whose date column is text.

    Only rows whose "日期" matches YYYY-MM-DD are cast to date; others become
    NULL and are sorted last, then filtered out after the LIMIT.
    """
    rows = _query_rows(
        f"""
        SELECT d, v
        FROM (
            SELECT
                CASE
                    WHEN "日期" ~ '^[0-9]{{4}}-[0-9]{{2}}-[0-9]{{2}}$' THEN "日期"::date
                    ELSE NULL
                END AS d,
                "价格"::float8 AS v
            FROM public."{table_name}"
            ORDER BY d DESC NULLS LAST
            LIMIT %s
        ) t
        WHERE d IS NOT NULL
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
def _query_growth_avg_series(value_field: str, series_name: str, limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Daily herd-average of *value_field* from the growth table, oldest first.

    NOTE(review): value_field is interpolated into the SQL; callers only pass
    hardcoded column names ("体重kg", "采食量kg", "饮水量l").
    """
    rows = _query_rows(
        f"""
        SELECT d, avg_v
        FROM (
            SELECT
                "采集时间"::date AS d,
                AVG("{value_field}"::float8) AS avg_v
            FROM public."日常生长数据"
            GROUP BY "采集时间"::date
            ORDER BY d DESC
            LIMIT %s
        ) t
        ORDER BY d ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series(series_name, rows)
def _query_inventory_series(limit_rows: int = SERIES_LOOKBACK_DAYS) -> dict[str, Any]:
    """Feed-inventory series ordered by record id ("编号"), oldest first.

    The record id stands in for time here; the resulting "timestamps" are
    therefore integer indexes, not dates.
    """
    rows = _query_rows(
        """
        SELECT idx, v
        FROM (
            SELECT
                "编号"::int AS idx,
                "原有库存_kg"::float8 AS v
            FROM public."饲料库存记录"
            WHERE "原有库存_kg" IS NOT NULL
            ORDER BY "编号" DESC
            LIMIT %s
        ) t
        ORDER BY idx ASC
        """,
        (limit_rows,),
    )
    return _rows_to_series("feed_inventory_kg", rows)
# Hardcoded demo date: 2025-11-05 is treated as "today" (latest data date).
DEMO_REFERENCE_DATE = datetime(2025, 11, 5, 0, 0, 0)
DEMO_LOOKBACK_DAYS = 90  # standard lookback for the forecast modules (m1-m6)
# Lookback for realtime analysis (m7/m8); kept short to cap row volume
# (500 pigs x 7 days ~= 3500 rows vs ~15000 rows for 30 days).
DEMO_LOOKBACK_DAYS_REALTIME = 7
# A pig is flagged when its recent feed/water average drops below this
# fraction of the herd average.
M8_FEED_THRESHOLD = 0.8
M8_WATER_THRESHOLD = 0.8
def _is_realtime_question(question: str) -> bool:
"""检测问题是否为实时查询(询问当前/今日情况)"""
realtime_keywords = ["今日", "当前", "现在", "今天", "最近", "实时", "最新", "此刻", "眼下"]
return any(keyword in question for keyword in realtime_keywords)
def _get_realtime_window(reference_date: datetime) -> tuple[date, date]:
    """Inclusive [start, end] window of DEMO_LOOKBACK_DAYS_REALTIME days ending at *reference_date*."""
    window_end = reference_date.date()
    window_start = (reference_date - timedelta(days=DEMO_LOOKBACK_DAYS_REALTIME - 1)).date()
    return (window_start, window_end)
def _build_weight_comment(weight_avg_kg: float, pig_count: int) -> str:
if pig_count <= 0:
return "样本不足,继续观察"
if weight_avg_kg < 25:
return "体重偏轻,建议关注营养与健康状态"
if weight_avg_kg <= 90:
return "体重正常,建议保持当前管理节奏"
return "体重偏高,建议评估出栏与饲喂策略"
def _query_growth_avg_series_with_daterange(
    value_field: str,
    series_name: str,
    start_date: date | None = None,
    end_date: date | None = None,
    limit_rows: int | None = None,
) -> dict[str, Any]:
    """Daily average of a growth-table field within an inclusive date range.

    Args:
        value_field: column name (e.g. "体重kg", "采食量kg", "饮水量l").
        series_name: name stored on the returned series.
        start_date: range start, inclusive (the SQL uses >=).
        end_date: range end, inclusive (the SQL uses <=).
        limit_rows: max number of daily points.

    Falls back to the plain lookback query when either date is missing.
    """
    # Without a full date range, reuse the original row-limited query.
    if start_date is None or end_date is None:
        return _query_growth_avg_series(value_field, series_name, limit_rows or SERIES_LOOKBACK_DAYS)
    rows = _query_rows(
        f"""
        SELECT d, avg_v
        FROM (
            SELECT
                "采集时间"::date AS d,
                AVG("{value_field}"::float8) AS avg_v
            FROM public."日常生长数据"
            WHERE "采集时间"::date >= %s AND "采集时间"::date <= %s
            GROUP BY "采集时间"::date
            ORDER BY d DESC
            LIMIT %s
        ) t
        ORDER BY d ASC
        """,
        (start_date, end_date, limit_rows or DEMO_LOOKBACK_DAYS),
    )
    return _rows_to_series(series_name, rows)
def _query_m7_statistics(start_date: date, end_date: date) -> dict[str, Any]:
    """Herd-level aggregates over an inclusive date window (m7 statistics).

    Returns pig count plus water/weight/feed aggregates, zero-filled when
    the window has no rows, together with the window metadata and a
    human-readable weight comment.
    """
    rows = _query_rows(
        """
        SELECT
            COALESCE(COUNT(DISTINCT "id"), 0)::int AS pig_count,
            COALESCE(SUM("饮水量l"::float8), 0)::float8 AS water_total_l,
            COALESCE(AVG("饮水量l"::float8), 0)::float8 AS water_avg_l,
            COALESCE(AVG("体重kg"::float8), 0)::float8 AS weight_avg_kg,
            COALESCE(AVG("采食量kg"::float8), 0)::float8 AS feed_avg_kg
        FROM public."日常生长数据"
        WHERE "采集时间"::date >= %s AND "采集时间"::date <= %s
        """,
        (start_date, end_date),
    )
    if not rows:
        pig_count = 0
        water_total_l = 0.0
        water_avg_l = 0.0
        weight_avg_kg = 0.0
        feed_avg_kg = 0.0
    else:
        pig_count, water_total_l, water_avg_l, weight_avg_kg, feed_avg_kg = rows[0]
        # Coerce driver types (e.g. Decimal) and NULLs to plain numbers.
        pig_count = int(pig_count or 0)
        water_total_l = float(water_total_l or 0.0)
        water_avg_l = float(water_avg_l or 0.0)
        weight_avg_kg = float(weight_avg_kg or 0.0)
        feed_avg_kg = float(feed_avg_kg or 0.0)
    return {
        "reference_date": DEMO_REFERENCE_DATE.date().isoformat(),
        "window_start": start_date.isoformat(),
        "window_end": end_date.isoformat(),
        "pig_count": pig_count,
        "water_total_l": round(water_total_l, 4),
        "water_avg_l": round(water_avg_l, 4),
        "weight_avg_kg": round(weight_avg_kg, 4),
        "weight_comment": _build_weight_comment(weight_avg_kg=weight_avg_kg, pig_count=pig_count),
        "feed_avg_kg": round(feed_avg_kg, 4),
    }
def _query_pig_count(start_date: date | None = None, end_date: date | None = None) -> int:
    """Count distinct pig ids, optionally within an inclusive date window.

    Args:
        start_date: window start, inclusive (SQL uses >=); None = all time.
        end_date: window end, inclusive (SQL uses <=); None = all time.

    Returns:
        Number of distinct non-NULL "id" values.

    Raises:
        QueryAPIError: E_PREDICT_DB on any database failure.
    """
    if start_date is None or end_date is None:
        sql = """
            SELECT COUNT(DISTINCT "id")
            FROM public."日常生长数据"
            WHERE "id" IS NOT NULL
        """
        params: tuple[Any, ...] = ()
    else:
        sql = """
            SELECT COUNT(DISTINCT "id")
            FROM public."日常生长数据"
            WHERE "id" IS NOT NULL
            AND "采集时间"::date >= %s AND "采集时间"::date <= %s
        """
        params = (start_date, end_date)
    try:
        rows = _query_rows(sql, params)
        return int(rows[0][0]) if rows else 0
    except QueryAPIError:
        raise
    except Exception as exc:
        LOGGER.exception("query_pig_count_failed start_date=%s end_date=%s", start_date, end_date)
        raise QueryAPIError("E_PREDICT_DB", "统计数据读取失败", {"reason": str(exc)}, status_code=500) from exc
def _query_m8_pig_window_avg(start_date: date, end_date: date) -> list[dict[str, Any]]:
    """Per-pig average feed/water intake over an inclusive date window.

    Returns one dict per pig with pig_id and the two averages (None when
    all underlying rows are NULL for that pig).
    """
    rows = _query_rows(
        """
        SELECT
            "id",
            AVG("采食量kg"::float8) AS recent_feed,
            AVG("饮水量l"::float8) AS recent_water
        FROM public."日常生长数据"
        WHERE "采集时间"::date >= %s AND "采集时间"::date <= %s
        AND "id" IS NOT NULL
        GROUP BY "id"
        """,
        (start_date, end_date),
    )
    return [
        {
            "pig_id": str(pig_id),
            "recent_feed": float(recent_feed) if recent_feed is not None else None,
            "recent_water": float(recent_water) if recent_water is not None else None,
        }
        for pig_id, recent_feed, recent_water in rows
    ]
def _query_anomaly_pigs(
    start_date: date | None = None,
    end_date: date | None = None,
    feed_threshold: float = 0.8,
    water_threshold: float = 0.8,
) -> list[dict[str, Any]]:
    """List pigs whose recent feed or water intake is abnormally low.

    A pig is flagged when its window average falls below the herd average
    times the corresponding threshold (defaults mirror M8_FEED_THRESHOLD /
    M8_WATER_THRESHOLD = 0.8).

    Args:
        start_date: window start, inclusive; both dates are required.
        end_date: window end, inclusive.
        feed_threshold: fraction of herd-average feed below which to flag.
        water_threshold: fraction of herd-average water below which to flag.

    Returns:
        One dict per anomalous pig (pig_id, recent averages, Chinese
        anomaly descriptions); empty when the window is missing or the
        herd baseline is all zero.
    """
    if start_date is None or end_date is None:
        return []
    # Herd-wide baseline over the same window.
    baseline = _query_m7_statistics(start_date, end_date)
    avg_feed = float(baseline["feed_avg_kg"])
    avg_water = float(baseline["water_avg_l"])
    if avg_feed <= 0 and avg_water <= 0:
        return []
    pig_rows = _query_m8_pig_window_avg(start_date, end_date)
    results = []
    for pig_item in pig_rows:
        pig_id = pig_item["pig_id"]
        recent_feed = pig_item["recent_feed"]
        recent_water = pig_item["recent_water"]
        anomalies = []
        if recent_feed is not None and avg_feed > 0 and recent_feed < avg_feed * feed_threshold:
            anomalies.append(f"采食量异常({recent_feed:.2f}kg,低于平均{avg_feed:.2f}kg)")
        if recent_water is not None and avg_water > 0 and recent_water < avg_water * water_threshold:
            anomalies.append(f"饮水量异常({recent_water:.2f}升,低于平均{avg_water:.2f}升)")
        if anomalies:
            results.append({
                "pig_id": str(pig_id),
                "recent_feed": round(float(recent_feed), 2) if recent_feed is not None else 0,
                "recent_water": round(float(recent_water), 2) if recent_water is not None else 0,
                "anomalies": anomalies,
            })
    return results
def build_series_from_db(module_id: str, question: str, horizon_days: int | None = None) -> dict[str, Any]:
    """Load the time series a module needs from PostgreSQL.

    Realtime modules (m7/m8) get an empty series pack carrying only the demo
    analysis window; forecast modules (m1-m6) get their per-module series
    plan, executed fetch-by-fetch with per-source error capture so one
    failing source does not abort the rest.

    Args:
        module_id: one of MODULE_IDS.
        question: original user question (echoed into the pack).
        horizon_days: requested forecast horizon; clamped to model limits.

    Returns:
        dict with module_id, question, horizon_days, series (non-empty
        series only) and source_errors; realtime packs also carry the
        reference date and window bounds.

    Raises:
        QueryAPIError: E_PREDICT_INPUT for an unknown module id, or
        E_PREDICT_DATA_EMPTY when no series returned any data.
    """
    horizon = _clamp_horizon(horizon_days)
    if module_id in REALTIME_MODULE_IDS:
        demo_start_date, demo_end_date = _get_realtime_window(DEMO_REFERENCE_DATE)
        LOGGER.info(
            "build_series_from_db using demo date range %s to %s for module %s",
            demo_start_date,
            demo_end_date,
            module_id,
        )
        return {
            "module_id": module_id,
            "question": question,
            "horizon_days": DEMO_LOOKBACK_DAYS_REALTIME,
            "series": [],
            "source_errors": [],
            "reference_date": DEMO_REFERENCE_DATE.date().isoformat(),
            "window_start": demo_start_date.isoformat(),
            "window_end": demo_end_date.isoformat(),
        }
    # Per-module fetch plan: (series name, zero-arg fetcher).
    query_plan: list[tuple[str, Any]] = []
    if module_id == "m1_price_forecast":
        query_plan = [
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m2_slaughter_window":
        query_plan = [
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
        ]
    elif module_id == "m3_feeding_plan":
        query_plan = [
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
            ("avg_feed_kg", lambda: _query_growth_avg_series("采食量kg", "avg_feed_kg")),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m4_disease_risk":
        query_plan = [
            ("avg_feed_kg", lambda: _query_growth_avg_series("采食量kg", "avg_feed_kg")),
            ("avg_water_l", lambda: _query_growth_avg_series("饮水量l", "avg_water_l")),
            ("avg_weight_kg", lambda: _query_growth_avg_series("体重kg", "avg_weight_kg")),
        ]
    elif module_id == "m5_inventory_procurement":
        query_plan = [
            ("feed_inventory_kg", _query_inventory_series),
            ("corn_price", lambda: _query_date_price_series("玉米价格", "corn_price")),
            ("soymeal_price", lambda: _query_text_date_price_series("豆粕价", "soymeal_price")),
        ]
    elif module_id == "m6_piglet_buy":
        query_plan = [
            ("piglet_price", lambda: _query_date_price_series("猪仔价格", "piglet_price")),
            ("hog_price", lambda: _query_date_price_series("生猪价格", "hog_price")),
        ]
    else:
        raise QueryAPIError("E_PREDICT_INPUT", "未知的 module_id", {"module_id": module_id})
    series: list[dict[str, Any]] = []
    source_errors: list[dict[str, Any]] = []
    for series_name, fetcher in query_plan:
        try:
            item = fetcher()
        except QueryAPIError as exc:
            # Record the failure and keep fetching the remaining sources.
            source_errors.append(
                {
                    "series": series_name,
                    "code": exc.code,
                    "message": exc.message,
                    "detail": exc.detail,
                }
            )
            LOGGER.warning(
                "predict_series_fetch_failed module_id=%s series=%s code=%s detail=%s",
                module_id,
                series_name,
                exc.code,
                json.dumps(exc.detail, ensure_ascii=False),
            )
            continue
        # Keep only series that actually produced data points; this makes
        # the previous second filtering pass redundant, so it was removed.
        if item.get("values"):
            series.append(item)
    if not series:
        raise QueryAPIError(
            "E_PREDICT_DATA_EMPTY",
            "预测所需数据库数据不足",
            {"module_id": module_id, "question": question, "source_errors": source_errors},
            status_code=400,
        )
    return {
        "module_id": module_id,
        "question": question,
        "horizon_days": horizon,
        "series": series,
        "source_errors": source_errors,
    }
def _accepts_kwarg(callable_obj: Any, arg_name: str) -> bool:
try:
import inspect
signature = inspect.signature(callable_obj)
return arg_name in signature.parameters
except Exception:
return False
def _create_timesfm_hparams() -> Any:
    """Build TimesFmHparams, passing only kwargs this timesfm version accepts.

    Keyword support is probed with _accepts_kwarg so the code works across
    timesfm releases with differing constructor signatures.

    Raises:
        QueryAPIError: E_TIMESFM_INIT when the class is missing entirely.
    """
    hparams_cls = getattr(timesfm, "TimesFmHparams", None)
    if hparams_cls is None:
        raise QueryAPIError("E_TIMESFM_INIT", "timesfm 缺少 TimesFmHparams", status_code=500)
    kwargs: dict[str, Any] = {}
    if _accepts_kwarg(hparams_cls, "backend"):
        kwargs["backend"] = TIMESFM_BACKEND
    if _accepts_kwarg(hparams_cls, "per_core_batch_size"):
        kwargs["per_core_batch_size"] = TIMESFM_BATCH_SIZE
    if _accepts_kwarg(hparams_cls, "horizon_len"):
        kwargs["horizon_len"] = TIMESFM_MAX_HORIZON
    if _accepts_kwarg(hparams_cls, "device"):
        kwargs["device"] = TIMESFM_DEVICE
    # The TimesFM 2.0 500m checkpoint requires these fixed architecture
    # parameters, otherwise the state_dict will not match.
    if TIMESFM_CHECKPOINT == "google/timesfm-2.0-500m-pytorch":
        if _accepts_kwarg(hparams_cls, "input_patch_len"):
            kwargs["input_patch_len"] = 32
        if _accepts_kwarg(hparams_cls, "output_patch_len"):
            kwargs["output_patch_len"] = 128
        if _accepts_kwarg(hparams_cls, "num_layers"):
            kwargs["num_layers"] = 50
        if _accepts_kwarg(hparams_cls, "model_dims"):
            kwargs["model_dims"] = 1280
        if _accepts_kwarg(hparams_cls, "use_positional_embedding"):
            kwargs["use_positional_embedding"] = False
    return hparams_cls(**kwargs)
def _create_timesfm_checkpoint() -> Any:
    """Build a TimesFmCheckpoint for the configured HF repo.

    Falls back to the raw checkpoint string when the class is absent, and
    to positional construction when the keyword is not supported.
    """
    checkpoint_cls = getattr(timesfm, "TimesFmCheckpoint", None)
    if checkpoint_cls is None:
        return TIMESFM_CHECKPOINT
    if _accepts_kwarg(checkpoint_cls, "huggingface_repo_id"):
        return checkpoint_cls(huggingface_repo_id=TIMESFM_CHECKPOINT)
    return checkpoint_cls(TIMESFM_CHECKPOINT)
def get_timesfm_model() -> Any:
    """Return the process-wide TimesFM model, initializing it on first use.

    Uses check -> lock -> re-check around the TIMESFM_MODEL global so only
    one thread performs the (slow) initialization. Several constructor call
    shapes are attempted for cross-version compatibility; the last failure
    reason is kept in TIMESFM_INIT_ERROR for the health endpoint.

    Raises:
        QueryAPIError: E_TIMESFM_INIT when the package is missing, the class
        is absent, or every constructor attempt fails.
    """
    global TIMESFM_MODEL
    global TIMESFM_INIT_ERROR
    if TIMESFM_MODEL is not None:
        return TIMESFM_MODEL
    with TIMESFM_MODEL_LOCK:
        # Re-check under the lock: another thread may have finished init.
        if TIMESFM_MODEL is not None:
            return TIMESFM_MODEL
        if timesfm is None:
            TIMESFM_INIT_ERROR = "timesfm package not installed"
            raise QueryAPIError("E_TIMESFM_INIT", "TimesFM 未安装,请在 python311 环境安装 timesfm[torch]", status_code=500)
        model_cls = getattr(timesfm, "TimesFm", None)
        if model_cls is None:
            TIMESFM_INIT_ERROR = "timesfm package missing TimesFm class"
            raise QueryAPIError("E_TIMESFM_INIT", "TimesFM 初始化失败:缺少 TimesFm 类", status_code=500)
        try:
            hparams = _create_timesfm_hparams()
            checkpoint = _create_timesfm_checkpoint()
            # Constructor signatures vary across timesfm releases; try the
            # known shapes in order and keep the first that works.
            init_attempts = [
                lambda: model_cls(hparams=hparams, checkpoint=checkpoint),
                lambda: model_cls(hparams, checkpoint),
                lambda: model_cls(checkpoint=checkpoint),
            ]
            last_error = None
            for attempt in init_attempts:
                try:
                    TIMESFM_MODEL = attempt()
                    TIMESFM_INIT_ERROR = None
                    return TIMESFM_MODEL
                except Exception as exc:  # pragma: no cover - runtime compatibility fallback
                    last_error = str(exc)
            TIMESFM_INIT_ERROR = last_error or "unknown init error"
        except QueryAPIError:
            raise
        except Exception as exc:  # pragma: no cover - runtime compatibility fallback
            TIMESFM_INIT_ERROR = str(exc)
        raise QueryAPIError(
            "E_TIMESFM_INIT",
            "TimesFM 初始化失败",
            {"reason": TIMESFM_INIT_ERROR, "checkpoint": TIMESFM_CHECKPOINT},
            status_code=500,
        )
def ensure_timesfm_ready() -> None:
    """Force TimesFM initialization; raises QueryAPIError when unavailable."""
    get_timesfm_model()
def _normalize_timesfm_points(raw_output: Any, horizon_days: int) -> list[float]:
payload = raw_output
if isinstance(payload, tuple) and payload:
payload = payload[0]
if np is not None and isinstance(payload, np.ndarray):
values = payload
elif isinstance(payload, list):
values = payload
elif isinstance(payload, dict):
candidates = payload.get("point_forecast") or payload.get("forecast") or payload.get("mean")
if candidates is None:
raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 返回结构无法解析", {"raw_type": str(type(payload))}, status_code=500)
values = candidates
else:
raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 返回结构无法解析", {"raw_type": str(type(payload))}, status_code=500)
if np is not None and isinstance(values, np.ndarray):
if values.ndim >= 2:
points = values[0].tolist()
else:
points = values.tolist()
else:
if values and isinstance(values[0], list):
points = values[0]
else:
points = values
numeric_points: list[float] = []
for item in points:
try:
numeric_points.append(float(item))
except Exception:
continue
if not numeric_points:
raise QueryAPIError("E_TIMESFM_INFER", "TimesFM 未返回有效预测点", status_code=500)
if len(numeric_points) < horizon_days:
numeric_points.extend([numeric_points[-1]] * (horizon_days - len(numeric_points)))
return numeric_points[:horizon_days]
def run_timesfm_forecast(series: list[float], horizon_days: int) -> dict[str, Any]:
    """Forecast *horizon_days* points from *series* using TimesFM.

    Inputs shorter than SERIES_MIN_POINTS are right-padded with their last
    value. Several forecast() call signatures are attempted to stay
    compatible with different timesfm releases; the first that succeeds
    wins.

    Returns:
        dict with horizon_days, point_forecast and input_last_value.

    Raises:
        QueryAPIError: E_TIMESFM_INFER on empty input or when every call
        signature fails (last error reason included).
    """
    if not series:
        raise QueryAPIError("E_TIMESFM_INFER", "输入时序为空", status_code=400)
    model = get_timesfm_model()
    padded_series = [float(v) for v in series]
    # Pad short histories so the model gets at least SERIES_MIN_POINTS points.
    if len(padded_series) < SERIES_MIN_POINTS:
        padded_series.extend([padded_series[-1]] * (SERIES_MIN_POINTS - len(padded_series)))
    # Call signatures vary across timesfm releases; try them in order.
    forecast_attempts = [
        lambda: model.forecast([padded_series], freq=[0]),
        lambda: model.forecast([padded_series], horizon_len=horizon_days, freq=[0]),
        lambda: model.forecast([padded_series], horizon_len=horizon_days),
        lambda: model.forecast([padded_series], horizon_days),
    ]
    if np is not None:
        np_input = np.asarray([padded_series], dtype=float)
        forecast_attempts.extend(
            [
                lambda: model.forecast(np_input, freq=[0]),
                lambda: model.forecast(np_input, horizon_len=horizon_days, freq=[0]),
            ]
        )
    last_error = None
    for attempt in forecast_attempts:
        try:
            raw = attempt()
            points = _normalize_timesfm_points(raw, horizon_days)
            return {
                "horizon_days": horizon_days,
                "point_forecast": points,
                "input_last_value": padded_series[-1],
            }
        except QueryAPIError:
            raise
        except Exception as exc:  # pragma: no cover - runtime compatibility fallback
            last_error = str(exc)
            continue
    raise QueryAPIError(
        "E_TIMESFM_INFER",
        "TimesFM 推理失败",
        {"reason": last_error or "unknown forecast error"},
        status_code=500,
    )
def run_module_logic(module_id: str, series_pack: dict[str, Any], forecast_map: dict[str, dict[str, Any]]) -> dict[str, Any]:
    """Apply per-module business rules on top of raw forecasts or statistics.

    Args:
        module_id: Normalized module identifier (m1..m8).
        series_pack: Series bundle built from the database; carries
            ``horizon_days`` and, for realtime modules, the statistics window.
        forecast_map: Series name -> TimesFM forecast payload (empty for
            realtime modules, which work off historical statistics instead).

    Returns:
        Dict with ``forecast``, ``recommendation``, ``risk_flags``,
        ``decision_metrics`` and ``backend_draft_summary`` keys.

    Raises:
        QueryAPIError: for an unknown module_id or a missing statistics window.
    """
    horizon_days = int(series_pack.get("horizon_days", 15))
    # Realtime modules (m7/m8) analyse a historical date window instead of a forecast.
    if module_id in REALTIME_MODULE_IDS:
        window_start = date.fromisoformat(series_pack.get("window_start", DEMO_REFERENCE_DATE.date().isoformat()))
        window_end = date.fromisoformat(series_pack.get("window_end", DEMO_REFERENCE_DATE.date().isoformat()))
    else:
        window_start = None
        window_end = None
    recommendation: dict[str, Any] = {}
    risk_flags: list[str] = []

    def _series_direction(name: str) -> str:
        # Classify a forecast as up/down/flat by comparing its last vs first point.
        forecast = forecast_map.get(name, {}).get("point_forecast", [])
        if len(forecast) < 2:
            return "flat"
        if forecast[-1] > forecast[0]:
            return "up"
        if forecast[-1] < forecast[0]:
            return "down"
        return "flat"

    if module_id == "m1_price_forecast":
        # Hog price trend drives the sell-timing advice.
        hog_dir = _series_direction("hog_price")
        recommendation["action"] = "择机延后出栏" if hog_dir == "up" else "关注近期出栏窗口"
        recommendation["reason"] = f"生猪价格未来{horizon_days}天趋势:{hog_dir}"
    elif module_id == "m2_slaughter_window":
        # Slaughter window needs price and weight rising together.
        hog_dir = _series_direction("hog_price")
        weight_dir = _series_direction("avg_weight_kg")
        recommendation["optimal_window"] = "未来3-7天" if hog_dir == "up" and weight_dir == "up" else "未来1-3天滚动评估"
        recommendation["reason"] = f"价格趋势={hog_dir}, 体重趋势={weight_dir}"
    elif module_id == "m3_feeding_plan":
        feed_dir = _series_direction("avg_feed_kg")
        recommendation["feeding_adjustment"] = "维持配方并小幅增量" if feed_dir == "up" else "控制喂量并优化配方"
        recommendation["reason"] = f"采食趋势={feed_dir}"
    elif module_id == "m4_disease_risk":
        # Simultaneous decline of water and feed intake flags a health risk.
        water = forecast_map.get("avg_water_l", {}).get("point_forecast", [])
        feed = forecast_map.get("avg_feed_kg", {}).get("point_forecast", [])
        if water and feed and water[-1] < water[0] and feed[-1] < feed[0]:
            risk_flags.append("采食与饮水同步下滑,建议排查健康风险")
        recommendation["risk_level"] = "high" if risk_flags else "medium"
        recommendation["action"] = "加强巡检与体温采样"
    elif module_id == "m5_inventory_procurement":
        # Restock alarm when the predicted feed inventory ever touches zero.
        inventory = forecast_map.get("feed_inventory_kg", {}).get("point_forecast", [])
        if inventory and min(inventory) <= 0:
            risk_flags.append("预测库存触底,需立即补货")
        recommendation["procurement_window"] = "未来1周内分批采购"
        recommendation["action"] = "结合价格低位滚动补货"
    elif module_id == "m6_piglet_buy":
        piglet_dir = _series_direction("piglet_price")
        recommendation["buy_window"] = "未来3-5天" if piglet_dir == "down" else "等待价格回调"
        recommendation["reason"] = f"猪仔价格趋势={piglet_dir}"
    elif module_id == "m7_statistics":
        if window_start is None or window_end is None:
            raise QueryAPIError("E_PREDICT_INPUT", "m7 缺少统计窗口", status_code=500)
        stats = _query_m7_statistics(window_start, window_end)
        recommendation["pig_count"] = stats["pig_count"]
        recommendation["statistics"] = stats
        # NOTE(review): window_start/window_end are concatenated with no separator
        # between them — looks like a lost delimiter (e.g. "至" or "~"); confirm.
        recommendation["summary"] = (
            f"统计窗口 {stats['window_start']}{stats['window_end']}:"
            f"猪群总数 {stats['pig_count']} 头,饮水总量 {stats['water_total_l']:.2f} 升,"
            f"饮水均值 {stats['water_avg_l']:.2f} 升,平均体重 {stats['weight_avg_kg']:.2f} kg({stats['weight_comment']})。"
        )
    elif module_id == "m8_anomaly_detection":
        if window_start is None or window_end is None:
            raise QueryAPIError("E_PREDICT_INPUT", "m8 缺少统计窗口", status_code=500)
        # Baselines come from the herd-wide window statistics (same source as m7).
        stats = _query_m7_statistics(window_start, window_end)
        feed_baseline = float(stats["feed_avg_kg"])
        water_baseline = float(stats["water_avg_l"])
        pig_rows = _query_m8_pig_window_avg(window_start, window_end)
        anomaly_pigs = []
        anomalies = []
        for pig_info in pig_rows:
            recent_feed = pig_info.get("recent_feed")
            recent_water = pig_info.get("recent_water")
            item_anomalies = []
            # A pig is anomalous when its recent intake drops below a threshold
            # fraction (M8_FEED_THRESHOLD / M8_WATER_THRESHOLD) of the herd average.
            if recent_feed is not None and feed_baseline > 0 and recent_feed < feed_baseline * M8_FEED_THRESHOLD:
                item_anomalies.append(f"采食量异常({recent_feed:.2f}kg,低于平均{feed_baseline:.2f}kg)")
            if recent_water is not None and water_baseline > 0 and recent_water < water_baseline * M8_WATER_THRESHOLD:
                item_anomalies.append(f"饮水量异常({recent_water:.2f}升,低于平均{water_baseline:.2f}升)")
            if item_anomalies:
                anomaly_entry = {
                    "pig_id": pig_info["pig_id"],
                    "recent_feed": round(float(recent_feed), 2) if recent_feed is not None else 0,
                    "recent_water": round(float(recent_water), 2) if recent_water is not None else 0,
                    "anomalies": item_anomalies,
                }
                anomaly_pigs.append(anomaly_entry)
                anomalies.append(
                    f"猪ID {pig_info['pig_id']}: {'; '.join(item_anomalies)}"
                )
        recommendation["baseline_statistics"] = {
            "window_start": window_start.isoformat(),
            "window_end": window_end.isoformat(),
            "feed_avg_kg": round(feed_baseline, 4),
            "water_avg_l": round(water_baseline, 4),
        }
        recommendation["anomaly_count"] = len(anomaly_pigs)
        recommendation["anomaly_pigs"] = anomaly_pigs
        recommendation["anomalies"] = anomalies
        # Risk tiers: >=5 anomalous pigs is high, >=2 is medium, else low.
        recommendation["risk_level"] = "高" if len(anomaly_pigs) >= 5 else "中" if len(anomaly_pigs) >= 2 else "低"
        recommendation["action"] = (
            "立即启动应急预案,对异常猪进行隔离和治疗" if len(anomaly_pigs) >= 5
            else "加强日常巡检,对异常猪进行重点监测" if len(anomaly_pigs) >= 2
            else "继续正常饲养管理,无明显异常"
        )
    else:
        raise QueryAPIError("E_PREDICT_INPUT", "未知 module_id", {"module_id": module_id}, status_code=400)
    # Realtime modules return a fixed statistics-based payload (no forecast map).
    if module_id in REALTIME_MODULE_IDS:
        if module_id == "m8_anomaly_detection":
            anomaly_count = int(recommendation.get("anomaly_count", 0))
            risk_level = recommendation.get("risk_level", "低")
        else:
            stats = recommendation.get("statistics", {})
            anomaly_count = 0
            risk_level = "低" if int(stats.get("pig_count", 0)) > 0 else "中"
        decision_metrics = {
            "horizon_days": DEMO_LOOKBACK_DAYS_REALTIME,
            "series_count": 0,
            "series": [],
            "overall_change_pct": 0.0,
            "trend_bias": "平稳",
            "dominant_series": "historical_baseline",
            "risk_score": float(anomaly_count),
            "risk_level": risk_level,
            "period_overview": [],
            "data_source": "historical_7d",
        }
        recommendation.setdefault("confidence", 1.0)
        recommendation.setdefault("summary_metric", float(recommendation.get("anomaly_count", 0)))
        recommendation["dominant_driver"] = decision_metrics["dominant_series"]
        recommendation["trend_bias"] = decision_metrics["trend_bias"]
        # NOTE(review): "8-7天(已结束)" is hard-coded; presumably tied to the fixed
        # 7-day realtime lookback (DEMO_LOOKBACK_DAYS_REALTIME) — confirm intent.
        recommendation["action_plan"] = {
            "short_window": "1-3天",
            "mid_window": "4-7天",
            "long_window": "8-7天(已结束)",
            "short_term_action": recommendation.get("action", "持续观察"),
            "mid_term_action": "按日复核异常明细",
            "long_term_action": "以周为单位复盘",
            "risk_action": "根据异常数量调整巡检强度",
        }
        summary = (
            f"{module_id} 已完成历史7天统计分析。窗口={series_pack.get('window_start')}~{series_pack.get('window_end')},"
            f"风险等级={decision_metrics['risk_level']},建议动作={recommendation.get('action', '持续观察')}。"
        )
        return {
            "forecast": {},
            "recommendation": recommendation,
            "risk_flags": risk_flags,
            "decision_metrics": decision_metrics,
            "backend_draft_summary": summary,
        }
    # Forecast-based modules: enrich the recommendation with aggregated metrics.
    all_points = [point for item in forecast_map.values() for point in item.get("point_forecast", [])]
    avg_point = round(mean(all_points), 4) if all_points else 0.0
    recommendation.setdefault("confidence", 0.65 if all_points else 0.0)
    recommendation.setdefault("summary_metric", avg_point)
    decision_metrics = _build_decision_metrics(forecast_map, horizon_days)
    recommendation["risk_level"] = decision_metrics["risk_level"]
    recommendation["dominant_driver"] = decision_metrics["dominant_series"]
    recommendation["trend_bias"] = decision_metrics["trend_bias"]
    recommendation["action_plan"] = _build_action_plan(module_id=module_id, horizon_days=horizon_days, decision_metrics=decision_metrics)
    summary = (
        f"{module_id} 已完成未来{horizon_days}天预测。主导指标={decision_metrics['dominant_series']},"
        f"总体趋势={decision_metrics['trend_bias']},风险等级={decision_metrics['risk_level']}。"
        f"建议动作={recommendation.get('action', '关注并分批复核')}"
    )
    return {
        "forecast": forecast_map,
        "recommendation": recommendation,
        "risk_flags": risk_flags,
        "decision_metrics": decision_metrics,
        "backend_draft_summary": summary,
    }
def _build_decision_metrics(
    forecast_map: dict[str, dict[str, Any]],
    horizon_days: int,
) -> dict[str, Any]:
    """Aggregate per-series forecast statistics into decision-support metrics.

    For every forecast series, computes start/end change, volatility and an
    acceleration label, then rolls them up into an overall trend bias, a
    weighted risk score/level and averaged short/mid/long window snapshots.
    """
    def _safe_pct(start: float, end: float) -> float:
        # Percentage change, guarded against a (near-)zero baseline.
        if abs(start) < 1e-9:
            return 0.0
        return round((end - start) / abs(start) * 100, 4)

    def _safe_volatility(points: list[float], base: float) -> float:
        # Mean absolute step-to-step move, expressed as a percentage of `base`.
        if len(points) < 2:
            return 0.0
        diffs = [abs(points[idx] - points[idx - 1]) for idx in range(1, len(points))]
        if not diffs:
            return 0.0
        return round((sum(diffs) / len(diffs)) / max(abs(base), 1e-9) * 100, 4)

    def _build_period_profile(points: list[float], start: float) -> list[dict[str, Any]]:
        # Short/mid/long window snapshots. Every segment is points[:days], so
        # each segment begins at the series start and `start` == segment[0].
        windows = [("短线", min(3, horizon_days)), ("中线", min(7, horizon_days)), ("长期", horizon_days)]
        period_metrics: list[dict[str, Any]] = []
        seen: set[int] = set()
        for label, days in windows:
            # Skip degenerate (<=1 day) windows and duplicates created by clamping
            # (e.g. horizon_days=3 makes 短线 and 中线 identical).
            if days <= 1 or days in seen:
                continue
            seen.add(days)
            segment = points[:days]
            if not segment:
                continue
            period_start = float(segment[0])
            period_end = float(segment[-1])
            period_metrics.append(
                {
                    "label": label,
                    "days": days,
                    "start": round(period_start, 4),
                    "end": round(period_end, 4),
                    "avg": round(mean(segment), 4),
                    "min": round(min(segment), 4),
                    "max": round(max(segment), 4),
                    "change_pct": _safe_pct(start, period_end),
                    "volatility": _safe_volatility(segment, period_start),
                }
            )
        return period_metrics

    # Per-series metrics.
    metric_items: list[dict[str, Any]] = []
    for name, series_item in forecast_map.items():
        points = series_item.get("point_forecast", []) or []
        if not points:
            continue
        start = float(points[0])
        end = float(points[-1])
        total_pct = _safe_pct(start, end)
        abs_change = round(abs(end - start), 4)
        volatility = _safe_volatility(points, start)
        trend = "up" if end > start else "down" if end < start else "flat"
        trend_cn = "上行" if trend == "up" else "下行" if trend == "down" else "平稳"
        # Acceleration: compare second-half mean vs first-half mean with a
        # ±0.2% dead band to avoid flapping on near-flat series.
        half = max(1, len(points) // 2)
        first_half = mean(points[:half])
        second_half = mean(points[half:])
        if second_half > first_half * 1.002:
            acceleration = "加速"
        elif second_half < first_half * 0.998:
            acceleration = "减速"
        else:
            acceleration = "平稳"
        metric_items.append(
            {
                "series": name,
                "start": start,
                "end": end,
                "change_pct": total_pct,
                "abs_change": abs_change,
                "volatility": volatility,
                "horizon_days": horizon_days,
                "trend": trend_cn,
                "acceleration": acceleration,
                "period_overview": _build_period_profile(points, start),
            }
        )
    overall_change = round(sum(item["change_pct"] for item in metric_items) / len(metric_items), 4) if metric_items else 0.0
    # Weighted risk score: change magnitude dominates (x1.2), volatility
    # contributes (x0.8), and any non-flat acceleration adds a fixed bump.
    risk_score = 0.0
    for item in metric_items:
        risk_score += abs(item["change_pct"]) * 1.2 + item["volatility"] * 0.8
        if item["acceleration"] != "平稳":
            risk_score += 1.2
    dominant_item = max(metric_items, key=lambda item: abs(item["change_pct"])) if metric_items else None
    dominant_series = dominant_item["series"] if dominant_item else "未知"
    # Average the per-series window snapshots across series, keyed by window length.
    period_overview: list[dict[str, Any]] = []
    if metric_items:
        aggregate: dict[int, list[dict[str, Any]]] = {}
        for item in metric_items:
            for window in item.get("period_overview", []):
                aggregate.setdefault(window["days"], []).append(window)
        for days, windows in aggregate.items():
            period_overview.append(
                {
                    "label": windows[0]["label"],
                    "days": days,
                    "start": round(mean([w["start"] for w in windows]), 4),
                    "end": round(mean([w["end"] for w in windows]), 4),
                    "avg": round(mean([w["avg"] for w in windows]), 4),
                    "min": round(min(w["min"] for w in windows), 4),
                    "max": round(max(w["max"] for w in windows), 4),
                    "change_pct": round(mean([w["change_pct"] for w in windows]), 4),
                    "volatility": round(mean([w["volatility"] for w in windows]), 4),
                }
            )
        period_overview.sort(key=lambda x: x["days"])
    # Overall bias: ±1% dead band around flat.
    if overall_change > 1:
        trend_bias = "上行"
    elif overall_change < -1:
        trend_bias = "下行"
    else:
        trend_bias = "平稳"
    risk_level = "高" if risk_score >= RISK_SCORE_HIGH else "中" if risk_score >= RISK_SCORE_LOW else "低"
    return {
        "horizon_days": horizon_days,
        "series_count": len(metric_items),
        "series": metric_items,
        "overall_change_pct": overall_change,
        "trend_bias": trend_bias,
        "dominant_series": dominant_series,
        "risk_score": round(risk_score, 4),
        "risk_level": risk_level,
        "period_overview": period_overview,
    }
def _build_action_plan(module_id: str, horizon_days: int, decision_metrics: dict[str, Any]) -> dict[str, str]:
trend = decision_metrics.get("trend_bias", "平稳")
risk_level = decision_metrics.get("risk_level", "低")
dominant = decision_metrics.get("dominant_series", "核心指标")
if horizon_days <= 3:
short_window = f"1-{horizon_days}天"
mid_window = "剩余期间"
long_window = f"{horizon_days + 1}-{horizon_days}天(已结束)"
elif horizon_days <= 7:
short_window = "1-3天"
mid_window = f"4-{horizon_days}天"
long_window = f"{horizon_days + 1}-{horizon_days}天(已结束)"
else:
short_window = "1-3天"
mid_window = "4-7天"
long_window = f"8-{horizon_days}天"
if module_id == "m1_price_forecast":
if trend == "上行":
short_action = f"偏上行:关注 {dominant} 的回撤确认,2 个交易日无反转可择机加仓出栏窗口"
mid_action = "同步关注玉米/豆粕成本,避免单边决策"
long_action = "建议滚动调整出栏节奏,降低一次性决策风险"
elif trend == "下行":
short_action = f"偏下行:优先观察 {dominant} 是否持续下探,短期提前安排销售计划"
mid_action = "设置分批出售线,避免集中抛售"
long_action = "若连续三周回落可启动保本优先策略"
else:
short_action = "偏平稳:先维持当前节奏,等待趋势突破确认"
mid_action = "以资金压力和饲料占用成本决定微调幅度"
long_action = "按周复盘并保留10%应急流动性头寸"
elif module_id == "m2_slaughter_window":
short_action = "先看价格窗口与体重趋势是否同步向上,未同步则延后"
mid_action = "逐步缩短出栏批次间隔,验证单月收益率"
long_action = "保持7天一次滚动评估"
elif module_id == "m3_feeding_plan":
short_action = "先小步调整饲喂结构,观察3天采食响应"
mid_action = "体重未达目标则反向压降日增重偏差过高区间"
long_action = "以料肉比趋势为准,每周复核一次"
elif module_id == "m4_disease_risk":
short_action = "优先做采食与饮水双指标突增突降告警"
mid_action = "出现同步下滑继续延长复核周期并分组检疫"
long_action = "连续出现两次异常后启动应急预案"
elif module_id == "m5_inventory_procurement":
short_action = "保留最近1周库存安全天数,优先做分批补货"
mid_action = "将采购节奏与价格拐点挂钩,弱势时加速补齐"
long_action = "风险偏高时建议延长库存冗余,避免断供"
else:
short_action = "先观察短线波动,按最小动作执行"
mid_action = "分批复核后再执行下一阶段"
long_action = "以周为单位复盘,避免过度追涨杀跌"
if risk_level == "高":
risk_action = "高风险预警:动作必须分批且可回撤,新增仓位前先做小额验证"
elif risk_level == "中":
risk_action = "中风险:每3-5天复核一次,执行范围控制在既定预算内"
else:
risk_action = "低风险:按标准流程推进,偏差过大时才调整"
return {
"short_window": short_window,
"mid_window": mid_window,
"long_window": long_window,
"short_term_action": short_action,
"mid_term_action": mid_action,
"long_term_action": long_action,
"risk_action": risk_action,
}
def generate_executive_summary(
    module_id: str,
    question: str,
    series_pack: dict[str, Any],
    module_output: dict[str, Any],
) -> str | None:
    """Render a Markdown executive summary from the module's decision output.

    Returns None when backend summaries are disabled via
    PREDICT_BACKEND_SUMMARY_ENABLED; otherwise a four-section Markdown block
    (conclusion, risk rating, decision analysis, risk notes).
    """
    if not PREDICT_BACKEND_SUMMARY_ENABLED:
        return None
    metrics = module_output.get("decision_metrics") or {}
    recommendation = module_output.get("recommendation") or {}
    risk_flags = module_output.get("risk_flags") or []
    horizon_days = int(series_pack.get("horizon_days", 15))
    recommendation_action = recommendation.get("action", "持续观察")
    risk_level = metrics.get("risk_level", "低")
    overall_change = metrics.get("overall_change_pct", 0.0)
    trend_bias = metrics.get("trend_bias", "平稳")
    risk_score = metrics.get("risk_score", 0.0)
    period_overview = metrics.get("period_overview", [])
    action_plan = recommendation.get("action_plan") or {}
    # Expand the single-character risk level into reader-facing wording.
    if risk_level == "高":
        risk_level_text = "高风险"
    elif risk_level == "中":
        risk_level_text = "中等风险"
    else:
        risk_level_text = "低风险"
    risk_lines = ";".join(risk_flags) if risk_flags else "当前无明显高危预警"
    # One fragment per short/mid/long window snapshot.
    period_desc = ";".join(
        [
            f"{item.get('label', '区间')}({item.get('days')}天):"
            f"均值={item.get('avg', 0):.4f}, 变动={item.get('change_pct', 0):+.2f}%, "
            f"波动={item.get('volatility', 0):.2f}%"
            for item in period_overview
        ]
    )
    if not period_desc:
        period_desc = "模型分段信号不足,采用整体指标口径。"
    return (
        f"**一、预测结论**\n"
        f"基于“{question}”,当前场景为 {module_id}。\n"
        f"核心判断:未来 {horizon_days} 天总体趋势{trend_bias},整体变化{overall_change:+.4f}%;"
        f"风险评分 {risk_score:.2f}({risk_level_text})。\n"
        f"核心驱动:{metrics.get('dominant_series', '无')};窗口快照:{period_desc}\n\n"
        f"**二、风险评级**\n"
        f"建议风险等级:{risk_level_text}。\n"
        f"依据:趋势偏向 {trend_bias}、序列协同、窗口波动与历史基线变化。\n\n"
        f"**三、决策分析**\n"
        f"- 短线({action_plan.get('short_window', '1-3天')}):{action_plan.get('short_term_action', recommendation_action)}\n"
        f"- 中线({action_plan.get('mid_window', '4-7天')}):{action_plan.get('mid_term_action', '按周复盘执行策略')}\n"
        f"- 远期({action_plan.get('long_window', f'8-{horizon_days}天')}):{action_plan.get('long_term_action', recommendation_action)}\n"
        f"- 风险动作:{action_plan.get('risk_action', '按月复盘')}\n\n"
        f"**四、风险提示**\n"
        f"{risk_lines}。建议至少7天复盘一次,避免短期噪声驱动过度交易。"
    )
def _insert_prediction_audit(
    trace_id: str,
    module_id: str,
    question: str,
    horizon_days: int,
    payload: PredictRunData,
) -> None:
    """Best-effort persistence of a prediction run for auditing.

    Upserts the run row (keyed by ``trace_id``) into ai_prediction_run, then
    appends an ai_prediction_output row linked to it. Any failure is logged
    and swallowed so auditing never breaks the API response path.
    """
    if not PREDICT_LOG_ENABLED:
        return
    if psycopg is None or not PG_DSN:
        LOGGER.warning("predict_audit_insert_skipped trace_id=%s reason=psycopg_or_dsn_missing", trace_id)
        return
    try:
        # autocommit: each INSERT commits immediately; no explicit transaction.
        with psycopg.connect(PG_DSN, autocommit=True) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO public.ai_prediction_run (
                        trace_id, module_id, question, horizon_days, db_used, model_trace, status, created_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s::jsonb, %s, now())
                    ON CONFLICT (trace_id) DO UPDATE
                    SET module_id = EXCLUDED.module_id,
                        question = EXCLUDED.question,
                        horizon_days = EXCLUDED.horizon_days,
                        db_used = EXCLUDED.db_used,
                        model_trace = EXCLUDED.model_trace,
                        status = EXCLUDED.status
                    RETURNING id
                    """,
                    (
                        trace_id,
                        module_id,
                        question,
                        horizon_days,
                        payload.db_used,
                        json.dumps(payload.model_trace, ensure_ascii=False),
                        "ok",
                    ),
                )
                row = cur.fetchone()
                if not row:
                    # No RETURNING row: skip the dependent output insert.
                    return
                run_id = int(row[0])
                cur.execute(
                    """
                    INSERT INTO public.ai_prediction_output (
                        run_id, forecast, recommendation, risk_flags, backend_draft_summary, created_at
                    )
                    VALUES (%s, %s::jsonb, %s::jsonb, %s::jsonb, %s, now())
                    """,
                    (
                        run_id,
                        json.dumps(payload.forecast, ensure_ascii=False),
                        json.dumps(payload.recommendation, ensure_ascii=False),
                        json.dumps(payload.risk_flags, ensure_ascii=False),
                        payload.backend_draft_summary,
                    ),
                )
    except Exception as exc:
        # Deliberate broad catch: audit logging is best-effort by design.
        LOGGER.warning("predict_audit_insert_skipped trace_id=%s reason=%s", trace_id, str(exc))
def _require_psycopg() -> None:
    """Fail fast with a structured 500 error when the psycopg driver is absent."""
    if psycopg is not None:
        return
    raise QueryAPIError("E_DRIVER_MISSING", "缺少 psycopg 依赖,请先安装后再启动服务", status_code=500)
def _to_jsonable(value: Any) -> Any:
if value is None:
return None
if isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, Decimal):
return float(value)
if isinstance(value, (datetime, date)):
return value.isoformat()
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def run_query(sql: str, trace_id: str) -> QuerySuccessData:
    """Execute a prepared SQL statement read-only and return normalized rows.

    The transaction is forced READ ONLY, capped by STATEMENT_TIMEOUT_MS, and
    at most MAX_ROWS rows are fetched. Cell values are converted to
    JSON-serializable equivalents via ``_to_jsonable``.

    Raises:
        QueryAPIError: when the DSN is missing or the database call fails.
    """
    _require_psycopg()
    if not PG_DSN:
        raise QueryAPIError("E_DSN_MISSING", "未配置 PIG_QUERY_PG_DSN", status_code=500)
    started = time.perf_counter()
    try:
        with psycopg.connect(PG_DSN, autocommit=False) as conn:
            with conn.cursor() as cur:
                # Session hardening must run before the user query in this transaction.
                cur.execute("SET TRANSACTION READ ONLY")
                # Safe interpolation: STATEMENT_TIMEOUT_MS is an int parsed at startup.
                cur.execute(f"SET LOCAL statement_timeout = '{STATEMENT_TIMEOUT_MS}ms'")
                cur.execute(sql)
                # Hard cap on the result size returned to the caller.
                rows = cur.fetchmany(MAX_ROWS)
                columns = [desc.name for desc in cur.description] if cur.description else []
    except QueryAPIError:
        raise
    except Exception as exc:
        raise QueryAPIError("E_DB_QUERY", "数据库查询失败", {"reason": str(exc)}, status_code=500) from exc
    elapsed_ms = round((time.perf_counter() - started) * 1000, 2)
    normalized_rows = [[_to_jsonable(value) for value in row] for row in rows]
    data = QuerySuccessData(
        columns=columns,
        rows=normalized_rows,
        row_count=len(normalized_rows),
        sql_executed=sql,
        empty=len(normalized_rows) == 0,
        trace_id=trace_id,
    )
    # Log a short SQL hash instead of the raw statement to keep log lines compact.
    LOGGER.info(
        "query_ok trace_id=%s sql_hash=%s row_count=%s duration_ms=%s",
        trace_id,
        hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16],
        data.row_count,
        elapsed_ms,
    )
    return data
async def verify_api_key(x_api_key: str = Header(default="", alias="X-API-Key")) -> None:
    """FastAPI dependency validating the ``X-API-Key`` request header.

    Uses a constant-time digest comparison so an attacker cannot recover the
    key character-by-character from response-timing differences.

    Raises:
        QueryAPIError: 401 when the header is missing or does not match.
    """
    import hmac  # local import: avoids touching the module's import block

    if not x_api_key or not hmac.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise QueryAPIError("E_UNAUTHORIZED", "无效的 API Key", status_code=status.HTTP_401_UNAUTHORIZED)
def build_error_response(exc: QueryAPIError, trace_id: str | None = None) -> JSONResponse:
    """Serialize a QueryAPIError into the service's standard error envelope.

    The trace id, when supplied and not already present, is folded into the
    error detail so callers can correlate failures with log lines.
    """
    payload_detail = dict(exc.detail)
    if trace_id and "trace_id" not in payload_detail:
        payload_detail["trace_id"] = trace_id
    err = ErrorDetail(code=exc.code, message=exc.message, detail=payload_detail)
    envelope = QueryResponse(ok=False, data=None, err=err)
    return JSONResponse(status_code=exc.status_code, content=envelope.model_dump())
# FastAPI application object; routes and exception handlers are registered below.
app = FastAPI(title=SERVICE_NAME, version="1.0.0")
@app.exception_handler(QueryAPIError)
async def query_api_error_handler(_: Any, exc: QueryAPIError) -> JSONResponse:
    """Translate domain QueryAPIError exceptions into the standard error envelope."""
    return build_error_response(exc)
@app.exception_handler(Exception)
async def unhandled_exception_handler(_: Any, exc: Exception) -> JSONResponse:  # pragma: no cover - defensive fallback
    """Last-resort handler: log the traceback and return a generic 500 envelope."""
    LOGGER.exception("unhandled error: %s", exc)
    return build_error_response(
        QueryAPIError("E_UNKNOWN", "系统错误", {"reason": str(exc)}, status_code=500)
    )
@app.get("/api/v1/health", response_model=HealthResponse)
async def health() -> HealthResponse:
    """Liveness probe reporting which optional dependencies are available."""
    now_utc = datetime.now(timezone.utc).replace(microsecond=0)
    parser = "sqlglot" if sqlglot is not None else "regex-fallback"
    driver = "psycopg" if psycopg is not None else "missing"
    return HealthResponse(
        service=SERVICE_NAME,
        time=now_utc.isoformat().replace("+00:00", "Z"),
        sql_parser=parser,
        db_driver=driver,
    )
@app.get("/api/v1/predict/health", response_model=PredictHealthResponse)
async def predict_health() -> PredictHealthResponse:
    """Readiness probe for the TimesFM forecasting stack.

    Reports whether the timesfm package is installed and whether the model
    can be (or already is) initialized, accumulating any error messages.
    """

    def _combine(existing: str | None, extra: str) -> str:
        # Append a new error fragment to whatever was already recorded.
        return extra if existing is None else f"{existing}; {extra}"

    ready = False
    error_text = TIMESFM_INIT_ERROR
    if timesfm is None:
        error_text = _combine(error_text, "timesfm package not installed")
    else:
        try:
            ensure_timesfm_ready()
        except QueryAPIError as exc:
            error_text = _combine(error_text, exc.message)
        else:
            ready = True
    return PredictHealthResponse(
        ok=True,
        service=SERVICE_NAME,
        time=_now_utc_iso(),
        timesfm_enabled=timesfm is not None,
        timesfm_ready=ready,
        timesfm_device=TIMESFM_DEVICE,
        timesfm_checkpoint=TIMESFM_CHECKPOINT,
        timesfm_init_error=error_text,
    )
@app.post("/api/v1/predict/run", response_model=PredictRunResponse)
async def predict_run(request: PredictRunRequest, _: None = Depends(verify_api_key)) -> PredictRunResponse:
    """Run one prediction module end-to-end and return its decision payload.

    Flow: normalize the module id -> build input series from the DB -> (for
    forecast modules) run TimesFM per series -> apply module business rules ->
    optionally generate an executive summary -> persist a best-effort audit row.
    """
    trace_id = request.trace_id or str(uuid.uuid4())
    module_id = normalize_module_id(request.module_id, request.question)
    is_realtime_module = module_id in REALTIME_MODULE_IDS
    if not is_realtime_module:
        # Fail fast before touching the DB when the forecasting model is unusable.
        ensure_timesfm_ready()
    horizon_days = infer_horizon_days(request.question)
    series_pack = build_series_from_db(module_id=module_id, question=request.question, horizon_days=horizon_days)
    # The series builder may adjust the requested horizon; trust its value.
    horizon_days = int(series_pack.get("horizon_days", horizon_days))
    if is_realtime_module:
        LOGGER.info(
            "predict_realtime_window trace_id=%s module_id=%s reference_date=%s window_start=%s window_end=%s",
            trace_id,
            module_id,
            series_pack.get("reference_date"),
            series_pack.get("window_start"),
            series_pack.get("window_end"),
        )
    forecast_map: dict[str, dict[str, Any]] = {}
    if not is_realtime_module:
        # One TimesFM forecast per non-empty series in the pack.
        for series_item in series_pack.get("series", []):
            series_name = str(series_item.get("name") or "series")
            values = series_item.get("values") or []
            if not isinstance(values, list) or not values:
                continue
            forecast_map[series_name] = run_timesfm_forecast(values, horizon_days=horizon_days)
        if not forecast_map:
            raise QueryAPIError("E_PREDICT_DATA_EMPTY", "缺少可预测序列", {"module_id": module_id}, status_code=400)
    module_output = run_module_logic(module_id, series_pack, forecast_map)
    if is_realtime_module:
        executive_summary = None
        model_trace = ["decision:historical_stats"]
    else:
        executive_summary = generate_executive_summary(
            question=request.question,
            module_id=module_id,
            series_pack=series_pack,
            module_output=module_output,
        )
        model_trace = [f"timesfm:{TIMESFM_CHECKPOINT}", "decision:backend_db_logic"]
        if executive_summary:
            model_trace.append("summary:rule_engine")
    # Fall back to the rule-engine draft when no executive summary was produced.
    final_summary = executive_summary or module_output["backend_draft_summary"]
    # NOTE(review): used only in the log line below; module routing relies on
    # REALTIME_MODULE_IDS, which may disagree with this question-based check.
    is_realtime = _is_realtime_question(request.question)
    data = PredictRunData(
        module_id=module_id,
        db_used=True,
        model_trace=model_trace,
        forecast=module_output["forecast"],
        recommendation=module_output["recommendation"],
        risk_flags=module_output["risk_flags"],
        decision_metrics=module_output["decision_metrics"],
        backend_draft_summary=module_output["backend_draft_summary"],
        executive_summary=final_summary,
        trace_id=trace_id,
    )
    _insert_prediction_audit(
        trace_id=trace_id,
        module_id=module_id,
        question=request.question,
        horizon_days=horizon_days,
        payload=data,
    )
    LOGGER.info(
        "predict_run_ok trace_id=%s module_id=%s series_count=%s is_realtime=%s",
        trace_id,
        module_id,
        len(forecast_map),
        is_realtime,
    )
    return PredictRunResponse(ok=True, data=data, err=None)
@app.post("/api/v1/query", response_model=QueryResponse)
async def query(request: QueryRequest, _: None = Depends(verify_api_key)) -> QueryResponse:
    """Sanitize and execute a caller-supplied SQL query, returning rows."""
    trace_id = request.trace_id or str(uuid.uuid4())
    # prepare_sql_for_execution (defined earlier in the file) returns the
    # rewritten SQL plus the set of tables it references.
    prepared = prepare_sql_for_execution(request.sql)
    LOGGER.info(
        "query_request trace_id=%s user_query=%s tables=%s",
        trace_id,
        request.user_query,
        json.dumps(sorted(prepared.tables), ensure_ascii=False),
    )
    data = run_query(prepared.sql, trace_id)
    return QueryResponse(ok=True, data=data, err=None)
if __name__ == "__main__":  # pragma: no cover - manual run path
    # Dev entrypoint: serve this module with uvicorn on $PORT (default 8000).
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False)