# agentic_rag / app.py
# Hugging Face Space file header (author: khaledsayed1, commit 5ff4e8e, verified)
import os
import gradio as gr
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.tools.retriever import create_retriever_tool
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field
from typing import Literal
from langchain_core.messages import HumanMessage, AIMessage
import logging
# Set up module-level logging so initialization problems surface in the
# Space's runtime logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get the API key from Hugging Face Spaces secrets and re-export it via the
# environment so the LangChain/Google GenAI clients can discover it.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if GOOGLE_API_KEY:
    os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
else:
    # The app still starts, but LegalConsultingBot will refuse to initialize.
    logger.warning("GOOGLE_API_KEY not found in environment variables")
class LegalConsultingBot:
    """Agentic-RAG assistant for SME legal-consulting questions.

    On construction it initializes two Gemini models, scrapes a fixed set of
    legal-consulting websites into an in-memory vector store, and compiles a
    LangGraph workflow (query -> retrieve -> grade -> answer / rewrite).
    Every stage degrades gracefully: if the retriever or graph cannot be
    built, ``chat`` falls back to a direct LLM response.
    """

    class AgentState(MessagesState):
        # Extends the stock MessagesState schema with the grading verdict:
        # grade_documents_node writes "grade_result" and the conditional edge
        # reads it. With plain MessagesState, LangGraph rejects node updates
        # for keys that are not part of the state schema, so grading could
        # never steer the graph.
        grade_result: str

    def __init__(self):
        # All members default to None so partial initialization failures
        # leave the bot in a detectable fallback state.
        self.graph = None
        self.retriever_tool = None
        self.response_model = None
        self.grader_model = None
        self.initialize_bot()

    def initialize_bot(self):
        """Initialize models, retriever, and workflow with error handling."""
        try:
            if not GOOGLE_API_KEY:
                # Without a key neither the chat models nor the embeddings
                # can work; leave everything None and let chat() report it.
                logger.error("Google API key not available")
                return
            self.setup_models()
            self.setup_retriever()
            self.setup_workflow()
            logger.info("Bot initialized successfully")
        except Exception as e:
            # Top-level boundary: never let a startup failure crash the app.
            logger.error(f"Failed to initialize bot: {e}")

    def setup_models(self):
        """Initialize the response and grader language models.

        Raises:
            Exception: re-raised after logging, since nothing works without
            the models.
        """
        try:
            self.response_model = init_chat_model(
                "gemini-2.0-flash",
                model_provider="google_genai",
                temperature=0
            )
            self.grader_model = init_chat_model(
                "gemini-2.0-flash",
                model_provider="google_genai",
                temperature=0
            )
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            raise

    def setup_retriever(self):
        """Build the retriever tool from a fixed list of legal-consulting URLs.

        Leaves ``self.retriever_tool`` as None (fallback mode) when no page
        can be loaded or vector-store construction fails.
        """
        urls = [
            "https://firststepslegal.co.uk/",
            "https://sprintlaw.co.uk/",
            "https://www.smbs.solutions/legal-and-compliance-resources-legal-assistance-for-businesses",
            "https://ignition.law/lawyers-for-smes/",
            "https://www.cocredo.co.uk/news/free-legal-advice-small-business-owners",
            "https://www.gannons.co.uk/sectors/smes/",
            "https://kkbservices.com/who-we-work-with/small-businesses/",
            "https://farringfordlegal.co.uk/",
            "https://stanislawlegal.com/en/legal-solutions/for-sme-companies/",
            "https://www.lawhive.co.uk/small-business/",
            "https://www.catalystlaw.co.uk/business-legal-advice.html",
            "https://medium.com/@kmitsme123/legal-consulting-tips-for-your-small-business-9075005eb574",
            "https://smecomply.co.uk/",
            "https://dojobusiness.com/blogs/news/legal-consultant-complete-guide",
            "https://englishlegaladvice.com/",
        ]
        try:
            # Best-effort scraping: a single unreachable site must not abort
            # retriever construction.
            docs = []
            successful_loads = 0
            for url in urls:
                try:
                    loader = WebBaseLoader(url)
                    docs.extend(loader.load())
                    successful_loads += 1
                    logger.info(f"Successfully loaded: {url}")
                except Exception as e:
                    logger.warning(f"Failed to load {url}: {e}")
                    continue
            logger.info(f"Successfully loaded {successful_loads}/{len(urls)} URLs")
            if not docs:
                logger.warning("No documents could be loaded, using fallback mode")
                return
            # Split documents into small overlapping chunks for retrieval.
            text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=300, chunk_overlap=50
            )
            doc_splits = text_splitter.split_documents(docs)
            logger.info(f"Created {len(doc_splits)} document chunks")
            # Create embeddings and an in-memory vectorstore (rebuilt on each
            # Space restart; nothing is persisted).
            embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            vectorstore = InMemoryVectorStore.from_documents(
                documents=doc_splits,
                embedding=embeddings
            )
            # Wrap the retriever as a LangChain tool the LLM can call.
            retriever = vectorstore.as_retriever()
            self.retriever_tool = create_retriever_tool(
                retriever,
                name="legal_consulting_retriever",
                description="Search and return relevant legal information and consulting resources for small and medium-sized businesses."
            )
            logger.info("Retriever tool created successfully")
        except Exception as e:
            logger.error(f"Error setting up retriever: {e}")
            self.retriever_tool = None

    def setup_workflow(self):
        """Compile the LangGraph workflow.

        With a retriever: generate -> (tools?) retrieve -> grade ->
        generate_answer | rewrite_question (loop back). Without one, the
        graph is a single generate node.
        """
        try:
            if not self.response_model:
                logger.error("Response model not available")
                return
            # AgentState = MessagesState + grade_result, so the grading
            # node's output is a legal state update.
            workflow = StateGraph(self.AgentState)
            workflow.add_node("generate_query_or_respond", self.generate_query_or_respond)
            if self.retriever_tool:
                # Retrieval-dependent nodes only exist when the tool does.
                workflow.add_node("retrieve", ToolNode([self.retriever_tool]))
                workflow.add_node("grade_documents", self.grade_documents_node)
                workflow.add_node("rewrite_question", self.rewrite_question)
                workflow.add_node("generate_answer", self.generate_answer)
            workflow.add_edge(START, "generate_query_or_respond")
            if self.retriever_tool:
                # Route to the tool node when the model emitted a tool call,
                # otherwise finish with the model's direct answer.
                workflow.add_conditional_edges(
                    "generate_query_or_respond",
                    tools_condition,
                    {
                        "tools": "retrieve",
                        END: END,
                    },
                )
                workflow.add_edge("retrieve", "grade_documents")
                # Branch on the grading verdict; default to answering so a
                # grader failure cannot deadlock the graph.
                workflow.add_conditional_edges(
                    "grade_documents",
                    lambda x: x.get("grade_result", "generate_answer"),
                    {
                        "generate_answer": "generate_answer",
                        "rewrite_question": "rewrite_question"
                    }
                )
                workflow.add_edge("generate_answer", END)
                workflow.add_edge("rewrite_question", "generate_query_or_respond")
            else:
                workflow.add_edge("generate_query_or_respond", END)
            self.graph = workflow.compile()
            logger.info("Workflow compiled successfully")
        except Exception as e:
            logger.error(f"Error setting up workflow: {e}")
            self.graph = None

    def generate_query_or_respond(self, state: MessagesState):
        """Let the model answer directly or emit a retriever tool call."""
        try:
            if not self.retriever_tool:
                # Fallback response when the retriever is not available.
                fallback_response = """I'm a legal consulting assistant for small and medium enterprises. While my document retriever is currently unavailable, I can still help answer general questions about:
- Business formation and structure
- Contract basics and employment law
- Intellectual property fundamentals
- Compliance and regulatory matters
- General legal considerations for SMEs
Please note: This is general information only, not legal advice. Always consult qualified legal professionals for specific matters."""
                return {"messages": [AIMessage(content=fallback_response)]}
            response = (
                self.response_model
                .bind_tools([self.retriever_tool])
                .invoke(state["messages"])
            )
            return {"messages": [response]}
        except Exception as e:
            logger.error(f"Error in generate_query_or_respond: {e}")
            return {"messages": [AIMessage(content="I'm sorry, I encountered an error. Please try again.")]}

    class GradeDocuments(BaseModel):
        """Grade documents using a binary score for relevance check."""
        binary_score: str = Field(
            description="Relevance score: 'yes' if relevant, or 'no' if not relevant"
        )

    def grade_documents_node(self, state: MessagesState):
        """Grade retrieved context and write the routing verdict.

        Returns a ``grade_result`` of ``"generate_answer"`` or
        ``"rewrite_question"`` for the conditional edge to consume.
        """
        try:
            # messages[0] is the original user question; the last message is
            # the retriever tool output.
            question = state["messages"][0].content
            context = state["messages"][-1].content if state["messages"] else ""
            GRADE_PROMPT = (
                "You are a grader assessing relevance of a retrieved document to a user question. \n "
                "Here is the retrieved document: \n\n {context} \n\n"
                "Here is the user question: {question} \n"
                "If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n"
                "Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."
            )
            prompt = GRADE_PROMPT.format(question=question, context=context)
            response = (
                self.grader_model
                .with_structured_output(self.GradeDocuments)
                .invoke([{"role": "user", "content": prompt}])
            )
            grade_result = "generate_answer" if response.binary_score == "yes" else "rewrite_question"
            return {"grade_result": grade_result}
        except Exception as e:
            logger.error(f"Error in grade_documents_node: {e}")
            return {"grade_result": "generate_answer"}  # Default to generating answer

    def rewrite_question(self, state: MessagesState):
        """Rewrite the original user question for a better retrieval query."""
        try:
            messages = state["messages"]
            question = messages[0].content if messages else ""
            REWRITE_PROMPT = (
                "Look at the input and try to reason about the underlying semantic intent / meaning.\n"
                "Here is the initial question:"
                "\n ------- \n"
                "{question}"
                "\n ------- \n"
                "Formulate an improved question that would be better for searching legal consulting information:"
            )
            prompt = REWRITE_PROMPT.format(question=question)
            response = self.response_model.invoke([{"role": "user", "content": prompt}])
            # Append the rewrite as a new human turn; the original question
            # stays at messages[0] for later nodes.
            return {"messages": [HumanMessage(content=response.content)]}
        except Exception as e:
            logger.error(f"Error in rewrite_question: {e}")
            # Return original message if rewriting fails
            return {"messages": state["messages"][:1] if state["messages"] else []}

    def generate_answer(self, state: MessagesState):
        """Generate the final answer from the retrieved context."""
        try:
            question = state["messages"][0].content if state["messages"] else ""
            context = state["messages"][-1].content if len(state["messages"]) > 1 else ""
            GENERATE_PROMPT = (
                "You are an assistant for question-answering tasks about legal consulting for small and medium businesses. "
                "Use the following pieces of retrieved context to answer the question. "
                "If you don't know the answer, just say that you don't know. "
                "Keep the answer concise but informative. Always remind users that this is general information and not legal advice.\n"
                "Question: {question} \n"
                "Context: {context}"
            )
            prompt = GENERATE_PROMPT.format(question=question, context=context)
            response = self.response_model.invoke([{"role": "user", "content": prompt}])
            return {"messages": [response]}
        except Exception as e:
            logger.error(f"Error in generate_answer: {e}")
            return {"messages": [AIMessage(content="I'm sorry, I encountered an error while generating the answer. Please try again.")]}

    def chat(self, message: str, history: list) -> str:
        """Main chat entry point for the Gradio interface.

        Args:
            message: the user's question.
            history: the Gradio chat history (currently unused).

        Returns:
            The assistant's reply text, or a user-facing error message.
        """
        try:
            if not message or not message.strip():
                return "Please enter a question."
            if not self.response_model:
                return "Sorry, the system is not properly initialized. Please check if the API key is configured correctly."
            # For simple cases without retriever, provide direct response.
            if not self.graph:
                prompt = f"""You are a helpful assistant specializing in legal consulting for small and medium enterprises.
Answer this question: {message}
Always remind users that this is general information only and not professional legal advice."""
                response = self.response_model.invoke([{"role": "user", "content": prompt}])
                return response.content if hasattr(response, 'content') else str(response)
            # Run the compiled graph from a fresh single-message state.
            initial_state = {"messages": [HumanMessage(content=message)]}
            result = self.graph.invoke(initial_state)
            # Extract the final response.
            if result and "messages" in result and result["messages"]:
                final_message = result["messages"][-1]
                if hasattr(final_message, 'content'):
                    return final_message.content
                else:
                    return str(final_message)
            else:
                return "I'm sorry, I couldn't generate a response. Please try again."
        except Exception as e:
            logger.error(f"Error in chat function: {e}")
            return "An error occurred while processing your request. Please try again."
# Initialize the bot once at import time so all Gradio callbacks share it.
logger.info("Initializing Legal Consulting Bot...")
bot = LegalConsultingBot()
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI wired to the global ``bot``.

    Returns:
        gr.Blocks: the assembled (but not yet launched) interface.
    """
    with gr.Blocks(
        title="Legal Consulting Assistant for SMEs",
        theme=gr.themes.Soft(),
        css="""
.container {
    max-width: 800px;
    margin: auto;
}
"""
    ) as demo:
        # NOTE: some emoji below are mojibake inherited from the original
        # file; kept byte-for-byte to preserve the rendered output.
        gr.Markdown("""
# 🏒 Legal Consulting Assistant for Small & Medium Enterprises
Get informed answers about legal consulting, compliance, and business law for SMEs.
This assistant uses information from various legal consulting websites to provide relevant guidance.
**⚠️ Important**: This provides general information only and does not constitute professional legal advice.
""")
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="πŸ’¬ Ask me about legal consulting for your business...",
                    avatar_images=("πŸ‘€", "πŸ€–"),
                    bubble_full_width=False
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="e.g., What legal structure should I choose for my startup?",
                        label="Your Question",
                        scale=4
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)
                clear = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary")

        def respond(message, chat_history):
            # Generator callback: each yield streams an interface update.
            if not message.strip():
                # Must yield (not return a value) from a generator so the
                # outputs are actually delivered to the UI.
                yield "", chat_history
                return
            # Show a typing indicator while the bot works.
            chat_history.append((message, "πŸ€” Thinking..."))
            yield "", chat_history
            # Replace the indicator with the real answer.
            bot_message = bot.chat(message, chat_history)
            chat_history[-1] = (message, bot_message)
            yield "", chat_history

        def clear_chat():
            # Reset the chatbot history to an empty conversation.
            return []

        # Event handlers: Enter key and Send button share the same callback.
        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
        clear.click(clear_chat, outputs=[chatbot])

        gr.Markdown("""
### πŸ“‹ Example Questions
- *What legal structure should I choose for my small business?*
- *What compliance requirements do SMEs need to consider?*
- *How can I protect my business's intellectual property?*
- *What should be included in employment contracts?*
### βš–οΈ Disclaimer
This chatbot provides general information about legal consulting for SMEs based on publicly available resources.
**This information should not be considered as professional legal advice.**
Always consult with qualified legal professionals for specific legal matters and before making important business decisions.
""")
    return demo
# Create and launch the interface
# Create and launch the interface when run as a script (Spaces entry point).
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces so the Space proxy can reach it
        server_port=7860,       # standard Hugging Face Spaces port
        share=False
    )