Spaces:
Sleeping
Sleeping
| # KEEPING YOUR ORIGINAL IMPORTS | |
| import gradio as gr | |
| import PyPDF2 | |
| import io | |
| from transformers import pipeline, AutoTokenizer | |
| import torch | |
| import re | |
| from typing import List, Tuple | |
| import warnings | |
# Silence every warning category process-wide (library deprecation notices
# would otherwise flood the application log).
warnings.filterwarnings(action="ignore")

# Extractive question-answering pipeline, created once at import time and
# reused by every QA request.
qa_pipeline = pipeline(
    task="question-answering",
    model="deepset/roberta-base-squad2",
)
| # === SUMMARIZER CLASS === | |
class PDFSummarizer:
    """Extract text from a PDF and condense it with a summarization model.

    The instance owns a Hugging Face ``summarization`` pipeline
    (``self.summarizer``) and the tokenizer of the same checkpoint
    (``self.tokenizer``). ``process_pdf`` is the main entry point; the other
    methods are the individual pipeline stages.
    """

    # Patterns used by clean_text, compiled once for the whole class.
    _DISALLOWED_CHARS = re.compile(r'[^\w\s.,!?;:()\-"]')
    _PAGE_MARKER = re.compile(r'---\s+Page\s+\d+\s+---')
    _WHITESPACE = re.compile(r'\s+')

    # (max_length, min_length) per UI summary type. Any other value
    # (e.g. "Comprehensive") uses _DEFAULT_LENGTHS.
    _SUMMARY_LENGTHS = {
        "Brief (Quick)": (60, 20),
        "Detailed": (100, 40),
    }
    _DEFAULT_LENGTHS = (150, 60)

    def __init__(self):
        """Load the summarization model, falling back to a CPU-only BART."""
        self.model_name = "sshleifer/distilbart-cnn-12-6"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        try:
            self.summarizer = pipeline(
                "summarization",
                model=self.model_name,
                device=0 if self.device == "cuda" else -1,
                framework="pt",
                # Half precision only on GPU; CPU stays in float32.
                model_kwargs={
                    "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32
                },
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            print("Model loaded successfully")
        except Exception as e:
            # Best-effort fallback: retry with a different checkpoint on CPU.
            print(f"Error loading model: {e}")
            self.model_name = "facebook/bart-large-cnn"
            self.summarizer = pipeline("summarization", model=self.model_name, device=-1)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            print("Fallback model loaded")

    def extract_text_from_pdf(self, pdf_file) -> str:
        """Return the text of every non-empty page, each preceded by a marker.

        Args:
            pdf_file: Raw PDF content as bytes (wrapped in ``io.BytesIO``).

        Returns:
            The concatenated page texts, each prefixed with a
            ``--- Page N ---`` line, stripped of surrounding whitespace.

        Raises:
            Exception: If the PDF cannot be parsed; the original error is
                chained as ``__cause__``.
        """
        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_file))
            parts = []
            for page_num, page in enumerate(pdf_reader.pages, start=1):
                # Treat a None result (page without a text layer) like an
                # empty page instead of crashing on .strip().
                page_text = page.extract_text() or ""
                if page_text.strip():
                    parts.append(f"\n--- Page {page_num} ---\n{page_text}")
            return "".join(parts).strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}") from e

    def clean_text(self, text: str) -> str:
        """Normalize extracted text for the model.

        Removes characters outside word characters, whitespace and basic
        punctuation, drops the ``--- Page N ---`` markers, and collapses all
        whitespace runs to single spaces.
        """
        text = self._DISALLOWED_CHARS.sub(' ', text)
        text = self._PAGE_MARKER.sub(' ', text)
        # Collapse whitespace last so the substitutions above cannot leave
        # runs of spaces behind.
        return self._WHITESPACE.sub(' ', text).strip()

    def chunk_text(self, text: str, max_chunk_length: int = 512,
                   max_chunks: int = 5) -> List[str]:
        """Group sentences into chunks of at most ``max_chunk_length`` words.

        Sentences are delimited by ``'. '``. A single sentence longer than
        the limit becomes its own (oversized) chunk.

        Args:
            text: Cleaned document text.
            max_chunk_length: Maximum number of whitespace-separated words
                per chunk.
            max_chunks: Only the first ``max_chunks`` chunks are returned, so
                longer documents are truncated.

        Returns:
            The list of chunk strings (empty for empty input).
        """
        chunks: List[str] = []
        current: List[str] = []
        current_words = 0
        for fragment in text.split('. '):
            sentence = fragment.strip()
            if not sentence:
                continue
            # split() consumed the period of every sentence but the last;
            # restore it without doubling existing end punctuation.
            if not sentence.endswith(('.', '!', '?')):
                sentence += '.'
            n_words = len(sentence.split())
            if current and current_words + n_words > max_chunk_length:
                chunks.append(' '.join(current))
                current, current_words = [], 0
            current.append(sentence)
            current_words += n_words
        if current:
            chunks.append(' '.join(current))
        return chunks[:max_chunks]

    def summarize_chunk(self, chunk: str, max_length: int = 100, min_length: int = 30) -> str:
        """Summarize one chunk of text with beam search (2 beams).

        Args:
            chunk: Text to summarize; over-long input is truncated by the
                pipeline (``truncation=True``).
            max_length: Upper bound on generated length.
            min_length: Lower bound on generated length.

        Returns:
            The summary text, or an ``"Error summarizing chunk: ..."`` string
            if the model call fails (errors are not raised).
        """
        try:
            summary = self.summarizer(
                chunk,
                # Generation lengths must be integers; callers may derive
                # them arithmetically.
                max_length=int(max_length),
                min_length=int(min_length),
                do_sample=False,
                truncation=True,
                early_stopping=True,
                num_beams=2,
            )
            return summary[0]['summary_text']
        except Exception as e:
            return f"Error summarizing chunk: {str(e)}"

    def process_pdf(self, pdf_file, summary_type: str) -> Tuple[str, str, str]:
        """Run the full extract -> clean -> chunk -> summarize pipeline.

        Args:
            pdf_file: Raw PDF content as bytes.
            summary_type: ``"Brief (Quick)"``, ``"Detailed"``, or anything
                else for the longest setting.

        Returns:
            ``(summary, stats_markdown, status_message)``. On failure the
            first element carries the error message and the other two are
            empty strings; no exception propagates.
        """
        try:
            raw_text = self.extract_text_from_pdf(pdf_file)
            if not raw_text.strip():
                return "β Error: No text could be extracted from the PDF.", "", ""
            cleaned_text = self.clean_text(raw_text)
            word_count = len(cleaned_text.split())
            char_count = len(cleaned_text)
            if word_count < 50:
                return "β Error: PDF contains too little text to summarize.", "", ""
            chunks = self.chunk_text(cleaned_text)
            max_len, min_len = self._SUMMARY_LENGTHS.get(summary_type, self._DEFAULT_LENGTHS)
            chunk_summaries = []
            for i, chunk in enumerate(chunks, start=1):
                print(f"Processing chunk {i}/{len(chunks)}")
                chunk_summaries.append(self.summarize_chunk(chunk, max_len, min_len))
            combined_summary = " ".join(chunk_summaries)
            if len(chunks) <= 2:
                final_summary = combined_summary
            else:
                # Second pass: condense the per-chunk summaries into one.
                final_summary = self.summarize_chunk(
                    combined_summary,
                    max_length=min(200, int(max_len * 1.5)),
                    min_length=min_len,
                )
            summary_words = len(final_summary.split())
            # An empty summary must not turn a finished run into a
            # ZeroDivisionError.
            ratio = f"{word_count / summary_words:.1f}:1" if summary_words else "n/a"
            # The count below is the number of text chunks summarized, not
            # PDF pages, so it is labelled accordingly.
            summary_stats = (
                "\n"
                "π **Document Statistics:**\n"
                f"- Original word count: {word_count:,}\n"
                f"- Original character count: {char_count:,}\n"
                f"- Chunks processed: {len(chunks)}\n"
                f"- Summary word count: {summary_words:,}\n"
                f"- Compression ratio: {ratio}\n"
            )
            return final_summary, summary_stats, "β Summary generated successfully!"
        except Exception as e:
            return f"β Error processing PDF: {str(e)}", "", ""
# Module-wide state shared by the Gradio callbacks.
# Cleaned text of the most recently uploaded PDF; used as the QA context.
global_pdf_text = ""
# Single summarizer instance; constructing it loads the model at import time.
pdf_summarizer = PDFSummarizer()
def summarize_pdf_interface(pdf_file, summary_type):
    """Gradio callback: summarize the uploaded PDF and cache its text for QA.

    Args:
        pdf_file: Filesystem path of the uploaded PDF, or None when no file
            is selected.
        summary_type: Summary length option chosen in the UI.

    Returns:
        ``(summary, stats_markdown, status_message)``; on failure the first
        element carries the error message and the other two are empty.

    Side effects:
        Rebinds the module-level ``global_pdf_text``. NOTE(review): this is
        one process-wide variable, so concurrent users share the same QA
        context.
    """
    global global_pdf_text
    # Forget the previous document first, so a cleared upload or a failed
    # extraction cannot leave QA answering from a stale PDF.
    global_pdf_text = ""
    if pdf_file is None:
        return "β Please upload a PDF file.", "", ""
    try:
        with open(pdf_file, 'rb') as f:
            pdf_content = f.read()
        global_pdf_text = pdf_summarizer.clean_text(
            pdf_summarizer.extract_text_from_pdf(pdf_content)
        )
        return pdf_summarizer.process_pdf(pdf_content, summary_type)
    except Exception as e:
        return f"β Error: {str(e)}", "", ""
| # === NEW: QA FUNCTION === | |
def answer_question_interface(question):
    """Gradio callback: answer a question from the last uploaded PDF's text.

    Args:
        question: The user's question as typed into the textbox.

    Returns:
        The answer span extracted by the QA pipeline, or a user-facing
        message when the question is blank, no PDF text is cached, or the
        pipeline raises.
    """
    # Validate the input before touching the model: a blank question would
    # otherwise surface as a raw pipeline exception message.
    if question is None or not question.strip():
        return "β Please enter a question."
    if not global_pdf_text:
        return "β Please upload and summarize a PDF first."
    try:
        answer = qa_pipeline(question=question.strip(), context=global_pdf_text)
        return answer["answer"]
    except Exception as e:
        return f"β Error: {str(e)}"
| # === GRADIO INTERFACE === | |
def create_interface():
    """Assemble the Gradio Blocks UI and return it (without launching).

    Layout: a left column with the PDF upload, summary-length selector,
    trigger button and status box; a right column with the summary and
    statistics; and below them a question/answer row.
    """
    with gr.Blocks(title="π AI PDF Summarizer & QA", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# π PDF Summarizer + π¬ Question Answering")

        with gr.Row():
            with gr.Column(scale=1):
                pdf_input = gr.File(
                    label="π Upload PDF",
                    file_types=[".pdf"],
                    type="filepath",
                )
                summary_type = gr.Radio(
                    choices=["Brief (Quick)", "Detailed", "Comprehensive"],
                    value="Detailed",
                    label="π Summary Length",
                )
                summarize_btn = gr.Button("π Generate Summary", variant="primary")
                status_output = gr.Textbox(label="π Status", interactive=False, max_lines=2)
            with gr.Column(scale=2):
                summary_output = gr.Textbox(label="π Summary", lines=15, interactive=False)
                stats_output = gr.Markdown(label="π Document Statistics")

        # Both pressing the button and changing the uploaded file run the
        # same summarization callback with the same inputs and outputs.
        for register in (summarize_btn.click, pdf_input.change):
            register(
                fn=summarize_pdf_interface,
                inputs=[pdf_input, summary_type],
                outputs=[summary_output, stats_output, status_output],
            )

        gr.Markdown("## π¬ Ask a Question About the PDF")
        with gr.Row():
            question_input = gr.Textbox(
                label="β Your Question",
                placeholder="e.g. What is the main finding?",
            )
            answer_output = gr.Textbox(label="π‘ Answer", interactive=False)

        # Pressing Enter in the question box runs the QA callback.
        question_input.submit(
            fn=answer_question_interface,
            inputs=question_input,
            outputs=answer_output,
        )

    return interface
| # === MAIN === | |
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server with its
    # default launch settings.
    create_interface().launch()