Spaces:
Sleeping
Sleeping
| """ | |
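| Knowledge loading module. | |
| Loads and manages the market knowledge base, the category knowledge base, the cross-category pain-point mapping, and the prompt templates. | |
| Knowledge data is read from JSON files. NOTE(review): the loader in this module does not itself run Pydantic validation or read YAML; confirm where validation happens. | |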
| """ | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| from jinja2 import Environment, FileSystemLoader, Template | |
class KnowledgeLoader:
    """Unified loader for market, category and shared knowledge bases.

    JSON documents are read from three directories and memoised in
    per-instance dictionaries, so a given file is parsed at most once per
    loader. Prompt templates come from a Jinja2 environment that has to be
    set up explicitly with :meth:`init_prompt_env`.
    """

    # Short market codes mapped to the file stem of their knowledge base.
    MARKET_FILENAMES = {"vn": "vietnam", "id": "indonesia", "th": "thailand"}
    # File stems that the ``load_all_*`` helpers skip.
    MARKET_TEMPLATE_STEMS = ("market_template",)
    CATEGORY_TEMPLATE_STEMS = ("category_template",)

    def __init__(self, markets_dir: Path, categories_dir: Path, shared_dir: Path):
        """Remember the three knowledge directories and start with empty caches.

        Args:
            markets_dir: Directory holding one ``<market>.json`` per market.
            categories_dir: Directory holding one ``<category>.json`` per category.
            shared_dir: Directory holding knowledge shared across categories.
        """
        self.markets_dir = Path(markets_dir)
        self.categories_dir = Path(categories_dir)
        self.shared_dir = Path(shared_dir)
        self._market_cache: dict[str, dict] = {}
        self._category_cache: dict[str, dict] = {}
        self._shared_cache: dict[str, dict] = {}
        self._prompt_env: Environment | None = None

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse the UTF-8 encoded JSON document stored at *path*."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _targets_market(pain_point: dict, country_code: str) -> bool:
        """Tell whether *pain_point* lists *country_code* in ``source_market``."""
        sources = pain_point.get("source_market")
        if isinstance(sources, str):
            # A bare string must match exactly; ``in`` would be a substring test.
            return sources == country_code
        try:
            return country_code in sources
        except TypeError:
            # Missing (None) or otherwise non-container value: no match.
            return False

    # ---- Market knowledge ----

    def load_market(self, country_code: str) -> dict:
        """Load (and cache) the knowledge base of a single market.

        The code is first translated through ``MARKET_FILENAMES`` (for
        example ``"vn"`` -> ``vietnam.json``); if that file is absent,
        ``<country_code>.json`` is tried as a fallback.

        Args:
            country_code: Short market code or the file stem itself.

        Returns:
            The parsed JSON document of the market.

        Raises:
            FileNotFoundError: If none of the candidate files exists.
        """
        if country_code in self._market_cache:
            return self._market_cache[country_code]

        filename = self.MARKET_FILENAMES.get(country_code, country_code)
        candidates = [self.markets_dir / f"{filename}.json"]
        if filename != country_code:
            candidates.append(self.markets_dir / f"{country_code}.json")

        path = next((p for p in candidates if p.exists()), None)
        if path is None:
            tried = " and ".join(str(p) for p in candidates)
            raise FileNotFoundError(
                f"市场知识库不存在: country_code={country_code}, tried={tried}"
            )

        # Cache under the file stem as well, so that "vn" and "vietnam"
        # share one parsed document instead of reading the file twice.
        if path.stem in self._market_cache:
            data = self._market_cache[path.stem]
        else:
            data = self._read_json(path)
            self._market_cache[path.stem] = data
        self._market_cache[country_code] = data
        return data

    def load_all_markets(self) -> dict[str, dict]:
        """Load every market file in ``markets_dir`` except the template.

        Returns:
            Mapping of file stem to parsed document, in sorted path order.
        """
        markets = {}
        # glob() order is filesystem-dependent; sort for a stable result.
        for path in sorted(self.markets_dir.glob("*.json")):
            if path.stem in self.MARKET_TEMPLATE_STEMS:
                continue
            markets[path.stem] = self.load_market(path.stem)
        return markets

    # ---- Category knowledge ----

    def load_category(self, category_code: str) -> dict:
        """Load (and cache) one category knowledge base, e.g. ``beauty_personal_care``.

        Args:
            category_code: File stem of the category document.

        Returns:
            The parsed JSON document of the category.

        Raises:
            FileNotFoundError: If ``<category_code>.json`` does not exist.
        """
        if category_code in self._category_cache:
            return self._category_cache[category_code]

        path = self.categories_dir / f"{category_code}.json"
        if not path.exists():
            raise FileNotFoundError(f"类目知识库不存在: {path}")

        data = self._read_json(path)
        self._category_cache[category_code] = data
        return data

    def load_all_categories(self) -> dict[str, dict]:
        """Load every category file in ``categories_dir`` except the template.

        Returns:
            Mapping of file stem to parsed document, in sorted path order.
        """
        categories = {}
        # glob() order is filesystem-dependent; sort for a stable result.
        for path in sorted(self.categories_dir.glob("*.json")):
            if path.stem in self.CATEGORY_TEMPLATE_STEMS:
                continue
            categories[path.stem] = self.load_category(path.stem)
        return categories

    # ---- Shared knowledge ----

    def load_cross_category_pain_points(self) -> dict:
        """Load the cross-category pain-point mapping.

        Returns:
            The parsed ``cross_category_pain_points.json``; when the file is
            missing, a fresh ``{"pain_points": []}`` placeholder, which is
            deliberately not cached so a file added later is still picked up.
        """
        if "cross_pain_points" in self._shared_cache:
            return self._shared_cache["cross_pain_points"]

        path = self.shared_dir / "cross_category_pain_points.json"
        if not path.exists():
            return {"pain_points": []}

        data = self._read_json(path)
        self._shared_cache["cross_pain_points"] = data
        return data

    # ---- Context injected into Stage 2 ----

    def build_market_context(self, country_code: str, category_code: str) -> dict:
        """Assemble the market + category + cross-category context for Stage 2.

        Args:
            country_code: Market to load; also matched against each
                cross-category pain point's ``source_market``.
            category_code: Category whose document is embedded unchanged.

        Returns:
            A flat dict of market fields (missing ones default to empty
            containers), the category document and the matching
            cross-category pain points.

        Raises:
            FileNotFoundError: If the market or category file is missing.
        """
        market = self.load_market(country_code)
        category = self.load_category(category_code)
        cross_pain = self.load_cross_category_pain_points()

        # Pain points may be a plain list or wrapped as {"items": [...]}.
        pain_points = market.get("consumer_pain_points", [])
        if isinstance(pain_points, dict):
            pain_points = pain_points.get("items", [])
        if not isinstance(pain_points, list):
            pain_points = []

        market_pains_cross = [
            pp
            for pp in (cross_pain.get("pain_points") or [])
            if isinstance(pp, dict) and self._targets_market(pp, country_code)
        ]

        # Tolerate a missing, null or malformed "cultural_dimensions" entry.
        cultural = market.get("cultural_dimensions")
        if not isinstance(cultural, dict):
            cultural = {}

        return {
            "country_code": country_code,
            "market_name": market.get("market_name", country_code),
            "market_overview": market.get("overview", {}),
            "climate": market.get("climate", {}),
            "pain_points": pain_points,
            "cross_category_pain_points": market_pains_cross,
            "price_psychology": market.get("price_psychology", {}),
            "beauty_standards": market.get("beauty_standards", {}),
            "scent_preferences": market.get("scent_preferences", {}),
            "packaging_preferences": market.get("packaging_preferences", {}),
            "religious_considerations": cultural.get("religious_considerations", {}),
            "category_attributes": category,
            "regulatory": market.get("regulatory", {}),
        }

    # ---- Prompt templates ----

    def init_prompt_env(self, prompts_dir: Path) -> None:
        """Create the Jinja2 environment that serves templates from *prompts_dir*."""
        self._prompt_env = Environment(loader=FileSystemLoader(str(prompts_dir)))

    def load_prompt(self, template_name: str) -> "Template":
        """Fetch one prompt template by name.

        Args:
            template_name: Name passed to the environment's ``get_template``.

        Raises:
            RuntimeError: If :meth:`init_prompt_env` has not been called yet.
        """
        if self._prompt_env is None:
            raise RuntimeError("Prompt 环境未初始化,请先调用 init_prompt_env()")
        return self._prompt_env.get_template(template_name)