"""RAG-based PDF query app (Streamlit).

Fetches PDFs from Google Drive links, extracts and chunks their text,
embeds chunks with sentence-transformers, indexes them in FAISS, and
answers user queries with a Groq-hosted LLM grounded on retrieved chunks.
"""
import os
from io import BytesIO

import faiss
import numpy as np
import PyPDF2
import requests
import streamlit as st
from groq import Groq
from langchain_community.embeddings import HuggingFaceEmbeddings
# Load the Groq API key from the environment.
# NOTE(review): a literal API key was previously committed here — that key
# is compromised and must be rotated; never hard-code credentials.
API = os.environ.get("GROQ_API_KEY")
if not API:
    st.error("GROQ_API_KEY environment variable is not set.")
    st.stop()

# Initialize Groq client
client = Groq(api_key=API)

# Initialize HuggingFace embedding model from langchain_community
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)

# Determine the vector size dynamically
sample_embedding = embedding_model.embed_query("test")
dimension = len(sample_embedding)

# Keep the FAISS index in session state: Streamlit re-executes this script
# on every interaction, so a plain module-level index would be rebuilt
# (empty) before each query and retrieval would never find anything.
if 'faiss_index' not in st.session_state:
    st.session_state['faiss_index'] = faiss.IndexFlatL2(dimension)
index = st.session_state['faiss_index']

# Streamlit front-end
st.title("RAG-based PDF Query Application")
# Function to extract file ID from a Google Drive share link
def extract_file_id(link):
    """Return the Drive file ID embedded in *link*, or None if absent.

    Supports both path-style links (``.../file/d/<id>/view``) and
    query-style links (``.../open?id=<id>``, ``.../uc?id=<id>``).
    Shows a Streamlit error and returns None for unrecognized formats.
    """
    if "/d/" in link:
        # .../file/d/<id>/view?...  -> take the path segment after /d/
        return link.split("/d/")[1].split("/")[0]
    if "id=" in link:
        # ...?id=<id>&...  -> take the query value up to the next '&'
        return link.split("id=")[1].split("&")[0]
    st.error("Invalid Google Drive link format.")
    return None
# Function to fetch PDF content from Google Drive
def fetch_pdf_from_drive(file_id):
    """Download the file with *file_id* via Drive's direct-download URL.

    Returns a BytesIO of the response body on HTTP 200, otherwise shows a
    Streamlit error and returns None.
    """
    url = f"https://drive.google.com/uc?export=download&id={file_id}"
    try:
        # Bound the request so a stalled download cannot hang the app.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        st.error(f"Failed to fetch document with ID: {file_id} ({exc})")
        return None
    if response.status_code == 200:
        return BytesIO(response.content)
    st.error(f"Failed to fetch document with ID: {file_id}")
    return None
# Ensure the per-session containers exist before any widget logic runs.
for _key, _default in (('chunks', []), ('embeddings_added', False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Input for Google Drive links
gdrive_links = st.text_area("Enter Google Drive links (one per line):")


def _create_chunks(text, chunk_size=500):
    """Greedily pack blank-line-separated paragraphs into ~chunk_size-char
    chunks; a paragraph longer than chunk_size becomes its own chunk."""
    import re
    paragraphs = re.split(r'\n\s*\n', text)
    chunks = []
    current_chunk = ""
    for paragraph in paragraphs:
        if len(current_chunk) + len(paragraph) <= chunk_size:
            current_chunk += " " + paragraph
        else:
            chunks.append(current_chunk.strip())
            current_chunk = paragraph
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks


if st.button("Process Documents"):
    if gdrive_links:
        links = gdrive_links.splitlines()
        page_texts = []
        for link in links:
            file_id = extract_file_id(link)
            if file_id:
                pdf_file = fetch_pdf_from_drive(file_id)
                if pdf_file:
                    pdf_reader = PyPDF2.PdfReader(pdf_file)
                    for page in pdf_reader.pages:
                        # extract_text() can return None (e.g. image-only
                        # pages); treat that as empty instead of crashing.
                        page_texts.append(page.extract_text() or "")
                    st.write(f"Processed document: {link}")
        # Join pages with a blank line so the paragraph splitter in
        # _create_chunks can see page boundaries.
        text = "\n\n".join(page_texts)
        if text.strip():
            # Drop blank chunks up front: FAISS row i must correspond to
            # st.session_state['chunks'][i] at query time.
            candidate_chunks = [c for c in _create_chunks(text) if c.strip()]
            st.write(f"Created {len(candidate_chunks)} chunks.")
            embedded_chunks = []
            embeddings = []
            for chunk in candidate_chunks:
                try:
                    embeddings.append(embedding_model.embed_query(chunk))
                    embedded_chunks.append(chunk)
                except Exception as e:
                    st.error(f"Error creating embedding for chunk: {chunk[:30]}... {str(e)}")
            if embeddings:
                # Store exactly the chunks that were embedded so retrieval
                # indices stay aligned even if some embeddings failed.
                st.session_state['chunks'] = embedded_chunks
                embeddings = np.array(embeddings, dtype=np.float32)
                faiss.normalize_L2(embeddings)
                # Reset before adding so reprocessing never leaves stale
                # rows whose positions no longer match the chunk list
                # (a no-op when the index is already empty).
                index.reset()
                index.add(embeddings)
                st.session_state['embeddings_added'] = True
                st.write("Embeddings generated and stored in FAISS.")
            else:
                st.error("No valid embeddings were generated. Please check the document content.")
        else:
            st.error("No valid text extracted from the documents.")
# Query input
if st.session_state['chunks'] and st.session_state['embeddings_added']:
    user_query = st.text_input("Enter your query:")
    if user_query:
        # Embed the query exactly as the chunks were embedded (same model,
        # same L2 normalization) so distances are comparable.
        query_embedding = embedding_model.embed_query(user_query)
        query_embedding = np.array([query_embedding], dtype=np.float32)
        faiss.normalize_L2(query_embedding)

        k = 5
        distances, indices = index.search(query_embedding, k)

        # FAISS pads results with -1 when the index holds fewer than k
        # vectors; drop those (and any out-of-range row) before mapping
        # back to chunks — chunks[-1] would silently be the wrong chunk.
        n_chunks = len(st.session_state['chunks'])
        hits = [
            (i, d)
            for i, d in zip(indices[0], distances[0])
            if 0 <= i < n_chunks
        ]

        # Use a relaxed threshold to find relevant chunks
        threshold = 0.6  # Adjust this value as needed
        relevant_chunks = [
            st.session_state['chunks'][i] for i, d in hits if d <= threshold
        ]
        if not relevant_chunks:
            st.write("No highly relevant chunks found. Using the best available chunks.")
            relevant_chunks = [st.session_state['chunks'][i] for i, _ in hits]

        if relevant_chunks:
            summarized_chunks = [chunk.strip() for chunk in relevant_chunks]
            prompt = (
                "You are an intelligent assistant. Strictly answer the query using the context provided below.\n\n"
                "Context:\n"
                + "\n\n".join(summarized_chunks)
                + f"\n\nQuery: {user_query}"
            )
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192"
            )
            response = chat_completion.choices[0].message.content
            st.write("### Response")
            st.write(response)
        else:
            st.error("No relevant content found in the indexed documents.")
else:
    st.warning("No chunks available or embeddings not added. Please upload and process a document first.")