import os
import json
import bcrypt
from typing import List
from pathlib import Path

from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_community.llms import HuggingFaceEndpoint
from langchain_huggingface import HuggingFaceEndpoint
#from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.callbacks.base import BaseCallbackHandler

import chainlit as cl
from chainlit.input_widget import TextInput, Select, Switch, Slider
from literalai import LiteralClient
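
# Credential records are read from the CHAINLIT_AUTH_LOGIN environment
# variable, a JSON list of objects with "ident", "pwd" and "role" keys,
# as consumed by auth_callback below. An illustrative (hypothetical) value:
#   [{"ident": "alice", "pwd": "s3cret", "role": "admindatapcc"},
#    {"ident": "bob", "pwd": "s3cret", "role": "userdatapcc"}]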
@cl.password_auth_callback
def auth_callback(username: str, password: str):
    auth = json.loads(os.environ['CHAINLIT_AUTH_LOGIN'])
    user = next((d for d in auth if d['ident'] == username), None)
    if user is None:
        # Unknown user: deny access instead of raising StopIteration.
        return None
    ident = user['ident']
    pwd = user['pwd']
    # bcrypt round-trip: hash the stored value with a fresh salt, then check
    # the submitted value against it (effectively a string equality check
    # routed through bcrypt's constant-time comparison).
    resultLogAdmin = bcrypt.checkpw(username.encode('utf-8'), bcrypt.hashpw(ident.encode('utf-8'), bcrypt.gensalt()))
    resultPwdAdmin = bcrypt.checkpw(password.encode('utf-8'), bcrypt.hashpw(pwd.encode('utf-8'), bcrypt.gensalt()))
    resultRole = user['role']
    if resultLogAdmin and resultPwdAdmin and resultRole == "admindatapcc":
        return cl.User(
            identifier=ident + " : 🧑‍💼 Admin Datapcc", metadata={"role": "admin", "provider": "credentials"}
        )
    elif resultLogAdmin and resultPwdAdmin and resultRole == "userdatapcc":
        return cl.User(
            identifier=ident + " : 🧑‍🎓 User Datapcc", metadata={"role": "user", "provider": "credentials"}
        )
    return None
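
# Literal AI client for tracing/observability; assumes LITERAL_API_KEY is set
# in the environment.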
literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))

chunk_size = 1024
chunk_overlap = 50

embeddings_model = HuggingFaceEmbeddings()

PDF_STORAGE_PATH = "./public/pdfs"
def process_pdfs(pdf_storage_path: str):
    """Load every PDF under pdf_storage_path, split it into chunks and index
    the chunks into a Chroma vector store backed by a SQL record manager."""
    pdf_directory = Path(pdf_storage_path)
    docs = []  # type: List[Document]
    # Use the module-level chunking parameters so they are defined in one place.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    for pdf_path in pdf_directory.glob("*.pdf"):
        loader = PyMuPDFLoader(str(pdf_path))
        documents = loader.load()
        docs += text_splitter.split_documents(documents)

    doc_search = Chroma.from_documents(docs, embeddings_model)

    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")

    return doc_search
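
# The vector store is built once at import time and shared by all chat
# sessions. With cleanup="incremental", LangChain's indexing API skips chunks
# whose content is unchanged and removes stale ones for a re-ingested source,
# so restarting the app does not duplicate documents.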
doc_search = process_pdfs(PDF_STORAGE_PATH)
#model = ChatOpenAI(model_name="gpt-4", streaming=True)
# HUGGINGFACEHUB_API_TOKEN must already be set in the environment for the
# endpoint below to authenticate.
repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

model = HuggingFaceEndpoint(
    repo_id=repo_id, max_new_tokens=8000, temperature=1.0, task="text-generation", streaming=True
)
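
# temperature=1.0 keeps generation diverse; a lower value (e.g. 0.1) would make
# answers more deterministic, which is a common choice for RAG question answering.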
@cl.on_chat_start
async def on_chat_start():
    await cl.Message("> REVIEWSTREAM").send()
    settings = await cl.ChatSettings(
        [
            Select(
                id="Model",
                label="Publications de recherche",
                values=["---", "HAL", "Persée"],
                initial_index=0,
            ),
        ]
    ).send()

    res = await cl.AskActionMessage(
        content="",
        actions=[
            cl.Action(name="continue", value="continue", label="<p><strong>✅ Continue</strong></p>"),
            cl.Action(name="cancel", value="cancel", label="<img src='./public/learn.svg' /><p>Cancel</p>"),
        ],
    ).send()
    if res and res.get("value") == "continue":
        await cl.Message(
            content="On continue!",
        ).send()

    template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        return "\n\n".join([d.page_content for d in docs])

    retriever = doc_search.as_retriever()

    runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)
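
# The runnable is stored in the per-user session, so concurrent users each get
# their own chain instance rather than sharing module-level state.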
@cl.on_message
async def on_message(message: cl.Message):
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            BaseCallbackHandler.__init__(self)
            self.msg = msg
            self.sources = set()  # to store unique (source, page) pairs

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            for d in documents:
                source_page_pair = (d.metadata['source'], d.metadata['page'])
                self.sources.add(source_page_pair)  # add unique pairs to the set

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if len(self.sources):
                sources_text = "\n".join([f"{source}#page={page}" for source, page in self.sources])
                self.msg.elements.append(
                    cl.Text(name="Sources", content=sources_text, display="inline")
                )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[
                cl.LangchainCallbackHandler(),
                PostMessageHandler(msg)
            ]),
        ):
            await msg.stream_token(chunk)

        await msg.send()
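
# To launch the app locally (assuming this file is saved as app.py):
#   chainlit run app.py -w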