import os
from typing import TypedDict, List

from pydantic import BaseModel, Field

# --- Core LangChain & Document Processing Imports ---
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, END

from core_utils.core_model_loaders import load_embedding_model, load_gemini_llm

# --- Initialize Models ---
embedding_model = load_embedding_model()
llm = load_gemini_llm()

# Maximum characters of document text placed into a single LLM prompt.
# Shared by every node that inlines document content.
_MAX_PROMPT_CHARS = 30000


# --- 1. RAG Chain Logic ---
def create_rag_chain(retriever):
    """Creates a RAG chain for answering questions about the document.

    Args:
        retriever: A LangChain retriever that maps a question string to a
            list of relevant document chunks.

    Returns:
        A runnable chain: question string -> answer string, grounded only
        in the retrieved context.
    """
    template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
    prompt = PromptTemplate.from_template(template)

    def format_docs(docs):
        # Flatten the retrieved documents into a single context string.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain


# --- 2. Demystifier Graph Logic ---
class DemystifierState(TypedDict):
    # Raw text chunks of the document under analysis (graph input).
    document_chunks: List[str]
    # Plain-language summary written by the "summarize" node.
    summary: str
    # Explained legal terms written by the "extract_terms" node.
    key_terms: str
    # Final Markdown report assembled by the "compile_report" node.
    final_report: str


def _document_text(state: DemystifierState) -> str:
    """Join the document chunks into one string, capped at the prompt budget.

    Shared by the summarize and key-term nodes so the truncation policy
    lives in exactly one place.
    """
    text = "\n\n".join(state["document_chunks"])
    # Truncate for safety if too large for prompt.
    return text[:_MAX_PROMPT_CHARS]


def summarize_document(state: DemystifierState):
    """Summarizes the provided document chunks in plain language."""
    print("---NODE: Summarizing Document---")
    text = _document_text(state)
    prompt = f"""
You are a legal expert. Summarize the following legal document content
in simple, easy-to-understand language. Focus on the main purpose and
parties involved.

Content: {text}
"""
    response = llm.invoke(prompt)
    return {"summary": response.content}


def extract_key_terms(state: DemystifierState):
    """Extracts and explains key legal terms for a layperson."""
    print("---NODE: Extracting Key Terms---")
    text = _document_text(state)
    prompt = f"""
Identify 5-7 complex legal terms or clauses from the text below.
List them and explain what they mean in plain English for a layperson.

Content: {text}
"""
    response = llm.invoke(prompt)
    return {"key_terms": response.content}


def generate_report(state: DemystifierState):
    """Compiles the summary and key terms into the final Markdown report."""
    print("---NODE: Generating Final Report---")
    report = f"""
# Document Analysis

## 📝 Summary
{state['summary']}

## 🔑 Key Terms & Definitions
{state['key_terms']}

## 💡 Expert Advice
Always consult with a qualified lawyer for critical legal decisions.
This analysis is AI-generated guidance.
"""
    return {"final_report": report}


# --- Build the Graph ---
workflow = StateGraph(DemystifierState)
workflow.add_node("summarize", summarize_document)
workflow.add_node("extract_terms", extract_key_terms)
workflow.add_node("compile_report", generate_report)

# Sequential pipeline: summarize -> extract_terms -> compile_report.
# (The original comment claimed parallel execution, but the wired edges
# run the nodes strictly one after another.)
workflow.set_entry_point("summarize")
workflow.add_edge("summarize", "extract_terms")
workflow.add_edge("extract_terms", "compile_report")
workflow.add_edge("compile_report", END)

demystifier_agent_graph = workflow.compile()
# --- 4. The Master "Controller" Function ---
def process_document_for_demystification(file_path: str):
    """Loads a PDF, runs the full analysis, creates a RAG chain, and returns both.

    Args:
        file_path: Path to the PDF file to analyze.

    Returns:
        A dict with keys:
            "report": the Markdown analysis report produced by the graph
                (``None`` if the graph yielded no ``final_report``),
            "rag_chain": a runnable chain for follow-up Q&A over the document.

    Raises:
        ValueError: If the PDF yields no loadable content.
    """
    print(f"--- Processing document: {file_path} ---")

    loader = PyPDFLoader(file_path)
    documents = loader.load()
    if not documents:
        raise ValueError("No content found in PDF.")

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(documents)

    print("--- Creating FAISS vector store for Q&A ---")
    vectorstore = FAISS.from_documents(chunks, embedding=embedding_model)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    rag_chain = create_rag_chain(retriever)

    print("--- Running analysis graph for the report ---")
    chunk_contents = [chunk.page_content for chunk in chunks]
    # No size limiting here: the graph nodes truncate the joined text
    # themselves before building their prompts.
    graph_input = {"document_chunks": chunk_contents}
    result = demystifier_agent_graph.invoke(graph_input)
    report = result.get("final_report")

    return {"report": report, "rag_chain": rag_chain}