# Source: ClareCourseWare / api/weaviate_retrieve.py
# Commit 0cde401 by claudqunwang — "feat(courseware): AI Teacher Assistant Agent 模块化"
# api/weaviate_retrieve.py
"""
与 ClareVoice 共用同一 Weaviate 数据库(GenAICourses)的检索封装。
教师 Agent 和 Clare 均可调用,需与 build_weaviate_index 使用相同 embedding(HF all-MiniLM-L6-v2)。
支持带引用的检索:返回 (text, refs),用于标注 [Source: Filename/Page]。
"""
import os
from typing import List, Optional, Tuple
from .config import USE_WEAVIATE, WEAVIATE_URL, WEAVIATE_API_KEY, WEAVIATE_COLLECTION
# Reference item shape: local VDB entries are {"type": "vdb", "source": "Filename", "page": "1"};
# web entries are {"type": "web", "url": "..."}. Kept as a plain ``dict`` alias for readability.
RefItem = dict
def retrieve_from_weaviate(query: str, top_k: int = 8, timeout_sec: float = 45.0) -> str:
    """
    Retrieve course snippets related to *query* from the GenAICourses collection
    on Weaviate Cloud.

    Embeds the query with HuggingFace all-MiniLM-L6-v2 (overridable via the
    ``HF_EMBEDDING_MODEL`` env var), which must match the model used by
    build_weaviate_index.

    Returns "" — so the teacher Agent still runs, just without RAG — when:
      * Weaviate is not configured (``USE_WEAVIATE`` falsy),
      * the query is empty or shorter than 3 characters,
      * weaviate / llama_index dependencies are not installed,
      * the call raises or exceeds ``timeout_sec``.

    :param query: natural-language search query
    :param top_k: number of nodes to retrieve
    :param timeout_sec: wall-clock budget for the whole retrieval call
    :return: retrieved chunks joined by "\\n\\n---\\n\\n", or "" on any failure
    """
    # Hoisted out of the try: the except clause below references this module,
    # so the name must be bound before any exception can be raised.
    import concurrent.futures

    if not USE_WEAVIATE or not query or len(query.strip()) < 3:
        return ""

    def _call() -> str:
        """Blocking retrieval; runs on a worker thread so the caller can time out."""
        try:
            import weaviate
            from weaviate.classes.init import Auth
            from llama_index.core import Settings, VectorStoreIndex
            from llama_index.vector_stores.weaviate import WeaviateVectorStore
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            # NOTE: mutates the llama_index global Settings — must match the
            # embedding model the index was built with.
            Settings.embed_model = HuggingFaceEmbedding(
                model_name=os.getenv("HF_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
            )
            client = weaviate.connect_to_weaviate_cloud(
                cluster_url=WEAVIATE_URL,
                auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
            )
            try:
                if not client.is_ready():
                    return ""
                vs = WeaviateVectorStore(
                    weaviate_client=client,
                    index_name=WEAVIATE_COLLECTION,
                )
                index = VectorStoreIndex.from_vector_store(vs)
                nodes = index.as_retriever(similarity_top_k=top_k).retrieve(query)
                return "\n\n---\n\n".join(n.get_content() for n in nodes) if nodes else ""
            finally:
                client.close()
        except ImportError as e:
            print(f"[weaviate_retrieve] 未安装 weaviate/llama_index,跳过 RAG: {e}")
            return ""
        except Exception as e:
            print(f"[weaviate_retrieve] {repr(e)}")
            return ""

    # BUG FIX: the original ran the future inside ``with ThreadPoolExecutor(...)``.
    # On timeout, the with-block's exit calls shutdown(wait=True), which blocks
    # until the worker finishes — so the timeout never actually returned early.
    # Manage the executor manually and shut it down without waiting instead.
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return ex.submit(_call).result(timeout=timeout_sec)
    except concurrent.futures.TimeoutError:
        print(f"[weaviate_retrieve] timeout after {timeout_sec}s")
        return ""
    finally:
        ex.shutdown(wait=False)
def retrieve_from_weaviate_with_refs(
    query: str, top_k: int = 8, timeout_sec: float = 45.0
) -> Tuple[str, List[RefItem]]:
    """
    Retrieve from Weaviate and return both the text and a citation list.

    Citations are used to annotate answers as [Source: Filename/Page]. When a
    node carries no file_name/page metadata, the collection name (or the
    literal "GenAICourses") is used as the source.

    Returns ("", []) — so the teacher Agent still runs, just without RAG — when
    Weaviate is unconfigured, the query is shorter than 3 characters,
    dependencies are missing, the call raises, or ``timeout_sec`` elapses.

    :param query: natural-language search query
    :param top_k: number of nodes to retrieve
    :param timeout_sec: wall-clock budget for the whole retrieval call
    :return: (chunks joined by "\\n\\n---\\n\\n",
              deduplicated refs of shape {"type": "vdb", "source": ..., "page": ...})
    """
    # Hoisted out of the try: the except clause below references this module.
    import concurrent.futures

    if not USE_WEAVIATE or not query or len(query.strip()) < 3:
        return "", []

    def _call() -> Tuple[str, List[RefItem]]:
        """Blocking retrieval; runs on a worker thread so the caller can time out."""
        try:
            import weaviate
            from weaviate.classes.init import Auth
            from llama_index.core import Settings, VectorStoreIndex
            from llama_index.vector_stores.weaviate import WeaviateVectorStore
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            # NOTE: mutates the llama_index global Settings — must match the
            # embedding model the index was built with.
            Settings.embed_model = HuggingFaceEmbedding(
                model_name=os.getenv("HF_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
            )
            client = weaviate.connect_to_weaviate_cloud(
                cluster_url=WEAVIATE_URL,
                auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
            )
            try:
                if not client.is_ready():
                    return "", []
                vs = WeaviateVectorStore(
                    weaviate_client=client,
                    index_name=WEAVIATE_COLLECTION,
                )
                index = VectorStoreIndex.from_vector_store(vs)
                nodes = index.as_retriever(similarity_top_k=top_k).retrieve(query)
                if not nodes:
                    return "", []
                texts: List[str] = []
                refs: List[RefItem] = []
                seen = set()  # (source, page) pairs already emitted — dedupes refs
                for n in nodes:
                    content = n.get_content()
                    # Skip empty/non-string content; such nodes contribute neither
                    # text nor a citation.
                    if not (isinstance(content, str) and content.strip()):
                        continue
                    texts.append(content.strip())
                    # NodeWithScore wraps the node; metadata may live on either.
                    node = getattr(n, "node", n)
                    meta = getattr(node, "metadata", None) or {}
                    fname = (
                        meta.get("file_name")
                        or meta.get("source_file")
                        or meta.get("filename")
                        or WEAVIATE_COLLECTION
                        or "GenAICourses"
                    ).strip()
                    page = meta.get("page_label") or meta.get("page_number") or meta.get("page") or ""
                    page_str = str(page).strip() if page else ""
                    key = (fname, page_str)
                    if key not in seen:
                        seen.add(key)
                        refs.append({"type": "vdb", "source": fname, "page": page_str})
                return "\n\n---\n\n".join(texts), refs
            finally:
                client.close()
        except ImportError as e:
            print(f"[weaviate_retrieve] 未安装 weaviate/llama_index,跳过 RAG: {e}")
            return "", []
        except Exception as e:
            print(f"[weaviate_retrieve] {repr(e)}")
            return "", []

    # BUG FIX: running the future inside ``with ThreadPoolExecutor(...)`` made the
    # timeout meaningless — on TimeoutError the with-exit calls shutdown(wait=True)
    # and blocks until the worker finishes. Shut down without waiting instead.
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return ex.submit(_call).result(timeout=timeout_sec)
    except concurrent.futures.TimeoutError:
        print(f"[weaviate_retrieve] timeout after {timeout_sec}s")
        return "", []
    finally:
        ex.shutdown(wait=False)