|
|
import gradio as gr |
|
|
import os |
|
|
from groq import Groq |
|
|
import PyPDF2 |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import numpy as np |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
import json |
|
|
from datetime import datetime |
|
|
import docx |
|
|
|
|
|
|
|
|
# Groq client handle; stays None when GROQ_API_KEY is unset or init fails
# (chat() retries initialization later and reports the error to the user).
client = None

try:
    api_key = os.environ.get("GROQ_API_KEY")
    if api_key:
        # Passing an explicit httpx client — presumably a workaround for a
        # Groq/httpx constructor incompatibility; TODO confirm still needed.
        import httpx

        client = Groq(api_key=api_key, http_client=httpx.Client())
        print("Groq client initialized successfully")
except Exception as e:
    # Never crash at import time over a misconfigured key; log and continue.
    print(f"Error initializing Groq client: {e}")
|
|
|
|
|
|
|
|
print("Loading sentence transformer model...") |
|
|
embedder = SentenceTransformer('all-MiniLM-L6-v2') |
|
|
print("Model loaded successfully!") |
|
|
|
|
|
|
|
|
# In-memory corpus state shared by all handlers (single-process store):
#   chunks               - chunk text strings
#   embeddings           - chunk embedding vectors, parallel to `chunks`
#   metadata             - per-chunk source info, parallel to `chunks`
#   conversation_history - recorded Q/A exchanges for download
document_store = {
    key: [] for key in ('chunks', 'embeddings', 'metadata', 'conversation_history')
}
|
|
|
|
|
def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF, one entry per non-empty page.

    Accepts either a filesystem path (str) or an object exposing a ``.name``
    path attribute (e.g. a Gradio upload). Returns a list of dicts with
    ``text``, ``page`` (1-based) and ``filename`` keys; empty list on error.
    """
    try:
        path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        reader = PyPDF2.PdfReader(path)
        base = os.path.basename(path)

        pages = []
        for index, page in enumerate(reader.pages, start=1):
            content = page.extract_text()
            # Skip pages with no extractable text (scans, blank pages).
            if content and content.strip():
                pages.append({'text': content, 'page': index, 'filename': base})
        return pages
    except Exception as e:
        print(f"Error reading PDF: {e}")
        return []
|
|
|
|
|
def extract_text_from_docx(docx_file):
    """Extract all paragraph text from a DOCX file.

    Accepts a path string or an object exposing a ``.name`` path attribute.
    Returns a single-entry list (the whole document is treated as page 1),
    or an empty list on any read error.
    """
    try:
        path = docx_file if isinstance(docx_file, str) else docx_file.name
        document = docx.Document(path)

        # Join non-blank paragraphs; DOCX has no real page boundaries here.
        body = '\n'.join(p.text for p in document.paragraphs if p.text.strip())
        return [{'text': body, 'page': 1, 'filename': os.path.basename(path)}]
    except Exception as e:
        print(f"Error reading DOCX: {e}")
        return []
|
|
|
|
|
def chunk_text(text_data, chunk_size=500, overlap=50):
    """Split extracted text into overlapping word-window chunks.

    Args:
        text_data: list of dicts with 'text', 'page' and 'filename' keys
            (as produced by the extract_text_from_* helpers).
        chunk_size: maximum number of words per chunk.
        overlap: number of words shared between consecutive chunks; must be
            smaller than chunk_size.

    Returns:
        (chunks, metadata): parallel lists — chunk strings, and dicts with
        'page', 'filename' and a 1-based corpus-wide 'chunk_id'.

    Raises:
        ValueError: if overlap >= chunk_size. Previously a zero step raised
            an opaque range() error and a negative step silently produced
            no chunks at all.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    step = chunk_size - overlap
    chunks = []
    metadata = []

    for data in text_data:
        words = data['text'].split()
        for start in range(0, len(words), step):
            chunk = ' '.join(words[start:start + chunk_size])
            # Drop fragments too short to be useful retrieval context.
            if len(chunk.strip()) > 50:
                chunks.append(chunk)
                metadata.append({
                    'page': data['page'],
                    'filename': data['filename'],
                    'chunk_id': len(chunks),  # 1-based, counts across all files
                })

    return chunks, metadata
|
|
|
|
|
def process_files(files):
    """Index uploaded PDF/DOCX files into the global document_store.

    Resets any previously indexed corpus, extracts text from each file,
    chunks it, embeds the chunks, and stores everything for retrieval.

    Args:
        files: list of uploads (objects with a ``.name`` path, or path strings).

    Returns:
        A Markdown status/summary string (also used for error reporting).
    """
    global document_store

    if not files:
        return "[ERROR] Please upload at least one file."

    try:
        # Start from a clean slate: a new upload replaces the old corpus.
        document_store = {'chunks': [], 'embeddings': [], 'metadata': [], 'conversation_history': []}

        all_text_data = []
        file_summaries = []

        for file in files:
            file_path = file.name if hasattr(file, 'name') else file
            file_ext = os.path.splitext(file_path)[1].lower()

            print(f"Processing file: {file_path}")

            if file_ext == '.pdf':
                text_data = extract_text_from_pdf(file)
            elif file_ext == '.docx':
                text_data = extract_text_from_docx(file)
            else:
                # Unsupported extension — skip (the UI already restricts types).
                continue

            all_text_data.extend(text_data)
            total_chars = sum(len(d['text']) for d in text_data)
            filename = os.path.basename(file_path)
            # BUG FIX: the summary previously printed the literal text
            # "(unknown)" instead of interpolating the computed filename.
            file_summaries.append(f"- **{filename}**: {len(text_data)} pages, {total_chars} characters")

        if not all_text_data:
            return "[ERROR] No valid text extracted."

        chunks, metadata = chunk_text(all_text_data)
        if not chunks:
            return "[ERROR] No text chunks created."

        embeddings = embedder.encode(chunks, show_progress_bar=False)

        document_store['chunks'] = chunks
        document_store['embeddings'] = embeddings
        document_store['metadata'] = metadata

        summary = f"**Successfully Processed {len(files)} file(s)**\n\n"
        summary += "\n".join(file_summaries)
        summary += f"\n\n**Created {len(chunks)} text chunks for retrieval.**"

        return summary
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"[ERROR] {str(e)}"
|
|
|
|
|
def retrieve_relevant_chunks(query, top_k=3):
    """Return the top_k most similar chunks (and their metadata) for a query.

    Yields ([], []) when no documents are indexed or on any retrieval error.
    """
    if not document_store['chunks']:
        return [], []

    try:
        query_vec = embedder.encode([query], show_progress_bar=False)
        scores = cosine_similarity(query_vec, document_store['embeddings'])[0]

        # Indices of the highest-scoring chunks, best first.
        ranked = np.argsort(scores)[::-1][:top_k]

        best_chunks = [document_store['chunks'][idx] for idx in ranked]
        best_meta = [document_store['metadata'][idx] for idx in ranked]
        return best_chunks, best_meta
    except Exception as e:
        print(f"Error retrieving chunks: {e}")
        return [], []
|
|
|
|
|
def chat(message, history):
    """Answer a question via RAG; returns a response string for ChatInterface.

    Retrieves the most relevant indexed chunks, sends them with recent
    conversation turns to the Groq LLM, appends source citations, and
    records the exchange in document_store['conversation_history'].
    """
    global client

    # Lazily (re)initialize the Groq client in case the API key was set
    # after startup (e.g. via the hosting platform's Settings).
    if client is None:
        try:
            api_key = os.environ.get("GROQ_API_KEY")
            if api_key:
                import httpx

                client = Groq(api_key=api_key, http_client=httpx.Client())
        except Exception as e:
            # BUG FIX: was a silent bare `except: pass`; at least log why
            # initialization failed so the [ERROR] below is diagnosable.
            print(f"Error initializing Groq client: {e}")

    if client is None:
        return "[ERROR] Groq API not initialized. Set GROQ_API_KEY in Settings."

    if not document_store['chunks']:
        return "[WARNING] Please upload and process documents first."

    try:
        relevant_chunks, metadata = retrieve_relevant_chunks(message, top_k=3)

        if not relevant_chunks:
            return "[ERROR] No relevant information found."

        context = "\n\n".join([
            f"[Source: {meta['filename']}, Page {meta['page']}]\n{chunk}"
            for chunk, meta in zip(relevant_chunks, metadata)
        ])

        messages = [
            {"role": "system", "content": "You are a helpful assistant that answers questions based on provided context. Be concise and accurate."}
        ]

        # Include recent history. Gradio may deliver it either as
        # (user, assistant) pairs or, in "messages" mode, as role/content
        # dicts — support both. BUG FIX: the old code unpacked every entry
        # as a 2-tuple, which silently turned message dicts into their key
        # strings ("role", "content") instead of the actual text.
        if history:
            # 3 tuple pairs == 6 message dicts (same number of turns).
            recent = history[-3:] if not isinstance(history[0], dict) else history[-6:]
            for entry in recent:
                if isinstance(entry, dict):
                    messages.append({"role": entry["role"], "content": entry["content"]})
                else:
                    user_msg, assistant_msg = entry
                    messages.append({"role": "user", "content": user_msg})
                    messages.append({"role": "assistant", "content": assistant_msg})

        messages.append({
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion: {message}"
        })

        response = client.chat.completions.create(
            messages=messages,
            model="llama-3.1-8b-instant",
            temperature=0.3,
            max_tokens=1024,
        )

        answer = response.choices[0].message.content

        sources = "\n\n**Sources:**\n" + "\n".join([
            f"- {m['filename']} (Page {m['page']})" for m in metadata
        ])

        full_answer = answer + sources

        document_store['conversation_history'].append({
            'timestamp': datetime.now().isoformat(),
            'query': message,
            'answer': answer
        })

        return full_answer

    except Exception as e:
        print(f"Error: {e}")
        return f"[ERROR] {str(e)}"
|
|
|
|
|
def download_history():
    """Write the conversation history to chat_history.json for download.

    Returns:
        The file path ("chat_history.json") on success, or None when there
        is no history yet or the file cannot be written.
    """
    if not document_store['conversation_history']:
        return None

    try:
        with open("chat_history.json", 'w', encoding='utf-8') as f:
            json.dump(document_store['conversation_history'], f, indent=2)
        return "chat_history.json"
    except (OSError, TypeError) as e:
        # BUG FIX: was a bare `except:` that silently swallowed every error
        # (including KeyboardInterrupt); narrow to write/serialization
        # failures and log them.
        print(f"Error writing chat history: {e}")
        return None
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Left column: document upload/processing and history download.
# Right column: the RAG chat interface. Component creation order defines
# the on-screen layout, so do not reorder these statements.
with gr.Blocks(title="Enhanced RAG Chatbot") as demo:

    gr.Markdown("""
    # Enhanced RAG-Based Chatbot
    Upload PDF/DOCX files and ask questions!

    **Features:** Multiple files, Semantic search, Source references, Chat history
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Document ingestion controls.
            file_upload = gr.File(
                label="Upload Documents (PDF/DOCX)",
                file_count="multiple",
                file_types=[".pdf", ".docx"]
            )
            process_btn = gr.Button("Process Documents", variant="primary")
            # Shows the summary / error string returned by process_files().
            process_output = gr.Markdown()

            gr.Markdown("### History")
            download_btn = gr.Button("Download (JSON)")
            download_file = gr.File(label="Download")

        with gr.Column(scale=2):
            # chat() receives (message, history) and returns the answer string.
            chat_interface = gr.ChatInterface(
                fn=chat
            )

    # Wire the buttons to their handlers.
    process_btn.click(process_files, [file_upload], [process_output])

    # download_history() takes no inputs and returns a file path (or None).
    download_btn.click(download_history, None, [download_file])

    gr.Markdown("""
    ---
    ### How It Works:
    1. Upload PDF/DOCX files and click "Process Documents"
    2. Ask questions - RAG finds relevant chunks and generates answers
    3. Sources are cited with page numbers
    """)
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch() |