File size: 5,047 Bytes
c3579fe
 
 
 
 
 
 
 
 
 
6195444
c3579fe
 
 
 
 
 
c7ff072
8593f1c
fe3a981
c7ff072
c3579fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d31c971
c3579fe
 
 
 
 
 
 
 
 
 
 
 
 
d31c971
c3579fe
c7ff072
 
 
 
c3579fe
c7ff072
c3579fe
c7ff072
c3579fe
 
d31c971
c7ff072
 
 
 
 
 
 
 
 
 
 
 
 
c3579fe
 
 
 
 
 
 
c7ff072
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3579fe
 
 
 
 
 
 
 
 
 
 
 
 
 
7a88482
c3579fe
 
 
f2585ef
 
c3579fe
f2585ef
c3579fe
 
 
 
 
 
c8845f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import uuid
import yaml
import hnswlib
from typing import List, Dict
from unstructured.partition.html import partition_html
from unstructured.chunking.title import chunk_by_title
import cohere
import requests
import json
import os

# Cohere SDK clients. Both embed and rerank calls go to the same API with
# the same key, so a single client suffices; the second name is kept as an
# alias for backward compatibility with existing call sites.
co_embed = cohere.ClientV2(os.environ.get("COHERE_API_KEY"))
co_rerank = co_embed  # alias — same credentials/endpoint, no second client needed

# Guard flag: set to "vectored" once the vectorstore has been built.
vectored = None

# Source pages (Japanese Wikipedia) to load, chunk, embed and index.
raw_documents = [
    {"title": "2006年トリノオリンピック", "url": "https://ja.wikipedia.org/wiki/2006年トリノオリンピック"},
    {"title": "生成AI", "url": "https://ja.wikipedia.org/wiki/生成的人工知能"}
]

# Vectorstore class and retrieval logic
class Vectorstore:
    """In-memory RAG store.

    Pipeline: load HTML pages and chunk them by title, embed the chunks
    with the Cohere API, index the embeddings with hnswlib, and serve
    dense retrieval followed by Cohere reranking.
    """

    def __init__(self, raw_documents: List[Dict[str, str]]):
        """Eagerly run the full load -> embed -> index pipeline.

        Parameters:
        raw_documents (List[Dict[str, str]]): source pages, each a dict
            with "title" and "url" keys.
        """
        self.raw_documents = raw_documents
        self.docs: List[Dict] = []
        self.docs_embs: List[List[float]] = []
        self.retrieve_top_k = 10  # dense-retrieval candidate count
        self.rerank_top_k = 3     # final result count after reranking
        self.load_and_chunk()
        self.embed()
        self.index()

    def load_and_chunk(self) -> None:
        """
        Loads the text from the sources and chunks the HTML content.
        """
        print("Loading documents...")

        for raw_document in self.raw_documents:
            elements = partition_html(url=raw_document["url"])
            chunks = chunk_by_title(elements)
            for chunk in chunks:
                self.docs.append(
                    {
                        "data": {
                            "title": raw_document["title"],
                            "text": str(chunk),
                            "url": raw_document["url"],
                        }
                    }
                )

    def embed(self) -> None:
        """
        Embeds the document chunks in batches using the Cohere API.
        """
        # Batch to stay under the embed endpoint's per-request text limit.
        batch_size = 90
        self.docs_len = len(self.docs)  # kept as an attribute for compatibility
        for i in range(0, self.docs_len, batch_size):
            # Slicing clamps at the end of the list, so no min() is needed.
            batch = self.docs[i : i + batch_size]
            texts = [item["data"]["text"] for item in batch]
            docs_embs_batch = co_embed.embed(
                texts=texts,
                model="embed-multilingual-v3.0",
                input_type="search_document",
                embedding_types=["float"],
            ).embeddings.float
            self.docs_embs.extend(docs_embs_batch)

    def index(self) -> None:
        """
        Indexes the document chunks for efficient retrieval.
        """
        print("Indexing document chunks...")

        # embed-multilingual-v3.0 returns 1024-dimensional vectors;
        # inner-product ("ip") space is used for similarity search.
        self.idx = hnswlib.Index(space="ip", dim=1024)
        # Size the index from the embeddings actually produced, rather than
        # relying on embed() having set self.docs_len beforehand.
        num_embs = len(self.docs_embs)
        self.idx.init_index(max_elements=num_embs, ef_construction=512, M=64)
        self.idx.add_items(self.docs_embs, list(range(num_embs)))

    def retrieve(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves document chunks based on the given query.

        Dense retrieval over the hnswlib index narrows the candidates to
        ``retrieve_top_k``; the Cohere rerank endpoint then reorders them
        and keeps the best ``rerank_top_k``.

        Parameters:
        query (str): The query to retrieve document chunks for.

        Returns:
        List[Dict[str, str]]: Retrieved chunk dicts with 'title', 'text',
        and 'url' keys, best match first.
        """

        # Dense retrieval: embed the query with the same model, but with
        # the "search_query" input type to match the document embeddings.
        query_emb = co_embed.embed(
            texts=[query],
            model="embed-multilingual-v3.0",
            input_type="search_query",
            embedding_types=["float"],
        ).embeddings.float

        doc_ids = self.idx.knn_query(query_emb, k=self.retrieve_top_k)[0][0]

        # Reranking: serialize each candidate to YAML so the title and URL
        # context is visible to the reranker alongside the chunk text.
        docs_to_rerank = [self.docs[doc_id]["data"] for doc_id in doc_ids]
        yaml_docs = [yaml.dump(doc, sort_keys=False) for doc in docs_to_rerank]
        rerank_results = co_rerank.rerank(
            query=query,
            documents=yaml_docs,
            model="rerank-v3.5",
            top_n=self.rerank_top_k,
        )

        # Map rerank result positions back to the original candidate ids.
        doc_ids_reranked = [doc_ids[result.index] for result in rerank_results.results]

        return [self.docs[doc_id]["data"] for doc_id in doc_ids_reranked]

# Build the vectorstore exactly once. (The original code constructed
# Vectorstore a second time unconditionally, repeating the expensive
# load/embed/index pipeline; the guarded build alone is sufficient.)
if vectored != "vectored":
    vectorstore = Vectorstore(raw_documents)
    vectored = "vectored"

# Gradio callback
def search(query):
    """Retrieve chunks relevant to *query* and format them as Markdown."""
    sections = []
    for doc in vectorstore.retrieve(query):
        sections.append(
            f"**Title**: {doc['title']}\n**Text**: {doc['text']}\n**URL**: {doc['url']}"
        )
    return "\n\n".join(sections)

# Gradio UI: one text box feeding `search`, one text box showing the result.
query_box = gr.Textbox(
    label="検索クエリ",
    value="生成的人工知能モデルについて教えてください。",
)
result_box = gr.Textbox(label="検索結果")

interface = gr.Interface(
    fn=search,
    inputs=[query_box],
    outputs=result_box,
    title="Vectorstore検索デモ",
    theme=gr.themes.Glass(),
    css="footer {visibility: hidden;}",
)

# Launch only when run as a script, not when imported as a module.
if __name__ == "__main__":
    interface.launch(favicon_path='favicon.ico')