File size: 6,164 Bytes
fbe1c8a
 
 
 
0cde401
fbe1c8a
 
0cde401
fbe1c8a
 
 
0cde401
 
 
fbe1c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cde401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# api/weaviate_retrieve.py
"""
与 ClareVoice 共用同一 Weaviate 数据库(GenAICourses)的检索封装。
教师 Agent 和 Clare 均可调用,需与 build_weaviate_index 使用相同 embedding(HF all-MiniLM-L6-v2)。
支持带引用的检索:返回 (text, refs),用于标注 [Source: Filename/Page]。
"""
import os
from typing import List, Optional, Tuple

from .config import USE_WEAVIATE, WEAVIATE_URL, WEAVIATE_API_KEY, WEAVIATE_COLLECTION

# 引用项:本地 VDB 为 {"type": "vdb", "source": "Filename", "page": "1"},Web 为 {"type": "web", "url": "..."}
RefItem = dict


def retrieve_from_weaviate(query: str, top_k: int = 8, timeout_sec: float = 45.0) -> str:
    """
    Retrieve course snippets relevant to *query* from the Weaviate Cloud
    collection (GenAICourses).

    Embeddings use HuggingFace all-MiniLM-L6-v2, matching the model used at
    index-build time. If Weaviate is not configured, the query is too short,
    or the optional dependencies are missing, an empty string is returned so
    the teacher agent keeps working (just without RAG).

    Args:
        query: Natural-language search query; ignored if under 3 chars.
        top_k: Number of top-similarity nodes to retrieve.
        timeout_sec: Hard wall-clock limit for the whole retrieval call.

    Returns:
        Retrieved snippets joined by "\\n\\n---\\n\\n", or "" on any failure.
    """
    if not USE_WEAVIATE or not query or len(query.strip()) < 3:
        return ""

    def _call() -> str:
        try:
            import weaviate
            from weaviate.classes.init import Auth
            from llama_index.core import Settings, VectorStoreIndex
            from llama_index.vector_stores.weaviate import WeaviateVectorStore
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            # Must match the embedding model used by build_weaviate_index.
            Settings.embed_model = HuggingFaceEmbedding(
                model_name=os.getenv("HF_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
            )
            client = weaviate.connect_to_weaviate_cloud(
                cluster_url=WEAVIATE_URL,
                auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
            )
            try:
                if not client.is_ready():
                    return ""
                vs = WeaviateVectorStore(
                    weaviate_client=client,
                    index_name=WEAVIATE_COLLECTION,
                )
                index = VectorStoreIndex.from_vector_store(vs)
                nodes = index.as_retriever(similarity_top_k=top_k).retrieve(query)
                return "\n\n---\n\n".join(n.get_content() for n in nodes) if nodes else ""
            finally:
                client.close()
        except ImportError as e:
            print(f"[weaviate_retrieve] 未安装 weaviate/llama_index,跳过 RAG: {e}")
            return ""
        except Exception as e:
            print(f"[weaviate_retrieve] {repr(e)}")
            return ""

    # Imported outside the try so the except clause can always resolve the name.
    import concurrent.futures

    # BUGFIX: a `with ThreadPoolExecutor(...)` block calls shutdown(wait=True)
    # on exit, which blocks until the worker finishes even after result()
    # raised TimeoutError — defeating the timeout. Manage the executor
    # manually and shut down without waiting so the caller returns promptly;
    # the orphaned worker cleans itself up via the inner finally/close().
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return ex.submit(_call).result(timeout=timeout_sec)
    except concurrent.futures.TimeoutError:
        print(f"[weaviate_retrieve] timeout after {timeout_sec}s")
        return ""
    finally:
        ex.shutdown(wait=False)


def retrieve_from_weaviate_with_refs(
    query: str, top_k: int = 8, timeout_sec: float = 45.0
) -> Tuple[str, List[RefItem]]:
    """
    Retrieve from Weaviate and return both the text and a citation list.

    Citations are used to render "[Source: Filename/Page]" markers. If a node
    carries no file_name/page metadata, the collection name (or
    "GenAICourses") is used as the source.

    Args:
        query: Natural-language search query; ignored if under 3 chars.
        top_k: Number of top-similarity nodes to retrieve.
        timeout_sec: Hard wall-clock limit for the whole retrieval call.

    Returns:
        (text, refs) — snippets joined by "\\n\\n---\\n\\n" and a deduplicated
        list of {"type": "vdb", "source": ..., "page": ...} dicts; ("", [])
        on any failure.
    """
    if not USE_WEAVIATE or not query or len(query.strip()) < 3:
        return "", []

    def _call() -> Tuple[str, List[RefItem]]:
        try:
            import weaviate
            from weaviate.classes.init import Auth
            from llama_index.core import Settings, VectorStoreIndex
            from llama_index.vector_stores.weaviate import WeaviateVectorStore
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            # Must match the embedding model used by build_weaviate_index.
            Settings.embed_model = HuggingFaceEmbedding(
                model_name=os.getenv("HF_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
            )
            client = weaviate.connect_to_weaviate_cloud(
                cluster_url=WEAVIATE_URL,
                auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
            )
            try:
                if not client.is_ready():
                    return "", []
                vs = WeaviateVectorStore(
                    weaviate_client=client,
                    index_name=WEAVIATE_COLLECTION,
                )
                index = VectorStoreIndex.from_vector_store(vs)
                nodes = index.as_retriever(similarity_top_k=top_k).retrieve(query)
                if not nodes:
                    return "", []
                texts: List[str] = []
                refs: List[RefItem] = []
                seen = set()  # dedupe refs on (source, page)
                for n in nodes:
                    content = n.get_content()
                    if isinstance(content, str) and content.strip():
                        texts.append(content.strip())
                    # NodeWithScore: metadata may live on n.node or on n itself.
                    node = getattr(n, "node", n)
                    meta = getattr(node, "metadata", None) or {}
                    fname = (meta.get("file_name") or meta.get("source_file") or meta.get("filename") or WEAVIATE_COLLECTION or "GenAICourses").strip()
                    page = (meta.get("page_label") or meta.get("page_number") or meta.get("page") or "")
                    page_str = str(page).strip() if page else ""
                    key = (fname, page_str)
                    if key not in seen:
                        seen.add(key)
                        refs.append({"type": "vdb", "source": fname, "page": page_str})
                return "\n\n---\n\n".join(texts), refs
            finally:
                client.close()
        except ImportError as e:
            print(f"[weaviate_retrieve] 未安装 weaviate/llama_index,跳过 RAG: {e}")
            return "", []
        except Exception as e:
            print(f"[weaviate_retrieve] {repr(e)}")
            return "", []

    # Imported outside the try so the except clause can always resolve the name.
    import concurrent.futures

    # BUGFIX: a `with ThreadPoolExecutor(...)` block calls shutdown(wait=True)
    # on exit, which blocks until the worker finishes even after result()
    # raised TimeoutError — defeating the timeout. Manage the executor
    # manually and shut down without waiting so the caller returns promptly;
    # the orphaned worker cleans itself up via the inner finally/close().
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return ex.submit(_call).result(timeout=timeout_sec)
    except concurrent.futures.TimeoutError:
        print(f"[weaviate_retrieve] timeout after {timeout_sec}s")
        return "", []
    finally:
        ex.shutdown(wait=False)