|
|
import os |
|
|
import gradio as gr |
|
|
import PyPDF2 |
|
|
import docx |
|
|
import pandas as pd |
|
|
from typing import List, Dict, Any |
|
|
import numpy as np |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
import re |
|
|
from groq import Groq |
|
|
import json |
|
|
import tempfile |
|
|
import io |
|
|
|
|
|
class RAGApplication:
    """Retrieval-Augmented Generation assistant.

    Extracts text from uploaded files, splits it into overlapping chunks,
    embeds the chunks with a sentence-transformers model, indexes them in
    FAISS, and answers questions by passing the most similar chunks as
    context to a Groq-hosted LLM.
    """

    def __init__(self):
        """Initialize the RAG application with necessary components."""
        # Groq chat client; reads GROQ_API_KEY from the environment.
        self.groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

        # Sentence embedding model; all-MiniLM-L6-v2 emits 384-dim vectors.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Inner-product index over L2-normalized vectors == cosine similarity.
        self.dimension = 384
        self.index = faiss.IndexFlatIP(self.dimension)

        # Parallel lists: chunk i's text and its metadata share index i.
        self.chunks = []
        self.chunk_metadata = []
        self.is_indexed = False

    def extract_text_from_file(self, file_path: str, file_type: str) -> str:
        """Extract plain text from a pdf/docx/txt/csv/xlsx file.

        Args:
            file_path: Path of the file on disk.
            file_type: Lowercase extension without the dot.

        Returns:
            The extracted text, or a string starting with "Error" on
            failure — callers (build_index) test for that prefix.
        """
        text = ""
        try:
            if file_type == "pdf":
                with open(file_path, 'rb') as file:
                    pdf_reader = PyPDF2.PdfReader(file)
                    # extract_text() may return None for image-only pages;
                    # coalesce to "" so the join never raises TypeError.
                    text = "".join(
                        (page.extract_text() or "") + "\n"
                        for page in pdf_reader.pages
                    )

            elif file_type == "docx":
                doc = docx.Document(file_path)
                text = "".join(
                    paragraph.text + "\n" for paragraph in doc.paragraphs
                )

            elif file_type == "txt":
                # errors="replace" keeps non-UTF-8 files usable instead of
                # failing the whole file on one bad byte.
                with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
                    text = file.read()

            elif file_type in ("csv", "xlsx"):
                if file_type == "csv":
                    df = pd.read_csv(file_path)
                else:
                    df = pd.read_excel(file_path)
                # Flatten tabular data to one printable string for chunking.
                text = df.to_string(index=False)

        except Exception as e:
            # The "Error" prefix is part of the contract checked upstream.
            return f"Error reading file: {str(e)}"

        return text

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping, roughly sentence-aligned chunks.

        Args:
            text: Raw document text.
            chunk_size: Soft maximum chunk length in characters.
            overlap: Number of trailing words carried into the next chunk
                to preserve context across chunk boundaries.

        Returns:
            List of chunk strings (empty for blank input).
        """
        if not text.strip():
            return []

        # Collapse all runs of whitespace to single spaces.
        text = re.sub(r'\s+', ' ', text.strip())

        # Naive sentence split on terminal punctuation.
        sentences = re.split(r'[.!?]+', text)

        chunks = []
        current_chunk = ""

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            # Flush the current chunk once adding this sentence would
            # exceed the size budget (but never emit an empty chunk).
            if len(current_chunk) + len(sentence) > chunk_size and current_chunk:
                chunks.append(current_chunk.strip())

                # Seed the next chunk with the last `overlap` words so
                # adjacent chunks share context.
                words = current_chunk.split()
                overlap_text = ' '.join(words[-overlap:]) if len(words) > overlap else current_chunk
                current_chunk = overlap_text + " " + sentence
            else:
                current_chunk += " " + sentence if current_chunk else sentence

        # Emit whatever is left over.
        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        return chunks

    def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """Embed text chunks; returns an (n_chunks, 384) array (empty for [])."""
        if not chunks:
            return np.array([])

        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=False)
        return embeddings

    def build_index(self, files) -> str:
        """Process uploaded files and (re)build the FAISS search index.

        Args:
            files: Gradio file objects exposing a `.name` path attribute.

        Returns:
            Human-readable status report string for the UI.
        """
        if not files:
            return "❌ No files uploaded. Please upload at least one file."

        try:
            # Reset all previous index state before re-indexing.
            self.chunks = []
            self.chunk_metadata = []
            self.is_indexed = False
            self.index = faiss.IndexFlatIP(self.dimension)

            all_chunks = []
            processing_status = []

            for file in files:
                # Display the basename, not the full temp-file path.
                file_name = os.path.basename(file.name)
                file_extension = os.path.splitext(file_name)[1].lstrip('.').lower()

                text = self.extract_text_from_file(file.name, file_extension)

                # extract_text_from_file signals failure with this prefix.
                if text.startswith("Error"):
                    processing_status.append(f"❌ {file_name}: {text}")
                    continue

                file_chunks = self.chunk_text(text)

                if not file_chunks:
                    processing_status.append(f"❌ {file_name}: No text could be extracted")
                    continue

                # Record metadata in lockstep with the chunk list so FAISS
                # row ids map directly back to file/chunk provenance.
                for i, chunk in enumerate(file_chunks):
                    self.chunk_metadata.append({
                        'file_name': file_name,
                        'chunk_id': i,
                        'chunk_text': chunk
                    })
                    all_chunks.append(chunk)

                processing_status.append(f"✅ {file_name}: {len(file_chunks)} chunks created")

            if not all_chunks:
                return "❌ No valid text chunks were created from the uploaded files."

            # FAISS requires float32, C-contiguous input.
            embeddings = np.ascontiguousarray(
                self.create_embeddings(all_chunks), dtype=np.float32
            )

            # Normalize in place so inner product equals cosine similarity.
            faiss.normalize_L2(embeddings)

            self.index.add(embeddings)
            self.chunks = all_chunks
            self.is_indexed = True

            status_report = "\n".join(processing_status)
            summary = f"\n\n📊 **Summary:**\n- Total chunks created: {len(all_chunks)}\n- Index built successfully!\n- Ready to answer questions!"

            return f"**File Processing Results:**\n\n{status_report}{summary}"

        except Exception as e:
            return f"❌ Error during indexing: {str(e)}"

    def search_similar_chunks(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to `top_k` indexed chunks most similar to `query`.

        Each result dict carries 'chunk', 'metadata', and
        'similarity_score' (cosine similarity; higher is better).
        Returns [] when nothing is indexed or on error.
        """
        if not self.is_indexed:
            return []

        try:
            query_embedding = np.ascontiguousarray(
                self.embedding_model.encode([query]), dtype=np.float32
            )
            faiss.normalize_L2(query_embedding)

            # Never ask FAISS for more neighbors than the index holds.
            k = min(top_k, self.index.ntotal)
            scores, indices = self.index.search(query_embedding, k)

            results = []
            for score, idx in zip(scores[0], indices[0]):
                # FAISS pads missing results with -1; a one-sided
                # `idx < len(...)` check would let -1 index the *last*
                # chunk, so bound-check both sides.
                if 0 <= idx < len(self.chunk_metadata):
                    results.append({
                        'chunk': self.chunks[idx],
                        'metadata': self.chunk_metadata[idx],
                        'similarity_score': float(score)
                    })

            return results

        except Exception as e:
            print(f"Search error: {e}")
            return []

    def generate_response(self, query: str, context_chunks: List[str]) -> str:
        """Generate an answer with the Groq API using retrieved context.

        Returns the model's reply, or an error string on API failure.
        """
        try:
            # Number the chunks so the model can cite them.
            context = "\n\n".join([f"Context {i+1}:\n{chunk}" for i, chunk in enumerate(context_chunks)])

            prompt = f"""Based on the following context information, please answer the user's question. If the answer cannot be found in the context, please say so clearly.

Context Information:
{context}

Question: {query}

Please provide a comprehensive and accurate answer based on the context provided above."""

            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided context. Always cite which part of the context supports your answer."
                    },
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model="llama-3.3-70b-versatile",
                temperature=0.3,  # low temperature: favor grounded answers
                max_tokens=1000
            )

            return chat_completion.choices[0].message.content

        except Exception as e:
            return f"Error generating response: {str(e)}"

    def query_documents(self, query: str, top_k: int = 5) -> tuple:
        """Answer a user question against the indexed documents.

        Returns:
            (response_text, sources_text) — both strings; sources_text is
            empty when no answer could be produced.
        """
        if not query.strip():
            return "Please enter a question.", ""

        if not self.is_indexed:
            return "Please upload and index some documents first.", ""

        similar_chunks = self.search_similar_chunks(query, top_k)

        if not similar_chunks:
            return "No relevant information found in the documents.", ""

        context_chunks = [chunk_data['chunk'] for chunk_data in similar_chunks]
        response = self.generate_response(query, context_chunks)

        # Build a per-source attribution list for the UI.
        sources = "\n\n📚 **Sources:**\n"
        for i, chunk_data in enumerate(similar_chunks):
            file_name = chunk_data['metadata']['file_name']
            similarity = chunk_data['similarity_score']
            sources += f"- **Source {i+1}:** {file_name} (Similarity: {similarity:.3f})\n"

        return response, sources
|
|
|
|
|
|
|
|
# Single module-level application instance shared by all Gradio callbacks.
rag_app = RAGApplication()
|
|
|
|
|
|
|
|
# Custom CSS injected into the Gradio Blocks interface (passed to gr.Blocks
# via the `css=` argument). `.main-header` styles the banner, `.upload-area`
# the file-drop zone, `.chat-container` the answer panel; `#component-0`
# targets Gradio's outermost auto-generated container.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}

.upload-area {
    border: 2px dashed #667eea;
    border-radius: 10px;
    padding: 2rem;
    text-align: center;
    background: #f8f9ff;
}

.chat-container {
    background: #ffffff;
    border-radius: 10px;
    padding: 1rem;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}

#component-0 {
    border-radius: 15px;
}
"""
|
|
|
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI wired to the module-level rag_app."""
    with gr.Blocks(css=custom_css, title="π€ RAG Document Assistant") as interface:

        # Header banner (styled by .main-header in custom_css).
        gr.HTML("""
            <div class="main-header">
                <h1>π€ RAG Document Assistant</h1>
                <p>Upload your documents and ask questions - powered by AI!</p>
            </div>
        """)

        with gr.Row():
            # Left column: document upload and indexing controls.
            with gr.Column(scale=1):
                gr.HTML("<h3>π Document Upload</h3>")

                file_upload = gr.File(
                    label="Upload Documents",
                    file_types=[".pdf", ".docx", ".txt", ".csv", ".xlsx"],
                    file_count="multiple",
                    height=200
                )

                upload_btn = gr.Button(
                    "π Process Documents",
                    variant="primary",
                    size="lg"
                )

                # Shows the per-file report returned by rag_app.build_index.
                upload_status = gr.Textbox(
                    label="Processing Status",
                    lines=8,
                    interactive=False,
                    placeholder="Upload documents and click 'Process Documents' to begin..."
                )

            # Right column (wider): question box and answer/source displays.
            with gr.Column(scale=2):
                gr.HTML("<h3>π¬ Ask Questions</h3>")

                with gr.Row():
                    query_input = gr.Textbox(
                        label="Your Question",
                        placeholder="Ask anything about your uploaded documents...",
                        lines=2,
                        scale=4
                    )
                    ask_btn = gr.Button("Ask", variant="primary", scale=1)

                # First element of the (response, sources) tuple from
                # rag_app.query_documents.
                response_output = gr.Textbox(
                    label="AI Response",
                    lines=10,
                    interactive=False,
                    placeholder="AI responses will appear here..."
                )

                # Second element of the tuple: per-chunk attribution.
                sources_output = gr.Textbox(
                    label="Sources",
                    lines=5,
                    interactive=False,
                    placeholder="Source information will appear here..."
                )

        # Static example-question hints below the main layout.
        gr.HTML("""
            <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
                <h4>π‘ Example Questions:</h4>
                <ul>
                    <li>"What are the main topics discussed in the document?"</li>
                    <li>"Can you summarize the key findings?"</li>
                    <li>"What recommendations are provided?"</li>
                    <li>"Tell me about [specific topic] mentioned in the documents"</li>
                </ul>
            </div>
        """)

        # Event wiring: index on button click; answer on Ask click or on
        # pressing Enter in the question box (both routes are identical).
        upload_btn.click(
            fn=rag_app.build_index,
            inputs=[file_upload],
            outputs=[upload_status]
        )

        ask_btn.click(
            fn=rag_app.query_documents,
            inputs=[query_input],
            outputs=[response_output, sources_output]
        )

        query_input.submit(
            fn=rag_app.query_documents,
            inputs=[query_input],
            outputs=[response_output, sources_output]
        )

    return interface
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces with a public share link.
    create_interface().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )