Spaces:
Sleeping
Sleeping
File size: 5,842 Bytes
8cd7f8e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | """
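Knowledge loading module.
Loads and manages the market knowledge base, the category knowledge base, the cross-category pain-point mapping, and prompt templates.
Knowledge data is read from JSON files and returned as plain dicts. NOTE(review): an earlier version of this docstring mentioned YAML sources and Pydantic validation; this module only reads ``*.json`` files and performs no schema validation itself, so any validation must happen in the caller.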
"""
import json
from pathlib import Path
from typing import Any
from jinja2 import Environment, FileSystemLoader, Template
class KnowledgeLoader:
    """Unified loader for market, category and shared knowledge plus prompts.

    Every knowledge file is a UTF-8 JSON document. Parsed documents are
    memoised per loader instance, so repeated lookups return the *same*
    object; callers that need to mutate a result should copy it first.
    """

    # Country codes whose knowledge file is named after the country
    # (``vn`` -> ``vietnam.json``). Unknown codes are used as the file stem.
    _MARKET_FILE_ALIASES = {"vn": "vietnam", "id": "indonesia", "th": "thailand"}

    # File stems skipped by the ``load_all_*`` helpers.
    _MARKET_TEMPLATE_STEMS = ("market_template",)
    _CATEGORY_TEMPLATE_STEMS = ("category_template",)

    def __init__(self, markets_dir: Path, categories_dir: Path, shared_dir: Path):
        """Remember the three knowledge directories and start with empty caches.

        Args:
            markets_dir: Directory holding one JSON file per market.
            categories_dir: Directory holding one JSON file per category.
            shared_dir: Directory holding cross-cutting JSON files.
        """
        self.markets_dir = Path(markets_dir)
        self.categories_dir = Path(categories_dir)
        self.shared_dir = Path(shared_dir)
        self._market_cache: dict[str, dict] = {}
        self._category_cache: dict[str, dict] = {}
        self._shared_cache: dict[str, dict] = {}
        self._prompt_env: Environment | None = None

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse the UTF-8 JSON document stored at ``path``."""
        with path.open("r", encoding="utf-8") as handle:
            return json.load(handle)

    # ---- Market knowledge ----
    def load_market(self, country_code: str) -> dict:
        """Load (and cache) the knowledge base of a single market.

        The aliased file name (e.g. ``vietnam.json`` for ``vn``) is tried
        first, then ``<country_code>.json``.

        Args:
            country_code: Market code such as ``"vn"``, or a file stem such
                as ``"vietnam"``.

        Returns:
            The parsed JSON document.

        Raises:
            FileNotFoundError: If neither candidate file exists.
        """
        if country_code in self._market_cache:
            return self._market_cache[country_code]

        alias = self._MARKET_FILE_ALIASES.get(country_code, country_code)
        aliased_path = self.markets_dir / f"{alias}.json"
        direct_path = self.markets_dir / f"{country_code}.json"

        if aliased_path.exists():
            path = aliased_path
        elif direct_path.exists():
            path = direct_path
        else:
            raise FileNotFoundError(
                f"市场知识库不存在: country_code={country_code}, "
                f"tried={direct_path} and {aliased_path}"
            )

        # Share one parsed document between a code and its file stem
        # ("vn" / "vietnam") so the same file is not read and cached twice,
        # e.g. when load_all_markets() runs after load_market("vn").
        stem = path.stem
        if stem in self._market_cache:
            data = self._market_cache[stem]
        else:
            data = self._read_json(path)
            self._market_cache[stem] = data
        self._market_cache[country_code] = data
        return data

    def load_all_markets(self) -> dict[str, dict]:
        """Load every market file, keyed by file stem.

        Template files are skipped. Files are visited in sorted order so the
        resulting dict has a deterministic key order.
        """
        return {
            path.stem: self.load_market(path.stem)
            for path in sorted(self.markets_dir.glob("*.json"))
            if path.stem not in self._MARKET_TEMPLATE_STEMS
        }

    # ---- Category knowledge ----
    def load_category(self, category_code: str) -> dict:
        """Load (and cache) a single category knowledge base.

        Args:
            category_code: File stem of the category, e.g.
                ``"beauty_personal_care"``.

        Returns:
            The parsed JSON document.

        Raises:
            FileNotFoundError: If ``<category_code>.json`` does not exist.
        """
        if category_code in self._category_cache:
            return self._category_cache[category_code]

        path = self.categories_dir / f"{category_code}.json"
        if not path.exists():
            raise FileNotFoundError(f"类目知识库不存在: {path}")

        data = self._read_json(path)
        self._category_cache[category_code] = data
        return data

    def load_all_categories(self) -> dict[str, dict]:
        """Load every category file, keyed by file stem.

        Template files are skipped. Files are visited in sorted order so the
        resulting dict has a deterministic key order.
        """
        return {
            path.stem: self.load_category(path.stem)
            for path in sorted(self.categories_dir.glob("*.json"))
            if path.stem not in self._CATEGORY_TEMPLATE_STEMS
        }

    # ---- Shared knowledge ----
    def load_cross_category_pain_points(self) -> dict:
        """Load (and cache) the cross-category pain-point mapping.

        Returns:
            The parsed ``cross_category_pain_points.json`` document, or
            ``{"pain_points": []}`` when the file is absent. The fallback is
            deliberately not cached, so a file created later is picked up.
        """
        cache_key = "cross_pain_points"
        if cache_key in self._shared_cache:
            return self._shared_cache[cache_key]

        path = self.shared_dir / "cross_category_pain_points.json"
        if not path.exists():
            return {"pain_points": []}

        data = self._read_json(path)
        self._shared_cache[cache_key] = data
        return data

    # ---- Context injected into Stage 2 ----
    @staticmethod
    def _extract_pain_points(market: dict) -> list:
        """Return the market's pain points as a list, whatever shape was stored."""
        raw = market.get("consumer_pain_points", [])
        if isinstance(raw, dict):
            raw = raw.get("items", [])
        return raw if isinstance(raw, list) else []

    @staticmethod
    def _applies_to_market(entry: Any, country_code: str) -> bool:
        """Tell whether a cross-category entry lists ``country_code`` as a source."""
        if not isinstance(entry, dict):
            return False
        sources = entry.get("source_market", [])
        # A bare string must be compared as a whole; ``in`` on a string would
        # substring-match ("vn" in "vnx").
        if isinstance(sources, str):
            return sources == country_code
        if isinstance(sources, (list, tuple, set, frozenset)):
            return country_code in sources
        # None or any other scalar: treat as "no source market declared".
        return False

    def build_market_context(self, country_code: str, category_code: str) -> dict:
        """Assemble the market + category + cross-category context for Stage 2.

        Args:
            country_code: Market to load via :meth:`load_market`.
            category_code: Category to load via :meth:`load_category`.

        Returns:
            A dict of top-level market fields (missing ones default to an
            empty dict, or to ``country_code`` for ``market_name``), the
            normalised ``pain_points`` list, the cross-category pain points
            whose ``source_market`` contains ``country_code``, and the whole
            category document under ``category_attributes``.

        Raises:
            FileNotFoundError: If the market or category file is missing.
        """
        market = self.load_market(country_code)
        category = self.load_category(category_code)
        shared = self.load_cross_category_pain_points()

        cross_category = [
            entry
            for entry in shared.get("pain_points", [])
            if self._applies_to_market(entry, country_code)
        ]
        # ``cultural_dimensions`` may be present but null in the JSON.
        cultural = market.get("cultural_dimensions", {}) or {}

        return {
            "country_code": country_code,
            "market_name": market.get("market_name", country_code),
            "market_overview": market.get("overview", {}),
            "climate": market.get("climate", {}),
            "pain_points": self._extract_pain_points(market),
            "cross_category_pain_points": cross_category,
            "price_psychology": market.get("price_psychology", {}),
            "beauty_standards": market.get("beauty_standards", {}),
            "scent_preferences": market.get("scent_preferences", {}),
            "packaging_preferences": market.get("packaging_preferences", {}),
            "religious_considerations": cultural.get("religious_considerations", {}),
            "category_attributes": category,
            "regulatory": market.get("regulatory", {}),
        }

    # ---- Prompt templates ----
    def init_prompt_env(self, prompts_dir: Path):
        """Initialise the Jinja2 environment that serves prompt templates.

        Args:
            prompts_dir: Directory the ``FileSystemLoader`` reads templates from.
        """
        self._prompt_env = Environment(loader=FileSystemLoader(str(prompts_dir)))

    def load_prompt(self, template_name: str) -> "Template":
        """Fetch a single prompt template by name.

        Raises:
            RuntimeError: If :meth:`init_prompt_env` has not been called yet.
        """
        if self._prompt_env is None:
            raise RuntimeError("Prompt 环境未初始化,请先调用 init_prompt_env()")
        return self._prompt_env.get_template(template_name)
|