File size: 11,746 Bytes
72e0c96
c1ea31c
95abb5a
c1ea31c
7bf9c65
 
 
c1ea31c
 
95abb5a
7bf9c65
c1ea31c
7bf9c65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
 
 
 
7bf9c65
 
 
c1ea31c
 
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
95abb5a
c1ea31c
 
95abb5a
c1ea31c
 
 
 
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
95abb5a
c1ea31c
7bf9c65
 
c1ea31c
95abb5a
c1ea31c
7bf9c65
 
 
 
c1ea31c
 
7bf9c65
c1ea31c
7bf9c65
c1ea31c
7bf9c65
 
 
c1ea31c
7bf9c65
c1ea31c
7bf9c65
c1ea31c
 
7bf9c65
 
c1ea31c
 
 
 
 
 
 
 
 
 
3eeedea
c1ea31c
 
 
3eeedea
c1ea31c
95abb5a
c1ea31c
3eeedea
 
95abb5a
c1ea31c
3eeedea
 
95abb5a
c1ea31c
3eeedea
c1ea31c
95abb5a
c1ea31c
3eeedea
c1ea31c
95abb5a
c1ea31c
 
95abb5a
c1ea31c
 
 
95abb5a
c1ea31c
 
3eeedea
c1ea31c
 
 
95abb5a
c1ea31c
3eeedea
 
95abb5a
c1ea31c
3eeedea
95abb5a
c1ea31c
 
95abb5a
c1ea31c
 
95abb5a
c1ea31c
 
 
 
 
 
 
95abb5a
3eeedea
 
95abb5a
c1ea31c
3eeedea
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
 
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72e0c96
c1ea31c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
72e0c96
95abb5a
c1ea31c
 
 
72e0c96
95abb5a
72e0c96
c1ea31c
 
 
 
 
95abb5a
c1ea31c
 
 
 
 
 
 
 
 
 
3eeedea
c1ea31c
 
 
3eeedea
72e0c96
c1ea31c
 
 
 
 
72e0c96
 
95abb5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
import gradio as gr
import PyPDF2
import re
import json
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import tempfile
import os

# Initialize the model and tokenizer directly
# (module import time — the app blocks here until weights are downloaded/cached)
print("Loading models... This may take a minute on first run.")

# Answer-aware question-generation checkpoint: T5-small fine-tuned with
# <hl>-highlighted answers (see generate_questions' prompt format).
# Chosen small so CPU-only inference stays feasible.
model_name = "valhalla/t5-small-qg-hl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Set to evaluation mode and CPU
model.eval()  # disable dropout layers for inference
device = torch.device("cpu")
model.to(device)

def generate_questions(context: str, answer: str, max_length: int = 128) -> str:
    """Generate one question whose answer is *answer*, given *context*.

    Uses the T5 answer-highlight convention: the answer span is wrapped
    in <hl> tokens and the whole thing is prefixed with the task name.
    Returns "" when generation fails or produces a trivially short question.
    """
    try:
        # Format: "generate question: <hl> answer <hl> context"
        prompt = f"generate question: <hl> {answer} <hl> {context}"

        # Encode the prompt, truncating long contexts to the model limit.
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(device)

        # Run generation without building autograd graphs.
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
                do_sample=True,
                temperature=0.7,
            )

        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Strip any leading "question:" / "q:" label the model may emit.
        decoded = re.sub(r'^(question:|q:)', '', decoded, flags=re.IGNORECASE).strip()

        # Discard degenerate (too-short) questions.
        return decoded if len(decoded) > 10 else ""

    except Exception as e:
        print(f"Error generating question: {e}")
        return ""

def extract_text_from_pdf(pdf_file) -> str:
    """Extract all text from a PDF.

    Args:
        pdf_file: A filesystem path or a binary file-like object —
            PyPDF2's PdfReader accepts either directly.

    Returns:
        The concatenated page text (newline-separated), or a string
        beginning with "Error" on failure. Callers (process_pdf) test
        for that prefix instead of catching exceptions.
    """
    text = ""
    try:
        # PdfReader handles both paths and file objects itself, so the
        # previous isinstance(pdf_file, str) branch — identical in both
        # arms — was dead code and has been removed.
        pdf_reader = PyPDF2.PdfReader(pdf_file)

        for page in pdf_reader.pages:
            page_text = page.extract_text()
            # extract_text() can return ""/None for image-only pages.
            if page_text:
                text += page_text + "\n"
    except Exception as e:
        return f"Error reading PDF: {str(e)}"

    return text

def clean_text(text: str) -> str:
    """Normalize extracted PDF text.

    Collapses all whitespace runs to single spaces and strips characters
    other than word characters, whitespace, and basic punctuation, so
    sentence boundaries survive for the chunking step.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    filtered = re.sub(r'[^\w\s.,;!?-]', '', collapsed)
    return filtered.strip()

def chunk_text(text: str, max_chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split *text* into sentence-aligned chunks with a little context overlap.

    Sentences are never split; a chunk grows until appending the next
    sentence would reach max_chunk_size. Each chunk after the first is
    prefixed with the tail of its predecessor (last two sentences, or the
    final *overlap* characters) so generated questions keep context.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)

    # Pass 1: greedy sentence packing into size-bounded chunks.
    chunks: List[str] = []
    buffer = ""
    for sentence in sentences:
        if len(buffer) + len(sentence) < max_chunk_size:
            buffer += " " + sentence
        else:
            if buffer:
                chunks.append(buffer.strip())
            buffer = sentence
    if buffer:
        chunks.append(buffer.strip())

    # Pass 2: prepend trailing context from the previous (original) chunk.
    result: List[str] = []
    for idx, chunk in enumerate(chunks):
        if idx > 0 and overlap > 0:
            prev_parts = chunks[idx - 1].split('. ')
            if len(prev_parts) > 1:
                tail = '. '.join(prev_parts[-2:])
            else:
                tail = chunks[idx - 1][-overlap:]
            chunk = tail + " " + chunk
        result.append(chunk)

    return result

def generate_qa_pairs(chunk: str, num_questions: int = 2) -> List[Dict[str, str]]:
    """Build up to *num_questions* Q/A flashcards from one text chunk.

    Each sufficiently long sentence in the chunk serves as a candidate
    answer; the model is asked to produce a matching question. Cards with
    empty or answer-identical questions are dropped.
    """
    flashcards: List[Dict[str, str]] = []

    # Too little text to make meaningful questions from.
    if len(chunk.split()) < 20:
        return []

    try:
        # Candidate answers: reasonably long sentence fragments.
        sentences = [s.strip() for s in chunk.split('. ') if len(s.strip()) > 20]
        if not sentences:
            return []

        # Take one candidate per requested question, in order.
        for answer in sentences[:min(num_questions, len(sentences))]:
            # Reject answers of fewer than three words.
            if len(answer.split()) < 3:
                continue

            question = generate_questions(chunk, answer)

            # Keep the card only when a distinct, non-empty question came back.
            if question and question != answer:
                flashcards.append({
                    "question": question,
                    "answer": answer,
                    "context": chunk[:200] + "..." if len(chunk) > 200 else chunk
                })

    except Exception as e:
        print(f"Error generating QA: {e}")

    return flashcards

def process_pdf(pdf_file, questions_per_chunk: int = 2, max_chunks: int = 20):
    """Main pipeline: extract -> clean -> chunk -> generate -> format.

    Implemented as a generator so the Gradio UI streams progress: each
    yield is a 4-tuple routed to (status_text, csv_output, json_output,
    output_display). Intermediate yields carry placeholders; only the
    final yield populates all four components.

    Args:
        pdf_file: Filepath of the uploaded PDF (or None when nothing uploaded).
        questions_per_chunk: Desired Q/A pairs per text chunk.
        max_chunks: Hard cap on chunks processed, to bound CPU time.
    """
    if pdf_file is None:
        # BUG FIX: this function is a generator, so a plain `return <value>`
        # only sets StopIteration.value and the message never reached the
        # UI. Yield the message instead, then stop.
        yield "Please upload a PDF file.", "", "", "Your flashcards will appear here..."
        return
    
    try:
        # Extract text
        yield "📄 Extracting text from PDF...", "", "", "Processing..."
        raw_text = extract_text_from_pdf(pdf_file)
        
        # extract_text_from_pdf reports failure via an "Error..." string.
        if raw_text.startswith("Error"):
            yield raw_text, "", "", "Error occurred"
            return
        
        if len(raw_text.strip()) < 100:
            yield "PDF appears to be empty or contains no extractable text.", "", "", "Error occurred"
            return
        
        # Clean text
        yield "🧹 Cleaning text...", "", "", "Processing..."
        cleaned_text = clean_text(raw_text)
        
        # Chunk text
        yield "✂️ Chunking text into sections...", "", "", "Processing..."
        chunks = chunk_text(cleaned_text)
        
        # Limit chunks for CPU performance
        chunks = chunks[:max_chunks]
        
        # Generate flashcards, reporting per-chunk progress.
        all_flashcards = []
        total_chunks = len(chunks)
        
        for i, chunk in enumerate(chunks):
            progress = f"🎴 Generating flashcards... ({i+1}/{total_chunks} chunks processed)"
            yield progress, "", "", "Processing..."
            
            cards = generate_qa_pairs(chunk, questions_per_chunk)
            all_flashcards.extend(cards)
        
        if not all_flashcards:
            yield "Could not generate flashcards from this PDF. Try a PDF with more textual content.", "", "", "No flashcards generated"
            return
        
        # Format output
        yield "✅ Finalizing...", "", "", "Almost done..."
        
        # Create formatted display
        display_text = format_flashcards_display(all_flashcards)
        
        # Create JSON download
        json_output = json.dumps(all_flashcards, indent=2, ensure_ascii=False)
        
        # Create Anki/CSV format; embedded quotes are doubled per RFC 4180.
        csv_lines = ["Question,Answer"]
        for card in all_flashcards:
            q = card['question'].replace('"', '""')
            a = card['answer'].replace('"', '""')
            csv_lines.append(f'"{q}","{a}"')
        csv_output = "\n".join(csv_lines)
        
        # FINAL OUTPUT - this updates all components
        yield "✅ Done! Generated {} flashcards".format(len(all_flashcards)), csv_output, json_output, display_text
        
    except Exception as e:
        error_msg = f"Error processing PDF: {str(e)}"
        print(error_msg)
        yield error_msg, "", "", error_msg

def format_flashcards_display(flashcards: List[Dict]) -> str:
    """Render the flashcards as one markdown document for the UI panel.

    Each card becomes a numbered section with its question, answer, a
    100-character context preview, and a horizontal-rule separator.
    """
    parts = [f"## 🎴 Generated {len(flashcards)} Flashcards\n"]

    for number, card in enumerate(flashcards, 1):
        parts.extend([
            f"### Card {number}",
            f"**Q:** {card['question']}",
            f"**A:** {card['answer']}",
            f"*Context: {card['context'][:100]}...*\n",
            "---\n",
        ])

    return "\n".join(parts)

def create_sample_flashcard():
    """Return a rendered demo flashcard for the example section of the UI."""
    demo_card = {
        "question": "What is the capital of France?",
        "answer": "Paris is the capital and most populous city of France.",
        "context": "Paris is the capital and most populous city of France...",
    }
    return format_flashcards_display([demo_card])

# Custom CSS injected into the Gradio Blocks app below.
# NOTE(review): these class selectors (.flashcard-container, .question,
# .answer) are not referenced by any component markup visible in this
# file — presumably intended for the markdown flashcard display; confirm
# they actually take effect.
custom_css = """
.flashcard-container {
    border: 2px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin: 10px 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
}
.question {
    font-size: 1.2em;
    font-weight: bold;
    margin-bottom: 10px;
}
.answer {
    font-size: 1em;
    opacity: 0.9;
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css, title="PDF to Flashcards") as demo:
    gr.Markdown("""
    # πŸ“š PDF to Flashcards Generator
    
    Upload any PDF document and automatically generate study flashcards (Q&A pairs) using AI.
    
    **Features:**
    - 🧠 Uses local CPU-friendly AI (no GPU needed)
    - πŸ“„ Extracts text from any PDF
    - βœ‚οΈ Intelligently chunks content
    - 🎴 Generates question-answer pairs
    - πŸ’Ύ Export to CSV (Anki-compatible) or JSON
    
    *Note: Processing is done entirely on CPU, so large PDFs may take a few minutes.*
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(
                label="Upload PDF",
                file_types=[".pdf"],
                type="filepath"
            )
            
            with gr.Row():
                questions_per_chunk = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Questions per section"
                )
                max_chunks = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Max sections to process"
                )
            
            process_btn = gr.Button("πŸš€ Generate Flashcards", variant="primary")
            
            gr.Markdown("""
            ### πŸ’‘ Tips:
            - Text-based PDFs work best (scanned images won't work)
            - Academic papers and articles work great
            - Adjust "Questions per section" based on content density
            """)
        
        with gr.Column(scale=2):
            status_text = gr.Textbox(
                label="Status",
                value="Ready to process PDF...",
                interactive=False
            )
            
            output_display = gr.Markdown(
                label="Generated Flashcards",
                value="Your flashcards will appear here..."
            )
    
    with gr.Row():
        with gr.Column():
            csv_output = gr.Textbox(
                label="CSV Format (for Anki import)",
                lines=10,
                visible=True
            )
            gr.Markdown("*Copy the CSV content and save as `.csv` file to import into Anki*")
        
        with gr.Column():
            json_output = gr.Textbox(
                label="JSON Format",
                lines=10,
                visible=True
            )
            gr.Markdown("*Raw JSON data for custom applications*")
    
    # FIXED: Direct binding without the broken .then() chain
    process_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, questions_per_chunk, max_chunks],
        outputs=[status_text, csv_output, json_output, output_display]
    )
    
    # Example section
    gr.Markdown("---")
    gr.Markdown("### 🎯 Example Output Format")
    gr.Markdown(create_sample_flashcard())

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()  # no arguments: Gradio's default host/port, no sharing