Legal-Docs-AI / app.py
ommore86's picture
ohh
96874c5
import os, uuid, gc
from functools import lru_cache
from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
import PyPDF2
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
# from langchain_huggingface import (
# HuggingFaceEmbeddings,
# HuggingFaceEndpoint,
# ChatHuggingFace,
# )
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
load_dotenv()
UPLOAD_FOLDER = os.path.join("/tmp", "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
app = Flask(__name__)
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER
qa_chains: dict[str, RetrievalQA] = {}
@lru_cache(maxsize=1)
def get_embedder():
return HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
@lru_cache(maxsize=1)
def get_llm():
# return ChatHuggingFace(
# llm=HuggingFaceEndpoint(
# repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
# temperature=0.3,
# max_new_tokens=512,
# )
# )
return ChatGroq(
model_name="llama-3.1-8b-instant", # other options - llama-3.3-70b-versatile(very accurate), mistral-saba-24b(unlimited)
temperature=0.3,
max_tokens=512,
)
def extract_text_from_pdf(path: str) -> str:
with open(path, "rb") as f:
reader = PyPDF2.PdfReader(f)
return "".join(p.extract_text() or "" for p in reader.pages)
def build_qa_for_pdf(path: str) -> RetrievalQA:
text = extract_text_from_pdf(path)
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\
.split_text(text)
store = FAISS.from_texts(chunks, get_embedder())
retriever = store.as_retriever(search_type="mmr", search_kwargs={"k": 5})
return RetrievalQA.from_chain_type(llm=get_llm(), retriever=retriever)
def summarize_pdf(path: str) -> str:
"""Summarize an entire PDF using the LLM."""
text = extract_text_from_pdf(path)
# Limit text size (avoid overloading model)
if len(text) > 5000:
text = text[:5000] # truncate long docs for quick summarization
llm = get_llm()
prompt = f"Summarize the following legal document in a clear and concise way:\n\n{text}"
result = llm.invoke(prompt)
return result.content if hasattr(result, "content") else str(result)
@app.route("/healthz")
def healthz():
return "ok", 200
@app.route("/", methods=["GET"])
def index():
return render_template("index.html")
@app.post("/upload")
def upload_file():
file = request.files.get("file")
if not file or not file.filename.lower().endswith(".pdf"):
return jsonify({"error": "Please upload a valid PDF."}), 400
filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
file.save(path)
try:
qa_chains[filename] = build_qa_for_pdf(path)
return jsonify({"file_id": filename}), 200
except Exception as e:
return jsonify({"error": f"Error building chain: {e}"}), 500
@app.post("/ask")
def ask_question():
data = request.get_json(force=True)
file_id = data.get("file_id")
question = (data.get("question") or "").strip()
if not file_id or not question:
return jsonify({"error": "Missing file_id or question."}), 400
qa_chain = qa_chains.get(file_id)
if not qa_chain:
return jsonify({"error": "File not found or expired."}), 404
try:
result = qa_chain.invoke({"query": question})
answer = result["result"] if isinstance(result, dict) else result
gc.collect()
return jsonify({"answer": answer}), 200
except Exception as e:
return jsonify({"error": f"Error: {e}"}), 500
@app.post("/summarize")
def summarize():
data = request.get_json(force=True)
file_id = data.get("file_id")
if not file_id:
return jsonify({"error": "Missing file_id."}), 400
# locate the uploaded file from /tmp/uploads
path = os.path.join(app.config["UPLOAD_FOLDER"], file_id)
if not os.path.exists(path):
return jsonify({"error": "File not found or expired."}), 404
try:
summary = summarize_pdf(path)
gc.collect()
return jsonify({"summary": summary}), 200
except Exception as e:
return jsonify({"error": f"Error summarizing: {e}"}), 500
if __name__ == "__main__":
app.run(debug=True)