Spaces:
Sleeping
Sleeping
Initial commit
Browse files- rag/rag_manager.py +76 -14
rag/rag_manager.py
CHANGED
|
@@ -4,6 +4,7 @@ from chromadb import PersistentClient
|
|
| 4 |
from chromadb.utils.embedding_functions import EmbeddingFunction
|
| 5 |
from config import CHROMA_DIR
|
| 6 |
|
|
|
|
| 7 |
CHROMA_DIR.mkdir(parents=True, exist_ok=True)
|
| 8 |
print(f"๐ ChromaDB ๊ฒฝ๋ก: {CHROMA_DIR.resolve()}")
|
| 9 |
|
|
@@ -11,13 +12,45 @@ _client = PersistentClient(path=str(CHROMA_DIR))
|
|
| 11 |
_collection = _client.get_or_create_collection(name="game_docs")
|
| 12 |
_embedder: Optional[EmbeddingFunction] = None
|
| 13 |
|
|
|
|
|
|
|
| 14 |
def set_embedder(embedder: Any):
    """Install *embedder* as the module-wide embedding backend.

    The object is stored as-is in the module-level ``_embedder`` slot; no
    validation is performed here. The functions in this module that need
    embeddings read that slot.
    """
    global _embedder
    _embedder = embedder
|
| 17 |
|
|
|
|
| 18 |
def chroma_initialized() -> bool:
|
| 19 |
-
return os.path.exists(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
|
|
|
|
|
|
| 21 |
def load_game_docs_from_disk(path: str) -> List[Dict[str, Any]]:
|
| 22 |
docs = []
|
| 23 |
for filename in os.listdir(path):
|
|
@@ -29,35 +62,39 @@ def load_game_docs_from_disk(path: str) -> List[Dict[str, Any]]:
|
|
| 29 |
for i, doc in enumerate(data):
|
| 30 |
if "id" not in doc:
|
| 31 |
doc["id"] = f"{filename}_{i}"
|
|
|
|
| 32 |
docs.append(doc)
|
| 33 |
-
|
| 34 |
if "id" not in data:
|
| 35 |
data["id"] = filename
|
|
|
|
| 36 |
docs.append(data)
|
| 37 |
elif filename.endswith(".txt"):
|
| 38 |
with open(full, "r", encoding="utf-8") as f:
|
| 39 |
content = f.read()
|
| 40 |
docs.append({
|
| 41 |
"id": filename,
|
|
|
|
| 42 |
"content": content,
|
| 43 |
"metadata": {}
|
| 44 |
})
|
| 45 |
return docs
|
| 46 |
|
|
|
|
|
|
|
| 47 |
def add_docs(docs: List[Dict[str, Any]], batch_size: int = 32):
|
| 48 |
assert _embedder is not None, "Embedder not initialized"
|
| 49 |
for i in range(0, len(docs), batch_size):
|
| 50 |
batch = docs[i:i+batch_size]
|
| 51 |
-
ids = []
|
| 52 |
-
contents = []
|
| 53 |
-
embeddings = []
|
| 54 |
-
metadatas = []
|
| 55 |
for doc in batch:
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
embeddings.append(emb)
|
| 62 |
_collection.add(
|
| 63 |
documents=contents,
|
|
@@ -66,9 +103,11 @@ def add_docs(docs: List[Dict[str, Any]], batch_size: int = 32):
|
|
| 66 |
ids=ids
|
| 67 |
)
|
| 68 |
|
|
|
|
|
|
|
| 69 |
def retrieve(query: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, top_k: int = 5) -> List[Dict[str, Any]]:
|
| 70 |
assert _embedder is not None, "Embedder not initialized"
|
| 71 |
-
|
| 72 |
if query:
|
| 73 |
q_emb = _embedder.encode(query).tolist()
|
| 74 |
res = _collection.query(
|
|
@@ -78,7 +117,6 @@ def retrieve(query: Optional[str] = None, filters: Optional[Dict[str, Any]] = No
|
|
| 78 |
)
|
| 79 |
docs = res.get("documents", [[]])[0]
|
| 80 |
metas = res.get("metadatas", [[]])[0]
|
| 81 |
-
return [{"content": d, "metadata": m} for d, m in zip(docs, metas)]
|
| 82 |
else:
|
| 83 |
res = _collection.get(
|
| 84 |
where=filters or {},
|
|
@@ -86,4 +124,28 @@ def retrieve(query: Optional[str] = None, filters: Optional[Dict[str, Any]] = No
|
|
| 86 |
)
|
| 87 |
docs = res.get("documents", [])
|
| 88 |
metas = res.get("metadatas", [])
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from chromadb.utils.embedding_functions import EmbeddingFunction
|
| 5 |
from config import CHROMA_DIR
|
| 6 |
|
| 7 |
+
# === Initialization ===
|
| 8 |
CHROMA_DIR.mkdir(parents=True, exist_ok=True)
|
| 9 |
print(f"๐ ChromaDB ๊ฒฝ๋ก: {CHROMA_DIR.resolve()}")
|
| 10 |
|
|
|
|
| 12 |
_collection = _client.get_or_create_collection(name="game_docs")
|
| 13 |
_embedder: Optional[EmbeddingFunction] = None
|
| 14 |
|
| 15 |
+
|
| 16 |
+
# === Embedder configuration ===
|
| 17 |
def set_embedder(embedder: Any):
    """Install *embedder* as the module-wide embedding backend.

    The object is stored as-is in the module-level ``_embedder`` slot; no
    validation is performed here. The functions in this module that need
    embeddings read that slot.
    """
    global _embedder
    _embedder = embedder
|
| 20 |
|
| 21 |
+
|
| 22 |
def chroma_initialized() -> bool:
    """Return True when the Chroma directory exists and is not empty."""
    chroma_path = str(CHROMA_DIR)
    # Guard first: os.listdir() raises if the path does not exist.
    if not os.path.exists(chroma_path):
        return False
    return bool(os.listdir(chroma_path))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# === Content extraction per document type ===
|
| 27 |
+
# Document types whose embeddable text lives in the "description" field.
_DESCRIPTION_TYPES = frozenset(
    {"description", "lore", "fallback", "main_res_validate", "npc_persona"}
)


def _as_text(value: Any) -> str:
    """Coerce a JSON-like value to text; strings pass through, None becomes ''."""
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        # Not JSON-serializable: fall back to the plain string form.
        return str(value)


def extract_content(doc: Dict[str, Any]) -> str:
    """Build the text used as the ``content`` of a document, based on its type.

    Args:
        doc: A document mapping. If it already carries a string ``content``
            that value wins; otherwise the text is derived from the fields
            associated with ``doc["type"]`` (matched case-insensitively).

    Returns:
        Always a ``str`` (possibly empty). Non-string field values such as a
        dict-valued ``trigger`` are serialized to JSON rather than returned
        raw, so the result can be handed to a text encoder.

    Behavior per type:
        * description / lore / fallback / main_res_validate / npc_persona:
          ``description``, falling back to a non-string ``content``.
        * trigger_def: ``description`` when non-empty, else the JSON form of
          ``trigger``.
        * dialogue_turn: ``"PLAYER: ...\\nNPC: ..."`` built from ``player``
          and ``npc``.
        * flag_def: ``examples_positive`` joined with newlines.
        * trigger_meta: the ``trigger`` field.
        * anything else: every string-valued field joined with newlines.
    """
    existing = doc.get("content")
    if isinstance(existing, str):
        return existing

    # A missing, None or otherwise non-string type is treated as "unknown".
    raw_type = doc.get("type", "")
    doc_type = raw_type.lower() if isinstance(raw_type, str) else ""

    if doc_type in _DESCRIPTION_TYPES:
        return _as_text(doc.get("description") or existing or "")

    if doc_type == "trigger_def":
        description = doc.get("description")
        if description:
            return _as_text(description)
        # No usable description: embed the trigger definition itself.
        return json.dumps(doc.get("trigger", {}), ensure_ascii=False)

    if doc_type == "dialogue_turn":
        # Store the player line and the NPC line together.
        return f"PLAYER: {doc.get('player', '')}\nNPC: {doc.get('npc', '')}".strip()

    if doc_type == "flag_def":
        examples = doc.get("examples_positive") or []
        if not isinstance(examples, (list, tuple)):
            # A bare value is one example, not a sequence of characters.
            examples = [examples]
        return "\n".join(_as_text(item) for item in examples)

    if doc_type == "trigger_meta":
        return _as_text(doc.get("trigger", ""))

    # Unknown type: concatenate every string-valued field.
    return "\n".join(value for value in doc.values() if isinstance(value, str))
|
| 51 |
|
| 52 |
+
|
| 53 |
+
# === Load documents from disk ===
|
| 54 |
def load_game_docs_from_disk(path: str) -> List[Dict[str, Any]]:
|
| 55 |
docs = []
|
| 56 |
for filename in os.listdir(path):
|
|
|
|
| 62 |
for i, doc in enumerate(data):
|
| 63 |
if "id" not in doc:
|
| 64 |
doc["id"] = f"{filename}_{i}"
|
| 65 |
+
doc["content"] = extract_content(doc)
|
| 66 |
docs.append(doc)
|
| 67 |
+
elif isinstance(data, dict):
|
| 68 |
if "id" not in data:
|
| 69 |
data["id"] = filename
|
| 70 |
+
data["content"] = extract_content(data)
|
| 71 |
docs.append(data)
|
| 72 |
elif filename.endswith(".txt"):
|
| 73 |
with open(full, "r", encoding="utf-8") as f:
|
| 74 |
content = f.read()
|
| 75 |
docs.append({
|
| 76 |
"id": filename,
|
| 77 |
+
"type": "text",
|
| 78 |
"content": content,
|
| 79 |
"metadata": {}
|
| 80 |
})
|
| 81 |
return docs
|
| 82 |
|
| 83 |
+
|
| 84 |
+
# === Add documents ===
|
| 85 |
def add_docs(docs: List[Dict[str, Any]], batch_size: int = 32):
|
| 86 |
assert _embedder is not None, "Embedder not initialized"
|
| 87 |
for i in range(0, len(docs), batch_size):
|
| 88 |
batch = docs[i:i+batch_size]
|
| 89 |
+
ids, contents, embeddings, metadatas = [], [], [], []
|
|
|
|
|
|
|
|
|
|
| 90 |
for doc in batch:
|
| 91 |
+
# id is required; content falls back to an empty string when missing
|
| 92 |
+
doc_id = doc.get("id", f"doc_{i}")
|
| 93 |
+
content = doc.get("content", "")
|
| 94 |
+
ids.append(doc_id)
|
| 95 |
+
contents.append(content)
|
| 96 |
+
metadatas.append(doc) # ์๋ณธ ์ ์ฒด ์ ์ฅ
|
| 97 |
+
emb = _embedder.encode(content).tolist() if content else []
|
| 98 |
embeddings.append(emb)
|
| 99 |
_collection.add(
|
| 100 |
documents=contents,
|
|
|
|
| 103 |
ids=ids
|
| 104 |
)
|
| 105 |
|
| 106 |
+
|
| 107 |
+
# === Retrieve documents ===
|
| 108 |
def retrieve(query: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, top_k: int = 5) -> List[Dict[str, Any]]:
|
| 109 |
assert _embedder is not None, "Embedder not initialized"
|
| 110 |
+
|
| 111 |
if query:
|
| 112 |
q_emb = _embedder.encode(query).tolist()
|
| 113 |
res = _collection.query(
|
|
|
|
| 117 |
)
|
| 118 |
docs = res.get("documents", [[]])[0]
|
| 119 |
metas = res.get("metadatas", [[]])[0]
|
|
|
|
| 120 |
else:
|
| 121 |
res = _collection.get(
|
| 122 |
where=filters or {},
|
|
|
|
| 124 |
)
|
| 125 |
docs = res.get("documents", [])
|
| 126 |
metas = res.get("metadatas", [])
|
| 127 |
+
|
| 128 |
+
# Restore the original document structure
|
| 129 |
+
results = []
|
| 130 |
+
for d, m in zip(docs, metas):
|
| 131 |
+
if isinstance(m, dict):
|
| 132 |
+
results.append({
|
| 133 |
+
"id": m.get("id", ""),
|
| 134 |
+
"type": m.get("type", "unknown"),
|
| 135 |
+
"npc_id": m.get("npc_id", ""),
|
| 136 |
+
"quest_stage": m.get("quest_stage", ""),
|
| 137 |
+
"location": m.get("location", ""),
|
| 138 |
+
"content": d,
|
| 139 |
+
"metadata": m
|
| 140 |
+
})
|
| 141 |
+
else:
|
| 142 |
+
results.append({
|
| 143 |
+
"id": "",
|
| 144 |
+
"type": "unknown",
|
| 145 |
+
"npc_id": "",
|
| 146 |
+
"quest_stage": "",
|
| 147 |
+
"location": "",
|
| 148 |
+
"content": d,
|
| 149 |
+
"metadata": {}
|
| 150 |
+
})
|
| 151 |
+
return results
|