Spaces:
Sleeping
Sleeping
File size: 4,729 Bytes
b4bcb01 96874c5 b4bcb01 96874c5 b4bcb01 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os, uuid, gc
from functools import lru_cache
from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
import PyPDF2
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
# from langchain_huggingface import (
# HuggingFaceEmbeddings,
# HuggingFaceEndpoint,
# ChatHuggingFace,
# )
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
# --- Runtime configuration --------------------------------------------------
# Load variables from a .env file (if present) into the process environment.
load_dotenv()

# Uploaded PDFs are written under /tmp/uploads; create the directory up front.
UPLOAD_FOLDER = os.path.join("/tmp", "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Point the Hugging Face caches at /tmp.
# NOTE(review): these variables are assigned after the Hugging Face related
# imports at the top of the file; if any of those libraries read the cache
# location at import time, the assignments may come too late -- confirm.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
        "HF_DATASETS_CACHE": "/tmp/huggingface/datasets",
    }
)

app = Flask(__name__)
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)

# Maps the file_id returned by /upload to the QA chain built for that PDF.
# Held in process memory only; nothing in this section removes entries.
qa_chains: dict[str, RetrievalQA] = {}
@lru_cache(maxsize=1)
def get_embedder():
    """Return the shared sentence-embedding model.

    The result is memoized by ``lru_cache``, so the embeddings object is
    constructed on the first call and reused on every later call.
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    return embedder
@lru_cache(maxsize=1)
def get_llm():
    """Return the shared Groq chat model used for Q&A and summaries.

    Memoized by ``lru_cache`` so a single client instance is reused.
    An earlier HuggingFaceEndpoint (Mixtral-8x7B-Instruct) variant was
    replaced by this Groq-backed model.
    """
    # Other model options noted by the author:
    #   llama-3.3-70b-versatile (very accurate), mistral-saba-24b (unlimited)
    groq_settings = {
        "model_name": "llama-3.1-8b-instant",
        "temperature": 0.3,
        "max_tokens": 512,
    }
    return ChatGroq(**groq_settings)
def extract_text_from_pdf(path: str) -> str:
    """Return the text PyPDF2 can extract from the PDF at *path*.

    Page texts are joined with a newline so that the last word of one page
    and the first word of the next are not fused into a single token.
    Pages that yield no text (``extract_text()`` returning ``None`` or an
    empty string) are skipped, so a PDF without any extractable text
    produces an empty string.
    """
    with open(path, "rb") as f:
        reader = PyPDF2.PdfReader(f)
        # Consume the pages while the file handle is still open.
        page_texts = (page.extract_text() or "" for page in reader.pages)
        return "\n".join(page_text for page_text in page_texts if page_text)
def build_qa_for_pdf(path: str) -> RetrievalQA:
    """Build a retrieval QA chain over the PDF at *path*.

    The PDF text is split into 1000-character chunks with a 200-character
    overlap, embedded into a FAISS store, and exposed through an MMR
    retriever (k=5) that feeds the shared LLM.

    Raises:
        ValueError: if the PDF yields no text chunks (for example a scanned
            or image-only document), instead of letting the vector store
            fail with an unrelated internal error on an empty input.
    """
    text = extract_text_from_pdf(path)
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(text)
    if not chunks:
        raise ValueError(
            "No extractable text found in the PDF "
            "(it may be scanned or image-only)."
        )
    store = FAISS.from_texts(chunks, get_embedder())
    retriever = store.as_retriever(search_type="mmr", search_kwargs={"k": 5})
    return RetrievalQA.from_chain_type(llm=get_llm(), retriever=retriever)
def summarize_pdf(path: str, max_chars: int = 5000) -> str:
    """Summarize the PDF at *path* using the shared LLM.

    Only the first *max_chars* characters of the extracted text are sent to
    the model, to keep the prompt small; longer documents are truncated.

    Args:
        path: Location of the PDF on disk.
        max_chars: Maximum number of characters of document text to include
            in the prompt. Defaults to 5000.

    Returns:
        The model's summary as a string.

    Raises:
        ValueError: if no text could be extracted from the PDF, so the model
            is never asked to summarize an empty document.
    """
    text = extract_text_from_pdf(path)
    if not text.strip():
        raise ValueError(
            "No extractable text found in the PDF "
            "(it may be scanned or image-only)."
        )
    # Slicing is a no-op for short documents, so no length check is needed.
    text = text[:max_chars]
    llm = get_llm()
    prompt = f"Summarize the following legal document in a clear and concise way:\n\n{text}"
    result = llm.invoke(prompt)
    # Chat models return a message object; fall back to str() for anything else.
    return result.content if hasattr(result, "content") else str(result)
@app.route("/healthz")
def healthz():
    """Liveness probe: always answers with the text ``ok`` and HTTP 200."""
    body, status = "ok", 200
    return body, status
@app.route("/", methods=["GET"])
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.post("/upload")
def upload_file():
    """Accept a PDF upload, index it, and return its ``file_id``.

    Expects a multipart form with a ``file`` field whose name ends in
    ``.pdf``. The file is saved under the upload folder with a random UUID
    prefix, a QA chain is built for it and stored in ``qa_chains``.

    Responses:
        200 ``{"file_id": ...}`` on success.
        400 when no file or a non-PDF file name is supplied.
        500 when building the QA chain fails; the saved upload is removed in
            that case because no ``file_id`` is handed back to the client.
    """
    file = request.files.get("file")
    if not file or not file.filename.lower().endswith(".pdf"):
        return jsonify({"error": "Please upload a valid PDF."}), 400

    # The UUID prefix keeps uploads with the same original name apart.
    filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
    path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
    file.save(path)

    try:
        qa_chains[filename] = build_qa_for_pdf(path)
    except Exception as e:
        # Best-effort cleanup: the client never learns this file_id, so the
        # file would otherwise stay on disk forever.
        try:
            os.remove(path)
        except OSError:
            pass
        return jsonify({"error": f"Error building chain: {e}"}), 500
    return jsonify({"file_id": filename}), 200
@app.post("/ask")
def ask_question():
    """Answer a question about a previously uploaded PDF.

    Expects a JSON object with string fields ``file_id`` (as returned by
    ``/upload``) and ``question``.

    Responses:
        200 ``{"answer": ...}`` on success.
        400 when the body is not a JSON object, or ``file_id`` / ``question``
            is missing, empty, or not a string.
        404 when no QA chain is registered for ``file_id``.
        500 when the chain raises while answering.
    """
    data = request.get_json(force=True)
    # A valid JSON body may still be a list, string, number or null; guard
    # before calling .get() so such bodies yield a 400 instead of a crash.
    if not isinstance(data, dict):
        return jsonify({"error": "Missing file_id or question."}), 400

    file_id = data.get("file_id")
    raw_question = data.get("question")
    question = raw_question.strip() if isinstance(raw_question, str) else ""
    if not isinstance(file_id, str) or not file_id or not question:
        return jsonify({"error": "Missing file_id or question."}), 400

    qa_chain = qa_chains.get(file_id)
    if not qa_chain:
        return jsonify({"error": "File not found or expired."}), 404

    try:
        result = qa_chain.invoke({"query": question})
        answer = result["result"] if isinstance(result, dict) else result
        gc.collect()
        return jsonify({"answer": answer}), 200
    except Exception as e:
        return jsonify({"error": f"Error: {e}"}), 500
@app.post("/summarize")
def summarize():
    """Summarize a previously uploaded PDF.

    Expects a JSON object with a string ``file_id`` (as returned by
    ``/upload``). The id is used as a file name inside the upload folder.

    Responses:
        200 ``{"summary": ...}`` on success.
        400 when the body is not a JSON object or ``file_id`` is missing,
            empty, or not a string.
        404 when ``file_id`` is not a plain file name or no such file exists.
        500 when summarization raises.
    """
    data = request.get_json(force=True)
    # A valid JSON body may still be a list, string, number or null.
    if not isinstance(data, dict):
        return jsonify({"error": "Missing file_id."}), 400

    file_id = data.get("file_id")
    if not isinstance(file_id, str) or not file_id:
        return jsonify({"error": "Missing file_id."}), 400

    # file_id is client-controlled: accept only a bare file name so values
    # such as "../../etc/passwd" or an absolute path cannot escape the
    # upload folder when joined below.
    if os.path.basename(file_id) != file_id or file_id in {".", ".."}:
        return jsonify({"error": "File not found or expired."}), 404

    # locate the uploaded file from /tmp/uploads
    path = os.path.join(app.config["UPLOAD_FOLDER"], file_id)
    if not os.path.isfile(path):
        return jsonify({"error": "File not found or expired."}), 404

    try:
        summary = summarize_pdf(path)
        gc.collect()
        return jsonify({"summary": summary}), 200
    except Exception as e:
        return jsonify({"error": f"Error summarizing: {e}"}), 500
if __name__ == "__main__":
    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the port execute arbitrary code. Keep it off unless it is
    # explicitly requested through the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(debug=debug_enabled)
|