Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import gradio as gr | |
| from llama_index import ( | |
| VectorStoreIndex, | |
| download_loader, | |
| ) | |
| import chromadb | |
| from llama_index.llms import MistralAI | |
| from llama_index.embeddings import MistralAIEmbedding | |
| from llama_index.vector_stores import ChromaVectorStore | |
| from llama_index.storage.storage_context import StorageContext | |
| from llama_index import ServiceContext | |
# --- Static UI text ----------------------------------------------------------
title = "Gaia Mistral Chat RAG PDF Demo"
description = "Example of an assistant with Gradio, RAG from PDF documents and Mistral AI via its API"
# French hint shown in the chat input; roughly "You can ask me a question
# about this context, press Enter to submit".
placeholder = "Vous pouvez me posez une question sur ce contexte, appuyer sur Entrée pour valider"
placeholder_url = "Extract text from this url"

# --- Mistral configuration ---------------------------------------------------
llm_model = "mistral-small"
# The API key is read from the environment; it is None when the variable is unset.
env_api_key = os.environ.get("MISTRAL_API_KEY")
# Placeholder binding; rebound at the end of this section once the index exists.
query_engine = None

# Chat model and embedding model both authenticate with the same key.
llm = MistralAI(model=llm_model, api_key=env_api_key)
embed_model = MistralAIEmbedding(api_key=env_api_key, model_name="mistral-embed")

# --- Vector store --------------------------------------------------------------
# On-disk Chroma client rooted at ./chroma_db; the collection is reused when it
# already exists and created otherwise.
db = chromadb.PersistentClient(path="./chroma_db")
chroma_collection = db.get_or_create_collection("quickstart")

# Wrap the Chroma collection so llama_index can use it as its vector store.
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
    chunk_size=1024,
)

# --- PDF loader ----------------------------------------------------------------
# download_loader returns the reader class; a single instance is shared by the app.
PDFReader = download_loader("PDFReader")
loader = PDFReader()

# --- Index and query engine ----------------------------------------------------
# The index is created with no nodes of its own; documents are added later
# through index.insert().
index = VectorStoreIndex(
    [],
    storage_context=storage_context,
    service_context=service_context,
)
query_engine = index.as_query_engine(similarity_top_k=5)
def get_documents_in_db():
    """Return a Markdown bullet list of the distinct file names stored in the collection.

    Each record's metadata is expected to carry a ``_node_content`` JSON string
    whose ``metadata.file_name`` field names the source file. Records that do
    not follow that shape are skipped instead of aborting the whole listing.

    Returns:
        A Markdown string: a bold header line followed by one `` - name`` line
        per distinct file, sorted alphabetically so the display is stable
        between calls.
    """
    print("Fetching documents in DB")
    metadatas = chroma_collection.get(include=["metadatas"])["metadatas"] or []

    file_names = set()
    for item in metadatas:
        # A record may have no metadata at all, or metadata without the
        # serialized node payload; neither can name a file.
        if not item or "_node_content" not in item:
            continue
        node_metadata = json.loads(item["_node_content"]).get("metadata") or {}
        file_name = node_metadata.get("file_name")
        if file_name is not None:
            file_names.add(file_name)

    # Sorting replaces the arbitrary iteration order of a set.
    docs = sorted(file_names)
    print(f"Found {len(docs)} documents")

    lines = ["**List of files in db:**"]
    lines.extend(f" - {name}" for name in docs)
    return "\n".join(lines) + "\n"
def empty_db():
    """Delete every record from the Chroma collection and return the refreshed listing.

    Returns:
        The Markdown string produced by ``get_documents_in_db()`` after the
        deletion.
    """
    ids = chroma_collection.get()["ids"]
    # Deleting with an empty id list is rejected by Chroma, so only call
    # delete when there is something to remove.
    if ids:
        chroma_collection.delete(ids=ids)
    return get_documents_in_db()
def load_file(file):
    """Parse the uploaded PDF, insert its documents into the index, and refresh the UI.

    Args:
        file: The value of the ``gr.File`` input; ``None`` when nothing has
            been uploaded yet.

    Returns:
        A 3-tuple matching the click handler's outputs: a hidden textbox
        (hides the "loaded" message), a visible textbox confirming the
        encoding, and the refreshed Markdown list of files in the database.

    Raises:
        gr.Error: If no file has been uploaded.
    """
    if file is None:
        # Surface a readable message in the UI instead of an opaque loader failure.
        raise gr.Error("No file selected: please load a PDF before encoding it.")

    documents = loader.load_data(file=file)
    for doc in documents:
        index.insert(doc)

    return (
        gr.Textbox(visible=False),
        gr.Textbox(value="Document encoded ! You can ask questions", visible=True),
        get_documents_in_db(),
    )
def load_document(input_file):
    """Return a visible textbox announcing which file was just uploaded.

    Args:
        input_file: The uploaded file value. Either an object exposing a
            ``name`` attribute holding its path, or a plain path string.

    Returns:
        A visible ``gr.Textbox`` whose value is ``Document loaded: <file name>``.
    """
    # Fall back to the value itself when it has no ``name`` attribute.
    path = str(getattr(input_file, "name", input_file))
    # basename handles the platform's path separator, unlike splitting on "/".
    file_name = os.path.basename(path)
    return gr.Textbox(value=f"Document loaded: {file_name}", visible=True)
# Gradio UI: section 1 uploads and encodes a PDF into the vector store,
# section 2 is a chat box answered by the query engine.
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# source whose indentation was lost — confirm the layout against the running app.
with gr.Blocks(title=title) as demo:
    gr.Markdown(
        """ # Welcome to Gaia Level 3 Demo
Add a file before interacting with the Chat.
This demo allows you to interact with a pdf file and then ask questions to Mistral APIs.
Mistral will answer with the context extracted from your uploaded file.
*The files will stay in the database unless there is 48h of inactivity or you re-build the space.*
"""
    )

    gr.Markdown(""" ### 1 / Extract data from PDF """)

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Load a pdf",
                file_types=[".pdf"],
                file_count="single",
                type="filepath",
                interactive=True,
            )
            # Hidden until a file is uploaded; load_document makes it visible.
            file_msg = gr.Textbox(
                label="Loaded documents:", container=False, visible=False
            )

            input_file.upload(
                fn=load_document,
                inputs=[
                    input_file,
                ],
                outputs=[file_msg],
                concurrency_limit=20,
            )

            help_msg = gr.Markdown(
                value="Once the document is loaded, press the Encode button below to add it to the db."
            )

            file_btn = gr.Button(value="Encode file ✅", interactive=True)
            # Hidden until encoding finishes; load_file makes it visible.
            btn_msg = gr.Textbox(container=False, visible=False)

            with gr.Row():
                # The function itself is passed as the value, not its result.
                db_list = gr.Markdown(value=get_documents_in_db)
                delete_btn = gr.Button(value="Empty db 🗑️", interactive=True, scale=0)

            file_btn.click(
                load_file,
                inputs=[input_file],
                outputs=[file_msg, btn_msg, db_list],
                show_progress="full",
            )
            delete_btn.click(empty_db, outputs=[db_list], show_progress="minimal")

    gr.Markdown(""" ### 2 / Ask a question about this context """)

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder=placeholder)
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Answer ``message`` with the query engine and append the exchange to the history.

        Args:
            message: The text submitted in the chat textbox.
            chat_history: The chatbot's current list of (user, assistant) pairs.

        Returns:
            The same history list with the new (message, answer) pair appended.
        """
        response = query_engine.query(message)
        chat_history.append((message, str(response)))
        return chat_history

    msg.submit(respond, [msg, chatbot], [chatbot])

demo.launch()