# NOTE: the three lines below were Hugging Face Spaces page chrome
# ("Spaces: Sleeping") accidentally captured with the source — not app code.
import gradio as gr
from transformers import pipeline
from pypdf import PdfReader
import torch
import math

# --- Configuration & Model Loading ---

# Pick GPU 0 when CUDA is available; -1 tells transformers to use the CPU.
device = 0 if torch.cuda.is_available() else -1
print(f"Loading models on device: {'GPU' if device == 0 else 'CPU'}...")

# 1. Summarization model.
# 'facebook/bart-large-cnn' is excellent for abstractive summarization.
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=device
)

# 2. Question-generation model.
# A lightweight end-to-end QG model; acceptable speed even on CPU.
qg_pipeline = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=device
)

print("Models loaded successfully.")
| # --- Core Logic Functions --- | |
def extract_text_from_pdf(pdf_file):
    """Extract all page text from an uploaded PDF.

    Args:
        pdf_file: Either a filesystem path (str) or an object exposing a
            ``.name`` attribute (gradio's legacy tempfile wrapper).

    Returns:
        The concatenated page text, "" when no file was given, or an
        "Error reading PDF: ..." message string on failure (kept as a
        string so the UI pipeline can surface it without crashing).
    """
    if pdf_file is None:
        return ""
    # Newer gradio versions pass the file path directly as a str; older
    # versions pass a tempfile-like object with a .name attribute.
    path = pdf_file if isinstance(pdf_file, str) else getattr(pdf_file, "name", pdf_file)
    try:
        reader = PdfReader(path)
        pages = []
        for page in reader.pages:
            page_text = page.extract_text()
            # extract_text() can return None/"" for image-only pages; skip them.
            if page_text:
                pages.append(page_text)
        return "\n".join(pages).strip()
    except Exception as e:
        return f"Error reading PDF: {str(e)}"
def split_text_into_chunks(text, max_chunk_len=3000):
    """Split text into whitespace-delimited chunks for the model.

    BART's input limit is ~1024 tokens; character count is used as a safe
    proxy (~4 chars/token), so chunks are capped at ~max_chunk_len chars.

    Args:
        text: Source text to split.
        max_chunk_len: Approximate maximum characters per chunk.

    Returns:
        List of non-empty chunk strings. A single word longer than
        max_chunk_len becomes its own (oversized) chunk rather than
        producing an empty chunk, which the original logic did.
    """
    chunks = []
    current_chunk = []
    current_length = 0
    for word in text.split():
        # +1 accounts for the joining space.
        if current_length + len(word) + 1 > max_chunk_len:
            # Bug fix: only flush when there is content; an oversized first
            # word used to append an empty "" chunk here.
            if current_chunk:
                chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = len(word)
        else:
            current_chunk.append(word)
            current_length += len(word) + 1
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
def generate_summary(text, length_mode="Medium"):
    """Produce an abstractive summary of *text*.

    Short inputs are summarized in one pass. Long inputs are chunked,
    each chunk summarized separately, and the combined result is
    recursively re-summarized only if it is still too long (otherwise the
    concatenation is returned to retain detail).

    Args:
        text: Source text, arbitrarily long.
        length_mode: "Short", "Medium", or "Long" output target.

    Returns:
        Summary string, or a notice/error message string.
    """
    if not text:
        return "No text provided."

    # Token-length targets per mode; anything unrecognized falls back to Medium.
    length_bounds = {"Short": (100, 30), "Long": (400, 150)}
    max_len, min_len = length_bounds.get(length_mode, (250, 60))

    # Short enough for a single direct pass.
    if len(text) < 3000:
        try:
            word_count = len(text.split())
            # Clamp targets to the input size so the model doesn't error on
            # very short inputs.
            adjusted_max = min(max_len, max(word_count // 2, 20))
            adjusted_min = min(min_len, max(adjusted_max - 10, 5))
            result = summarizer(text, max_length=adjusted_max, min_length=adjusted_min, do_sample=False)
            return result[0]['summary_text']
        except Exception as e:
            return f"Error in summarization: {str(e)}"

    # Long input: summarize chunk by chunk, skipping chunks that fail.
    partial_summaries = []
    for piece in split_text_into_chunks(text, max_chunk_len=3000):
        try:
            out = summarizer(piece, max_length=150, min_length=40, do_sample=False)
        except Exception as e:
            print(f"Skipping chunk due to error: {e}")
            continue
        partial_summaries.append(out[0]['summary_text'])

    combined_text = " ".join(partial_summaries)
    # Recursive pass only when still too long; otherwise keep the detail of
    # the per-chunk summaries.
    if len(combined_text) > 4000:
        return generate_summary(combined_text, length_mode)
    return combined_text
def generate_questions_list(text, num_questions=10):
    """Generate up to *num_questions* study questions from *text*.

    Uses the end-to-end QG model (valhalla/t5-small-e2e-qg), which emits
    questions directly from raw text prefixed with "generate questions: ".

    Args:
        text: Context to generate questions from.
        num_questions: Maximum number of distinct questions to return.

    Returns:
        List of question strings (possibly empty), or a single-element
        list holding an error message on failure.
    """
    if not text:
        return []
    try:
        # QG models work best on shorter contexts, so process the text in
        # segments; cap the number visited to keep latency bounded.
        segments = split_text_into_chunks(text, max_chunk_len=2000)[:5]
        questions = []
        for segment in segments:
            # The model expects the "generate questions: " task prefix.
            outputs = qg_pipeline(
                "generate questions: " + segment,
                max_length=64,
                num_return_sequences=2,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )
            # Keep only questions we haven't seen yet.
            for candidate in outputs:
                question = candidate['generated_text']
                if question not in questions:
                    questions.append(question)
            if len(questions) >= num_questions:
                break
        return questions[:num_questions]
    except Exception as e:
        return [f"Could not generate questions: {str(e)}"]
def format_bullet_notes(summary_text):
    """Render a prose summary as markdown bullets, one per sentence.

    Sentence boundaries are approximated by ". "; pre-existing newlines
    also split. Blank fragments are dropped.
    """
    fragments = summary_text.replace(". ", ".\n").split("\n")
    return "\n".join("- " + frag.strip() for frag in fragments if frag.strip())
| # --- Main App Logic --- | |
def process_pdf_data(file_obj, length_mode, enable_questions):
    """Gradio click handler: PDF -> (summary, bullet notes, questions) markdown.

    Args:
        file_obj: Uploaded file from gr.File (or None).
        length_mode: "Short" / "Medium" / "Long" summary target.
        enable_questions: Whether to also generate study questions.

    Returns:
        Tuple of three markdown strings (summary, notes, questions).
    """
    if file_obj is None:
        return "Please upload a PDF file.", "", ""

    # 1. Pull the raw text out of the PDF.
    raw_text = extract_text_from_pdf(file_obj)
    if not raw_text or len(raw_text) < 50:
        return "Error: Could not extract text from PDF or PDF is empty.", "", ""

    print(f"Extracted {len(raw_text)} characters. Processing...")

    # 2. Summarize — chunking of long inputs is handled inside generate_summary.
    final_summary = generate_summary(raw_text, length_mode)

    # 3. Bullet-point notes derived from the summary.
    notes_markdown = "### π Key Bullet Notes\n\n" + format_bullet_notes(final_summary)

    # 4. Optional study questions.
    questions_markdown = ""
    if enable_questions:
        # Prefer the summary as question context so questions target key
        # points; fall back to the head of the raw text if the summary is short.
        context_for_q = final_summary if len(final_summary) > 500 else raw_text[:2000]
        numbered = (
            f"{i}. {q}\n"
            for i, q in enumerate(generate_questions_list(context_for_q, num_questions=10), 1)
        )
        questions_markdown = "### β Important Questions\n\n" + "".join(numbered)

    summary_markdown = f"### π Summary\n\n{final_summary}"
    return summary_markdown, notes_markdown, questions_markdown
| # --- Gradio UI --- | |
# --- Gradio UI ---

# Soft theme with a blue/slate palette.
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")

with gr.Blocks(theme=theme, title="AI Notes Maker") as app:
    gr.Markdown(
        """
        # π AI Notes Maker
        Upload a PDF lecture, paper, or article. Get a summary, key notes, and study questions instantly.
        """
    )
    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column(scale=1):
            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
            with gr.Accordion("Settings", open=True):
                length_slider = gr.Radio(
                    ["Short", "Medium", "Long"],
                    label="Notes Length",
                    value="Medium"
                )
                question_check = gr.Checkbox(
                    label="Generate Important Questions",
                    value=True
                )
            submit_btn = gr.Button("Generate Notes", variant="primary")
        # Right column: generated outputs.
        with gr.Column(scale=2):
            output_summary = gr.Markdown(label="Summary")
            output_notes = gr.Markdown(label="Key Notes")
            output_questions = gr.Markdown(label="Questions")

    submit_btn.click(
        fn=process_pdf_data,
        inputs=[pdf_input, length_slider, question_check],
        outputs=[output_summary, output_notes, output_questions]
    )

if __name__ == "__main__":
    app.launch()