Spaces:
Sleeping
Sleeping
| # api/weaviate_retrieve.py | |
| """ | |
| 与 ClareVoice 共用同一 Weaviate 数据库(GenAICourses)的检索封装。 | |
| 教师 Agent 和 Clare 均可调用,需与 build_weaviate_index 使用相同 embedding(HF all-MiniLM-L6-v2)。 | |
| 支持带引用的检索:返回 (text, refs),用于标注 [Source: Filename/Page]。 | |
| """ | |
| import os | |
| from typing import List, Optional, Tuple | |
| from .config import USE_WEAVIATE, WEAVIATE_URL, WEAVIATE_API_KEY, WEAVIATE_COLLECTION | |
| # 引用项:本地 VDB 为 {"type": "vdb", "source": "Filename", "page": "1"},Web 为 {"type": "web", "url": "..."} | |
| RefItem = dict | |
def retrieve_from_weaviate(query: str, top_k: int = 8, timeout_sec: float = 45.0) -> str:
    """Retrieve course fragments relevant to *query* from the Weaviate Cloud
    GenAICourses collection.

    Uses the HuggingFace all-MiniLM-L6-v2 embedding so queries match the model
    used when the index was built (see build_weaviate_index).

    Args:
        query: Natural-language search query; must have >= 3 non-space chars.
        top_k: Number of top-similarity chunks to retrieve.
        timeout_sec: Hard wall-clock limit for the whole retrieval call.

    Returns:
        The retrieved chunks joined by a "---" separator line, or "" when
        Weaviate is not configured, the query is too short, the optional
        dependencies are missing, the call errors, or the timeout elapses --
        the teacher agent keeps running, just without RAG.
    """
    # Cheap guards first: reject empty/short queries before even looking at
    # the USE_WEAVIATE config flag.
    if not query or len(query.strip()) < 3 or not USE_WEAVIATE:
        return ""

    def _call() -> str:
        try:
            import weaviate
            from weaviate.classes.init import Auth
            from llama_index.core import Settings, VectorStoreIndex
            from llama_index.vector_stores.weaviate import WeaviateVectorStore
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            # Must be the same embedding model used at index-build time,
            # otherwise similarity scores are meaningless.
            Settings.embed_model = HuggingFaceEmbedding(
                model_name=os.getenv("HF_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
            )
            client = weaviate.connect_to_weaviate_cloud(
                cluster_url=WEAVIATE_URL,
                auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
            )
            try:
                if not client.is_ready():
                    return ""
                vs = WeaviateVectorStore(
                    weaviate_client=client,
                    index_name=WEAVIATE_COLLECTION,
                )
                index = VectorStoreIndex.from_vector_store(vs)
                nodes = index.as_retriever(similarity_top_k=top_k).retrieve(query)
                return "\n\n---\n\n".join(n.get_content() for n in nodes) if nodes else ""
            finally:
                # Always release the connection, even on retrieval errors.
                client.close()
        except ImportError as e:
            print(f"[weaviate_retrieve] 未安装 weaviate/llama_index,跳过 RAG: {e}")
            return ""
        except Exception as e:
            print(f"[weaviate_retrieve] {repr(e)}")
            return ""

    import concurrent.futures

    # BUG FIX: the original wrapped the executor in a `with` block, whose
    # __exit__ calls shutdown(wait=True) -- on timeout it still blocked until
    # the worker finished, so the timeout never returned early.  Manage the
    # executor manually and shut down without waiting.
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return ex.submit(_call).result(timeout=timeout_sec)
    except concurrent.futures.TimeoutError:
        print(f"[weaviate_retrieve] timeout after {timeout_sec}s")
        return ""
    finally:
        ex.shutdown(wait=False, cancel_futures=True)
def retrieve_from_weaviate_with_refs(
    query: str, top_k: int = 8, timeout_sec: float = 45.0
) -> Tuple[str, List[RefItem]]:
    """Retrieve course fragments from Weaviate together with citation refs.

    Refs are used to annotate answers as [Source: Filename/Page].  When a node
    carries no file/page metadata, the collection name (or "GenAICourses")
    is used as the source.

    Args:
        query: Natural-language search query; must have >= 3 non-space chars.
        top_k: Number of top-similarity chunks to retrieve.
        timeout_sec: Hard wall-clock limit for the whole retrieval call.

    Returns:
        (text, refs): retrieved chunks joined by a "---" separator line, and a
        de-duplicated list of {"type": "vdb", "source": ..., "page": ...}
        dicts.  ("", []) on any failure (no config, short query, missing
        dependencies, error, or timeout).
    """
    # Cheap guards first: reject empty/short queries before even looking at
    # the USE_WEAVIATE config flag.
    if not query or len(query.strip()) < 3 or not USE_WEAVIATE:
        return "", []

    def _call() -> Tuple[str, List[RefItem]]:
        try:
            import weaviate
            from weaviate.classes.init import Auth
            from llama_index.core import Settings, VectorStoreIndex
            from llama_index.vector_stores.weaviate import WeaviateVectorStore
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            # Must be the same embedding model used at index-build time.
            Settings.embed_model = HuggingFaceEmbedding(
                model_name=os.getenv("HF_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
            )
            client = weaviate.connect_to_weaviate_cloud(
                cluster_url=WEAVIATE_URL,
                auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
            )
            try:
                if not client.is_ready():
                    return "", []
                vs = WeaviateVectorStore(
                    weaviate_client=client,
                    index_name=WEAVIATE_COLLECTION,
                )
                index = VectorStoreIndex.from_vector_store(vs)
                nodes = index.as_retriever(similarity_top_k=top_k).retrieve(query)
                if not nodes:
                    return "", []
                texts: List[str] = []
                refs: List[RefItem] = []
                seen = set()
                for n in nodes:
                    content = n.get_content()
                    # Skip nodes with no usable text; BUG FIX: the original
                    # still emitted a ref for such nodes, so refs could cite
                    # sources that contributed nothing to the answer.
                    if not (isinstance(content, str) and content.strip()):
                        continue
                    texts.append(content.strip())
                    # NodeWithScore: metadata may live on n.node or on n itself.
                    node = getattr(n, "node", n)
                    meta = getattr(node, "metadata", None) or {}
                    fname = (
                        meta.get("file_name")
                        or meta.get("source_file")
                        or meta.get("filename")
                        or WEAVIATE_COLLECTION
                        or "GenAICourses"
                    ).strip()
                    page = meta.get("page_label") or meta.get("page_number") or meta.get("page") or ""
                    page_str = str(page).strip() if page else ""
                    key = (fname, page_str)
                    if key not in seen:  # de-duplicate identical (file, page) refs
                        seen.add(key)
                        refs.append({"type": "vdb", "source": fname, "page": page_str})
                return "\n\n---\n\n".join(texts), refs
            finally:
                # Always release the connection, even on retrieval errors.
                client.close()
        except ImportError as e:
            print(f"[weaviate_retrieve] 未安装 weaviate/llama_index,跳过 RAG: {e}")
            return "", []
        except Exception as e:
            print(f"[weaviate_retrieve] {repr(e)}")
            return "", []

    import concurrent.futures

    # BUG FIX: the original wrapped the executor in a `with` block, whose
    # __exit__ calls shutdown(wait=True) -- on timeout it still blocked until
    # the worker finished, so the timeout never returned early.  Manage the
    # executor manually and shut down without waiting.
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return ex.submit(_call).result(timeout=timeout_sec)
    except concurrent.futures.TimeoutError:
        print(f"[weaviate_retrieve] timeout after {timeout_sec}s")
        return "", []
    finally:
        ex.shutdown(wait=False, cancel_futures=True)