| import gradio as gr |
| import uuid |
| import yaml |
| import hnswlib |
| from typing import List, Dict |
| from unstructured.partition.html import partition_html |
| from unstructured.chunking.title import chunk_by_title |
| import cohere |
| import requests |
| import json |
| import os |
|
|
# Cohere v2 clients: one used for embedding, one used for reranking. Both are
# authenticated with the key read from the COHERE_API_KEY environment variable.
_cohere_api_key = os.environ.get("COHERE_API_KEY")
co_embed = cohere.ClientV2(_cohere_api_key)
co_rerank = cohere.ClientV2(_cohere_api_key)

# Flag set to "vectored" once the module-level vector store has been built.
vectored = None

# Japanese Wikipedia pages that the vector store loads, chunks and indexes.
raw_documents = [
    {
        "title": "2006年トリノオリンピック",
        "url": "https://ja.wikipedia.org/wiki/2006年トリノオリンピック",
    },
    {
        "title": "生成AI",
        "url": "https://ja.wikipedia.org/wiki/生成的人工知能",
    },
]
|
|
| |
class Vectorstore:
    """
    In-memory semantic search over chunked web pages.

    On construction every source page is loaded and split into title-based
    chunks, each chunk is embedded with Cohere, and the embeddings are stored
    in an hnswlib inner-product index. ``retrieve`` performs a two-stage
    search: nearest-neighbour lookup in the index, then Cohere reranking.
    """

    # Cohere embedding model. Documents and queries must use the same model.
    embed_model = "embed-multilingual-v3.0"
    # Cohere rerank model used in the second retrieval stage.
    rerank_model = "rerank-v3.5"
    # Index dimensionality used only when there are no embeddings to infer it
    # from. 1024 matches the value this class previously hard-coded.
    default_dim = 1024
    # Maximum number of chunk texts sent in a single embed request.
    embed_batch_size = 90

    def __init__(self, raw_documents: List[Dict[str, str]]):
        """
        Builds the store: loads/chunks, embeds, and indexes the sources.

        Parameters:
        raw_documents (List[Dict[str, str]]): Sources to index; each entry needs "title" and "url" keys.
        """
        self.raw_documents = raw_documents
        self.docs = []
        self.docs_embs = []
        self.retrieve_top_k = 10  # candidates fetched from the vector index
        self.rerank_top_k = 3  # results kept after reranking
        self.load_and_chunk()
        self.embed()
        self.index()

    def load_and_chunk(self) -> None:
        """
        Loads each source page and splits its HTML into title-based chunks.

        Appends one {"data": {"title", "text", "url"}} record per chunk to
        self.docs; title and url are copied from the source entry.
        """
        print("Loading documents...")
        for raw_document in self.raw_documents:
            # partition_html is given the URL, so it retrieves the page itself.
            elements = partition_html(url=raw_document["url"])
            for chunk in chunk_by_title(elements):
                self.docs.append(
                    {
                        "data": {
                            "title": raw_document["title"],
                            "text": str(chunk),
                            "url": raw_document["url"],
                        }
                    }
                )

    def embed(self) -> None:
        """
        Embeds the document chunks using the Cohere API.

        Texts are sent in batches of embed_batch_size. The float vectors are
        appended to self.docs_embs in the same order as self.docs, so
        position i in both lists refers to the same chunk.
        """
        self.docs_len = len(self.docs)
        for start in range(0, self.docs_len, self.embed_batch_size):
            # Slicing past the end is safe; the last batch is simply shorter.
            batch = self.docs[start : start + self.embed_batch_size]
            texts = [item["data"]["text"] for item in batch]
            response = co_embed.embed(
                texts=texts,
                model=self.embed_model,
                input_type="search_document",
                embedding_types=["float"],
            )
            self.docs_embs.extend(response.embeddings.float)

    def index(self) -> None:
        """
        Indexes the document chunks for efficient retrieval.

        Builds an hnswlib inner-product index whose labels are positions in
        self.docs / self.docs_embs. An empty document set yields an empty
        index instead of an error.
        """
        print("Indexing document chunks...")
        num_embs = len(self.docs_embs)
        # Infer the dimension from the data so a different embed model works too.
        dim = len(self.docs_embs[0]) if self.docs_embs else self.default_dim
        self.idx = hnswlib.Index(space="ip", dim=dim)
        # At least 1 slot: a zero-capacity index allocates zero bytes, which
        # hnswlib may reject on some platforms.
        self.idx.init_index(max_elements=max(num_embs, 1), ef_construction=512, M=64)
        if num_embs:
            self.idx.add_items(self.docs_embs, list(range(num_embs)))

    def retrieve(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves document chunks based on the given query.

        Parameters:
        query (str): The query to retrieve document chunks for.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved document chunks, with 'title', 'text', and 'url' keys.
        At most rerank_top_k entries, in reranked order; empty if nothing is indexed.
        """
        # hnswlib raises if asked for more neighbours than the index holds,
        # so never request more than the number of indexed chunks.
        k = min(self.retrieve_top_k, len(self.docs_embs))
        if k <= 0:
            return []

        query_emb = co_embed.embed(
            texts=[query],
            model=self.embed_model,
            input_type="search_query",
            embedding_types=["float"],
        ).embeddings.float

        # [0] selects the labels array; the second [0] selects our single query.
        doc_ids = self.idx.knn_query(query_emb, k=k)[0][0]

        docs_to_rerank = [self.docs[doc_id]["data"] for doc_id in doc_ids]
        # allow_unicode=True keeps Japanese text readable for the reranker
        # instead of emitting \uXXXX escape sequences.
        yaml_docs = [
            yaml.dump(doc, sort_keys=False, allow_unicode=True) for doc in docs_to_rerank
        ]
        rerank_results = co_rerank.rerank(
            query=query,
            documents=yaml_docs,
            model=self.rerank_model,
            top_n=min(self.rerank_top_k, len(yaml_docs)),
        )

        # result.index refers to a position in yaml_docs, i.e. in doc_ids.
        return [self.docs[doc_ids[result.index]]["data"] for result in rerank_results.results]
|
|
# Build the vector store exactly once at import time. Construction loads every
# page, embeds every chunk via Cohere and builds the index, so it must not be
# repeated. A second, unconditional build used to follow here and discarded
# the first one. The `vectored` flag is kept for any code that checks it.
vectorstore = Vectorstore(raw_documents)
vectored = "vectored"
|
|
| |
def search(query):
    """
    Runs `query` through the module-level vector store and formats the hits.

    Each hit becomes three lines ("**Title**", "**Text**", "**URL**" labels);
    hits are separated by a blank line. Returns "" when there are no hits.
    """
    hits = vectorstore.retrieve(query)
    formatted = (
        f"**Title**: {hit['title']}\n**Text**: {hit['text']}\n**URL**: {hit['url']}"
        for hit in hits
    )
    return "\n\n".join(formatted)
|
|
| |
# Single-textbox search UI around `search`. The CSS hides the page footer.
interface = gr.Interface(
    fn=search,
    title="Vectorstore検索デモ",
    theme=gr.themes.Glass(),
    inputs=[
        gr.Textbox(
            label="検索クエリ",
            value="生成的人工知能モデルについて教えてください。",
        )
    ],
    outputs=gr.Textbox(label="検索結果"),
    css="footer {visibility: hidden;}",
)
|
|
| |
if __name__ == "__main__":
    # Look for the favicon next to this script rather than in the current
    # working directory, so launching from elsewhere still finds it. If the
    # file is absent, pass None (the launch() default) instead of a dead path.
    _favicon = os.path.join(os.path.dirname(os.path.abspath(__file__)), "favicon.ico")
    interface.launch(favicon_path=_favicon if os.path.isfile(_favicon) else None)
|
|