File size: 5,842 Bytes
8cd7f8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
知识加载模块

负责加载和管理:市场知识库、类目知识库、跨类目痛点映射、Prompt模板。
所有数据从 JSON/YAML 文件加载并通过 Pydantic 校验。
"""

import json
from pathlib import Path
from typing import Any
from jinja2 import Environment, FileSystemLoader, Template


class KnowledgeLoader:
    """Unified loader for the knowledge-base files.

    Reads market knowledge bases, category knowledge bases and the shared
    cross-category pain-point mapping from JSON files, and hands out Jinja2
    prompt templates. Parsed JSON documents are cached in memory per loader
    instance; the loader returns the cached objects themselves (not copies),
    so mutating a returned dict also mutates the cache.
    """

    # Country codes whose knowledge-base file is named after the country
    # (e.g. "vn" is stored as "vietnam.json").
    _MARKET_FILENAMES = {"vn": "vietnam", "id": "indonesia", "th": "thailand"}

    def __init__(self, markets_dir: Path, categories_dir: Path, shared_dir: Path):
        """Remember the three data directories and start with empty caches.

        Args:
            markets_dir: Directory holding one ``<market>.json`` per market.
            categories_dir: Directory holding one ``<category>.json`` per category.
            shared_dir: Directory holding ``cross_category_pain_points.json``.
        """
        self.markets_dir = Path(markets_dir)
        self.categories_dir = Path(categories_dir)
        self.shared_dir = Path(shared_dir)
        self._market_cache: dict[str, dict] = {}
        self._category_cache: dict[str, dict] = {}
        self._shared_cache: dict[str, dict] = {}
        self._prompt_env: Environment | None = None

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse the UTF-8 encoded JSON file at *path* and return its content."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    # ---- Market knowledge ----

    def load_market(self, country_code: str) -> dict:
        """Load (and cache) the knowledge base of a single market.

        The file ``<alias>.json`` is tried first, where the alias comes from
        ``_MARKET_FILENAMES``; ``<country_code>.json`` is the fallback.

        Args:
            country_code: Market identifier, either a short code such as
                ``"vn"`` or a file stem such as ``"vietnam"``.

        Returns:
            The parsed JSON document of the market.

        Raises:
            FileNotFoundError: If none of the candidate files exists.
        """
        if country_code in self._market_cache:
            return self._market_cache[country_code]

        alias = self._MARKET_FILENAMES.get(country_code, country_code)
        # Without an alias both candidates are the same file: check it once.
        names = [alias] if alias == country_code else [alias, country_code]

        for name in names:
            # The same file may already be cached under its other spelling
            # (e.g. load_all_markets() keys entries by file stem); reuse that
            # object instead of parsing the file a second time.
            if name in self._market_cache:
                data = self._market_cache[name]
                break
            path = self.markets_dir / f"{name}.json"
            if path.exists():
                data = self._read_json(path)
                self._market_cache[name] = data
                break
        else:
            tried = " and ".join(str(self.markets_dir / f"{name}.json") for name in names)
            raise FileNotFoundError(
                f"市场知识库不存在: market_code={country_code}, tried={tried}"
            )

        self._market_cache[country_code] = data
        return data

    def load_all_markets(self) -> dict[str, dict]:
        """Load every market file in ``markets_dir``.

        Returns:
            A mapping from file stem to market document, in sorted file
            order. ``market_template.json`` is skipped.
        """
        markets = {}
        # glob() yields entries in arbitrary order; sort for a stable result.
        for path in sorted(self.markets_dir.glob("*.json")):
            if path.stem == "market_template":
                continue
            markets[path.stem] = self.load_market(path.stem)
        return markets

    # ---- Category knowledge ----

    def load_category(self, category_code: str) -> dict:
        """Load (and cache) a single category knowledge base.

        Args:
            category_code: File stem of the category, e.g.
                ``"beauty_personal_care"``.

        Returns:
            The parsed JSON document of the category.

        Raises:
            FileNotFoundError: If ``<category_code>.json`` does not exist.
        """
        if category_code in self._category_cache:
            return self._category_cache[category_code]

        path = self.categories_dir / f"{category_code}.json"
        if not path.exists():
            raise FileNotFoundError(f"类目知识库不存在: {path}")

        data = self._read_json(path)
        self._category_cache[category_code] = data
        return data

    def load_all_categories(self) -> dict[str, dict]:
        """Load every category file in ``categories_dir``.

        Returns:
            A mapping from file stem to category document, in sorted file
            order. ``category_template.json`` is skipped.
        """
        categories = {}
        # glob() yields entries in arbitrary order; sort for a stable result.
        for path in sorted(self.categories_dir.glob("*.json")):
            if path.stem == "category_template":
                continue
            categories[path.stem] = self.load_category(path.stem)
        return categories

    # ---- Shared knowledge ----

    def load_cross_category_pain_points(self) -> dict:
        """Load (and cache) the cross-category pain-point mapping.

        Returns:
            The parsed content of ``cross_category_pain_points.json``, or
            ``{"pain_points": []}`` when that file does not exist. The
            fallback is not cached, so a file created later is picked up.
        """
        if "cross_pain_points" in self._shared_cache:
            return self._shared_cache["cross_pain_points"]

        path = self.shared_dir / "cross_category_pain_points.json"
        if not path.exists():
            return {"pain_points": []}

        data = self._read_json(path)
        self._shared_cache["cross_pain_points"] = data
        return data

    # ---- Context injected into Stage 2 ----

    def build_market_context(self, country_code: str, category_code: str) -> dict:
        """Assemble the market + category + cross-category context for Stage 2.

        Args:
            country_code: Market identifier passed to :meth:`load_market`;
                also matched against each pain point's ``source_market``.
            category_code: Category identifier passed to :meth:`load_category`.

        Returns:
            A dict combining selected top-level market fields, the whole
            category document and the cross-category pain points whose
            ``source_market`` contains *country_code*.

        Raises:
            FileNotFoundError: If the market or category file is missing.
        """
        market = self.load_market(country_code)
        category = self.load_category(category_code)
        cross_pain = self.load_cross_category_pain_points()

        # "consumer_pain_points" may be a plain list or a {"items": [...]}
        # wrapper; anything else is treated as "no pain points".
        pain_points = market.get("consumer_pain_points", [])
        if isinstance(pain_points, dict):
            pain_points = pain_points.get("items", [])
        if not isinstance(pain_points, list):
            pain_points = []

        market_pains_cross = []
        for pp in cross_pain.get("pain_points") or []:
            # A JSON null would make the membership test raise, and a bare
            # string would turn it into a substring match; normalise both.
            source_markets = pp.get("source_market") or []
            if isinstance(source_markets, str):
                source_markets = [source_markets]
            if country_code in source_markets:
                market_pains_cross.append(pp)

        # "cultural_dimensions" may be present but null in the JSON.
        cultural_dimensions = market.get("cultural_dimensions", {}) or {}

        return {
            "country_code": country_code,
            "market_name": market.get("market_name", country_code),
            "market_overview": market.get("overview", {}),
            "climate": market.get("climate", {}),
            "pain_points": pain_points,
            "cross_category_pain_points": market_pains_cross,
            "price_psychology": market.get("price_psychology", {}),
            "beauty_standards": market.get("beauty_standards", {}),
            "scent_preferences": market.get("scent_preferences", {}),
            "packaging_preferences": market.get("packaging_preferences", {}),
            "religious_considerations": cultural_dimensions.get(
                "religious_considerations", {}
            ),
            "category_attributes": category,
            "regulatory": market.get("regulatory", {}),
        }

    # ---- Prompt templates ----

    def init_prompt_env(self, prompts_dir: Path) -> None:
        """Create the Jinja2 environment that loads templates from *prompts_dir*."""
        self._prompt_env = Environment(loader=FileSystemLoader(str(prompts_dir)))

    def load_prompt(self, template_name: str) -> "Template":
        """Return the prompt template named *template_name*.

        Raises:
            RuntimeError: If :meth:`init_prompt_env` has not been called yet.
        """
        if self._prompt_env is None:
            raise RuntimeError("Prompt 环境未初始化,请先调用 init_prompt_env()")
        return self._prompt_env.get_template(template_name)