# Source: Hugging Face Space app.py
# Author: SyedZainAliShah — commit ca5f03e (verified), message: "Update app.py"
import gradio as gr
import os
from groq import Groq
import PyPDF2
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json
from datetime import datetime
import docx
# --- Module-level initialization (runs once at import) ---

# Initialize Groq client. Stays None when GROQ_API_KEY is unset or init
# fails, so the UI can still load; chat() reports the problem at call time.
client = None
try:
    api_key = os.environ.get("GROQ_API_KEY")
    if api_key:
        # NOTE(review): passing an explicit httpx.Client presumably works
        # around a groq/httpx version incompatibility — confirm before removing.
        import httpx
        client = Groq(api_key=api_key, http_client=httpx.Client())
        print("Groq client initialized successfully")
except Exception as e:
    print(f"Error initializing Groq client: {e}")

# Initialize sentence transformer model used for chunk/query embeddings.
print("Loading sentence transformer model...")
embedder = SentenceTransformer('all-MiniLM-L6-v2')
print("Model loaded successfully!")

# Global storage shared across Gradio callbacks:
#   chunks               - list[str]: text chunks from processed documents
#   embeddings           - chunk embedding vectors, parallel to `chunks`
#   metadata             - per-chunk dicts: filename / page / chunk_id
#   conversation_history - list of {timestamp, query, answer} dicts
document_store = {
    'chunks': [],
    'embeddings': [],
    'metadata': [],
    'conversation_history': []
}
def extract_text_from_pdf(pdf_file):
    """Extract per-page text from a PDF.

    Accepts either a filesystem path string or a Gradio file object
    (anything with a ``.name`` attribute). Returns a list of
    ``{'text', 'page', 'filename'}`` dicts — one per non-empty page —
    or an empty list if the file cannot be read.
    """
    try:
        # Normalize both accepted input forms to a plain path.
        path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        reader = PyPDF2.PdfReader(path)
        filename = os.path.basename(path)
        pages = []
        # Pages are reported 1-based to match how readers cite them.
        for page_number, page in enumerate(reader.pages, start=1):
            content = page.extract_text()
            if content and content.strip():
                pages.append({
                    'text': content,
                    'page': page_number,
                    'filename': filename
                })
        return pages
    except Exception as e:
        print(f"Error reading PDF: {e}")
        return []
def extract_text_from_docx(docx_file):
    """Extract text from a DOCX file as a single pseudo-page.

    Accepts a path string or a Gradio file object. Returns a one-element
    list ``[{'text', 'page': 1, 'filename'}]`` with all non-blank
    paragraphs joined by newlines, or ``[]`` if reading fails.
    """
    try:
        # Normalize both accepted input forms to a plain path.
        path = docx_file if isinstance(docx_file, str) else docx_file.name
        document = docx.Document(path)
        body = '\n'.join(p.text for p in document.paragraphs if p.text.strip())
        # DOCX has no page concept here, so everything is "page 1".
        return [{'text': body, 'page': 1, 'filename': os.path.basename(path)}]
    except Exception as e:
        print(f"Error reading DOCX: {e}")
        return []
def chunk_text(text_data, chunk_size=500, overlap=50):
    """Split extracted page texts into overlapping word-based chunks.

    Args:
        text_data: list of ``{'text', 'page', 'filename'}`` dicts.
        chunk_size: maximum number of words per chunk.
        overlap: number of words shared between consecutive chunks.

    Returns:
        ``(chunks, metadata)`` — parallel lists. Fragments of 50 or fewer
        characters (after stripping) are discarded as retrieval noise.
    """
    chunks, metadata = [], []
    step = chunk_size - overlap  # window advance between chunk starts
    for entry in text_data:
        words = entry['text'].split()
        for start in range(0, len(words), step):
            piece = ' '.join(words[start:start + chunk_size])
            # Skip fragments too short to be useful context.
            if len(piece.strip()) <= 50:
                continue
            chunks.append(piece)
            metadata.append({
                'page': entry['page'],
                'filename': entry['filename'],
                'chunk_id': len(chunks),  # 1-based chunk identifier
            })
    return chunks, metadata
def process_files(files):
    """Process uploaded PDF/DOCX files into the global document_store.

    Extracts text from each supported file, chunks it, embeds the chunks,
    and returns a Markdown status summary for the UI. Unsupported file
    types are skipped. Any previous store contents are discarded.
    """
    global document_store
    if not files:
        return "[ERROR] Please upload at least one file."
    try:
        # Reset store so results reflect only the current upload batch.
        document_store = {'chunks': [], 'embeddings': [], 'metadata': [], 'conversation_history': []}
        all_text_data = []
        file_summaries = []
        for file in files:
            file_path = file.name if hasattr(file, 'name') else file
            file_ext = os.path.splitext(file_path)[1].lower()
            print(f"Processing file: {file_path}")
            if file_ext == '.pdf':
                text_data = extract_text_from_pdf(file)
            elif file_ext == '.docx':
                text_data = extract_text_from_docx(file)
            else:
                continue  # unsupported extension — skip silently
            all_text_data.extend(text_data)
            total_chars = sum(len(d['text']) for d in text_data)
            filename = os.path.basename(file_path)
            # BUG FIX: the summary previously printed the literal "(unknown)"
            # instead of interpolating the computed filename.
            file_summaries.append(f"- **{filename}**: {len(text_data)} pages, {total_chars} characters")
        if not all_text_data:
            return "[ERROR] No valid text extracted."
        chunks, metadata = chunk_text(all_text_data)
        if not chunks:
            return "[ERROR] No text chunks created."
        embeddings = embedder.encode(chunks, show_progress_bar=False)
        document_store['chunks'] = chunks
        document_store['embeddings'] = embeddings
        document_store['metadata'] = metadata
        summary = f"**Successfully Processed {len(files)} file(s)**\n\n"
        summary += "\n".join(file_summaries)
        summary += f"\n\n**Created {len(chunks)} text chunks for retrieval.**"
        return summary
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"[ERROR] {str(e)}"
def retrieve_relevant_chunks(query, top_k=3):
    """Return the top_k chunks most similar to ``query``.

    Embeds the query, ranks stored chunk embeddings by cosine similarity,
    and returns ``(chunks, metadata)`` best-first. Returns two empty lists
    when nothing has been indexed or retrieval fails.
    """
    if not document_store['chunks']:
        return [], []
    try:
        query_vec = embedder.encode([query], show_progress_bar=False)
        scores = cosine_similarity(query_vec, document_store['embeddings'])[0]
        # Full descending sort, then take the first top_k indices.
        best = np.argsort(scores)[::-1][:top_k]
        best_chunks = [document_store['chunks'][i] for i in best]
        best_meta = [document_store['metadata'][i] for i in best]
        return best_chunks, best_meta
    except Exception as e:
        print(f"Error retrieving chunks: {e}")
        return [], []
def chat(message, history):
    """Answer a user question via RAG; returns a response string for
    gr.ChatInterface.

    Retrieves the most relevant chunks for ``message``, builds a
    context-grounded prompt (including up to the last 3 history
    exchanges), calls Groq, logs the exchange, and appends cited sources.
    """
    global client
    # Lazily (re)initialize the client in case GROQ_API_KEY was set after
    # the Space started.
    if client is None:
        try:
            api_key = os.environ.get("GROQ_API_KEY")
            if api_key:
                import httpx
                client = Groq(api_key=api_key, http_client=httpx.Client())
        except Exception as e:
            # BUG FIX: was a bare `except: pass`, which hid all failures
            # (including SystemExit/KeyboardInterrupt). Log instead.
            print(f"Error re-initializing Groq client: {e}")
    if client is None:
        return "[ERROR] Groq API not initialized. Set GROQ_API_KEY in Settings."
    if not document_store['chunks']:
        return "[WARNING] Please upload and process documents first."
    try:
        # Retrieve grounding context.
        relevant_chunks, metadata = retrieve_relevant_chunks(message, top_k=3)
        if not relevant_chunks:
            return "[ERROR] No relevant information found."
        context = "\n\n".join(
            f"[Source: {meta['filename']}, Page {meta['page']}]\n{chunk}"
            for chunk, meta in zip(relevant_chunks, metadata)
        )
        messages = [
            {"role": "system", "content": "You are a helpful assistant that answers questions based on provided context. Be concise and accurate."}
        ]
        # Include up to the last 3 (user, assistant) tuple exchanges for
        # conversational continuity (Gradio 4 tuple-format history).
        if history:
            for user_msg, assistant_msg in history[-3:]:
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})
        # Current query carries the retrieved context inline.
        messages.append({
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion: {message}"
        })
        response = client.chat.completions.create(
            messages=messages,
            model="llama-3.1-8b-instant",
            temperature=0.3,  # low temperature keeps answers grounded
            max_tokens=1024,
        )
        answer = response.choices[0].message.content
        # Cite the retrieved sources under the answer.
        sources = "\n\n**Sources:**\n" + "\n".join(
            f"- {m['filename']} (Page {m['page']})" for m in metadata
        )
        full_answer = answer + sources
        # Log the exchange for the downloadable history.
        document_store['conversation_history'].append({
            'timestamp': datetime.now().isoformat(),
            'query': message,
            'answer': answer
        })
        return full_answer
    except Exception as e:
        print(f"Error: {e}")
        return f"[ERROR] {str(e)}"
def download_history():
    """Serialize the conversation history to ``chat_history.json``.

    Returns the filename for the gr.File output, or ``None`` when there
    is no history yet or the write fails.
    """
    if not document_store['conversation_history']:
        return None
    try:
        with open("chat_history.json", 'w') as f:
            json.dump(document_store['conversation_history'], f, indent=2)
        return "chat_history.json"
    except Exception as e:
        # BUG FIX: was a silent bare `except:`; log the failure while
        # keeping the None-on-error contract.
        print(f"Error writing chat history: {e}")
        return None
# --- Gradio UI layout and event wiring ---
with gr.Blocks(title="Enhanced RAG Chatbot") as demo:
    # Page header / feature summary.
    gr.Markdown("""
# Enhanced RAG-Based Chatbot
Upload PDF/DOCX files and ask questions!
**Features:** Multiple files, Semantic search, Source references, Chat history
""")
    with gr.Row():
        # Left column: upload, processing status, and history download.
        with gr.Column(scale=1):
            file_upload = gr.File(
                label="Upload Documents (PDF/DOCX)",
                file_count="multiple",
                file_types=[".pdf", ".docx"]
            )
            process_btn = gr.Button("Process Documents", variant="primary")
            process_output = gr.Markdown()
            gr.Markdown("### History")
            download_btn = gr.Button("Download (JSON)")
            download_file = gr.File(label="Download")
        # Right column: chat panel.
        with gr.Column(scale=2):
            # Minimal ChatInterface compatible with Gradio 4.44.1
            chat_interface = gr.ChatInterface(
                fn=chat
            )
    # Process files: button -> process_files(files) -> status markdown.
    process_btn.click(process_files, [file_upload], [process_output])
    # Download: button -> download_history() -> downloadable file path.
    download_btn.click(download_history, None, [download_file])
    gr.Markdown("""
---
### How It Works:
1. Upload PDF/DOCX files and click "Process Documents"
2. Ask questions - RAG finds relevant chunks and generates answers
3. Sources are cited with page numbers
""")

# Launch only when executed as a script (Spaces run app.py directly).
if __name__ == "__main__":
    demo.launch()