"""
Knowledge loading module.

Loads and manages the market knowledge bases, the category knowledge bases,
the cross-category pain-point mapping, and the prompt templates.

Knowledge data is read from JSON files and returned as the raw parsed objects;
this module performs no schema validation. Prompt templates are Jinja2
templates loaded from a directory on disk.
"""
import json
from pathlib import Path
from typing import Any

from jinja2 import Environment, FileSystemLoader, Template


class KnowledgeLoader:
    """Unified loader for the market, category and shared knowledge bases.

    Parsed JSON files are cached per instance, so repeated loads of the same
    key return the *same* dict object. Treat returned data as read-only:
    mutating it also changes what later calls (and other callers) see.
    """

    # Short country codes whose market file uses a longer stem,
    # e.g. "vn" -> <markets_dir>/vietnam.json.
    _MARKET_FILE_ALIASES = {"vn": "vietnam", "id": "indonesia", "th": "thailand"}

    # File stems that are authoring templates, not real knowledge bases.
    _MARKET_TEMPLATE_STEMS = ("market_template",)
    _CATEGORY_TEMPLATE_STEMS = ("category_template",)

    # Key under which the cross-category pain-point mapping is cached.
    _CROSS_PAIN_KEY = "cross_pain_points"

    def __init__(self, markets_dir: Path, categories_dir: Path, shared_dir: Path):
        """Remember the knowledge directories; nothing is read until first use.

        Args:
            markets_dir: directory holding one ``<market>.json`` per market.
            categories_dir: directory holding one ``<category>.json`` per category.
            shared_dir: directory holding shared files such as
                ``cross_category_pain_points.json``.

        Each argument is wrapped in ``Path``, so plain strings are accepted too.
        """
        self.markets_dir = Path(markets_dir)
        self.categories_dir = Path(categories_dir)
        self.shared_dir = Path(shared_dir)
        self._market_cache: dict[str, dict] = {}
        self._category_cache: dict[str, dict] = {}
        self._shared_cache: dict[str, dict] = {}
        # Set by init_prompt_env(); load_prompt() raises until then.
        self._prompt_env: Environment | None = None

    # ---- helpers ----

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse the UTF-8 encoded JSON file at *path*."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    # ---- market knowledge ----

    def load_market(self, country_code: str) -> dict:
        """Load a single market knowledge base (cached per ``country_code``).

        *country_code* may be a short alias (``"vn"``, ``"id"``, ``"th"``) that
        maps to a longer file stem, or the file stem itself (``"vietnam"``).
        The aliased file is tried first, then ``<country_code>.json``.

        Raises:
            FileNotFoundError: if none of the candidate files exists.
        """
        if country_code in self._market_cache:
            return self._market_cache[country_code]

        filename = self._MARKET_FILE_ALIASES.get(country_code, country_code)
        candidates = [self.markets_dir / f"{filename}.json"]
        if filename != country_code:
            # Fall back to a file named after the raw code itself.
            candidates.append(self.markets_dir / f"{country_code}.json")

        path = next((p for p in candidates if p.exists()), None)
        if path is None:
            tried = " and ".join(str(p) for p in candidates)
            raise FileNotFoundError(
                f"市场知识库不存在: country_code={country_code}, tried={tried}"
            )

        data = self._read_json(path)
        self._market_cache[country_code] = data
        return data

    def load_all_markets(self) -> dict[str, dict]:
        """Load every market file in ``markets_dir``, keyed by file stem.

        Template files are skipped. Files are visited in sorted order because
        ``Path.glob`` order is unspecified; this keeps key order deterministic.
        """
        return {
            path.stem: self.load_market(path.stem)
            for path in sorted(self.markets_dir.glob("*.json"))
            if path.stem not in self._MARKET_TEMPLATE_STEMS
        }

    # ---- category knowledge ----

    def load_category(self, category_code: str) -> dict:
        """Load a single category knowledge base (e.g. ``beauty_personal_care``).

        Raises:
            FileNotFoundError: if ``<categories_dir>/<category_code>.json`` is missing.
        """
        if category_code in self._category_cache:
            return self._category_cache[category_code]

        path = self.categories_dir / f"{category_code}.json"
        if not path.exists():
            raise FileNotFoundError(f"类目知识库不存在: {path}")

        data = self._read_json(path)
        self._category_cache[category_code] = data
        return data

    def load_all_categories(self) -> dict[str, dict]:
        """Load every category file in ``categories_dir``, keyed by file stem.

        Template files are skipped; files are visited in sorted order.
        """
        return {
            path.stem: self.load_category(path.stem)
            for path in sorted(self.categories_dir.glob("*.json"))
            if path.stem not in self._CATEGORY_TEMPLATE_STEMS
        }

    # ---- shared knowledge ----

    def load_cross_category_pain_points(self) -> dict:
        """Load the cross-category pain-point mapping.

        Returns ``{"pain_points": []}`` when the file is missing. That fallback
        is not cached, so a file created later is picked up on a later call.
        """
        if self._CROSS_PAIN_KEY in self._shared_cache:
            return self._shared_cache[self._CROSS_PAIN_KEY]

        path = self.shared_dir / "cross_category_pain_points.json"
        if not path.exists():
            return {"pain_points": []}

        data = self._read_json(path)
        self._shared_cache[self._CROSS_PAIN_KEY] = data
        return data

    def _cross_pain_points_for(self, country_code: str) -> list:
        """Return the cross-category pain points whose ``source_market`` includes *country_code*."""
        matched = []
        for pp in self.load_cross_category_pain_points().get("pain_points") or []:
            source_markets = pp.get("source_market") or []
            # A bare string would turn `in` into a substring test
            # ("vn" in "vnx" is True); treat it as a single market instead.
            if isinstance(source_markets, str):
                source_markets = [source_markets]
            if country_code in source_markets:
                matched.append(pp)
        return matched

    # ---- Stage 2 context injection ----

    def build_market_context(self, country_code: str, category_code: str) -> dict:
        """Build the market + category + cross-category pain-point context for Stage 2.

        Args:
            country_code: market identifier passed to :meth:`load_market`; it is
                also matched exactly against each cross pain point's
                ``source_market``.
            category_code: category identifier passed to :meth:`load_category`.

        Returns:
            A flat dict of selected top-level market fields (missing ones fall
            back to ``{}``/``[]``), the matching cross-category pain points, and
            the whole category knowledge base under ``"category_attributes"``
            (the cached object, not a copy).

        Raises:
            FileNotFoundError: if the market or category file does not exist.
        """
        market = self.load_market(country_code)
        category = self.load_category(category_code)
        market_pains_cross = self._cross_pain_points_for(country_code)

        # consumer_pain_points may be a plain list or a {"items": [...]} wrapper;
        # anything else is treated as "no pain points".
        pain_points = market.get("consumer_pain_points", [])
        if isinstance(pain_points, dict):
            pain_points = pain_points.get("items", [])
        if not isinstance(pain_points, list):
            pain_points = []

        cultural = market.get("cultural_dimensions", {})
        religious = (
            cultural.get("religious_considerations", {})
            if isinstance(cultural, dict)
            else {}
        )

        return {
            "country_code": country_code,
            "market_name": market.get("market_name", country_code),
            "market_overview": market.get("overview", {}),
            "climate": market.get("climate", {}),
            "pain_points": pain_points,
            "cross_category_pain_points": market_pains_cross,
            "price_psychology": market.get("price_psychology", {}),
            "beauty_standards": market.get("beauty_standards", {}),
            "scent_preferences": market.get("scent_preferences", {}),
            "packaging_preferences": market.get("packaging_preferences", {}),
            "religious_considerations": religious,
            "category_attributes": category,
            "regulatory": market.get("regulatory", {}),
        }

    # ---- prompt templates ----

    def init_prompt_env(self, prompts_dir: Path):
        """Initialise the Jinja2 prompt-template environment rooted at *prompts_dir*."""
        self._prompt_env = Environment(loader=FileSystemLoader(str(prompts_dir)))

    def load_prompt(self, template_name: str) -> Template:
        """Load a single prompt template by name (relative to ``prompts_dir``).

        Raises:
            RuntimeError: if :meth:`init_prompt_env` has not been called yet.
        """
        if self._prompt_env is None:
            raise RuntimeError("Prompt 环境未初始化,请先调用 init_prompt_env()")
        return self._prompt_env.get_template(template_name)