| |
| import os |
| import gc |
| import tempfile |
| import gradio as gr |
| import torch |
| import numpy as np |
| import faiss |
| from typing import Tuple, Dict, Any, Optional |
| import spaces |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
| |
# Silence the HF tokenizers fork-safety warning emitted when the web server
# spawns worker processes after the tokenizer has been loaded.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# --- Model / runtime configuration ---
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"    # chat model used for Q&A and summaries
EMBED_MODEL_NAME = "BAAI/bge-large-en-v1.5"    # sentence-embedding model used for retrieval
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hard character cap applied to prompts before LLM invocation (see invoke_llm).
MAX_PROMPT_LENGTH = 28000


# Prompt for retrieval-augmented Q&A; filled with the retrieved chunks
# ({context}) and the user's question ({question}).
QA_PROMPT_TEMPLATE = (
    "System: You are a helpful assistant. Answer the user's question based *only* on the provided context. "
    "If the answer is not found in the context, state that clearly.\n\n"
    "Context:\n---\n{context}\n---\n\nQuestion: {question}\n\nAnswer:"
)


# Summary prompts keyed by the UI radio-button choice ("Quick" / "Standard" /
# "Detailed"); each is filled with the full document text via {text}.
SUMMARY_PROMPTS = {
    "Quick": (
        "You are an expert academic summarizer. Provide a single, concise paragraph that summarizes the absolute key takeaway of the following document. "
        "Be brief and direct.\n\nDocument:\n---\n{text}\n---\n\nQuick Summary:"
    ),
    "Standard": (
        "You are an expert academic summarizer. Provide a detailed, well-structured summary of the following document. "
        "Cover the key points, methodology, findings, and conclusions.\n\n"
        "Document:\n---\n{text}\n---\n\nStandard Summary:"
    ),
    "Detailed": (
        "You are an expert academic summarizer. Provide a highly detailed and comprehensive summary of the following document. "
        "Go into depth on the methodology, specific results, limitations, and any mention of future work. Use multiple paragraphs for structure.\n\n"
        "Document:\n---\n{text}\n---\n\nDetailed Summary:"
    )
}
|
|
| |
class ModelManager:
    """Lazily loads and caches either the LLM pipeline or the embedding model.

    Only one of the two models is intended to be resident at a time: each
    getter clears any previously cached model before loading its own, so a
    small GPU can alternate between generation and embedding workloads.
    """

    # Class-level caches; at most one is non-None at any given moment.
    _llm_pipe = None      # transformers text-generation pipeline
    _embed_model = None   # LangChain HuggingFaceEmbeddings wrapper

    @classmethod
    def _clear_gpu_memory(cls):
        """Drop cached model references and release GPU memory."""
        # Resetting the class attributes is what actually releases the
        # models: the original `del model` loop only deleted local aliases
        # of the cached objects and freed nothing.
        cls._llm_pipe = None
        cls._embed_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("[Memory] GPU Memory Cleared.")

    @classmethod
    def get_llm_pipeline(cls):
        """Return the cached LLM pipeline, loading it (and evicting the
        embedding model) on first use. Returns None if loading fails."""
        if cls._llm_pipe is None:
            cls._clear_gpu_memory()
            print("[LLM] Loading model...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
                model = AutoModelForCausalLM.from_pretrained(
                    LLM_MODEL_NAME,
                    device_map=DEVICE,
                    torch_dtype=torch.bfloat16
                )
                cls._llm_pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    max_new_tokens=1024,
                    temperature=0.2,
                    top_p=0.95,
                )
                print("[LLM] Model loaded successfully.")
            except Exception as e:
                # Caller treats None as "LLM unavailable"; don't raise.
                print(f"[LLM] Failed to load model: {e}")
                return None
        return cls._llm_pipe

    @classmethod
    def get_embedding_model(cls):
        """Return the cached embedding model, loading it (and evicting the
        LLM) on first use. Returns None if loading fails."""
        # Imported lazily so the module loads even without langchain installed.
        from langchain_huggingface import HuggingFaceEmbeddings
        if cls._embed_model is None:
            cls._clear_gpu_memory()
            print("[Embed] Loading embedding model...")
            try:
                cls._embed_model = HuggingFaceEmbeddings(
                    model_name=EMBED_MODEL_NAME,
                    model_kwargs={"device": DEVICE},
                    encode_kwargs={"normalize_embeddings": True}
                )
                print("[Embed] Embedding model loaded successfully.")
            except Exception as e:
                print(f"[Embed] Failed to load model: {e}")
                return None
        return cls._embed_model
|
|
| |
@spaces.GPU
def invoke_llm(prompt_str: str) -> str:
    """Run the text-generation pipeline on *prompt_str* and return the completion.

    The prompt is hard-truncated to MAX_PROMPT_LENGTH characters before the
    call. On any failure a human-readable error string is returned instead
    of raising, so UI handlers can display it directly.
    """
    if len(prompt_str) > MAX_PROMPT_LENGTH:
        prompt_str = prompt_str[:MAX_PROMPT_LENGTH]
        print(f"[invoke_llm] Prompt truncated to {MAX_PROMPT_LENGTH} characters.")

    try:
        pipe = ModelManager.get_llm_pipeline()
        if not pipe:
            return "Error: LLM could not be loaded."

        with torch.no_grad():
            outputs = pipe(prompt_str)

        if isinstance(outputs, list) and outputs and "generated_text" in outputs[0]:
            text = outputs[0]["generated_text"]
            # The pipeline echoes the prompt at the start of the output.
            # Strip it as a *prefix* only: the previous
            # `.replace(prompt_str, "")` removed every occurrence of the
            # prompt text, mangling answers that quote the context verbatim.
            if text.startswith(prompt_str):
                text = text[len(prompt_str):]
            return text.strip()
        return "No valid response was generated."

    except Exception as e:
        print(f"[invoke_llm] Error: {e}")
        return f"LLM invocation failed: {e}"
|
|
@spaces.GPU
def process_pdf_and_index(pdf_path: str) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Load a PDF, embed its text chunks, and persist a FAISS index.

    Returns a (status message, state bundle) pair; the bundle holds the
    on-disk index path and the raw chunk texts, or is None on failure.
    """
    # Imported lazily so the module loads even without langchain installed.
    from langchain_community.document_loaders import PyMuPDFLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    if not pdf_path:
        return "No file path provided.", None

    try:
        print("[Process] Loading and splitting PDF...")
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
        pages = PyMuPDFLoader(pdf_path).load()
        chunk_texts = [
            chunk.page_content
            for chunk in splitter.split_documents(pages)
            if chunk.page_content.strip()
        ]

        if not chunk_texts:
            return "No text could be extracted from the PDF.", None
        print(f"[Process] Extracted {len(chunk_texts)} text chunks.")

        embedder = ModelManager.get_embedding_model()
        if not embedder:
            return "Could not load embedding model.", None

        print(f"[Process] Creating embeddings...")
        vectors = np.array(embedder.embed_documents(chunk_texts), dtype=np.float32)

        print("[Process] Building and saving FAISS index...")
        faiss_index = faiss.IndexFlatL2(vectors.shape[1])
        faiss_index.add(vectors)

        # delete=False: only the path is needed here; faiss writes the file
        # and later requests read it back from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".faiss") as handle:
            index_path = handle.name
        faiss.write_index(faiss_index, index_path)

        bundle = {"index_path": index_path, "texts": chunk_texts}
        return f"Successfully processed and indexed {len(chunk_texts)} chunks.", bundle

    except Exception as e:
        print(f"[process_pdf] Exception: {e}")
        return f"Error processing PDF: {e}", None
|
|
@spaces.GPU
def retrieve_and_answer(question: str, state_bundle: Dict[str, Any]) -> Tuple[str, str]:
    """Answer *question* from the indexed PDF.

    Embeds the question, retrieves the top-5 chunks from the FAISS index in
    the state bundle, and asks the LLM to answer from that context. Returns
    (answer, sources preview); both degrade to messages on failure.
    """
    if not (state_bundle and "index_path" in state_bundle):
        return "Please upload and process a PDF first.", ""

    try:
        embedder = ModelManager.get_embedding_model()
        if not embedder:
            return "Error loading embedding model.", ""

        faiss_index = faiss.read_index(state_bundle["index_path"])
        chunk_texts = state_bundle.get("texts", [])

        query_vec = np.array([embedder.embed_query(question)], dtype=np.float32)
        _, hit_ids = faiss_index.search(query_vec, k=5)

        # Filter out-of-range ids (faiss pads with -1 when fewer than k hits).
        hits = [chunk_texts[i] for i in hit_ids[0] if 0 <= i < len(chunk_texts)]
        if not hits:
            return "Could not find relevant information.", ""

        context = "\n\n---\n\n".join(hits)
        sources_preview = "\n\n---\n\n".join(h[:500] + "..." for h in hits)

        prompt = QA_PROMPT_TEMPLATE.format(context=context, question=question)
        return invoke_llm(prompt), sources_preview

    except Exception as e:
        print(f"[retrieve_and_answer] Error: {e}")
        return f"An error occurred: {e}", ""
|
|
@spaces.GPU
def summarize_document(state_bundle: Dict[str, Any], summary_type: str) -> Tuple[str, Optional[str]]:
    """Generate a summary of the processed document and save it to a temp file.

    Args:
        state_bundle: State dict produced by process_pdf_and_index; must
            contain a non-empty "texts" list of chunk strings.
        summary_type: "Quick", "Standard", or "Detailed"; unknown values
            fall back to "Standard".

    Returns:
        (summary text, path to a .txt file holding it) on success, or
        (error message, None) when no processed document is available.
    """
    if not (state_bundle and "texts" in state_bundle):
        return "Please upload and process a PDF first.", None

    texts = state_bundle.get("texts", [])
    if not texts:
        return "No text available to summarize.", None

    full_text = "\n\n".join(texts)

    prompt_template = SUMMARY_PROMPTS.get(summary_type, SUMMARY_PROMPTS["Standard"])
    prompt = prompt_template.format(text=full_text)

    print(f"[Summarize] Generating '{summary_type}' summary...")
    # invoke_llm truncates over-long prompts to MAX_PROMPT_LENGTH internally.
    final_summary = invoke_llm(prompt)

    # Context manager closes the handle even if write() raises (the previous
    # open/write/close sequence leaked the handle on failure). delete=False
    # so the download button can serve the file after this function returns.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as tmp:
        tmp.write(final_summary)
        summary_path = tmp.name

    return final_summary, summary_path
|
|
| |
# --- Gradio UI wiring ---
with gr.Blocks(title="PDF Summarizer & Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 PDF Summarizer & Q&A Assistant")
    gr.Markdown("Upload a PDF to generate a summary or ask questions about its content.")

    # Holds the {"index_path", "texts"} bundle from process_pdf_and_index.
    state = gr.State()

    with gr.Row():
        pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
        process_btn = gr.Button("Process PDF", variant="primary")

    status_output = gr.Textbox(label="Status", interactive=False)

    with gr.Tabs():
        with gr.TabItem("Summarization"):
            gr.Markdown("### Generate a Summary")
            gr.Markdown("Select the level of detail you want in the summary.")
            summary_type_radio = gr.Radio(
                choices=["Quick", "Standard", "Detailed"],
                value="Standard",
                label="Summary Type"
            )
            summary_btn = gr.Button("Generate Summary", variant="secondary")
            out_summary = gr.Textbox(label="Document Summary", lines=20, max_lines=25)
            download_btn = gr.DownloadButton("Download Summary", visible=False)

        with gr.TabItem("Question & Answer"):
            gr.Markdown("### Ask a Question")
            gr.Markdown("Ask a specific question about the document's content.")
            q_text = gr.Textbox(label="Your Question", placeholder="e.g., What was the main conclusion of the study?")
            q_btn = gr.Button("Get Answer", variant="secondary")
            q_out = gr.Textbox(label="Answer", lines=8)
            q_sources = gr.Textbox(label="Retrieved Sources", lines=8, max_lines=10)

    def handle_process(pdf_file):
        """Wrapper to handle PDF processing and clear old outputs."""
        if pdf_file is None:
            return "Please upload a file first.", None, "", "", "", "", None
        # Bug fix: with type="filepath", gr.File passes a plain string path,
        # not a file wrapper — calling `.name` on it raised AttributeError.
        status_msg, bundle = process_pdf_and_index(pdf_file)
        return status_msg, bundle, "", "", "", "", None

    process_btn.click(
        fn=handle_process,
        inputs=[pdf_in],
        outputs=[status_output, state, out_summary, q_text, q_out, q_sources, download_btn]
    )

    q_btn.click(
        fn=retrieve_and_answer,
        inputs=[q_text, state],
        outputs=[q_out, q_sources]
    )

    summary_btn.click(
        fn=summarize_document,
        inputs=[state, summary_type_radio],
        outputs=[out_summary, download_btn]
    )


if __name__ == "__main__":
    demo.launch(share=False, show_error=True)
|
|