# co_agent / app.py
# (Hugging Face Space header: fudii0921 — "Update app.py" — commit d31c971, verified)
import gradio as gr
import uuid
import yaml
import hnswlib
from typing import List, Dict
from unstructured.partition.html import partition_html
from unstructured.chunking.title import chunk_by_title
import cohere
import requests
import json
import os
# One Cohere client handles both embedding and reranking; the original code
# built two identical clients from the same API key.  The second name is kept
# as an alias so Vectorstore's references to co_rerank still work.
co_embed = cohere.ClientV2(os.environ.get("COHERE_API_KEY"))
co_rerank = co_embed

# Module-level guard flag: set to "vectored" once the Vectorstore is built.
vectored = None

# Source documents to download, chunk and index at startup.
raw_documents = [
    {"title": "2006年トリノオリンピック", "url": "https://ja.wikipedia.org/wiki/2006年トリノオリンピック"},
    {"title": "生成AI", "url": "https://ja.wikipedia.org/wiki/生成的人工知能"},
]
# あなたのクラスとロジックをここに統合します
class Vectorstore:
def __init__(self, raw_documents: List[Dict[str, str]]):
self.raw_documents = raw_documents
self.docs = []
self.docs_embs = []
self.retrieve_top_k = 10
self.rerank_top_k = 3
self.load_and_chunk()
self.embed()
self.index()
def load_and_chunk(self) -> None:
"""
Loads the text from the sources and chunks the HTML content.
"""
print("Loading documents...")
for raw_document in self.raw_documents:
elements = partition_html(url=raw_document["url"])
chunks = chunk_by_title(elements)
for chunk in chunks:
self.docs.append(
{
"data": {
"title": raw_document["title"],
"text": str(chunk),
"url": raw_document["url"],
}
}
)
def embed(self) -> None:
"""
Embeds the document chunks using the Cohere API.
"""
#print("Embedding document chunks...")
batch_size = 90
self.docs_len = len(self.docs)
for i in range(0, self.docs_len, batch_size):
batch = self.docs[i : min(i + batch_size, self.docs_len)]
texts = [item["data"]["text"] for item in batch]
docs_embs_batch = co_embed.embed(
texts=texts,
model="embed-multilingual-v3.0",
input_type="search_document",
embedding_types=["float"]
).embeddings.float
self.docs_embs.extend(docs_embs_batch)
#print(docs_embs_batch)
def index(self) -> None:
"""
Indexes the document chunks for efficient retrieval.
"""
print("Indexing document chunks...")
self.idx = hnswlib.Index(space="ip", dim=1024)
self.idx.init_index(max_elements=self.docs_len, ef_construction=512, M=64)
self.idx.add_items(self.docs_embs, list(range(len(self.docs_embs))))
#print(f"Indexing complete with {self.idx.get_current_count()} document chunks.")
def retrieve(self, query: str) -> List[Dict[str, str]]:
"""
Retrieves document chunks based on the given query.
Parameters:
query (str): The query to retrieve document chunks for.
Returns:
List[Dict[str, str]]: A list of dictionaries representing the retrieved document chunks, with 'title', 'text', and 'url' keys.
"""
# Dense retrieval
query_emb = co_embed.embed(
texts=[query],
model="embed-multilingual-v3.0",
input_type="search_query",
embedding_types=["float"]
).embeddings.float
doc_ids = self.idx.knn_query(query_emb, k=self.retrieve_top_k)[0][0]
# Reranking
docs_to_rerank = [self.docs[doc_id]["data"] for doc_id in doc_ids]
yaml_docs = [yaml.dump(doc, sort_keys=False) for doc in docs_to_rerank]
rerank_results = co_rerank.rerank(
query=query,
documents=yaml_docs,
model="rerank-v3.5", # Pass a dummy string
top_n=self.rerank_top_k
)
doc_ids_reranked = [doc_ids[result.index] for result in rerank_results.results]
docs_retrieved = []
for doc_id in doc_ids_reranked:
docs_retrieved.append(self.docs[doc_id]["data"])
return docs_retrieved
# Build the Vectorstore exactly once at module load.  The original code
# instantiated it twice -- once behind the `vectored` guard and once
# unconditionally -- doubling the download/embedding work and API cost.
if vectored != "vectored":
    vectorstore = Vectorstore(raw_documents)
    vectored = "vectored"
# Gradio callback: run retrieval and format the hits as a markdown string.
def search(query):
    """Retrieve chunks for *query* and return them as markdown sections."""
    hits = vectorstore.retrieve(query)
    sections = []
    for hit in hits:
        sections.append(
            f"**Title**: {hit['title']}\n**Text**: {hit['text']}\n**URL**: {hit['url']}"
        )
    return "\n\n".join(sections)
# Gradio UI: a single query text box in, formatted search results out.
interface = gr.Interface(
    fn=search,
    inputs=[gr.Textbox(label="検索クエリ", value="生成的人工知能モデルについて教えてください。")],
    outputs=gr.Textbox(label="検索結果"),
    title="Vectorstore検索デモ",
    theme=gr.themes.Glass(),
    css="footer {visibility: hidden;}",  # hide the default Gradio footer
)

# Launch the app only when executed as a script.
if __name__ == "__main__":
    interface.launch(favicon_path='favicon.ico')