Spaces:
Sleeping
Sleeping
| import os, uuid, gc | |
| from functools import lru_cache | |
| from flask import Flask, render_template, request, jsonify | |
| from werkzeug.utils import secure_filename | |
| from dotenv import load_dotenv | |
| import PyPDF2 | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| # from langchain_huggingface import ( | |
| # HuggingFaceEmbeddings, | |
| # HuggingFaceEndpoint, | |
| # ChatHuggingFace, | |
| # ) | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_groq import ChatGroq | |
| from langchain.chains import RetrievalQA | |
# Pull settings from a .env file (if any) into the process environment.
load_dotenv()

# Directory where uploaded PDFs are written.
UPLOAD_FOLDER = os.path.join("/tmp", "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Point the Hugging Face cache locations at /tmp.
# NOTE(review): these are assigned after the Hugging Face imports above;
# confirm the libraries read them lazily rather than at import time.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
        "HF_DATASETS_CACHE": "/tmp/huggingface/datasets",
    }
)

app = Flask(__name__)
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)

# In-memory map from the file_id handed out by upload_file() to its QA chain.
qa_chains: dict[str, RetrievalQA] = {}
@lru_cache(maxsize=1)
def get_embedder():
    """Return the shared embedding model used to index PDF chunks.

    The instance is cached so the ``all-MiniLM-L6-v2`` model is constructed
    once per process instead of on every call.

    Returns:
        A ``HuggingFaceEmbeddings`` instance for
        ``sentence-transformers/all-MiniLM-L6-v2``.
    """
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
    )
@lru_cache(maxsize=1)
def get_llm():
    """Return the shared Groq chat model used for QA and summarization.

    The client is cached so it is constructed once per process rather than
    on every call.

    Returns:
        A ``ChatGroq`` instance configured with temperature 0.3 and a
        512-token completion limit.
    """
    return ChatGroq(
        # Alternatives noted by the original author:
        # llama-3.3-70b-versatile (very accurate), mistral-saba-24b (unlimited).
        model_name="llama-3.1-8b-instant",
        temperature=0.3,
        max_tokens=512,
    )
def extract_text_from_pdf(path: str) -> str:
    """Return the text PyPDF2 can extract from the PDF at ``path``.

    Pages that yield no text are skipped. The remaining pages are joined
    with a newline so the last word of one page is not fused with the first
    word of the next.

    Args:
        path: Filesystem path of the PDF to read.

    Returns:
        The concatenated page text, or ``""`` when no page yields text.
    """
    with open(path, "rb") as pdf_file:
        reader = PyPDF2.PdfReader(pdf_file)
        # Extract while the file is still open: the reader pulls page data
        # from the underlying stream.
        page_texts = [page.extract_text() or "" for page in reader.pages]
    return "\n".join(text for text in page_texts if text)
def build_qa_for_pdf(path: str) -> RetrievalQA:
    """Build a retrieval-QA chain over the text of the PDF at ``path``.

    The PDF text is split into 1000-character chunks with 200 characters of
    overlap, embedded into a FAISS store, and exposed through an MMR
    retriever that returns 5 chunks per query.

    Args:
        path: Filesystem path of the PDF to index.

    Returns:
        A ``RetrievalQA`` chain backed by the FAISS retriever and the LLM
        from ``get_llm()``.

    Raises:
        ValueError: If the PDF yields no text chunks (for example a scanned,
            image-only document).
    """
    text = extract_text_from_pdf(path)
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(text)
    if not chunks:
        # Fail with a clear message instead of handing the vector store an
        # empty list of texts to index.
        raise ValueError("No extractable text found in PDF.")

    store = FAISS.from_texts(chunks, get_embedder())
    retriever = store.as_retriever(search_type="mmr", search_kwargs={"k": 5})
    return RetrievalQA.from_chain_type(llm=get_llm(), retriever=retriever)
def summarize_pdf(path: str, max_chars: int = 5000) -> str:
    """Summarize the PDF at ``path`` using the LLM.

    Only the first ``max_chars`` characters of the extracted text are sent
    to the model, to keep the prompt small.

    Args:
        path: Filesystem path of the PDF to summarize.
        max_chars: Maximum number of extracted characters to include in the
            prompt. Expected to be non-negative.

    Returns:
        The model's summary text.

    Raises:
        ValueError: If no text could be extracted from the PDF.
    """
    text = extract_text_from_pdf(path)[:max_chars]
    if not text.strip():
        # Without this the model would be asked to summarize an empty document.
        raise ValueError("No extractable text found in PDF.")

    prompt = f"Summarize the following legal document in a clear and concise way:\n\n{text}"
    result = get_llm().invoke(prompt)
    # Chat models return a message object; fall back to str() for anything else.
    return result.content if hasattr(result, "content") else str(result)
def healthz():
    """Return the fixed body ``"ok"`` together with HTTP status 200.

    NOTE(review): no route registration is visible in this view of the
    file; confirm where this handler is bound to a URL.
    """
    status_code = 200
    return "ok", status_code
def index():
    """Render the ``index.html`` template and return the result.

    NOTE(review): no route registration is visible in this view of the
    file; confirm where this handler is bound to a URL.
    """
    page = render_template("index.html")
    return page
def upload_file():
    """Accept a PDF upload, index it, and return its ``file_id``.

    Reads the ``file`` part of the request, stores it in the upload folder
    under a UUID-prefixed sanitized name, builds a QA chain for it, and
    registers the chain in ``qa_chains`` under that name.

    Returns:
        ``({"file_id": ...}, 200)`` on success,
        ``({"error": ...}, 400)`` when no ``.pdf`` file was supplied, or
        ``({"error": ...}, 500)`` when the chain could not be built.

    NOTE(review): no route registration is visible in this view of the
    file; confirm where this handler is bound to a URL.
    """
    file = request.files.get("file")
    if not file or not file.filename.lower().endswith(".pdf"):
        return jsonify({"error": "Please upload a valid PDF."}), 400

    # The UUID prefix keeps uploads with the same original name from colliding.
    filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
    path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
    file.save(path)

    try:
        qa_chains[filename] = build_qa_for_pdf(path)
    except Exception as e:
        # The client never receives this file_id, so the saved upload would
        # otherwise be orphaned on disk. Removal is best-effort.
        try:
            os.remove(path)
        except OSError:
            pass
        return jsonify({"error": f"Error building chain: {e}"}), 500

    return jsonify({"file_id": filename}), 200
def ask_question():
    """Answer a question about a previously uploaded PDF.

    Expects a JSON object with string fields ``file_id`` (as returned by
    the upload handler) and ``question``.

    Returns:
        ``({"answer": ...}, 200)`` on success,
        ``({"error": ...}, 400)`` when the body is not a JSON object or a
        field is missing, empty, or not a string,
        ``({"error": ...}, 404)`` when no chain is registered for ``file_id``,
        or ``({"error": ...}, 500)`` when the chain raises.

    NOTE(review): no route registration is visible in this view of the
    file; confirm where this handler is bound to a URL.
    """
    # silent=True yields None on unparsable bodies; a non-object body (list,
    # null, number) is rejected below instead of crashing on ``.get``.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Missing file_id or question."}), 400

    file_id = data.get("file_id")
    question = data.get("question")
    if not isinstance(file_id, str) or not isinstance(question, str):
        return jsonify({"error": "Missing file_id or question."}), 400

    question = question.strip()
    if not file_id or not question:
        return jsonify({"error": "Missing file_id or question."}), 400

    qa_chain = qa_chains.get(file_id)
    if qa_chain is None:
        return jsonify({"error": "File not found or expired."}), 404

    try:
        result = qa_chain.invoke({"query": question})
        answer = result["result"] if isinstance(result, dict) else result
        gc.collect()
        return jsonify({"answer": answer}), 200
    except Exception as e:
        return jsonify({"error": f"Error: {e}"}), 500
def summarize():
    """Summarize a previously uploaded PDF identified by ``file_id``.

    Expects a JSON object with a string field ``file_id``. The file is
    looked up by name inside the upload folder.

    Returns:
        ``({"summary": ...}, 200)`` on success,
        ``({"error": ...}, 400)`` when the body is not a JSON object or
        ``file_id`` is missing or not a string,
        ``({"error": ...}, 404)`` when ``file_id`` does not name a regular
        file directly inside the upload folder,
        or ``({"error": ...}, 500)`` when summarization raises.

    NOTE(review): no route registration is visible in this view of the
    file; confirm where this handler is bound to a URL.
    """
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Missing file_id."}), 400

    file_id = data.get("file_id")
    if not file_id or not isinstance(file_id, str):
        return jsonify({"error": "Missing file_id."}), 400

    # file_id is untrusted input. Accept only a bare filename so values such
    # as "../x" or "/etc/x" cannot make os.path.join escape the upload folder.
    if os.path.basename(file_id) != file_id:
        return jsonify({"error": "File not found or expired."}), 404

    path = os.path.join(app.config["UPLOAD_FOLDER"], file_id)
    # isfile (not exists) also rejects "." and ".." which resolve to directories.
    if not os.path.isfile(path):
        return jsonify({"error": "File not found or expired."}), 404

    try:
        summary = summarize_pdf(path)
        gc.collect()
        return jsonify({"summary": summary}), 200
    except Exception as e:
        return jsonify({"error": f"Error summarizing: {e}"}), 500
if __name__ == "__main__":
    # Debug mode turns on Werkzeug's interactive debugger, which lets anyone
    # who can reach the server execute code. Enable it only on explicit
    # opt-in through the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(debug=debug_enabled)