File size: 4,729 Bytes
b4bcb01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96874c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4bcb01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96874c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4bcb01
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os, uuid, gc
from functools import lru_cache

from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
import PyPDF2

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
# from langchain_huggingface import (
#     HuggingFaceEmbeddings,
#     HuggingFaceEndpoint,
#     ChatHuggingFace,
# )
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA

load_dotenv()
# Uploads live under /tmp so the app works on read-only / ephemeral
# filesystems (e.g. serverless or free-tier hosts).
UPLOAD_FOLDER = os.path.join("/tmp", "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Redirect all HuggingFace caches to /tmp for the same reason: the
# default home-directory cache may not be writable in the deploy target.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"


app = Flask(__name__)
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER
# In-memory registry of QA chains keyed by generated file_id; entries
# live for the process lifetime (lost on restart — "expired" in errors).
qa_chains: dict[str, RetrievalQA] = {}

@lru_cache(maxsize=1)
def get_embedder():
    """Return the process-wide sentence-embedding model.

    ``lru_cache(maxsize=1)`` turns this into a lazy singleton, so the
    expensive model download/load happens at most once per process.
    """
    model_id = "sentence-transformers/all-MiniLM-L6-v2"
    return HuggingFaceEmbeddings(model_name=model_id)

@lru_cache(maxsize=1)
def get_llm():
    """Return the process-wide chat model (Groq-hosted Llama).

    Cached as a singleton via ``lru_cache(maxsize=1)``.  Other Groq
    options: llama-3.3-70b-versatile (very accurate),
    mistral-saba-24b (unlimited).
    """
    return ChatGroq(
        temperature=0.3,
        max_tokens=512,
        model_name="llama-3.1-8b-instant",
    )

def extract_text_from_pdf(path: str) -> str:
    """Extract the text layer of every page in a PDF.

    Pages with no extractable text (e.g. scanned images) contribute an
    empty string instead of raising.

    Args:
        path: Filesystem path to the PDF file.

    Returns:
        The text of all pages joined with newlines.
    """
    with open(path, "rb") as f:
        reader = PyPDF2.PdfReader(f)
        # Join with "\n": the previous "".join fused the last word of
        # one page with the first word of the next, corrupting tokens
        # that the splitter and LLM later see.
        return "\n".join(p.extract_text() or "" for p in reader.pages)

def build_qa_for_pdf(path: str) -> RetrievalQA:
    """Build a retrieval-QA chain over the text of a single PDF.

    The document is chunked, embedded into an in-memory FAISS index,
    and wired to the shared LLM through an MMR retriever (k=5).
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(extract_text_from_pdf(path))

    index = FAISS.from_texts(chunks, get_embedder())
    retriever = index.as_retriever(search_type="mmr", search_kwargs={"k": 5})
    return RetrievalQA.from_chain_type(llm=get_llm(), retriever=retriever)

def summarize_pdf(path: str) -> str:
    """Summarize an entire PDF with the shared LLM.

    Input is capped at 5000 characters so long documents still get a
    fast, single-prompt summary.
    """
    # Unconditional slice == the original len() check + truncate.
    document = extract_text_from_pdf(path)[:5000]

    prompt = f"Summarize the following legal document in a clear and concise way:\n\n{document}"
    response = get_llm().invoke(prompt)

    # Chat models return a message object with .content; fall back to
    # str() for plain-string LLM responses.
    if hasattr(response, "content"):
        return response.content
    return str(response)


@app.route("/healthz")
def healthz():
    """Liveness probe: always reports the process as healthy."""
    return ("ok", 200)

@app.route("/", methods=["GET"])
def index():
    """Serve the single-page upload/chat UI."""
    page = render_template("index.html")
    return page

@app.post("/upload")
def upload_file():
    """Accept a PDF upload and index it for question answering.

    Saves the file under UPLOAD_FOLDER and registers a QA chain in
    ``qa_chains``; the returned ``file_id`` is later passed to /ask
    and /summarize.

    Returns:
        200 with {"file_id": ...} on success,
        400 for a missing/non-PDF upload,
        500 if chain construction fails.
    """
    file = request.files.get("file")
    if not file or not file.filename.lower().endswith(".pdf"):
        return jsonify({"error": "Please upload a valid PDF."}), 400

    # uuid prefix prevents collisions between identically named uploads.
    filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
    path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
    file.save(path)

    try:
        qa_chains[filename] = build_qa_for_pdf(path)
        return jsonify({"file_id": filename}), 200
    except Exception as e:
        # Fix: previously a failed indexing left the saved PDF on disk
        # forever, slowly filling /tmp. Best-effort cleanup here.
        try:
            os.remove(path)
        except OSError:
            pass
        return jsonify({"error": f"Error building chain: {e}"}), 500

@app.post("/ask")
def ask_question():
    """Answer a question against a previously uploaded PDF.

    Expects JSON {"file_id": ..., "question": ...}; looks up the QA
    chain built at upload time and returns {"answer": ...}.
    """
    payload = request.get_json(force=True)
    file_id = payload.get("file_id")
    question = (payload.get("question") or "").strip()

    if not (file_id and question):
        return jsonify({"error": "Missing file_id or question."}), 400

    chain = qa_chains.get(file_id)
    if chain is None:
        return jsonify({"error": "File not found or expired."}), 404

    try:
        outcome = chain.invoke({"query": question})
        answer = outcome["result"] if isinstance(outcome, dict) else outcome
        gc.collect()  # keep RSS down on memory-tight free-tier hosts
        return jsonify({"answer": answer}), 200
    except Exception as e:
        return jsonify({"error": f"Error: {e}"}), 500
    
@app.post("/summarize")
def summarize():
    """Summarize a previously uploaded PDF.

    Expects JSON {"file_id": ...} where file_id was returned by
    /upload; reads the stored PDF back from UPLOAD_FOLDER.

    Returns:
        200 with {"summary": ...} on success,
        400 if file_id is missing,
        404 if the file is unknown/expired,
        500 if summarization fails.
    """
    data = request.get_json(force=True)
    file_id = data.get("file_id")

    if not file_id:
        return jsonify({"error": "Missing file_id."}), 400

    # Security fix: file_id is untrusted client input joined into a
    # filesystem path — a value like "../../etc/passwd" escaped the
    # upload folder. Reject anything containing path components.
    if os.path.basename(file_id) != file_id:
        return jsonify({"error": "File not found or expired."}), 404

    # locate the uploaded file from /tmp/uploads
    path = os.path.join(app.config["UPLOAD_FOLDER"], file_id)
    if not os.path.exists(path):
        return jsonify({"error": "File not found or expired."}), 404

    try:
        summary = summarize_pdf(path)
        gc.collect()
        return jsonify({"summary": summary}), 200
    except Exception as e:
        return jsonify({"error": f"Error summarizing: {e}"}), 500


if __name__ == "__main__":
    # Development entry point only — debug=True enables the Werkzeug
    # debugger/reloader and must not be used behind a production server.
    app.run(debug=True)