# jan-contract/agents/demystifier_agent.py
# Restored missing agent logic (commit bd6f8a7).
import os
from typing import TypedDict, List
from pydantic import BaseModel, Field
# --- Core LangChain & Document Processing Imports ---
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from core_utils.core_model_loaders import load_embedding_model, load_gemini_llm
from langgraph.graph import StateGraph, END
# --- Initialize Models ---
embedding_model = load_embedding_model()
llm = load_gemini_llm()
# --- 1. RAG Chain Logic ---
def create_rag_chain(retriever):
"""Creates a RAG chain for answering questions about the document."""
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
prompt = PromptTemplate.from_template(template)
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return rag_chain
# --- 2. Demystifier Graph Logic ---
class DemystifierState(TypedDict):
document_chunks: List[str]
summary: str
key_terms: str
final_report: str
def summarize_document(state: DemystifierState):
"""Summarizes the provided document chunks."""
print("---NODE: Summarizing Document---")
text = "\n\n".join(state["document_chunks"])
# Truncate for safety if too large for prompt
text = text[:30000]
prompt = f"""
You are a legal expert. Summarize the following legal document content in simple, easy-to-understand language.
Focus on the main purpose and parties involved.
Content:
{text}
"""
response = llm.invoke(prompt)
return {"summary": response.content}
def extract_key_terms(state: DemystifierState):
"""Extracts and explains key legal terms."""
print("---NODE: Extracting Key Terms---")
text = "\n\n".join(state["document_chunks"])
text = text[:30000]
prompt = f"""
Identify 5-7 complex legal terms or clauses from the text below.
List them and explain what they mean in plain English for a layperson.
Content:
{text}
"""
response = llm.invoke(prompt)
return {"key_terms": response.content}
def generate_report(state: DemystifierState):
"""Compiles the final analysis report."""
print("---NODE: Generating Final Report---")
report = f"""
# Document Analysis
## 📝 Summary
{state['summary']}
## 🔑 Key Terms & Definitions
{state['key_terms']}
## 💡 Expert Advice
Always consult with a qualified lawyer for critical legal decisions. This analysis is AI-generated guidance.
"""
return {"final_report": report}
# --- Build the Graph ---
workflow = StateGraph(DemystifierState)
workflow.add_node("summarize", summarize_document)
workflow.add_node("extract_terms", extract_key_terms)
workflow.add_node("compile_report", generate_report)
# Parallel execution of summary and terms
workflow.set_entry_point("summarize")
workflow.add_edge("summarize", "extract_terms")
workflow.add_edge("extract_terms", "compile_report")
workflow.add_edge("compile_report", END)
demystifier_agent_graph = workflow.compile()
# --- 4. The Master "Controller" Function ---
def process_document_for_demystification(file_path: str):
"""Loads a PDF, runs the full analysis, creates a RAG chain, and returns both."""
print(f"--- Processing document: {file_path} ---")
loader = PyPDFLoader(file_path)
documents = loader.load()
if not documents:
raise ValueError("No content found in PDF.")
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_documents(documents)
print("--- Creating FAISS vector store for Q&A ---")
vectorstore = FAISS.from_documents(chunks, embedding=embedding_model)
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
rag_chain = create_rag_chain(retriever)
print("--- Running analysis graph for the report ---")
chunk_contents = [chunk.page_content for chunk in chunks]
# Limit context to avoid token limits if document is huge
graph_input = {"document_chunks": chunk_contents}
result = demystifier_agent_graph.invoke(graph_input)
report = result.get("final_report")
return {"report": report, "rag_chain": rag_chain}