Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import hashlib | |
| import requests | |
| from typing import List, Optional | |
| from datetime import datetime | |
| from langchain.schema.embeddings import Embeddings | |
| from streamlit.runtime.uploaded_file_manager import UploadedFile | |
| from clickhouse_connect import get_client | |
| from multiprocessing.pool import ThreadPool | |
| from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings | |
| from .helper import create_retriever_tool | |
# Endpoint of the hosted Unstructured partitioning API that every upload is sent to.
parser_url = "https://api.unstructured.io/general/v0/general"


def parse_files(api_key, user_id, files: List[UploadedFile]):
    """Parse uploaded files through the Unstructured API into knowledge-base rows.

    Each file is POSTed to ``parser_url``. Of the elements returned, only those
    whose ``type`` is ``"NarrativeText"`` and whose text splits into more than
    10 space-separated pieces are kept.

    Args:
        api_key: Sent as the ``unstructured-api-key`` request header.
        user_id: Copied into the ``user_id`` field of every produced row.
        files: Uploaded files to parse. Up to 8 are processed concurrently.

    Returns:
        A flat list of dicts with the keys ``text``, ``file_name``,
        ``entity_id``, ``user_id`` and ``created_by``. Rows of different files
        appear in completion order, not in the order of ``files``.

    Raises:
        ValueError: If the parser answers with a non-200 status. The message
            is the decoded JSON body when there is one, otherwise the raw
            response text.
    """
    headers = {
        "accept": "application/json",
        "unstructured-api-key": api_key,
    }
    form = {"strategy": "auto", "ocr_languages": ["eng"]}

    def parse_file(file):
        # getvalue() returns the whole buffer regardless of the current stream
        # position, so the hash stays correct even if the upload was already
        # read elsewhere (read() would then yield b"" and every file would
        # share the same hash).
        content = file.getvalue()
        file_hash = hashlib.sha256(content).hexdigest()
        response = requests.post(
            parser_url,
            headers=headers,
            data=form,
            files={"files": (file.name, content, file.type)},
        )
        # Check the status before decoding: error pages are not always JSON,
        # and a decode failure would otherwise hide the real problem.
        if response.status_code != 200:
            try:
                detail = str(response.json())
            except ValueError:
                detail = response.text
            raise ValueError(detail)
        return [
            {
                "text": element["text"],
                "file_name": element["metadata"]["filename"],
                # Content-derived id: the same paragraph of the same file
                # always maps to the same entity_id.
                "entity_id": hashlib.sha256(
                    (file_hash + element["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for element in response.json()
            if element["type"] == "NarrativeText"
            and len(element["text"].split(" ")) > 10
        ]

    rows = []
    with ThreadPool(8) as pool:
        for parsed in pool.imap_unordered(parse_file, files):
            rows.extend(parsed)
    return rows
def extract_embedding(embeddings: Embeddings, texts):
    """Attach an embedding vector to every row in ``texts``.

    Args:
        embeddings: Object whose ``embed_documents`` method is called once
            with the list of all ``"text"`` values.
        texts: Sequence of dicts that each carry a ``"text"`` key. The dicts
            are modified in place: a ``"vector"`` key is added to each.

    Returns:
        The same ``texts`` object that was passed in.

    Raises:
        ValueError: If ``texts`` is empty (or ``None``).
    """
    if not texts:
        raise ValueError("No texts extracted!")
    vectors = embeddings.embed_documents([row["text"] for row in texts])
    # Index into the result (rather than zip) so that a backend returning too
    # few vectors fails loudly instead of leaving rows without a vector.
    for position, row in enumerate(texts):
        row["vector"] = vectors[position]
    return texts
class PrivateKnowledgeBase:
    """Per-user document store and retriever-tool registry on MyScale.

    Two tables are used: ``kb_table`` holds embedded paragraphs of uploaded
    files, ``tool_table`` holds named groups of files ("tools") that can be
    turned into retriever tools with :meth:`as_tools`. Every row in both
    tables carries the owning ``user_id`` and all queries filter on it.
    """

    def __init__(
        self,
        host,
        port,
        username,
        password,
        embedding: Embeddings,
        parser_api_key,
        db="chat",
        kb_table="private_kb",
        tool_table="private_tool",
    ) -> None:
        """Create both tables if they are missing and open the vector store.

        Args:
            host: Database host.
            port: Database port.
            username: Database user.
            password: Database password.
            embedding: Embedding model handed to the vector store. The table
                schema only accepts vectors of length 768.
            parser_api_key: API key later passed to ``parse_files``.
            db: Database that holds both tables.
            kb_table: Name of the paragraph table.
            tool_table: Name of the tool table.
        """
        super().__init__()
        kb_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
                entity_id String,
                file_name String,
                text String,
                user_id String,
                created_by DateTime,
                vector Array(Float32),
                CONSTRAINT cons_vec_len CHECK length(vector) = 768,
                VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
            ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """
        tool_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
                tool_id String,
                tool_name String,
                file_names Array(String),
                user_id String,
                created_by DateTime,
                tool_description String
            ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.kb_table = kb_table
        self.tool_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        # Short-lived client used only to create the schema; the vector store
        # below opens its own connection, so release this one when done.
        client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        try:
            client.command("SET allow_experimental_object_type=1")
            client.command(kb_schema_)
            client.command(tool_schema_)
        finally:
            client.close()
        self.parser_api_key = parser_api_key
        self.vstore = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    @staticmethod
    def _quote(value) -> str:
        """Return ``value`` as a single-quoted, escaped SQL string literal."""
        # Backslashes first, otherwise the backslash added for a quote would
        # itself be doubled.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def list_files(self, user_id, tool_name=None):
        """List the files stored for ``user_id``.

        Args:
            user_id: Owner whose files are listed.
            tool_name: Accepted for interface compatibility; not used by the
                query.

        Returns:
            A list of dicts with ``file_name``, ``num_paragraph`` (row count)
            and ``max_chars`` (length of the longest paragraph).
        """
        query = f"""
            SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
                arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
            FROM {self.vstore.config.database}.{self.kb_table}
            WHERE user_id = {self._quote(user_id)} GROUP BY file_name
        """
        return list(self.vstore.client.query(query).named_results())

    def add_by_file(
        self, user_id, files: List[UploadedFile], **kwargs
    ):
        """Parse, embed and insert ``files`` into the paragraph table.

        Args:
            user_id: Owner recorded on every inserted row.
            files: Uploaded files to ingest.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: Propagated from ``extract_embedding`` when no
                paragraph was extracted, or from ``parse_files`` on a parser
                error.
        """
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vstore.embeddings, data)
        self.vstore.client.insert_df(
            self.kb_table,
            pd.DataFrame(data),
            database=self.vstore.config.database,
        )

    def clear(self, user_id):
        """Delete every paragraph and every tool owned by ``user_id``."""
        database = self.vstore.config.database
        owner = self._quote(user_id)
        for table in (self.kb_table, self.tool_table):
            self.vstore.client.command(
                f"DELETE FROM {database}.{table} WHERE user_id = {owner}"
            )

    def create_tool(
        self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
    ):
        """Register a tool, i.e. a named group of file names, for ``user_id``.

        Args:
            user_id: Owner of the tool.
            tool_name: Name of the tool; together with ``user_id`` it
                determines ``tool_id``.
            tool_description: Free-text description stored with the tool.
            files: File names the tool covers. ``None`` is stored as an empty
                list because the column is a non-nullable array.
        """
        record = {
            "tool_id": hashlib.sha256(
                (user_id + tool_name).encode("utf-8")
            ).hexdigest(),
            "tool_name": tool_name,
            "file_names": [] if files is None else files,
            "user_id": user_id,
            "created_by": datetime.now(),
            "tool_description": tool_description,
        }
        self.vstore.client.insert_df(
            self.tool_table,
            pd.DataFrame([record]),
            database=self.vstore.config.database,
        )

    def list_tools(self, user_id, tool_name=None):
        """List the tools of ``user_id``, optionally restricted to one name.

        Args:
            user_id: Owner whose tools are listed.
            tool_name: When truthy, only the tool with this name is returned.

        Returns:
            A list of dicts, one per tool, with ``tool_name``,
            ``tool_description`` and the number of file names.
        """
        conditions = [f"user_id = {self._quote(user_id)}"]
        if tool_name:
            conditions.append(f"tool_name = {self._quote(tool_name)}")
        query = f"""
            SELECT tool_name, tool_description, length(file_names)
            FROM {self.vstore.config.database}.{self.tool_table}
            WHERE {' AND '.join(conditions)}
        """
        return list(self.vstore.client.query(query).named_results())

    def remove_tools(self, user_id, tool_names):
        """Delete the named tools of ``user_id``.

        Args:
            user_id: Owner of the tools.
            tool_names: Iterable of tool names; nothing is sent to the
                database when it is empty.
        """
        quoted = [self._quote(name) for name in tool_names]
        if not quoted:
            # An empty IN list would only produce a pointless (or invalid)
            # statement.
            return
        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
            WHERE user_id = {self._quote(user_id)} AND tool_name IN [{",".join(quoted)}]"""
        self.vstore.client.command(query)

    def as_tools(self, user_id, tool_name=None):
        """Build one retriever tool per registered tool of ``user_id``.

        Each retriever is restricted to rows owned by ``user_id`` whose
        ``file_name`` is listed in the corresponding tool's ``file_names``.

        Args:
            user_id: Owner of the tools.
            tool_name: When truthy, only this tool is built.

        Returns:
            A dict mapping tool name to the object returned by
            ``create_retriever_tool``.
        """
        owner = self._quote(user_id)
        tool_table = f"{self.vstore.config.database}.{self.tool_table}"

        def where_for(name):
            # The filter is handed to the vector store as raw SQL, so values
            # must be escaped here rather than bound as parameters.
            return (
                f"user_id={owner} "
                f"""AND file_name IN (
                    SELECT arrayJoin(file_names) FROM (
                        SELECT file_names
                        FROM {tool_table}
                        WHERE user_id = {owner} AND tool_name = {self._quote(name)})
                )"""
            )

        return {
            tool["tool_name"]: create_retriever_tool(
                self.vstore.as_retriever(
                    search_kwargs={"where_str": where_for(tool["tool_name"])},
                ),
                name=tool["tool_name"],
                description=tool["tool_description"],
            )
            for tool in self.list_tools(user_id=user_id, tool_name=tool_name)
        }