# Source: Hugging Face Space by Miraj74 — "Update app.py" (commit a06da41, verified)
import gradio as gr
import PyPDF2
import io
from together import Together
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms.base import LLM
from typing import List, Optional
import traceback
# ---------------------------
# WRAP TOGETHER API AS LLM
# ---------------------------
class TogetherLLM(LLM):
    """LangChain-compatible LLM wrapper around the Together chat-completions API.

    Subclasses LangChain's pydantic-based ``LLM``; the fields below are
    declared as pydantic fields and then populated in ``__init__`` via
    ``object.__setattr__`` to bypass pydantic validation/immutability for
    the arbitrary-typed ``Together`` client.
    """

    # Pydantic field declarations (defaults; real values set in __init__).
    client: Together = None
    model: str = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
    temperature: float = 0.3
    max_tokens: int = 1000

    def __init__(self, client, model="meta-llama/Llama-3.3-70B-Instruct-Turbo", temperature=0.3, max_tokens=1000, **kwargs):
        # Call the pydantic base init first, then attach attributes directly;
        # object.__setattr__ avoids pydantic's __setattr__ hooks.
        super().__init__(**kwargs)
        object.__setattr__(self, 'client', client)
        object.__setattr__(self, 'model', model)
        object.__setattr__(self, 'temperature', temperature)
        object.__setattr__(self, 'max_tokens', max_tokens)

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses for this LLM implementation.
        return "together-llm"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send a single-turn chat completion and return the reply text.

        NOTE(review): ``stop`` is accepted but not forwarded to the API, and
        API errors are returned as a plain string instead of raised — the
        chain will present failures as ordinary answers.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

    class Config:
        # Permit the non-pydantic Together client as a field type.
        arbitrary_types_allowed = True
# ---------------------------
# PDF TEXT EXTRACTION
# ---------------------------
def extract_text_from_pdf(pdf_file):
    """Extract text from PDF with page references.

    Accepts a Gradio upload (has ``.name``), a file-like object, or raw
    bytes. Returns one Document per page; pages with no text or an
    extraction error get a bracketed placeholder so page numbering stays
    complete. On total failure returns a single Document with page -1.
    """
    extracted = []
    try:
        print("Starting PDF extraction...")
        # Normalize the three accepted input shapes into raw bytes.
        if hasattr(pdf_file, 'name'):
            # File uploaded through Gradio — read from its temp path.
            with open(pdf_file.name, 'rb') as handle:
                raw_bytes = handle.read()
        elif hasattr(pdf_file, "read"):
            raw_bytes = pdf_file.read()
            if hasattr(pdf_file, "seek"):
                pdf_file.seek(0)
        else:
            raw_bytes = pdf_file

        reader = PyPDF2.PdfReader(io.BytesIO(raw_bytes))
        print(f"PDF has {len(reader.pages)} pages")

        for page_num, page in enumerate(reader.pages, start=1):
            meta = {"page": page_num, "source": "financial_policy"}
            try:
                page_text = page.extract_text()
            except Exception as e:
                print(f"Error extracting page {page_num}: {str(e)}")
                extracted.append(Document(
                    page_content=f"[Error extracting page {page_num}: {str(e)}]",
                    metadata=meta,
                ))
                continue
            if page_text and page_text.strip():
                extracted.append(Document(
                    page_content=page_text.strip(),
                    metadata=meta,
                ))
                print(f"Extracted text from page {page_num}: {len(page_text)} characters")
            else:
                extracted.append(Document(
                    page_content="[No extractable text found on this page]",
                    metadata=meta,
                ))

        print(f"Total documents extracted: {len(extracted)}")
        return extracted
    except Exception as e:
        print(f"Error in PDF extraction: {str(e)}")
        traceback.print_exc()
        return [Document(page_content=f"Error extracting text: {str(e)}", metadata={"page": -1})]
# ---------------------------
# BUILD KNOWLEDGE BASE (FAISS)
# ---------------------------
def build_vector_db(docs):
    """Convert extracted documents into FAISS vector DB.

    Chunks the page Documents, embeds them with a CPU MiniLM model, and
    returns a FAISS index — or None if any step fails.
    """
    try:
        print("Building vector database...")
        # Chunk pages into overlapping ~1000-char segments for retrieval.
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            separators=["\n\n", "\n", ". ", " ", ""],
        ).split_documents(docs)
        print(f"Split into {len(chunks)} chunks")

        # Small sentence-transformer, forced onto CPU.
        embedder = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
        )
        print("Embeddings model loaded")

        index = FAISS.from_documents(chunks, embedder)
        print("Vector database created successfully")
        return index
    except Exception as e:
        print(f"Error building vector database: {str(e)}")
        traceback.print_exc()
        return None
# ---------------------------
# CHATBOT PIPELINE
# ---------------------------
def create_chatbot(api_key, db):
    """Set up ConversationalRetrievalChain with memory.

    Wires the Together-backed LLM, a top-4 similarity retriever over the
    FAISS db, and buffer memory into a conversational QA chain. Returns
    the chain, or None on failure.
    """
    try:
        print("Creating chatbot...")
        together_client = Together(api_key=api_key)
        wrapped_llm = TogetherLLM(client=together_client)

        # Buffer memory keyed the way ConversationalRetrievalChain expects;
        # output_key is required because the chain also returns sources.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

        chain = ConversationalRetrievalChain.from_llm(
            llm=wrapped_llm,
            retriever=db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 4},
            ),
            memory=memory,
            return_source_documents=True,
            verbose=True,
        )
        print("Chatbot created successfully")
        return chain
    except Exception as e:
        print(f"Error creating chatbot: {str(e)}")
        traceback.print_exc()
        return None
# ---------------------------
# GRADIO APP
# ---------------------------
def create_app():
    """Assemble and return the Gradio Blocks application.

    Layout: left column holds the API key, PDF upload, and processing status;
    right column hosts the chat. Two gr.State values (FAISS db, QA chain) are
    threaded through the event handlers per session.

    Fixes vs. the previous revision:
    - ``process_pdf_handler`` is a generator, but its early validation exits
      used ``return msg, None, None``; inside a generator that value becomes
      ``StopIteration.value`` and Gradio never displays it, so the warnings
      were silently dropped. They now ``yield`` then ``return``.
    - ``chatbot`` was listed twice in the event outputs while the error paths
      returned different values for the duplicates; the handler now returns a
      single history value bound to the chatbot once.
    - Mojibake emoji in user-facing strings (UTF-8 decoded as cp874/cp1252)
      restored to the intended characters.
    """
    with gr.Blocks(title="📊 Financial Policy Document Chatbot", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 📊 Financial Policy Document Chatbot")
        gr.Markdown("""
        Upload a financial policy PDF document and ask questions about its content.
        The chatbot will provide answers with page references from the document.
        """)
        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="Together API Key",
                    placeholder="Enter your Together API key here...",
                    type="password",
                )
                pdf_file = gr.File(
                    label="Upload Financial Policy PDF",
                    file_types=[".pdf"],
                )
                process_button = gr.Button("📄 Process PDF", variant="primary")
                status_message = gr.Textbox(label="Status", interactive=False, lines=3)
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Chat with Financial Policy Document", height=500)
                with gr.Row():
                    question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="Example: What is the budget allocation for infrastructure?",
                        lines=2,
                        scale=4,
                    )
                    submit_button = gr.Button("🔍 Ask", variant="secondary", scale=1)
                gr.Markdown("""
                **Sample Questions:**
                - What is the debt policy outlined in the document?
                - How much budget is allocated for infrastructure?
                - What are the revenue sources mentioned?
                - What are the key financial objectives?
                """)

        # Per-session state: FAISS index and the conversational QA chain.
        db_state = gr.State()
        qa_chain_state = gr.State()

        def process_pdf_handler(pdf_file, api_key):
            """Generator handler: streams status text while processing the PDF.

            Yields (status, db, qa_chain); db/qa_chain stay None until the
            final successful yield.
            """
            try:
                if pdf_file is None:
                    yield "⚠️ Please upload a PDF file.", None, None
                    return
                if not api_key or api_key.strip() == "":
                    yield "⚠️ Please enter your Together API key.", None, None
                    return
                yield "🔄 Processing PDF... This may take a few moments.", None, None

                docs = extract_text_from_pdf(pdf_file)
                if not docs:
                    yield "⚠️ No text could be extracted from the PDF.", None, None
                    return

                # Placeholder/error pages don't count as readable content.
                valid_docs = [
                    doc for doc in docs
                    if not doc.page_content.startswith("[Error")
                    and not doc.page_content.startswith("[No extractable")
                ]
                if not valid_docs:
                    yield "⚠️ No readable text found in the PDF.", None, None
                    return

                yield f"📄 Extracted text from {len(docs)} pages. Building search database...", None, None
                db = build_vector_db(docs)
                if db is None:
                    yield "⚠️ Failed to build search database.", None, None
                    return

                yield "🔍 Search database created. Setting up chatbot...", None, None
                qa_chain = create_chatbot(api_key, db)
                if qa_chain is None:
                    yield "⚠️ Failed to create chatbot.", None, None
                    return

                yield (
                    f"✅ Successfully processed PDF with {len(docs)} pages. Ready to answer questions!",
                    db,
                    qa_chain,
                )
            except Exception as e:
                print(f"Process PDF Error: {str(e)}")
                traceback.print_exc()
                yield f"❌ Error processing PDF: {str(e)}", None, None

        def chat_handler(user_question, qa_chain, history):
            """Answer one question; returns (updated history, cleared textbox)."""
            if not user_question or user_question.strip() == "":
                return history, ""
            if qa_chain is None:
                return history + [(user_question, "⚠️ Please process a PDF document first.")], ""
            try:
                result = qa_chain({"question": user_question})
                answer = result["answer"]

                # Append page references gathered from the retrieved chunks.
                sources = result.get("source_documents") or []
                pages = [doc.metadata["page"] for doc in sources if "page" in doc.metadata]
                if pages:
                    unique_pages = sorted(set(pages))
                    if len(unique_pages) == 1:
                        answer += f"\n\n📌 **Reference:** Page {unique_pages[0]}"
                    else:
                        answer += f"\n\n📌 **References:** Pages {', '.join(map(str, unique_pages))}"

                return history + [(user_question, answer)], ""
            except Exception as e:
                print(f"Chat Error: {str(e)}")
                traceback.print_exc()
                return history + [(user_question, f"❌ Error processing question: {str(e)}")], ""

        # Bind events; the chatbot component appears exactly once per outputs list.
        process_button.click(
            fn=process_pdf_handler,
            inputs=[pdf_file, api_key_input],
            outputs=[status_message, db_state, qa_chain_state],
        )
        submit_button.click(
            fn=chat_handler,
            inputs=[question, qa_chain_state, chatbot],
            outputs=[chatbot, question],
        )
        question.submit(
            fn=chat_handler,
            inputs=[question, qa_chain_state, chatbot],
            outputs=[chatbot, question],
        )
    return app
# ---------------------------
# MAIN EXECUTION
# ---------------------------
if __name__ == "__main__":
app = create_app()
app.launch(
share=True,
server_name="0.0.0.0",
server_port=7860,
debug=True
)