# RAG Document Q&A System — Gradio app (Hugging Face Space)
| import gradio as gr | |
| import os | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import PyPDF2 | |
| import docx | |
| import requests | |
| import json | |
| from typing import List | |
| import logging | |
# Configure logging
# Root logger at INFO so document processing and API errors are visible in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger, stdlib convention
class RAGSystem:
    """Retrieval-Augmented Generation over user-uploaded documents.

    Files are converted to plain text, split into ~500-character chunks,
    and embedded with a SentenceTransformer model. At query time the
    chunks most similar to the question are retrieved by cosine
    similarity and passed as context to the Groq chat-completions API.
    """

    def __init__(self):
        # Small, fast general-purpose embedding model.
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        self.documents: List[str] = []  # text chunks from all processed files
        self.embeddings = None          # array of chunk embeddings, or None before processing
        self.groq_api_key = None
        self.groq_base_url = "https://api.groq.com/openai/v1/chat/completions"

    def set_api_key(self, api_key: str):
        """Set the Groq API key used by query_groq()."""
        self.groq_api_key = api_key

    def extract_text_from_pdf(self, file_path: str) -> str:
        """Extract text from a PDF file; return "" if extraction fails."""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                text = ""
                for page in pdf_reader.pages:
                    # extract_text() may return None (e.g. image-only pages);
                    # fall back to "" so one bad page doesn't lose the whole file.
                    text += (page.extract_text() or "") + "\n"
                return text
        except Exception as e:
            logger.error("Error extracting text from PDF: %s", e)
            return ""

    def extract_text_from_docx(self, file_path: str) -> str:
        """Extract text from a DOCX file; return "" if extraction fails."""
        try:
            doc = docx.Document(file_path)
            text = ""
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
            return text
        except Exception as e:
            logger.error("Error extracting text from DOCX: %s", e)
            return ""

    def extract_text_from_txt(self, file_path: str) -> str:
        """Extract text from a plain-text file; return "" if reading fails."""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except Exception as e:
            logger.error("Error extracting text from TXT: %s", e)
            return ""

    def process_documents(self, files) -> str:
        """Extract, chunk, and embed the uploaded files.

        Replaces any previously indexed documents. Returns a user-facing
        status message.
        """
        if not files:
            return "No files uploaded."
        self.documents = []
        all_text = ""
        for file in files:
            file_path = file.name  # gradio file objects expose the temp path via .name
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.pdf':
                text = self.extract_text_from_pdf(file_path)
            elif file_extension == '.docx':
                text = self.extract_text_from_docx(file_path)
            elif file_extension == '.txt':
                text = self.extract_text_from_txt(file_path)
            else:
                # Unsupported extension: skip.
                continue
            if text.strip():
                chunks = self.split_text(text)
                self.documents.extend(chunks)
                all_text += text + "\n"
        if self.documents:
            # Embed every chunk once, up front; queries only embed themselves.
            self.embeddings = self.embedder.encode(self.documents)
            return f"✅ Processed {len(files)} files with {len(self.documents)} text chunks."
        else:
            return "⚠️ No text could be extracted from the uploaded files."

    def split_text(self, text: str, chunk_size: int = 500) -> List[str]:
        """Split *text* into chunks of roughly chunk_size characters.

        Sentences (naively delimited by '.') are greedily packed into
        chunks; a single sentence longer than chunk_size becomes its own
        chunk.
        """
        sentences = text.split('.')
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            # Skip empty fragments (e.g. the trailing '' after a final '.')
            # so chunks don't end in a spurious extra period.
            if not sentence:
                continue
            if len(current_chunk) + len(sentence) < chunk_size:
                current_chunk += sentence + "."
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence + "."
        if current_chunk:
            chunks.append(current_chunk.strip())
        return [chunk for chunk in chunks if chunk.strip()]

    def retrieve_relevant_chunks(self, query: str, top_k: int = 3) -> List[str]:
        """Return the top_k document chunks most similar to *query*.

        Returns [] when no documents have been indexed yet.
        """
        if not self.documents or self.embeddings is None:
            return []
        query_embedding = self.embedder.encode([query])
        similarities = cosine_similarity(query_embedding, self.embeddings)[0]
        # Indices of the top_k highest-similarity chunks, best first.
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [self.documents[i] for i in top_indices]

    def query_groq(self, prompt: str) -> str:
        """Send *prompt* to the Groq chat-completions API and return the reply.

        Never raises: HTTP errors, timeouts, and malformed responses are
        turned into user-facing error strings.
        """
        if not self.groq_api_key:
            return "⚠️ Please set your Groq API key first."
        headers = {
            "Authorization": f"Bearer {self.groq_api_key}",
            "Content-Type": "application/json"
        }
        data = {
            "model": "llama-3.1-8b-instant",  # ✅ Valid Groq model
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant. Answer questions based on the provided context. If the context doesn't contain enough information to answer the question, say so clearly."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": 0.7,
            "max_tokens": 1024,
            "stream": False
        }
        result = None  # safe value for the error branch below
        try:
            # Explicit timeout so a stalled connection cannot hang the UI forever.
            response = requests.post(self.groq_base_url, headers=headers,
                                     json=data, timeout=60)
            response.raise_for_status()
            result = response.json()
            return result["choices"][0]["message"]["content"]
        except requests.exceptions.RequestException as e:
            logger.error("Error querying Groq API: %s", e)
            return f"Error querying Groq API: {str(e)}"
        except (KeyError, IndexError, TypeError, ValueError):
            # Malformed JSON body or a response missing the expected fields.
            logger.error("Unexpected Groq API response: %s", result)
            return f"Unexpected Groq API response: {json.dumps(result, indent=2)}"

    def answer_query(self, query: str) -> str:
        """Answer *query* with RAG: retrieve relevant chunks, then ask Groq."""
        if not self.documents:
            return "⚠️ No documents have been processed yet. Please upload and process documents first."
        if not self.groq_api_key:
            return "⚠️ Please set your Groq API key first."
        relevant_chunks = self.retrieve_relevant_chunks(query)
        if not relevant_chunks:
            return "⚠️ No relevant information found in the documents."
        # Join the retrieved chunks into a single context block for the LLM.
        context = "\n\n".join(relevant_chunks)
        prompt = f"""Context from documents:
{context}

Question: {query}

Please answer the question based on the provided context. If the context doesn't contain enough information to fully answer the question, please mention what information is missing."""
        return self.query_groq(prompt)
# Initialize RAG system
# Module-level singleton shared by all Gradio event handlers below.
rag_system = RAGSystem()
| # Gradio interface functions | |
def set_api_key(api_key):
    """Gradio handler: store the Groq API key on the shared RAG system.

    Rejects blank input so the UI doesn't report success for an empty key
    (consistent with the validation in process_files/answer_question).
    """
    if not api_key or not api_key.strip():
        return "⚠️ Please enter a valid API key."
    rag_system.set_api_key(api_key.strip())
    return "✅ API key set successfully!"
def process_files(files):
    """Gradio handler: index the uploaded files into the shared RAG system."""
    if files:
        return rag_system.process_documents(files)
    return "⚠️ Please upload at least one file."
def answer_question(query):
    """Gradio handler: answer a user question against the indexed documents."""
    if query.strip():
        return rag_system.answer_query(query)
    return "⚠️ Please enter a question."
# Create Gradio interface
# Two tabs: "Setup" (API key + document upload) and "Ask Questions" (Q&A).
with gr.Blocks(title="RAG Document Q&A System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 RAG Document Q&A System")
    gr.Markdown("Upload documents and ask questions about their content using AI!")

    # --- Setup tab: API key entry and document upload/processing ---
    with gr.Tab("Setup"):
        gr.Markdown("## Step 1: Set your Groq API Key")
        gr.Markdown("Get your free API key from [Groq Console](https://console.groq.com/)")
        with gr.Row():
            api_key_input = gr.Textbox(
                type="password",  # mask the key in the UI
                label="Groq API Key",
                placeholder="Enter your Groq API key here..."
            )
            set_key_btn = gr.Button("Set API Key", variant="primary")
        api_key_status = gr.Textbox(label="Status", interactive=False)

        gr.Markdown("## Step 2: Upload Documents")
        gr.Markdown("Upload PDF, DOCX, or TXT files")
        file_upload = gr.Files(
            file_types=[".pdf", ".docx", ".txt"],  # matches the extractors in RAGSystem
            label="Upload Documents",
            file_count="multiple"
        )
        process_btn = gr.Button("Process Documents", variant="primary")
        process_status = gr.Textbox(label="Processing Status", interactive=False)

    # --- Q&A tab: question input and model answer ---
    with gr.Tab("Ask Questions"):
        gr.Markdown("## Ask Questions About Your Documents")
        with gr.Row():
            with gr.Column(scale=4):
                query_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a question about your documents...",
                    lines=2
                )
            with gr.Column(scale=1):
                ask_btn = gr.Button("Ask Question", variant="primary")
        answer_output = gr.Textbox(
            label="Answer",
            lines=10,
            interactive=False
        )
        # Example questions (clicking one fills the question box)
        gr.Markdown("### Example Questions:")
        examples = gr.Examples(
            examples=[
                ["What is the main topic of the document?"],
                ["Can you summarize the key points?"],
                ["What are the conclusions mentioned?"],
                ["Are there any specific dates or numbers mentioned?"]
            ],
            inputs=query_input
        )

    # Event handlers: wire each control to its module-level handler function
    set_key_btn.click(
        fn=set_api_key,
        inputs=[api_key_input],
        outputs=[api_key_status]
    )
    process_btn.click(
        fn=process_files,
        inputs=[file_upload],
        outputs=[process_status]
    )
    ask_btn.click(
        fn=answer_question,
        inputs=[query_input],
        outputs=[answer_output]
    )
    # Allow Enter key to submit questions
    query_input.submit(
        fn=answer_question,
        inputs=[query_input],
        outputs=[answer_output]
    )
if __name__ == "__main__":
    # share=True requests a public gradio.live link when run locally;
    # NOTE(review): presumably this is ignored when hosted on HF Spaces — confirm.
    demo.launch(share=True)