File size: 3,899 Bytes
08583a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

import sys
import os
from langchain_core.tools import tool

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils.call_llm import llm
from configs.config import Config
env = Config()

def generate_search_string(query: str) -> str:
    """
    Produce a Wikipedia-optimized search string for the given query.

    The local LLM is asked to distill the query into a single continuous
    search phrase with no surrounding text, formatting, or quotes.

    Args:
        query (str): The input query for generating the search string.

    Returns:
        str: A single continuous search string optimized for Wikipedia search.

    Raises:
        ValueError: If the query is not a non-empty string, or if the LLM
            response contains no usable text.
    """
    if not isinstance(query, str) or not query:
        raise ValueError("Query must be a non-empty string.")

    prompt = f"""
    Generate an optimal Wikipedia search string from the query '{query}'. \n
    Just return a single continuous search string without any additional text or formatting or quotation marks. \n
    Do not include any other text or explanation."""

    response = env.LOCAL_LLM.invoke(prompt)
    # Treat a missing response and a whitespace-only response the same way.
    search_string = response.content.strip() if response else ""
    if not search_string:
        raise ValueError("Failed to generate a valid search string.")

    return search_string

def document_store(query, chunk_size, chunk_overlap):
    """Build a FAISS vector store from Wikipedia pages matching the query.

    The query is first rewritten into a Wikipedia-friendly search string,
    the matching pages are loaded, concatenated, split into overlapping
    chunks, and indexed with the configured embedding model.

    Args:
        query: The user's original query; also stored as each chunk's
            ``source`` metadata.
        chunk_size: Maximum characters per text chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.

    Returns:
        A FAISS vector store over the chunked Wikipedia text.

    Raises:
        ValueError: If no text could be loaded or no chunks were produced.
    """
    from langchain_community.document_loaders import WikipediaLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.schema.document import Document
    from langchain_community.vectorstores.faiss import FAISS

    embedding_model = env.EMBED_MODEL
    language = "en"

    search_query = generate_search_string(query)
    if not search_query:
        raise ValueError("Search query is empty or invalid.")

    loader = WikipediaLoader(query=search_query, lang=language)
    documents = loader.load()
    combined_text = "".join(doc.page_content for doc in documents if doc.page_content)
    if not combined_text:
        raise ValueError("No text found in the loaded documents.")

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    chunks = splitter.split_text(combined_text)
    if not chunks:
        raise ValueError("No chunks generated from the combined text.")

    docs = [
        Document(page_content=chunk, metadata={"source": query})
        for chunk in chunks
    ]

    # FAISS.from_documents embeds the documents itself; the previous explicit
    # embed_documents() pass was discarded and duplicated that work, embedding
    # every chunk twice per call.
    store = FAISS.from_documents(docs, embedding=embedding_model)
    return store

def search(query, chunk_size, chunk_overlap, threshold=0.5):
    """Retrieve Wikipedia chunks relevant to the query from a vector store.

    Args:
        query: The search query; used both to build the store and as the
            similarity-search text.
        chunk_size: Maximum characters per text chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.
        threshold: Maximum FAISS distance score to keep (lower scores are
            closer matches). Defaults to 0.5, preserving prior behavior.

    Returns:
        list[tuple]: (document, score) pairs with score <= threshold,
        at most 5 entries.
    """
    store = document_store(query, chunk_size, chunk_overlap)
    results = store.similarity_search_with_score(query, k=5)

    # FAISS scores are distances: smaller means more similar, so keep only
    # results at or below the relevance threshold.
    return [(doc, score) for doc, score in results if score <= threshold]
    
@tool("wikipedia_search_tool")
def wikipedia_search_tool(query: str, chunk_size: int = 1000, chunk_overlap: int = 200):
    """
    Run the Wikipedia search tool with the given query and parameters.

    Retrieves relevant Wikipedia chunks for the query, then asks the local
    LLM to answer using the top-scoring chunk as context.

    Args:
        query (str): The user's question or search query.
        chunk_size (int): Maximum characters per text chunk. Defaults to 1000.
        chunk_overlap (int): Overlap between consecutive chunks. Defaults to 200.

    Returns:
        The LLM response message produced from the retrieved context.
    """
    print("----- Wiki Run ---")
    default_prompts = env.WIKI_DEFAULT_PROMPTS

    response = search(query, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    if not response:
        response = [("No relevant documents found.", 1.0)]

    # search() yields (Document, score) pairs, while the fallback entry holds a
    # plain string. Extract the text either way — previously the raw Document
    # object itself was passed as the message content.
    top_hit = response[0][0]
    context = getattr(top_hit, "page_content", top_hit)

    llm_input = [
        {"role": "system", "content": default_prompts["system"]},
        {"role": "user", "content": default_prompts["user"].format(query=query)},
        {"role": "user", "content": context},
    ]

    return env.LOCAL_LLM.invoke(llm_input)