import math
import re

import gradio as gr
import torch
from pypdf import PdfReader
from transformers import pipeline

# --- Configuration & Model Loading ---

# Use GPU if available, otherwise CPU.
# Hugging Face pipelines use an integer device index: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

print(f"Loading models on device: {'GPU' if device == 0 else 'CPU'}...")

# 1. Summarization Model
# 'facebook/bart-large-cnn' is excellent for abstractive summarization.
# Loaded once at module import so every request reuses the same pipeline.
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=device
)

# 2. Question Generation Model
# Using a specific lightweight model for QG to ensure quality questions.
# Running this on CPU is fast enough if GPU isn't available.
qg_pipeline = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=device
)

print("Models loaded successfully.")

# --- Core Logic Functions ---

def extract_text_from_pdf(pdf_file):
    """Extract all page text from an uploaded PDF.

    Args:
        pdf_file: Either a filepath string (newer Gradio versions pass the
            path directly) or a tempfile-like object exposing ``.name``
            (older Gradio versions). ``None`` when nothing was uploaded.

    Returns:
        The concatenated text of all pages separated by newlines, ``""`` when
        no file was given, or an ``"Error reading PDF: ..."`` message string
        on failure (the caller is expected to check for this prefix).
    """
    if pdf_file is None:
        return ""

    # Gradio's File component may hand us a str path or an object with .name
    # depending on version — support both instead of assuming .name exists.
    path = pdf_file if isinstance(pdf_file, str) else getattr(pdf_file, "name", pdf_file)

    try:
        reader = PdfReader(path)
        pages = []
        for page in reader.pages:
            page_text = page.extract_text()
            # extract_text() can return None/empty for image-only pages.
            if page_text:
                pages.append(page_text)
        return "\n".join(pages).strip()
    except Exception as e:
        # Surface the failure as a string so the UI can display it.
        return f"Error reading PDF: {str(e)}"

def split_text_into_chunks(text, max_chunk_len=3000):
    """Split *text* into word-boundary chunks of at most ~max_chunk_len chars.

    BART's input limit is ~1024 tokens; character length is used as a safe
    proxy (~4 chars/token). Words are never split; a single word longer than
    ``max_chunk_len`` becomes its own (oversized) chunk.

    Args:
        text: The text to split.
        max_chunk_len: Soft character budget per chunk (joined with spaces).

    Returns:
        A list of non-empty chunk strings; ``[]`` for empty/whitespace input.
    """
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0

    for word in words:
        # Only flush a non-empty chunk: without the `current_chunk` guard an
        # over-long first word would append an empty-string chunk.
        if current_chunk and current_length + len(word) + 1 > max_chunk_len:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = len(word)
        else:
            current_chunk.append(word)
            current_length += len(word) + 1  # +1 for the joining space

    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks

def generate_summary(text, length_mode="Medium"):
    """Summarize *text* with the BART pipeline, chunking long inputs.

    Short inputs (< 3000 chars) are summarized directly with length limits
    clamped to the input size. Longer inputs are split into chunks, each
    chunk summarized, and the pieces concatenated; if the concatenation is
    itself very long, a recursive second pass condenses it further.

    Args:
        text: The text to summarize.
        length_mode: "Short", "Medium", or "Long" (anything else = Medium).

    Returns:
        The summary string, or an error/placeholder message string.
    """
    if not text:
        return "No text provided."

    # Token-length presets per UI choice; unknown modes fall back to Medium.
    presets = {"Short": (100, 30), "Long": (400, 150)}
    max_len, min_len = presets.get(length_mode, (250, 60))

    # Direct path: input fits in a single model call.
    if len(text) < 3000:
        try:
            # Clamp constraints to the input size so very short texts don't
            # trigger model errors (min_length > input length).
            word_count = len(text.split())
            capped_max = min(max_len, max(word_count // 2, 20))
            capped_min = min(min_len, max(capped_max - 10, 5))

            result = summarizer(text, max_length=capped_max, min_length=capped_min, do_sample=False)
            return result[0]['summary_text']
        except Exception as e:
            return f"Error in summarization: {str(e)}"

    # Long path: summarize each chunk independently, skipping failures.
    partial_summaries = []
    for piece in split_text_into_chunks(text, max_chunk_len=3000):
        try:
            out = summarizer(piece, max_length=150, min_length=40, do_sample=False)
            partial_summaries.append(out[0]['summary_text'])
        except Exception as e:
            print(f"Skipping chunk due to error: {e}")
            continue

    merged = " ".join(partial_summaries)

    # Recursive pass only when the merged summaries are still very long;
    # otherwise return the concatenation to avoid losing detail.
    if len(merged) > 4000:
        return generate_summary(merged, length_mode)
    return merged

def generate_questions_list(text, num_questions=10):
    """Generate up to *num_questions* distinct study questions from *text*.

    The valhalla/t5-small-e2e-qg model generates questions end-to-end from
    raw text. The input is processed in chunks (at most 5, to bound runtime)
    until enough unique questions are collected.

    Args:
        text: Context text to generate questions from.
        num_questions: Maximum number of questions to return.

    Returns:
        A list of question strings (possibly fewer than requested), or a
        single-element list with an error message on failure.
    """
    if not text:
        return []

    # QG models work best on shorter contexts. We'll use the generated summary
    # as context if the text is too long, or the text itself if short.
    # However, generating 10 distinct questions usually requires providing
    # answers or using an end-to-end generator.
    # valhalla/t5-small-e2e-qg generates questions directly.

    try:
        collected = []

        # Process only the first few chunks to avoid taking forever.
        for segment in split_text_into_chunks(text, max_chunk_len=2000)[:5]:
            # This specific model generates questions given text with the
            # "generate questions: " prefix (trained to output questions).
            prompt = "generate questions: " + segment

            # Sample two candidate questions per chunk.
            results = qg_pipeline(
                prompt,
                max_length=64,
                num_return_sequences=2,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )

            # De-duplicate while preserving generation order.
            for item in results:
                candidate = item['generated_text']
                if candidate not in collected:
                    collected.append(candidate)

            if len(collected) >= num_questions:
                break

        return collected[:num_questions]
    except Exception as e:
        return [f"Could not generate questions: {str(e)}"]

def format_bullet_notes(summary_text):
    """Convert a prose summary into Markdown bullet points, one per sentence.

    Sentences are recognized by terminal punctuation (``.``, ``!``, ``?``)
    followed by whitespace, or by existing newlines — a generalization of the
    original ``". "``-only split, which missed ``?``/``!`` sentence endings.

    Args:
        summary_text: Prose text to reformat.

    Returns:
        A newline-joined string of ``- sentence`` bullets ("" for empty input).
    """
    # Split after ./!/? + whitespace, or on any run of newlines.
    sentences = re.split(r"(?<=[.!?])\s+|\n+", summary_text)
    bullets = [f"- {s.strip()}" for s in sentences if s.strip()]
    return "\n".join(bullets)

# --- Main App Logic ---

def process_pdf_data(file_obj, length_mode, enable_questions):
    """Gradio callback: turn an uploaded PDF into summary, notes, questions.

    Args:
        file_obj: The uploaded file from gr.File (None if nothing uploaded).
        length_mode: "Short" | "Medium" | "Long" summary length preset.
        enable_questions: Whether to also generate study questions.

    Returns:
        A 3-tuple of Markdown strings: (summary, bullet notes, questions).
        On failure the first element carries the error message and the other
        two are empty strings.
    """
    if file_obj is None:
        return "Please upload a PDF file.", "", ""

    # 1. Extract Text
    raw_text = extract_text_from_pdf(file_obj)

    # extract_text_from_pdf signals failure via an error-message string; it is
    # longer than the 50-char emptiness threshold, so it must be caught here
    # or it would be summarized as if it were document content.
    if raw_text.startswith("Error reading PDF:"):
        return raw_text, "", ""
    if not raw_text or len(raw_text) < 50:
        return "Error: Could not extract text from PDF or PDF is empty.", "", ""

    status_msg = f"Extracted {len(raw_text)} characters. Processing..."
    print(status_msg)

    # 2. Summarize — generate_summary handles chunking of long text itself.
    final_summary = generate_summary(raw_text, length_mode)

    # 3. Create Notes (bullet-formatted version of the summary)
    notes_markdown = "### πŸ“ Key Bullet Notes\n\n" + format_bullet_notes(final_summary)

    # 4. Generate Questions (if requested)
    questions_markdown = ""
    if enable_questions:
        # Use the summary as context so questions focus on key points, unless
        # it is too short — then fall back to the start of the raw text.
        context_for_q = final_summary if len(final_summary) > 500 else raw_text[:2000]
        qs = generate_questions_list(context_for_q, num_questions=10)

        questions_markdown = "### ❓ Important Questions\n\n"
        for i, q in enumerate(qs, 1):
            questions_markdown += f"{i}. {q}\n"

    # Combine Summary for display
    summary_markdown = f"### πŸ“– Summary\n\n{final_summary}"

    return summary_markdown, notes_markdown, questions_markdown

# --- Gradio UI ---

# Soft theme with blue/slate accents for the whole app.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
)

# Layout: left column holds inputs/settings, right column the three
# Markdown output panes (summary, notes, questions).
with gr.Blocks(theme=theme, title="AI Notes Maker") as app:
    gr.Markdown(
        """
        # πŸ“‘ AI Notes Maker
        Upload a PDF lecture, paper, or article. Get a summary, key notes, and study questions instantly.
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            # Restrict uploads to PDFs only.
            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
            
            with gr.Accordion("Settings", open=True):
                # Maps directly to generate_summary's length_mode argument.
                length_slider = gr.Radio(
                    ["Short", "Medium", "Long"], 
                    label="Notes Length", 
                    value="Medium"
                )
                question_check = gr.Checkbox(
                    label="Generate Important Questions", 
                    value=True
                )
            
            submit_btn = gr.Button("Generate Notes", variant="primary")
        
        with gr.Column(scale=2):
            output_summary = gr.Markdown(label="Summary")
            output_notes = gr.Markdown(label="Key Notes")
            output_questions = gr.Markdown(label="Questions")
            
    # Wire the button to the processing pipeline; outputs map 1:1 to the
    # three Markdown panes above.
    submit_btn.click(
        fn=process_pdf_data,
        inputs=[pdf_input, length_slider, question_check],
        outputs=[output_summary, output_notes, output_questions]
    )

if __name__ == "__main__":
    app.launch()