# File size: 8,952 Bytes
import json
import os
import re
from typing import Optional
import chromadb
from chromadb.config import Settings
from app.db.note_dao import load_note
from app.utils.logger import get_logger
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Read from the NOTE_OUTPUT_DIR environment variable, defaulting to
# "note_results". NOTE(review): not referenced anywhere in the visible code;
# presumably the directory where note results are written - confirm.
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
# Directory created by VectorStoreManager.__init__ and passed as the storage
# path of the persistent ChromaDB client. Read from the VECTOR_DB_DIR
# environment variable, defaulting to "vector_db".
VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db")
def _chunk_markdown(markdown: str, min_length: int = 30) -> list[dict]:
    """Split markdown into semantic chunks at H2/H3 headings.

    The text is cut immediately before every line that starts with ``##`` or
    ``###`` followed by whitespace. Whatever precedes the first such heading
    (including an H1 title) becomes a chunk titled ``"intro"``.

    Args:
        markdown: Markdown text. ``None`` or an empty string yields no chunks.
        min_length: Sections whose stripped length is below this many
            characters are dropped. Empty sections are always dropped.

    Returns:
        A list of ``{"text": ..., "metadata": ...}`` dicts, where metadata
        holds ``source_type="markdown"`` and ``section_title`` (the heading
        text, or ``"intro"`` when the section has no H2/H3 heading).
    """
    # A stored note may carry an explicit None here; re.split would raise.
    if not markdown:
        return []

    chunks = []
    # The lookahead keeps each heading attached to the section it opens.
    for section in re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE):
        section = section.strip()
        if not section or len(section) < min_length:
            continue
        heading_match = re.match(r'^(#{2,3})\s+(.+)', section)
        title = heading_match.group(2).strip() if heading_match else "intro"
        chunks.append({
            "text": section,
            "metadata": {"source_type": "markdown", "section_title": title},
        })
    return chunks
def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
    """Group transcript segments into overlapping sliding-window chunks.

    Each chunk joins up to ``window_size`` consecutive segments, one per line,
    formatted as ``[<start>s] <text>``. Consecutive windows share ``overlap``
    segments. Iteration stops as soon as a window reaches the last segment, so
    no trailing chunk is produced that merely repeats the tail of the
    previous window.

    Args:
        segments: Segment dicts; the keys ``start``, ``end`` and ``text`` are
            read, with defaults ``0``, ``0`` and ``""`` when missing.
        window_size: Maximum number of segments per chunk. Values below 1
            yield no chunks.
        overlap: Number of segments shared by neighbouring windows. Negative
            values are treated as 0; values >= ``window_size`` fall back to a
            step of one segment.

    Returns:
        A list of ``{"text": ..., "metadata": ...}`` dicts, where metadata
        holds ``source_type="transcript"``, ``start_time`` (start of the first
        segment in the window) and ``end_time`` (end of the last one).
    """
    if not segments or window_size < 1:
        return []

    # Clamp so the step never exceeds the window (which would skip segments)
    # and never drops below one (which would loop forever).
    step = max(window_size - max(overlap, 0), 1)
    total = len(segments)
    chunks = []
    for begin in range(0, total, step):
        window = segments[begin:begin + window_size]
        text = "\n".join(
            f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window
        )
        chunks.append({
            "text": text,
            "metadata": {
                "source_type": "transcript",
                "start_time": window[0].get("start", 0),
                "end_time": window[-1].get("end", 0),
            },
        })
        # This window already covers the final segment; any further window
        # would contain only segments that were just emitted.
        if begin + window_size >= total:
            break
    return chunks
def _build_meta_chunk(audio_meta: dict) -> list[dict]:
    """Build one searchable chunk from video metadata.

    Reads the title, uploader, description (first 500 characters), tags
    (first 20), duration, platform and page URL from ``audio_meta`` and its
    nested ``raw_info`` dict, and joins the non-empty ones line by line.

    Returns:
        A one-element list holding the chunk (``source_type="meta"``), or an
        empty list when ``audio_meta`` is empty or no field has a value.
    """
    if not audio_meta:
        return []

    # raw_info may be present but None; treat that like an empty dict.
    raw_info = audio_meta.get("raw_info", {}) or {}
    lines = []

    # The top-level title takes precedence over the one inside raw_info.
    if video_title := (audio_meta.get("title") or raw_info.get("title", "")):
        lines.append(f"视频标题:{video_title}")

    if author := raw_info.get("uploader", ""):
        lines.append(f"视频作者/UP主:{author}")

    if description := raw_info.get("description", ""):
        lines.append(f"视频简介:{description[:500]}")

    tag_list = raw_info.get("tags", [])
    if isinstance(tag_list, list) and tag_list:
        lines.append(f"标签:{', '.join(map(str, tag_list[:20]))}")

    if length := audio_meta.get("duration", 0):
        minutes, seconds = divmod(int(length), 60)
        lines.append(f"视频时长:{minutes}分{seconds}秒")

    if source_platform := audio_meta.get("platform", ""):
        lines.append(f"平台:{source_platform}")

    if page_url := raw_info.get("webpage_url", ""):
        lines.append(f"链接:{page_url}")

    return [{
        "text": "\n".join(lines),
        "metadata": {"source_type": "meta"},
    }] if lines else []
class VectorStoreManager:
    """ChromaDB-backed vector store for notes, one collection per task."""

    # Baseline per-source retrieval quotas used by ``query``. They sum to the
    # default ``n_results`` of 6 and are scaled for other values.
    _SOURCE_QUOTAS = {"meta": 1, "markdown": 2, "transcript": 3}

    def __init__(self):
        """Create the storage directory if needed and open a persistent client."""
        os.makedirs(VECTOR_DB_DIR, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=VECTOR_DB_DIR,
            settings=Settings(anonymized_telemetry=False),
        )

    def _collection_name(self, task_id: str) -> str:
        """Return the collection name for a task: the task_id itself.

        ``list_indexed_task_ids`` relies on this one-to-one mapping.
        """
        return task_id

    @staticmethod
    def _distance_key(chunk: dict) -> float:
        """Sort key: smaller distance first, chunks without a distance last."""
        distance = chunk.get("distance")
        return float("inf") if distance is None else distance

    def index_task(self, task_id: str) -> None:
        """Load a note and (re)build its vector index.

        Meta, markdown and transcript chunks are written into a freshly
        created collection; any existing collection for the task is dropped
        first. Nothing is indexed when the note is missing or has no content.
        """
        note_data = load_note(task_id)
        if note_data is None:
            logger.warning(f"笔记不存在,跳过索引: {task_id}")
            return

        # Fields may be stored as explicit None; normalise before chunking.
        markdown = note_data.get("markdown") or ""
        transcript = note_data.get("transcript") or {}
        segments = transcript.get("segments") or []
        audio_meta = note_data.get("audio_meta") or {}

        meta_chunks = _build_meta_chunk(audio_meta)
        md_chunks = _chunk_markdown(markdown)
        tr_chunks = _chunk_transcript(segments)
        all_chunks = meta_chunks + md_chunks + tr_chunks
        if not all_chunks:
            logger.warning(f"笔记内容为空,跳过索引: {task_id}")
            return

        col_name = self._collection_name(task_id)
        # Drop any previous collection so re-indexing is idempotent. A missing
        # collection is the normal first-run case; a genuine failure will
        # resurface when create_collection is called below.
        try:
            self._client.delete_collection(col_name)
        except Exception as e:
            logger.debug(f"删除旧 collection 跳过: {col_name}: {e}")

        collection = self._client.create_collection(
            name=col_name,
            metadata={"hnsw:space": "cosine"},
        )
        documents = [c["text"] for c in all_chunks]
        metadatas = [c["metadata"] for c in all_chunks]
        ids = [f"{task_id}_{i}" for i in range(len(all_chunks))]
        collection.add(documents=documents, metadatas=metadatas, ids=ids)
        logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")

    def _parse_results(self, results: dict) -> list[dict]:
        """Convert a single-query ChromaDB result into a list of chunk dicts.

        Only the first result group is read, since callers send one query
        text. Each chunk has ``text``, ``metadata`` (``{}`` when the result
        carries no metadatas) and ``distance`` (``None`` when absent).
        """
        if not results:
            return []
        doc_groups = results.get("documents")
        if not doc_groups or not doc_groups[0]:
            return []
        metadatas = results.get("metadatas")
        distances = results.get("distances")
        return [
            {
                "text": doc,
                "metadata": metadatas[0][i] if metadatas else {},
                "distance": distances[0][i] if distances else None,
            }
            for i, doc in enumerate(doc_groups[0])
        ]

    def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]:
        """Retrieve chunks for one task with a per-source quota.

        Every source type is queried separately so that meta, markdown and
        transcript chunks are all represented. With the default
        ``n_results=6`` the quotas are meta 1, markdown 2, transcript 3. For
        other values the quotas are scaled proportionally (rounded up), and
        if the merged list is longer than ``n_results`` only the closest
        chunks are kept.

        Args:
            task_id: Task whose collection is searched.
            query_text: Natural-language query.
            n_results: Maximum number of chunks returned; values <= 0 return
                an empty list.

        Returns:
            Chunk dicts grouped in source order (meta, markdown, transcript).
            Empty when the collection does not exist.
        """
        if n_results <= 0:
            return []

        col_name = self._collection_name(task_id)
        try:
            collection = self._client.get_collection(col_name)
        except Exception:
            logger.warning(f"Collection 不存在: {col_name}")
            return []

        base_total = sum(self._SOURCE_QUOTAS.values())
        found: list[dict] = []
        for source_type, base in self._SOURCE_QUOTAS.items():
            # Ceiling division keeps at least one slot for every source.
            quota = (base * n_results + base_total - 1) // base_total
            try:
                results = collection.query(
                    query_texts=[query_text],
                    n_results=quota,
                    where={"source_type": source_type},
                )
            except Exception as e:
                # Best effort: one failing source must not hide the others.
                logger.warning(
                    f"检索失败 task_id={task_id}, source_type={source_type}: {e}"
                )
                continue
            found.extend(self._parse_results(results))

        if len(found) > n_results:
            # Keep the closest n_results but preserve the source grouping.
            ranked = sorted(
                range(len(found)), key=lambda i: self._distance_key(found[i])
            )
            found = [found[i] for i in sorted(ranked[:n_results])]
        return found

    def list_indexed_task_ids(self) -> list[str]:
        """Return the task_id of every indexed task (collection name == task_id).

        Returns an empty list if the collections cannot be listed.
        """
        try:
            # Depending on the ChromaDB version, list_collections yields
            # either Collection objects or plain name strings.
            return [
                c if isinstance(c, str) else c.name
                for c in self._client.list_collections()
            ]
        except Exception as e:
            logger.warning(f"列出 collection 失败: {e}")
            return []

    def query_across(
        self,
        query_text: str,
        task_ids: Optional[list[str]] = None,
        n_results_per_task: int = 3,
        max_total: int = 12,
    ) -> list[dict]:
        """Search several notes, merge the hits by distance and truncate.

        Tasks are queried one after another; a failure on one task is logged
        and skipped.

        Args:
            query_text: Natural-language query.
            task_ids: Tasks to search; ``None`` means every indexed task.
                Duplicate ids are searched once.
            n_results_per_task: Maximum chunks taken from each task.
            max_total: Maximum chunks returned overall; values <= 0 return
                an empty list.

        Returns:
            Chunk dicts sorted by ascending distance (chunks without a
            distance last), each carrying an extra ``task_id`` field so the
            caller can map a hit back to its note.
        """
        if task_ids is None:
            task_ids = self.list_indexed_task_ids()
        if not task_ids or max_total <= 0:
            return []

        merged: list[dict] = []
        # dict.fromkeys de-duplicates while keeping the caller's order.
        for tid in dict.fromkeys(task_ids):
            try:
                chunks = self.query(tid, query_text, n_results=n_results_per_task)
            except Exception as e:
                logger.warning(f"跨笔记检索单笔记失败 task_id={tid}: {e}")
                continue
            for ch in chunks:
                ch["task_id"] = tid
            merged.extend(chunks)

        merged.sort(key=self._distance_key)
        return merged[:max_total]

    def delete_index(self, task_id: str) -> None:
        """Delete the vector index of a task; failures are logged, not raised."""
        col_name = self._collection_name(task_id)
        try:
            self._client.delete_collection(col_name)
            logger.info(f"已删除向量索引: {task_id}")
        except Exception as e:
            # Typically the collection simply does not exist.
            logger.debug(f"删除向量索引跳过: {task_id}: {e}")

    def is_indexed(self, task_id: str) -> bool:
        """Return True if the task has a complete index, including a meta chunk.

        Any error while reading the collection (for example, it does not
        exist) is treated as "not indexed".
        """
        col_name = self._collection_name(task_id)
        try:
            col = self._client.get_collection(col_name)
            if col.count() == 0:
                return False
            # Older indexes may lack the meta chunk; report those as not
            # indexed.
            meta = col.get(where={"source_type": "meta"}, limit=1)
            return len(meta["ids"]) > 0
        except Exception:
            return False
|