Spaces:
Sleeping
Sleeping
# Standard library
import os
import pickle

# Third-party: UI, models, and LangChain components
import gradio as gr
from transformers import pipeline
from langchain_classic.chains import RetrievalQAWithSourcesChain
from langchain_community.document_loaders import UnstructuredURLLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_google_genai import ChatGoogleGenerativeAI
# ------------------------
# Configuration
# ------------------------

# API key read from the environment.
# NOTE(review): the env var is named API_TOKEN and the original comment called
# it a "HuggingFace API token", but the value is passed to
# ChatGoogleGenerativeAI below, so it must actually hold a Google API key.
token = os.environ.get("API_TOKEN")

# ------------------------
# LLM
# ------------------------
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.7,
    api_key=token,
    # BUG FIX: the original passed `max_token=100`, which is not a recognized
    # parameter; the correct name is `max_output_tokens` (caps answer length).
    max_output_tokens=100,
)

# Global handle to the QA chain; populated by process_urls_with_logs().
chain = None

# ------------------------
# Paths used to persist the FAISS store and the processed URL list
# ------------------------
FAISS_FILE = "vectorstore.pkl"
URLS_FILE = "urls.pkl"
# ------------------------
# Process URLs: build or reuse the FAISS vector store and the QA chain
# ------------------------
def process_urls_with_logs(url1, url2, url3):
    """Build (or reuse) the FAISS index for up to three URLs.

    The index is rebuilt only when the set of non-blank URLs differs from the
    set persisted on disk; otherwise the pickled store is reloaded. In both
    cases the module-level ``chain`` is (re)initialized.

    Returns a human-readable status string for the Gradio status box.
    """
    global chain

    # Strip each field once, then drop blanks (original stripped twice).
    urls = [u.strip() for u in (url1, url2, url3)]
    urls = [u for u in urls if u]
    if not urls:
        return "Please provide at least one URL."

    # Load the previously processed URL list, if both artifacts exist.
    if os.path.exists(FAISS_FILE) and os.path.exists(URLS_FILE):
        with open(URLS_FILE, "rb") as f:
            saved_urls = pickle.load(f)
    else:
        saved_urls = []

    # Guard clause: reuse the existing index when the URL set is unchanged.
    if set(urls) == set(saved_urls):
        print("No new URLs. Using existing FAISS.")
        # NOTE(review): pickling a FAISS store is fragile and unsafe for
        # untrusted files; FAISS.save_local/load_local is the supported API.
        with open(FAISS_FILE, "rb") as f:
            vectorstore = pickle.load(f)
        chain = RetrievalQAWithSourcesChain.from_llm(
            llm=llm, retriever=vectorstore.as_retriever()
        )
        return "Existing FAISS loaded."

    print("New URLs detected or FAISS does not exist. Recreating FAISS...")
    # Remove any old store from module globals to free RAM before rebuilding.
    if "vectorstore" in globals():
        del globals()["vectorstore"]

    print("Loading URLs...")
    loader = UnstructuredURLLoader(urls=urls)
    documents = loader.load()
    if not documents:
        # BUG FIX: FAISS.from_documents raises on an empty corpus; report a
        # clear status instead of crashing when no URL yields content.
        return "No content could be loaded from the given URLs."

    print("Splitting documents into chunks...")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=200)
    splits = text_splitter.split_documents(documents)

    print("Creating embeddings...")
    embeddings = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")

    print("Creating vector database (FAISS)...")
    vectorstore = FAISS.from_documents(documents=splits, embedding=embeddings)

    # Persist the store and the URL set so later calls can skip the rebuild.
    with open(FAISS_FILE, "wb") as f:
        pickle.dump(vectorstore, f)
    with open(URLS_FILE, "wb") as f:
        pickle.dump(urls, f)

    print("Initializing LLM chain...")
    chain = RetrievalQAWithSourcesChain.from_llm(
        llm=llm, retriever=vectorstore.as_retriever()
    )
    return "FAISS successfully created/recreated!"
# ------------------------
# Function to answer questions
# ------------------------
def ask_question(question):
    """Answer ``question`` with the retrieval chain built from the URLs.

    Returns an ``(answer, sources)`` pair — one value per Gradio output box.
    """
    global chain
    if chain is None:
        # BUG FIX: this handler is wired to TWO Gradio outputs
        # (answer_output, sources_output), so the early return must also be a
        # 2-tuple; the original returned a bare string, mis-mapping outputs.
        return "Please process URLs first.", ""
    result = chain.invoke({'question': question})
    answer = result.get("answer", "")
    sources = result.get("sources", "")
    return answer, sources
# ------------------------
# Gradio Interface
# ------------------------
with gr.Blocks() as app:
    with gr.Row():
        # Left column: URL inputs plus the processing trigger.
        with gr.Column(scale=1):
            gr.Markdown("## Insert URLs")
            url1 = gr.Textbox(label="URL 1")
            url2 = gr.Textbox(label="URL 2")
            url3 = gr.Textbox(label="URL 3")
            process_btn = gr.Button("Process URLs")
            status_output = gr.Textbox(label="Status", lines=8)

        # Right column: question input and the answer/sources displays.
        with gr.Column(scale=2):
            gr.Markdown("## Write your question")
            question_box = gr.Textbox(
                label="Your Question",
                placeholder="Type your question based on the URLs...",
                lines=4,
            )
            ask_btn = gr.Button("Ask")
            answer_output = gr.Textbox(label="Answer", lines=8)
            sources_output = gr.Textbox(label="Sources", lines=4)

    # Wire each button to its handler.
    process_btn.click(
        process_urls_with_logs,
        inputs=[url1, url2, url3],
        outputs=status_output,
    )
    ask_btn.click(
        ask_question,
        inputs=question_box,
        outputs=[answer_output, sources_output],
    )

# Launch the Gradio app
app.launch()