| import json |
| import os |
| import re |
| from typing import Optional |
|
|
| import chromadb |
| from chromadb.config import Settings |
|
|
| from app.db.note_dao import load_note |
| from app.utils.logger import get_logger |
|
|
# Logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)

# Storage directories; each can be overridden through the environment variable
# of the same name and otherwise falls back to the default given here.
NOTE_OUTPUT_DIR = os.environ.get("NOTE_OUTPUT_DIR", "note_results")
VECTOR_DB_DIR = os.environ.get("VECTOR_DB_DIR", "vector_db")
|
|
|
|
| def _chunk_markdown(markdown: str) -> list[dict]: |
| """按 H2/H3 标题拆分 markdown 为语义块。""" |
| sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE) |
| chunks = [] |
| for section in sections: |
| section = section.strip() |
| if not section or len(section) < 30: |
| continue |
| heading_match = re.match(r'^(#{2,3})\s+(.+)', section) |
| title = heading_match.group(2).strip() if heading_match else "intro" |
| chunks.append({ |
| "text": section, |
| "metadata": {"source_type": "markdown", "section_title": title}, |
| }) |
| return chunks |
|
|
|
|
| def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]: |
| """将转录 segments 按滑动窗口分组。""" |
| if not segments: |
| return [] |
| chunks = [] |
| step = max(window_size - overlap, 1) |
| for i in range(0, len(segments), step): |
| window = segments[i:i + window_size] |
| if not window: |
| break |
| text = "\n".join( |
| f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window |
| ) |
| chunks.append({ |
| "text": text, |
| "metadata": { |
| "source_type": "transcript", |
| "start_time": window[0].get("start", 0), |
| "end_time": window[-1].get("end", 0), |
| }, |
| }) |
| return chunks |
|
|
|
|
| def _build_meta_chunk(audio_meta: dict) -> list[dict]: |
| """将视频元信息(标题、作者、描述、标签等)构建为可检索的 chunk。""" |
| if not audio_meta: |
| return [] |
|
|
| raw = audio_meta.get("raw_info", {}) or {} |
| parts = [] |
|
|
| title = audio_meta.get("title") or raw.get("title", "") |
| if title: |
| parts.append(f"视频标题:{title}") |
|
|
| uploader = raw.get("uploader", "") |
| if uploader: |
| parts.append(f"视频作者/UP主:{uploader}") |
|
|
| desc = raw.get("description", "") |
| if desc: |
| parts.append(f"视频简介:{desc[:500]}") |
|
|
| tags = raw.get("tags", []) |
| if tags and isinstance(tags, list): |
| parts.append(f"标签:{', '.join(str(t) for t in tags[:20])}") |
|
|
| duration = audio_meta.get("duration", 0) |
| if duration: |
| m, s = divmod(int(duration), 60) |
| parts.append(f"视频时长:{m}分{s}秒") |
|
|
| platform = audio_meta.get("platform", "") |
| if platform: |
| parts.append(f"平台:{platform}") |
|
|
| url = raw.get("webpage_url", "") |
| if url: |
| parts.append(f"链接:{url}") |
|
|
| if not parts: |
| return [] |
|
|
| return [{ |
| "text": "\n".join(parts), |
| "metadata": {"source_type": "meta"}, |
| }] |
|
|
|
|
class VectorStoreManager:
    """ChromaDB-backed vector store for notes, with one collection per task."""

    # Per-source retrieval quotas, in the order the sources are queried.
    # Their sum (6) matches the default ``n_results`` of ``query``.
    _SOURCE_QUOTAS = (("meta", 1), ("markdown", 2), ("transcript", 3))

    def __init__(self):
        """Create the storage directory and open a persistent Chroma client."""
        os.makedirs(VECTOR_DB_DIR, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=VECTOR_DB_DIR,
            settings=Settings(anonymized_telemetry=False),
        )

    def _collection_name(self, task_id: str) -> str:
        """Return the collection name for a task: the task_id itself.

        NOTE(review): this relies on task_id already being a valid Chroma
        collection name (the original comment assumes UUIDs) - confirm.
        """
        return task_id

    @classmethod
    def _quotas_for(cls, n_results: int) -> dict[str, int]:
        """Split ``n_results`` across the source types.

        Slots are handed out round-robin in source order, never exceeding a
        source's configured quota, so small budgets still touch every source
        (3 -> 1/1/1) and the default budget of 6 yields the full 1/2/3 split.
        Budgets above the quota sum are capped at that sum.
        """
        quotas = {source: 0 for source, _ in cls._SOURCE_QUOTAS}
        budget = min(max(n_results, 0), sum(cap for _, cap in cls._SOURCE_QUOTAS))
        # budget never exceeds the total capacity, so every pass hands out at
        # least one slot and the loop terminates.
        while budget > 0:
            for source, cap in cls._SOURCE_QUOTAS:
                if budget > 0 and quotas[source] < cap:
                    quotas[source] += 1
                    budget -= 1
        return quotas

    def index_task(self, task_id: str) -> None:
        """Load a task's note and (re)build its vector index.

        Chunks are built from the note's ``audio_meta``, ``markdown`` and
        ``transcript.segments``. Any existing collection for the task is
        dropped first, then a new cosine-space collection is created and
        filled. Nothing is indexed when the note is missing or yields no
        chunks.
        """
        note_data = load_note(task_id)
        if note_data is None:
            logger.warning(f"笔记不存在,跳过索引: {task_id}")
            return

        # Stored values may be None rather than absent; fall back to empty
        # values so the chunkers do not fail on them.
        markdown = note_data.get("markdown") or ""
        transcript = note_data.get("transcript") or {}
        segments = transcript.get("segments") or []
        audio_meta = note_data.get("audio_meta") or {}

        all_chunks = (
            _build_meta_chunk(audio_meta)
            + _chunk_markdown(markdown)
            + _chunk_transcript(segments)
        )
        if not all_chunks:
            logger.warning(f"笔记内容为空,跳过索引: {task_id}")
            return

        col_name = self._collection_name(task_id)

        # Best-effort removal of a previous index. A failure here is expected
        # when the task has never been indexed, so it is deliberately ignored.
        try:
            self._client.delete_collection(col_name)
        except Exception:
            pass

        collection = self._client.create_collection(
            name=col_name,
            metadata={"hnsw:space": "cosine"},
        )

        documents = [chunk["text"] for chunk in all_chunks]
        metadatas = [chunk["metadata"] for chunk in all_chunks]
        ids = [f"{task_id}_{i}" for i in range(len(all_chunks))]

        collection.add(documents=documents, metadatas=metadatas, ids=ids)
        logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")

    def _parse_results(self, results: dict) -> list[dict]:
        """Convert a Chroma query result into a list of chunk dicts.

        Only the first query's rows are read. Each chunk has ``text``,
        ``metadata`` (``{}`` when the result carries no metadatas) and
        ``distance`` (``None`` when the result carries no distances).
        """
        if not results:
            return []
        document_rows = results.get("documents")
        if not document_rows or not document_rows[0]:
            return []

        metadata_rows = results.get("metadatas")
        distance_rows = results.get("distances")
        metadatas = metadata_rows[0] if metadata_rows else None
        distances = distance_rows[0] if distance_rows else None

        return [
            {
                "text": document,
                "metadata": metadatas[i] if metadatas else {},
                "distance": distances[i] if distances else None,
            }
            for i, document in enumerate(document_rows[0])
        ]

    def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]:
        """Retrieve chunks for one task using per-source quotas.

        With the default ``n_results=6`` the quotas are meta 1, markdown 2 and
        transcript 3. A smaller ``n_results`` shrinks the quotas (see
        ``_quotas_for``) so the total never exceeds it; larger values are
        capped at 6.

        Returns:
            Chunks in source order (meta, markdown, transcript). An empty list
            is returned when the collection cannot be opened; a source whose
            query fails is logged and skipped.
        """
        col_name = self._collection_name(task_id)
        try:
            collection = self._client.get_collection(col_name)
        except Exception:
            logger.warning(f"Collection 不存在: {col_name}")
            return []

        found: list[dict] = []
        for source_type, quota in self._quotas_for(n_results).items():
            if quota == 0:
                continue
            # Parsing stays inside the try so one malformed result cannot
            # abort retrieval for the remaining sources.
            try:
                results = collection.query(
                    query_texts=[query_text],
                    n_results=quota,
                    where={"source_type": source_type},
                )
                found.extend(self._parse_results(results))
            except Exception as e:
                logger.warning(
                    f"向量检索失败: task_id={task_id}, source_type={source_type}: {e}"
                )

        return found

    def list_indexed_task_ids(self) -> list[str]:
        """Return the task_ids that have a collection (name equals task_id).

        Accepts both shapes ``list_collections()`` may produce: plain name
        strings or collection objects exposing ``.name``. Returns ``[]`` and
        logs a warning if listing fails.
        """
        try:
            return [
                entry if isinstance(entry, str) else entry.name
                for entry in self._client.list_collections()
            ]
        except Exception as e:
            logger.warning(f"列出 collection 失败: {e}")
            return []

    def query_across(
        self,
        query_text: str,
        task_ids: Optional[list[str]] = None,
        n_results_per_task: int = 3,
        max_total: int = 12,
    ) -> list[dict]:
        """Retrieve across several notes and merge the hits by distance.

        Each task is queried in turn (sequentially) through ``query``. Every
        returned chunk gets an extra ``task_id`` field so callers can map it
        back to its note. The merged list is sorted by ascending distance,
        with chunks lacking a distance last, and cut to ``max_total``.

        Args:
            query_text: Text to search for.
            task_ids: Tasks to search; ``None`` means every indexed task.
            n_results_per_task: Upper bound on chunks taken from each task.
            max_total: Maximum number of chunks returned overall.
        """
        if task_ids is None:
            task_ids = self.list_indexed_task_ids()

        if not task_ids:
            return []

        all_chunks: list[dict] = []
        for tid in task_ids:
            try:
                chunks = self.query(tid, query_text, n_results=n_results_per_task)
            except Exception as e:
                logger.warning(f"跨笔记检索单笔记失败 task_id={tid}: {e}")
                continue
            for chunk in chunks:
                chunk["task_id"] = tid
            all_chunks.extend(chunks)

        def _distance(chunk: dict) -> float:
            # Chunks without a distance sort after all scored chunks.
            value = chunk.get("distance")
            return float("inf") if value is None else value

        all_chunks.sort(key=_distance)
        return all_chunks[:max_total]

    def delete_index(self, task_id: str) -> None:
        """Delete a task's vector index; failures are logged, never raised."""
        col_name = self._collection_name(task_id)
        try:
            self._client.delete_collection(col_name)
        except Exception as e:
            logger.warning(f"删除向量索引失败(可能不存在): {task_id}: {e}")
        else:
            logger.info(f"已删除向量索引: {task_id}")

    def is_indexed(self, task_id: str) -> bool:
        """Report whether a task has a complete index, including a meta chunk.

        Returns ``False`` when the collection is missing, is empty, has no
        chunk with ``source_type == "meta"``, or when any lookup fails.
        """
        col_name = self._collection_name(task_id)
        try:
            col = self._client.get_collection(col_name)
            if col.count() == 0:
                return False
            meta = col.get(where={"source_type": "meta"}, limit=1)
            return len(meta["ids"]) > 0
        except Exception:
            return False
|
|