# Hugging Face Space: Legal Consulting Assistant for SMEs (status at capture time: Running)
| import os | |
| import gradio as gr | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
| from langchain_core.vectorstores import InMemoryVectorStore | |
| from langchain.tools.retriever import create_retriever_tool | |
| from langgraph.graph import MessagesState, StateGraph, START, END | |
| from langchain.chat_models import init_chat_model | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from pydantic import BaseModel, Field | |
| from typing import Literal | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| import logging | |
# Set up logging (module-level logger shared by all classes/functions below)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get API key from Hugging Face Spaces secrets
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if GOOGLE_API_KEY:
    # Re-export into the process environment so the Google GenAI client
    # libraries pick the key up implicitly.
    os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
else:
    # Setup continues; LegalConsultingBot.initialize_bot() checks this and
    # degrades to an error message instead of crashing at import.
    logger.warning("GOOGLE_API_KEY not found in environment variables")
class LegalConsultingBot:
    """Agentic RAG chatbot for SME legal-consulting questions.

    Combines two Gemini chat models (a responder and a relevance grader),
    a web retriever built over public legal-consulting sites, and a
    LangGraph workflow (retrieve -> grade -> rewrite/answer).

    Every setup stage is best-effort: on failure the corresponding
    attribute stays None and chat() falls back to a direct LLM answer
    (or a static error message when even the model is unavailable).
    """

    def __init__(self):
        # Populated by initialize_bot(); all remain None on failure so
        # chat() can detect and handle the degraded modes.
        self.graph = None
        self.retriever_tool = None
        self.response_model = None
        self.grader_model = None
        self.initialize_bot()

    def initialize_bot(self):
        """Initialize the bot with error handling.

        Runs model, retriever and workflow setup in order. Any exception
        is logged rather than raised so the Gradio app still starts.
        """
        try:
            if not GOOGLE_API_KEY:
                logger.error("Google API key not available")
                return
            self.setup_models()
            self.setup_retriever()
            self.setup_workflow()
            logger.info("Bot initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize bot: {e}")

    def setup_models(self):
        """Initialize the language models (responder and grader).

        Raises:
            Exception: re-raised after logging so initialize_bot() aborts
                (a bot without models cannot do anything useful).
        """
        try:
            # temperature=0 for deterministic, conservative answers.
            self.response_model = init_chat_model(
                "gemini-2.0-flash",
                model_provider="google_genai",
                temperature=0
            )
            self.grader_model = init_chat_model(
                "gemini-2.0-flash",
                model_provider="google_genai",
                temperature=0
            )
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            raise

    def setup_retriever(self):
        """Initialize the document retriever with legal consulting URLs.

        Scrapes a fixed list of public sites at startup, chunks them, and
        builds an in-memory vector store (rebuilt on every restart). On any
        failure self.retriever_tool is left None and the bot runs in
        fallback (no-retrieval) mode.
        """
        # Public legal-consulting resources for SMEs, scraped at startup.
        urls = [
            "https://firststepslegal.co.uk/",
            "https://sprintlaw.co.uk/",
            "https://www.smbs.solutions/legal-and-compliance-resources-legal-assistance-for-businesses",
            "https://ignition.law/lawyers-for-smes/",
            "https://www.cocredo.co.uk/news/free-legal-advice-small-business-owners",
            "https://www.gannons.co.uk/sectors/smes/",
            "https://kkbservices.com/who-we-work-with/small-businesses/",
            "https://farringfordlegal.co.uk/",
            "https://stanislawlegal.com/en/legal-solutions/for-sme-companies/",
            "https://www.lawhive.co.uk/small-business/",
            "https://www.catalystlaw.co.uk/business-legal-advice.html",
            "https://medium.com/@kmitsme123/legal-consulting-tips-for-your-small-business-9075005eb574",
            "https://smecomply.co.uk/",
            "https://dojobusiness.com/blogs/news/legal-consultant-complete-guide",
            "https://englishlegaladvice.com/",
        ]
        try:
            # Load documents with error handling: a failing URL is skipped
            # rather than aborting the whole setup.
            docs = []
            successful_loads = 0
            for url in urls:
                try:
                    loader = WebBaseLoader(url)
                    docs.extend(loader.load())
                    successful_loads += 1
                    logger.info(f"Successfully loaded: {url}")
                except Exception as e:
                    logger.warning(f"Failed to load {url}: {e}")
                    continue
            logger.info(f"Successfully loaded {successful_loads}/{len(urls)} URLs")
            if not docs:
                logger.warning("No documents could be loaded, using fallback mode")
                return
            # Split documents into token-based chunks sized for embedding.
            text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=300, chunk_overlap=50
            )
            doc_splits = text_splitter.split_documents(docs)
            logger.info(f"Created {len(doc_splits)} document chunks")
            # Create embeddings and vectorstore (in-memory: not persisted)
            embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            vectorstore = InMemoryVectorStore.from_documents(
                documents=doc_splits,
                embedding=embeddings
            )
            # Create retriever tool that the responder model can call.
            retriever = vectorstore.as_retriever()
            self.retriever_tool = create_retriever_tool(
                retriever,
                name="legal_consulting_retriever",
                description="Search and return relevant legal information and consulting resources for small and medium-sized businesses."
            )
            logger.info("Retriever tool created successfully")
        except Exception as e:
            logger.error(f"Error setting up retriever: {e}")
            # Leave the tool unset so workflow/chat degrade cleanly.
            self.retriever_tool = None

    def setup_workflow(self):
        """Set up the LangGraph workflow.

        With a retriever: query -> (maybe) retrieve -> grade -> either
        answer or rewrite-and-retry. Without one: query -> END. On failure
        self.graph stays None and chat() answers with the bare model.
        """
        try:
            if not self.response_model:
                logger.error("Response model not available")
                return
            # Create workflow over the built-in messages-only state schema.
            workflow = StateGraph(MessagesState)
            # Add nodes
            workflow.add_node("generate_query_or_respond", self.generate_query_or_respond)
            if self.retriever_tool:
                workflow.add_node("retrieve", ToolNode([self.retriever_tool]))
                workflow.add_node("grade_documents", self.grade_documents_node)
                workflow.add_node("rewrite_question", self.rewrite_question)
                workflow.add_node("generate_answer", self.generate_answer)
            # Add edges
            workflow.add_edge(START, "generate_query_or_respond")
            if self.retriever_tool:
                # tools_condition routes to "tools" when the model emitted a
                # tool call, otherwise ends the run with the direct answer.
                workflow.add_conditional_edges(
                    "generate_query_or_respond",
                    tools_condition,
                    {
                        "tools": "retrieve",
                        END: END,
                    },
                )
                workflow.add_edge("retrieve", "grade_documents")
                # NOTE(review): "grade_result" is not a declared channel of
                # MessagesState; LangGraph normally rejects or drops updates
                # to unknown state keys, so this routing may always take the
                # "generate_answer" default — verify against a custom state
                # schema that declares grade_result.
                workflow.add_conditional_edges(
                    "grade_documents",
                    lambda x: x.get("grade_result", "generate_answer"),
                    {
                        "generate_answer": "generate_answer",
                        "rewrite_question": "rewrite_question"
                    }
                )
            else:
                workflow.add_edge("generate_query_or_respond", END)
            workflow.add_edge("generate_answer", END)
            workflow.add_edge("rewrite_question", "generate_query_or_respond")
            self.graph = workflow.compile()
            logger.info("Workflow compiled successfully")
        except Exception as e:
            logger.error(f"Error setting up workflow: {e}")
            self.graph = None

    def generate_query_or_respond(self, state: MessagesState):
        """Generate query or respond directly.

        Entry node: lets the model either answer outright or emit a call to
        the retriever tool. Returns a {"messages": [...]} state update.
        """
        try:
            if not self.retriever_tool:
                # Fallback response when retriever is not available
                fallback_response = """I'm a legal consulting assistant for small and medium enterprises. While my document retriever is currently unavailable, I can still help answer general questions about:
- Business formation and structure
- Contract basics and employment law
- Intellectual property fundamentals
- Compliance and regulatory matters
- General legal considerations for SMEs
Please note: This is general information only, not legal advice. Always consult qualified legal professionals for specific matters."""
                return {"messages": [AIMessage(content=fallback_response)]}
            # Bind the retriever tool so the model may choose to call it.
            response = (
                self.response_model
                .bind_tools([self.retriever_tool])
                .invoke(state["messages"])
            )
            return {"messages": [response]}
        except Exception as e:
            logger.error(f"Error in generate_query_or_respond: {e}")
            return {"messages": [AIMessage(content="I'm sorry, I encountered an error. Please try again.")]}

    class GradeDocuments(BaseModel):
        """Grade documents using a binary score for relevance check."""
        # Structured-output schema handed to the grader model; constrains
        # its reply to a single yes/no field.
        binary_score: str = Field(
            description="Relevance score: 'yes' if relevant, or 'no' if not relevant"
        )

    def grade_documents_node(self, state: MessagesState):
        """Node wrapper for document grading.

        Grades the latest retrieved context against the user question and
        routes to answer generation or question rewriting.
        """
        try:
            # assumes messages[0] is the original user question and the last
            # message carries the retrieved context — TODO confirm this
            # holds once conversations become multi-turn.
            question = state["messages"][0].content
            context = state["messages"][-1].content if state["messages"] else ""
            GRADE_PROMPT = (
                "You are a grader assessing relevance of a retrieved document to a user question. \n "
                "Here is the retrieved document: \n\n {context} \n\n"
                "Here is the user question: {question} \n"
                "If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n"
                "Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."
            )
            prompt = GRADE_PROMPT.format(question=question, context=context)
            response = (
                self.grader_model
                .with_structured_output(self.GradeDocuments)
                .invoke([{"role": "user", "content": prompt}])
            )
            grade_result = "generate_answer" if response.binary_score == "yes" else "rewrite_question"
            # NOTE(review): see setup_workflow — "grade_result" is not part
            # of MessagesState, so this update may be rejected/ignored.
            return {"grade_result": grade_result}
        except Exception as e:
            logger.error(f"Error in grade_documents_node: {e}")
            return {"grade_result": "generate_answer"}  # Default to generating answer

    def rewrite_question(self, state: MessagesState):
        """Rewrite the original user question.

        Asks the responder model for a search-friendlier phrasing and feeds
        it back into the graph as a new HumanMessage.
        """
        try:
            messages = state["messages"]
            question = messages[0].content if messages else ""
            REWRITE_PROMPT = (
                "Look at the input and try to reason about the underlying semantic intent / meaning.\n"
                "Here is the initial question:"
                "\n ------- \n"
                "{question}"
                "\n ------- \n"
                "Formulate an improved question that would be better for searching legal consulting information:"
            )
            prompt = REWRITE_PROMPT.format(question=question)
            response = self.response_model.invoke([{"role": "user", "content": prompt}])
            return {"messages": [HumanMessage(content=response.content)]}
        except Exception as e:
            logger.error(f"Error in rewrite_question: {e}")
            # Return original message if rewriting fails
            return {"messages": state["messages"][:1] if state["messages"] else []}

    def generate_answer(self, state: MessagesState):
        """Generate an answer based on retrieved context.

        Formats question + retrieved context into a grounded-QA prompt and
        returns the model's reply as the state update.
        """
        try:
            question = state["messages"][0].content if state["messages"] else ""
            # Last message is treated as the retrieved context (tool output).
            context = state["messages"][-1].content if len(state["messages"]) > 1 else ""
            GENERATE_PROMPT = (
                "You are an assistant for question-answering tasks about legal consulting for small and medium businesses. "
                "Use the following pieces of retrieved context to answer the question. "
                "If you don't know the answer, just say that you don't know. "
                "Keep the answer concise but informative. Always remind users that this is general information and not legal advice.\n"
                "Question: {question} \n"
                "Context: {context}"
            )
            prompt = GENERATE_PROMPT.format(question=question, context=context)
            response = self.response_model.invoke([{"role": "user", "content": prompt}])
            return {"messages": [response]}
        except Exception as e:
            logger.error(f"Error in generate_answer: {e}")
            return {"messages": [AIMessage(content="I'm sorry, I encountered an error while generating the answer. Please try again.")]}

    def chat(self, message: str, history: list) -> str:
        """Main chat function for Gradio interface.

        Args:
            message: the user's question (may be empty/whitespace).
            history: Gradio chat history; currently unused by the graph,
                which is seeded with only the new message.

        Returns:
            The assistant's reply as plain text; user-facing error strings
            on any failure (never raises).
        """
        try:
            if not message or not message.strip():
                return "Please enter a question."
            if not self.response_model:
                return "Sorry, the system is not properly initialized. Please check if the API key is configured correctly."
            # For simple cases without retriever, provide direct response
            if not self.graph:
                prompt = f"""You are a helpful assistant specializing in legal consulting for small and medium enterprises.
Answer this question: {message}
Always remind users that this is general information only and not professional legal advice."""
                response = self.response_model.invoke([{"role": "user", "content": prompt}])
                return response.content if hasattr(response, 'content') else str(response)
            # Create initial state (note: prior history is not passed in,
            # so each question is answered statelessly)
            initial_state = {"messages": [HumanMessage(content=message)]}
            # Run the graph
            result = self.graph.invoke(initial_state)
            # Extract the final response
            if result and "messages" in result and result["messages"]:
                final_message = result["messages"][-1]
                if hasattr(final_message, 'content'):
                    return final_message.content
                else:
                    return str(final_message)
            else:
                return "I'm sorry, I couldn't generate a response. Please try again."
        except Exception as e:
            logger.error(f"Error in chat function: {e}")
            return f"An error occurred while processing your request. Please try again."
# Initialize the bot once at import time; module-level singleton used by
# the Gradio event handlers below. Construction never raises (errors are
# logged and the bot degrades), so the app always starts.
logger.info("Initializing Legal Consulting Bot...")
bot = LegalConsultingBot()
| # Create Gradio interface | |
def create_interface():
    """Build and return the Gradio Blocks UI for the legal assistant.

    Lays out a chatbot panel with a question box, send/clear buttons, and
    static markdown (intro, examples, disclaimer). Responses stream through
    a generator so a typing indicator shows while the bot works.
    """
    with gr.Blocks(
        title="Legal Consulting Assistant for SMEs",
        theme=gr.themes.Soft(),
        css="""
    .container {
        max-width: 800px;
        margin: auto;
    }
    """
    ) as app:
        gr.Markdown("""
    # 🏢 Legal Consulting Assistant for Small & Medium Enterprises
    Get informed answers about legal consulting, compliance, and business law for SMEs.
    This assistant uses information from various legal consulting websites to provide relevant guidance.
    **⚠️ Important**: This provides general information only and does not constitute professional legal advice.
    """)

        with gr.Row():
            with gr.Column(scale=4):
                chat_display = gr.Chatbot(
                    height=500,
                    placeholder="💬 Ask me about legal consulting for your business...",
                    avatar_images=("👤", "🤖"),
                    bubble_full_width=False
                )
                with gr.Row():
                    question_box = gr.Textbox(
                        placeholder="e.g., What legal structure should I choose for my startup?",
                        label="Your Question",
                        scale=4
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

        def on_message(message, chat_history):
            # Generator handler: first yield shows a typing placeholder,
            # second replaces it with the real answer.
            if not message.strip():
                return "", chat_history
            chat_history.append((message, "🤔 Thinking..."))
            yield "", chat_history
            reply = bot.chat(message, chat_history)
            chat_history[-1] = (message, reply)
            yield "", chat_history

        def wipe_history():
            # Reset the chatbot panel to an empty conversation.
            return []

        # Wire events: Enter in the textbox and the Send button both submit;
        # the clear button empties the chat panel.
        question_box.submit(on_message, [question_box, chat_display], [question_box, chat_display])
        send_btn.click(on_message, [question_box, chat_display], [question_box, chat_display])
        clear_btn.click(wipe_history, outputs=[chat_display])

        gr.Markdown("""
    ### 📝 Example Questions
    - *What legal structure should I choose for my small business?*
    - *What compliance requirements do SMEs need to consider?*
    - *How can I protect my business's intellectual property?*
    - *What should be included in employment contracts?*
    ### ⚖️ Disclaimer
    This chatbot provides general information about legal consulting for SMEs based on publicly available resources.
    **This information should not be considered as professional legal advice.**
    Always consult with qualified legal professionals for specific legal matters and before making important business decisions.
    """)

    return app
| # Create and launch the interface | |
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at the standard
    # Hugging Face Spaces port, without a public share link.
    interface = create_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False)