Toly-89 commited on
Commit
8cd7f8e
·
verified ·
1 Parent(s): 5994a21

Upload phase_b/src/knowledge/loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. phase_b/src/knowledge/loader.py +149 -0
phase_b/src/knowledge/loader.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 知识加载模块
3
+
4
+ 负责加载和管理:市场知识库、类目知识库、跨类目痛点映射、Prompt模板。
5
+ 所有数据从 JSON/YAML 文件加载并通过 Pydantic 校验。
6
+ """
7
+
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Any
11
+ from jinja2 import Environment, FileSystemLoader, Template
12
+
13
+
14
+ class KnowledgeLoader:
15
+ """知识库统一加载器"""
16
+
17
+ def __init__(self, markets_dir: Path, categories_dir: Path, shared_dir: Path):
18
+ self.markets_dir = Path(markets_dir)
19
+ self.categories_dir = Path(categories_dir)
20
+ self.shared_dir = Path(shared_dir)
21
+ self._market_cache: dict[str, dict] = {}
22
+ self._category_cache: dict[str, dict] = {}
23
+ self._shared_cache: dict[str, dict] = {}
24
+ self._prompt_env: Environment | None = None
25
+
26
+ # ---- 市场知识 ----
27
+
28
+ def load_market(self, country_code: str) -> dict:
29
+ """加载单个市场知识库"""
30
+ if country_code in self._market_cache:
31
+ return self._market_cache[country_code]
32
+
33
+ # 支持 country_code → filename 映射
34
+ filename_map = {"vn": "vietnam", "id": "indonesia", "th": "thailand"}
35
+ filename = filename_map.get(country_code, country_code)
36
+ path = self.markets_dir / f"{filename}.json"
37
+
38
+ if not path.exists():
39
+ # 尝试直接用 country_code
40
+ path = self.markets_dir / f"{country_code}.json"
41
+ if not path.exists():
42
+ raise FileNotFoundError(f"市场知识库不存在: marker_code={country_code}, tried={path} and {self.markets_dir / f'{filename}.json'}")
43
+
44
+ with open(path, "r", encoding="utf-8") as f:
45
+ data = json.load(f)
46
+ self._market_cache[country_code] = data
47
+ return data
48
+
49
+ def load_all_markets(self) -> dict[str, dict]:
50
+ """加载所有市场"""
51
+ markets = {}
52
+ for path in self.markets_dir.glob("*.json"):
53
+ if path.stem in ("market_template",):
54
+ continue
55
+ markets[path.stem] = self.load_market(path.stem)
56
+ return markets
57
+
58
+ # ---- 类目知识 ----
59
+
60
+ def load_category(self, category_code: str) -> dict:
61
+ """加载单个类目知识库(如 beauty_personal_care)"""
62
+ if category_code in self._category_cache:
63
+ return self._category_cache[category_code]
64
+
65
+ path = self.categories_dir / f"{category_code}.json"
66
+ if not path.exists():
67
+ raise FileNotFoundError(f"类目知识库不存在: {path}")
68
+
69
+ with open(path, "r", encoding="utf-8") as f:
70
+ data = json.load(f)
71
+ self._category_cache[category_code] = data
72
+ return data
73
+
74
+ def load_all_categories(self) -> dict[str, dict]:
75
+ """加载所有类目"""
76
+ categories = {}
77
+ for path in self.categories_dir.glob("*.json"):
78
+ if path.stem in ("category_template",):
79
+ continue
80
+ categories[path.stem] = self.load_category(path.stem)
81
+ return categories
82
+
83
+ # ---- 共享知识 ----
84
+
85
+ def load_cross_category_pain_points(self) -> dict:
86
+ """加载跨类目痛点映射"""
87
+ if "cross_pain_points" in self._shared_cache:
88
+ return self._shared_cache["cross_pain_points"]
89
+
90
+ path = self.shared_dir / "cross_category_pain_points.json"
91
+ if not path.exists():
92
+ return {"pain_points": []}
93
+
94
+ with open(path, "r", encoding="utf-8") as f:
95
+ data = json.load(f)
96
+ self._shared_cache["cross_pain_points"] = data
97
+ return data
98
+
99
+ # ---- 为Stage 2准备注入上下文 ----
100
+
101
+ def build_market_context(self, country_code: str, category_code: str) -> dict:
102
+ """构建Stage 2所需的市场+类目+跨类目痛点完整上下文"""
103
+ market = self.load_market(country_code)
104
+ category = self.load_category(category_code)
105
+ cross_pain = self.load_cross_category_pain_points()
106
+
107
+ # 直接取 market 的顶层字段
108
+ pain_points = market.get("consumer_pain_points", [])
109
+ if isinstance(pain_points, dict):
110
+ pain_points = pain_points.get("items", [])
111
+ if not isinstance(pain_points, list):
112
+ pain_points = []
113
+
114
+ # 获取跨类目映射
115
+ market_pains_cross = []
116
+ for pp in cross_pain.get("pain_points", []):
117
+ source_markets = pp.get("source_market", [])
118
+ if country_code in source_markets:
119
+ market_pains_cross.append(pp)
120
+
121
+ return {
122
+ "country_code": country_code,
123
+ "market_name": market.get("market_name", country_code),
124
+ "market_overview": market.get("overview", {}),
125
+ "climate": market.get("climate", {}),
126
+ "pain_points": pain_points,
127
+ "cross_category_pain_points": market_pains_cross,
128
+ "price_psychology": market.get("price_psychology", {}),
129
+ "beauty_standards": market.get("beauty_standards", {}),
130
+ "scent_preferences": market.get("scent_preferences", {}),
131
+ "packaging_preferences": market.get("packaging_preferences", {}),
132
+ "religious_considerations": (
133
+ market.get("cultural_dimensions", {}) or {}
134
+ ).get("religious_considerations", {}),
135
+ "category_attributes": category,
136
+ "regulatory": market.get("regulatory", {}),
137
+ }
138
+
139
+ # ---- Prompt 模板 ----
140
+
141
+ def init_prompt_env(self, prompts_dir: Path):
142
+ """初始化 Jijna2 prompt模板环境"""
143
+ self._prompt_env = Environment(loader=FileSystemLoader(str(prompts_dir)))
144
+
145
+ def load_prompt(self, template_name: str) -> Template:
146
+ """加载单个 prompt 模板"""
147
+ if self._prompt_env is None:
148
+ raise RuntimeError("Prompt 环境未初始化,请先调用 init_prompt_env()")
149
+ return self._prompt_env.get_template(template_name)