# NOTE(review): the following header is Hugging Face Spaces page residue, not code:
# "Spaces: Runtime error" — the Space was failing at the time this file was scraped.
import os
import time

import gradio as gr
from langchain.docstore.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter

import vector_db as vdb
from llm_model import LLMModel
# Text-splitting configuration: chunks of up to 2000 characters with a
# 200-character overlap between consecutive chunks.
chunk_size = 2000
chunk_overlap = 200

# Paths of files uploaded and ingested so far (appended to by load_docs).
uploaded_docs = []
# UI table summarising uploaded files (name + content length).
uploaded_df = gr.Dataframe(headers=["file_name", "content_length"])
# Upload widget restricted to plain-text documentation formats.
upload_files_section = gr.Files(
    file_types=[".md", ".mdx", ".rst", ".txt"],
)
# Chat widget shared with the streaming ChatInterface built below.
chatbot_stream = gr.Chatbot(bubble_full_width=False, show_copy_button=True)
def load_docs(files):
    """Ingest uploaded files into the vector database.

    For each uploaded file: read its text, split it into overlapping chunks
    (``chunk_size``/``chunk_overlap``), ask the LLM for potential
    question/answer pairs per chunk, and load those Q&A documents into the
    vector DB.

    Fixes vs. original: stale comment quoting 1000/150 chunking removed
    (config is 2000/200), explicit UTF-8 decoding on read, per-chunk
    "Analysing document..." toast hoisted to once per document, and
    ``os.path.basename`` instead of splitting on "/".

    Args:
        files: iterable of gradio file objects (each exposes a ``.name`` path).

    Returns:
        The accumulated module-level list of uploaded file paths.
    """
    all_docs = []
    all_qa = []
    for file in files:
        if file.name is None:
            continue
        with open(file.name, "r", encoding="utf-8") as f:
            file_content = f.read()
        file_name = os.path.basename(file.name)
        # Attach the source file name so every chunk can be traced back
        # to the originating document.
        doc = Document(page_content=file_content, metadata={"source": file_name})
        # Split into chunks of `chunk_size` chars with `chunk_overlap` overlap,
        # using separators appropriate for the file's markup language.
        language = get_language(file_name)
        text_splitter = RecursiveCharacterTextSplitter.from_language(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            language=language,
        )
        doc_chunks = text_splitter.split_documents([doc])
        print(f"Number of chunks: {len(doc_chunks)}")
        gr.Info("Analysing document...")
        # For each chunk, ask the LLM for potential Q&A pairs; the Q&A text
        # (not the raw chunk) is what gets indexed into the vector DB.
        for doc_chunk in doc_chunks:
            potential_qa_from_doc = llm_model.get_potential_question_answer(doc_chunk.page_content)
            all_qa.append(Document(page_content=potential_qa_from_doc, metadata=doc_chunk.metadata))
        all_docs += doc_chunks
        uploaded_docs.append(file.name)
    vector_db.load_docs_into_vector_db(all_qa)
    gr.Info("Loaded document(s) into vector db.")
    return uploaded_docs
def get_language(file_name: str):
    """Map a file name to the splitter ``Language`` used for chunking.

    ``.rst`` files use RST separators; everything else — including ``.md``
    and ``.mdx`` — falls back to Markdown separators.
    """
    return Language.RST if file_name.endswith(".rst") else Language.MARKDOWN
def get_vector_db():
    """Create and return a fresh ``vdb.VectorDB`` instance."""
    return vdb.VectorDB()
def get_llm_model(_db: vdb.VectorDB):
    """Build the LLM model backed by a retriever over the vector DB.

    Args:
        _db: vector database wrapper; its ``docs_db`` store supplies the
            retriever, configured to return the top-2 similar documents.

    Returns:
        An ``LLMModel`` configured with that retriever.
    """
    retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    return LLMModel(retriever=retriever)
def predict(message, history):
    """Stream an answer to *message* for the gradio ChatInterface.

    The full answer is fetched from the LLM first, then yielded one
    character at a time with a small delay to give the UI a typing effect.

    Fix vs. original: ~25 lines of commented-out experimental code
    (alternative inference paths, source formatting, timing) deleted.

    Args:
        message: the user's question.
        history: prior chat turns (unused; required by ChatInterface).

    Yields:
        Progressively longer prefixes of the answer text.
    """
    resp = llm_model.answer_question_inference_text_gen(message)
    for i in range(len(resp)):
        time.sleep(0.005)  # throttle so the UI renders a typing effect
        yield resp[:i + 1]
def vote(data: gr.LikeData):
    """Acknowledge a like/dislike on a chat response with a toast."""
    if not data.liked:
        gr.Warning("You downvoted this response π")
        return
    gr.Info("You upvoted this response π")
# Application wiring: build the vector DB and the retriever-backed LLM once at
# import time; both are used as module-level globals by the handlers above.
vector_db = get_vector_db()
llm_model = get_llm_model(vector_db)

# Streaming chat UI; default_concurrency_limit=1 serialises generations.
# NOTE: some emoji in the strings below appear mojibake-mangled in the source
# scrape ("π"); they are runtime strings and are left byte-for-byte as found.
chat_interface_stream = gr.ChatInterface(
    predict,
    title="π Document answering bot",
    description="ππ¦ Upload some documents on the side and ask questions!",
    textbox=gr.Textbox(container=False, scale=7),
    chatbot=chatbot_stream,
    examples=["What is Data Caterer?"],
).queue(default_concurrency_limit=1)

# Page layout: narrow upload column on the left, wide chat column on the right.
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column(scale=1, min_width=100) as upload_col:
            gr.Interface(
                load_docs,
                title="π Upload documents",
                inputs=upload_files_section,
                outputs=gr.Files(),
                allow_flagging="never"
            )
            # upload_files_section.upload(load_docs, inputs=upload_files_section)
        with gr.Column(scale=4, min_width=600) as chat_col:
            # Register the like/dislike callback before rendering the chat.
            chatbot_stream.like(vote, None, None)
            chat_interface_stream.render()

blocks.queue().launch()