# Toly-89's picture
# Upload phase_b/src/knowledge/loader.py with huggingface_hub
# 8cd7f8e verified
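"""
Knowledge loading module.

Loads and manages the market knowledge base, the category knowledge base,
the cross-category pain-point mapping, and the prompt templates.

NOTE: the loaders in this module read JSON files and return the parsed
content as plain dicts; YAML loading and Pydantic validation are not
performed in this module.
"""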
import json
from pathlib import Path
from typing import Any
from jinja2 import Environment, FileSystemLoader, Template
class KnowledgeLoader:
    """Unified loader for the market, category and shared knowledge bases.

    Parsed JSON documents are cached on the instance under the key they were
    requested with, so repeated lookups do not re-read the file. The cached
    objects are returned directly (not copied), so callers share them.
    """

    # Short country code -> market file stem. Codes without an entry are used
    # as the file stem unchanged.
    _MARKET_FILENAMES = {"vn": "vietnam", "id": "indonesia", "th": "thailand"}
    # File stems that the bulk loaders skip.
    _MARKET_TEMPLATE_STEMS = ("market_template",)
    _CATEGORY_TEMPLATE_STEMS = ("category_template",)

    def __init__(self, markets_dir: Path, categories_dir: Path, shared_dir: Path):
        """Remember the three data directories and start with empty caches.

        Args:
            markets_dir: Directory holding one ``<market>.json`` per market.
            categories_dir: Directory holding one ``<category>.json`` per category.
            shared_dir: Directory holding ``cross_category_pain_points.json``.
        """
        self.markets_dir = Path(markets_dir)
        self.categories_dir = Path(categories_dir)
        self.shared_dir = Path(shared_dir)
        self._market_cache: dict[str, dict] = {}
        self._category_cache: dict[str, dict] = {}
        self._shared_cache: dict[str, dict] = {}
        # Stays None until init_prompt_env() is called.
        self._prompt_env: Environment | None = None

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse the UTF-8 encoded JSON file at *path* and return its content."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _applies_to_market(source_market: Any, country_code: str) -> bool:
        """Return True when *source_market* names *country_code*."""
        if isinstance(source_market, str):
            # A bare string must match exactly; ``in`` on a str would be a
            # substring test ("vn" in "vn_north").
            return source_market == country_code
        if isinstance(source_market, (list, tuple, set, frozenset, dict)):
            return country_code in source_market
        # None, numbers, etc. cannot name a market.
        return False

    # ---- Market knowledge ----
    def load_market(self, country_code: str) -> dict:
        """Load (and cache) the knowledge base of a single market.

        The file ``<mapped name>.json`` is tried first, where the mapped name
        comes from ``_MARKET_FILENAMES``; if it does not exist,
        ``<country_code>.json`` is tried.

        Args:
            country_code: Market key, e.g. ``"vn"`` or a file stem.

        Returns:
            The parsed JSON content of the market file.

        Raises:
            FileNotFoundError: If neither candidate file exists.
        """
        if country_code in self._market_cache:
            return self._market_cache[country_code]

        mapped_stem = self._MARKET_FILENAMES.get(country_code, country_code)
        mapped_path = self.markets_dir / f"{mapped_stem}.json"
        raw_path = self.markets_dir / f"{country_code}.json"

        path = mapped_path if mapped_path.exists() else raw_path
        if not path.exists():
            raise FileNotFoundError(
                f"市场知识库不存在: market_code={country_code}, "
                f"tried={raw_path} and {mapped_path}"
            )

        data = self._read_json(path)
        self._market_cache[country_code] = data
        return data

    def load_all_markets(self) -> dict[str, dict]:
        """Load every market file found in ``markets_dir``.

        Returns:
            A dict keyed by file stem. Template stems are skipped. Files are
            visited in sorted order so the key order is deterministic. Each
            stem is passed to ``load_market``, so it goes through the same
            code -> file-name mapping as any other key.
        """
        return {
            path.stem: self.load_market(path.stem)
            for path in sorted(self.markets_dir.glob("*.json"))
            if path.stem not in self._MARKET_TEMPLATE_STEMS
        }

    # ---- Category knowledge ----
    def load_category(self, category_code: str) -> dict:
        """Load (and cache) a single category knowledge base.

        Args:
            category_code: File stem of the category, e.g.
                ``"beauty_personal_care"``.

        Returns:
            The parsed JSON content of ``<category_code>.json``.

        Raises:
            FileNotFoundError: If the category file does not exist.
        """
        if category_code in self._category_cache:
            return self._category_cache[category_code]

        path = self.categories_dir / f"{category_code}.json"
        if not path.exists():
            raise FileNotFoundError(f"类目知识库不存在: {path}")

        data = self._read_json(path)
        self._category_cache[category_code] = data
        return data

    def load_all_categories(self) -> dict[str, dict]:
        """Load every category file found in ``categories_dir``.

        Returns:
            A dict keyed by file stem. Template stems are skipped. Files are
            visited in sorted order so the key order is deterministic.
        """
        return {
            path.stem: self.load_category(path.stem)
            for path in sorted(self.categories_dir.glob("*.json"))
            if path.stem not in self._CATEGORY_TEMPLATE_STEMS
        }

    # ---- Shared knowledge ----
    def load_cross_category_pain_points(self) -> dict:
        """Load (and cache) the cross-category pain-point mapping.

        Returns:
            The parsed content of ``cross_category_pain_points.json``, or
            ``{"pain_points": []}`` when the file does not exist.
        """
        cache_key = "cross_pain_points"
        if cache_key in self._shared_cache:
            return self._shared_cache[cache_key]

        path = self.shared_dir / "cross_category_pain_points.json"
        if not path.exists():
            # The file is optional. The fallback is not cached, so a file
            # created later is picked up on the next call.
            return {"pain_points": []}

        data = self._read_json(path)
        self._shared_cache[cache_key] = data
        return data

    # ---- Context injected into Stage 2 ----
    def build_market_context(self, country_code: str, category_code: str) -> dict:
        """Build the market + category + cross-category context for Stage 2.

        Args:
            country_code: Market key accepted by ``load_market``.
            category_code: Category key accepted by ``load_category``.

        Returns:
            A dict combining selected top-level market fields, the full
            category document (``"category_attributes"``) and the
            cross-category pain points whose ``source_market`` names
            *country_code*.

        Raises:
            FileNotFoundError: If the market or category file is missing.
        """
        market = self.load_market(country_code)
        category = self.load_category(category_code)
        cross_pain = self.load_cross_category_pain_points()

        # "consumer_pain_points" may be a list, or a dict wrapping the list
        # under "items"; anything else is treated as no pain points.
        pain_points = market.get("consumer_pain_points", [])
        if isinstance(pain_points, dict):
            pain_points = pain_points.get("items", [])
        if not isinstance(pain_points, list):
            pain_points = []

        # Keep only well-formed entries that apply to this market. ``or []``
        # tolerates an explicit null in the JSON.
        market_pains_cross = [
            pp
            for pp in (cross_pain.get("pain_points", []) or [])
            if isinstance(pp, dict)
            and self._applies_to_market(pp.get("source_market"), country_code)
        ]

        # ``or {}`` tolerates an explicit null for "cultural_dimensions".
        cultural_dimensions = market.get("cultural_dimensions", {}) or {}

        return {
            "country_code": country_code,
            "market_name": market.get("market_name", country_code),
            "market_overview": market.get("overview", {}),
            "climate": market.get("climate", {}),
            "pain_points": pain_points,
            "cross_category_pain_points": market_pains_cross,
            "price_psychology": market.get("price_psychology", {}),
            "beauty_standards": market.get("beauty_standards", {}),
            "scent_preferences": market.get("scent_preferences", {}),
            "packaging_preferences": market.get("packaging_preferences", {}),
            "religious_considerations": cultural_dimensions.get(
                "religious_considerations", {}
            ),
            "category_attributes": category,
            "regulatory": market.get("regulatory", {}),
        }

    # ---- Prompt templates ----
    def init_prompt_env(self, prompts_dir: Path):
        """Initialise the Jinja2 environment used to load prompt templates.

        Args:
            prompts_dir: Directory the template loader reads from.
        """
        self._prompt_env = Environment(loader=FileSystemLoader(str(prompts_dir)))

    def load_prompt(self, template_name: str) -> "Template":
        """Return the prompt template called *template_name*.

        Raises:
            RuntimeError: If ``init_prompt_env`` has not been called yet.
        """
        if self._prompt_env is None:
            raise RuntimeError("Prompt 环境未初始化,请先调用 init_prompt_env()")
        return self._prompt_env.get_template(template_name)