# RAG question-answering service: builds a FAISS vector index over a remote PDF
# and answers questions with a Groq-hosted LLM through a LangChain RetrievalQA chain.
import asyncio
import hashlib
import logging
import os
import time

from fastapi import status
from langchain.chains import RetrievalQA
from langchain.schema import Document
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter

from app.core.template import prompt_template_description
# Module-level embedding model, shared by index build and query time.
# NOTE(review): loaded eagerly at import — first import downloads/loads the
# sentence-transformers model; confirm that is acceptable for app startup.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2")
# Async PDF loader
async def pdf_loader(url: str):
    """Asynchronously load the PDF at *url*, returning its pages as a list.

    Each element is a page Document produced by PyPDFLoader's lazy
    async iterator; the whole document is materialized before returning.
    """
    loader = PyPDFLoader(url)
    return [page async for page in loader.alazy_load()]
# Main function to create/load vectorstore
async def load_and_create_vector_store(url: str):
    """
    Load a PDF from *url* and build — or reuse — a FAISS vectorstore for it.

    The on-disk cache directory is derived from a hash of the URL so that
    different documents never collide. (Previously the path was a fixed
    constant, so after the first document was indexed, every other URL
    silently answered from the wrong, stale index.)

    Args:
        url: Location of the PDF document.

    Returns:
        A retriever over the document's chunks (top-2, score-filtered).

    Raises:
        ValueError: If the loader yields no pages.
    """
    # Per-document cache key; md5 is fine here — it's a cache key, not security.
    doc_key = hashlib.md5(url.encode("utf-8")).hexdigest()
    vectorstore_path = f"/tmp/database/faiss_index/{doc_key}"

    if os.path.exists(f"{vectorstore_path}/index.faiss"):
        logging.info("Vector store already exists, loading it.")
        vectorstore = FAISS.load_local(
            vectorstore_path, embeddings, allow_dangerous_deserialization=True
        )
    else:
        logging.info("Vector store not found. Creating new one from document.")
        pages = await pdf_loader(url)
        if not pages:
            raise ValueError("No pages loaded from the document.")

        # Join pages so the splitter sees one continuous text and can place
        # chunk boundaries on paragraph breaks rather than page breaks.
        full_text = "\n\n".join(page.page_content for page in pages)
        documents = [Document(page_content=full_text, metadata={"source": url})]

        # Paragraph-separated chunks with overlap for better retrieval quality.
        text_splitter = CharacterTextSplitter(
            separator="\n\n",
            chunk_size=2500,
            chunk_overlap=300,
            length_function=len,
        )
        split_docs = text_splitter.split_documents(documents)
        logging.info("Document split into %d chunks", len(split_docs))

        vectorstore = FAISS.from_documents(split_docs, embeddings)
        vectorstore.save_local(vectorstore_path)

    # NOTE(review): score_threshold is passed with the default similarity
    # search type; FAISS accepts it, but the threshold's direction depends on
    # the distance metric — confirm 0.5 does what is intended here.
    return vectorstore.as_retriever(
        search_kwargs={"k": 2, "score_threshold": 0.5}
    )
async def llm_setup(config, url):
    """
    Set up the LLM and build a RetrievalQA chain over the document at *url*.

    Args:
        config: Configuration mapping with MODEL_NAME, and optional
            TEMPERATURE (default 0), MAX_TOKENS (default 300), and
            MAX_RETRIES (default 3).
        url: URL of the document to index and retrieve from.

    Returns:
        object: The configured RetrievalQA chain.
    """
    # Pass native numeric types. The original wrapped every argument in an
    # f-string, handing ChatGroq strings like "0" / "300" / "3" — and the
    # literal string "None" when MODEL_NAME was missing, masking the error.
    llm = ChatGroq(
        model=config.get('MODEL_NAME'),
        temperature=float(config.get('TEMPERATURE', 0)),
        max_tokens=int(config.get('MAX_TOKENS', 300)),  # room for JSON responses
        max_retries=int(config.get('MAX_RETRIES', 3)),
        api_key=os.getenv('GROQ_KEY'),
    )
    # Security fix: never log the API key (the original wrote GROQ_KEY to logs).
    logging.info("LLM initialized with model: %s", config.get('MODEL_NAME'))

    # Choose template based on whether we need structured JSON output
    prompt_template = prompt_template_description()
    retriever = await load_and_create_vector_store(url=url)

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt_template},
    )
    return qa_chain
async def llm_response_generator(config, url, questions, *, time_budget: float = 28.0):
    """
    Generate answers from the LLM within a soft time budget.

    Args:
        config: Configuration dictionary with LLM settings.
        url: URL of the document to process.
        questions: List of questions to answer.
        time_budget: Soft deadline in seconds; remaining questions are
            skipped once it is exceeded. Defaults to 28, leaving a safety
            margin inside a 30-second upstream limit.

    Returns:
        Tuple of (response dict, HTTP status code). On a per-question
        failure a placeholder answer is appended; on setup failure an
        empty answer list and a 500 status are returned.
    """
    try:
        start = time.time()
        qa_chain = await llm_setup(config, url)

        answers = []
        for question in questions:
            # Check the deadline before each question, not after.
            if time.time() - start > time_budget:
                logging.warning("Time limit reached, skipping remaining questions.")
                break
            try:
                answers.append(await qa_chain.arun(question))
            except Exception as e:
                # A single failed question degrades to a placeholder
                # rather than aborting the whole batch.
                logging.error(f"Error answering: {question} | {e}")
                answers.append("Error processing question.")
        return {"answers": answers}, status.HTTP_200_OK
    except Exception as e:
        # Boundary handler: any setup/infrastructure failure maps to a 500.
        logging.error(f"Error in llm_response_generator: {e}")
        return {"answers": []}, status.HTTP_500_INTERNAL_SERVER_ERROR