# Source: example/app.py (Hugging Face Space by heerjtdev)
# Last commit: c1ea31c "Update app.py" (verified), 10.7 kB
import gradio as gr
import PyPDF2
import re
import json
from typing import List, Dict, Tuple
from transformers import pipeline
import tempfile
import os
# Build the T5 question-generation pipeline once, at import time, so every
# request reuses the same model instance. The print warns that start-up can be
# slow on the first run.
print("Loading models... This may take a minute on first run.")
_QG_CHECKPOINT = "valhalla/t5-small-qg-hl"  # same checkpoint for model and tokenizer
qa_generator = pipeline(
    task="text2text-generation",
    model=_QG_CHECKPOINT,
    tokenizer=_QG_CHECKPOINT,
    device=-1,  # Force CPU
)
def extract_text_from_pdf(pdf_file) -> str:
    """Extract the text of every page of a PDF.

    Args:
        pdf_file: A filesystem path or a binary file-like object. Either one is
            passed unchanged to ``PyPDF2.PdfReader``, which accepts both, so no
            type dispatch is needed.

    Returns:
        The page texts concatenated, each followed by a newline. On any failure
        this function returns a string starting with ``"Error reading PDF: "``
        instead of raising. ``process_pdf`` relies on that prefix to detect
        errors.
    """
    try:
        reader = PyPDF2.PdfReader(pdf_file)
        # ``or ""`` is a defensive guard. If a page yields no text object, the
        # page contributes an empty line instead of raising a TypeError that
        # would discard the text of every other page.
        page_texts = [(page.extract_text() or "") + "\n" for page in reader.pages]
    except Exception as e:
        return f"Error reading PDF: {str(e)}"
    # A single join avoids the repeated string copies of ``+=`` in a loop.
    return "".join(page_texts)
def clean_text(text: str) -> str:
    """Normalise extracted PDF text before chunking.

    Every run of whitespace, including newlines, collapses to one space. Any
    character other than word characters, whitespace and ``. , ; ! ? -`` is
    then dropped, so apostrophes, quotes, brackets and colons disappear. The
    result is stripped of leading and trailing whitespace.
    """
    single_spaced = re.sub(r'\s+', ' ', text)
    punctuation_filtered = re.sub(r'[^\w\s.,;!?-]', '', single_spaced)
    return punctuation_filtered.strip()
def chunk_text(text: str, max_chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split text into sentence-aligned chunks, each prefixed with context from its predecessor.

    A sentence ends at ``.``, ``!`` or ``?`` followed by whitespace. Sentences
    are packed greedily into chunks of about ``max_chunk_size`` characters. A
    single sentence longer than ``max_chunk_size`` becomes its own oversized
    chunk and is never split mid-sentence.

    Args:
        text: Input text. In this app it is the output of ``clean_text``.
        max_chunk_size: Soft upper bound, in characters, for a chunk before
            any overlap is added.
        overlap: When > 0, every chunk after the first is prefixed with the
            last two sentences of the previous chunk. If the previous chunk
            holds a single sentence, its last ``overlap`` characters are used
            instead. When <= 0, no overlap is added.

    Returns:
        The list of chunk strings. It is empty when ``text`` has no
        non-whitespace content.
    """
    sentence_boundary = re.compile(r'(?<=[.!?])\s+')
    # Skip empty or whitespace-only pieces so that blank input yields []
    # rather than a single empty chunk.
    sentences = [s for s in sentence_boundary.split(text) if s.strip()]

    chunks: List[str] = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) < max_chunk_size:
            current_chunk += " " + sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk.strip())

    if overlap <= 0:
        return chunks

    overlapped_chunks = chunks[:1]
    for previous, chunk in zip(chunks, chunks[1:]):
        # Re-split with the same boundary rule used for chunking, so sentences
        # ending in "!" or "?" also count as whole sentences here.
        prev_sentences = sentence_boundary.split(previous)
        if len(prev_sentences) > 1:
            overlap_text = " ".join(prev_sentences[-2:])
        else:
            overlap_text = previous[-overlap:]
        overlapped_chunks.append(overlap_text + " " + chunk)
    return overlapped_chunks
def generate_qa_pairs(chunk: str, num_questions: int = 2) -> List[Dict[str, str]]:
    """Generate question/answer flashcards from one text chunk.

    The chunk is split on ``". "`` into sentence-like pieces. Each of the first
    ``num_questions`` pieces serves as an answer. The piece is wrapped in
    ``<hl>`` markers in place inside the chunk, and the result is sent to
    ``qa_generator``, whose output becomes the question.

    Args:
        chunk: Text to generate questions from. A chunk with fewer than 20
            words, or with fewer than two ``". "``-separated pieces, yields no
            cards.
        num_questions: Maximum number of cards to attempt for this chunk.
            The value is coerced to ``int``; negative values mean none.

    Returns:
        A list of dicts, each with three keys:

        - ``"question"``: the generated question.
        - ``"answer"``: the highlighted piece, stripped.
        - ``"context"``: the chunk, truncated to 200 characters plus ``"..."``.

        If generation fails for one piece, the error is printed, that piece is
        skipped, and the remaining pieces are still attempted.
    """
    # Skip chunks that are too short to ask meaningful questions about.
    if len(chunk.split()) < 20:
        return []
    sentences = chunk.split('. ')
    if len(sentences) < 2:
        return []

    context_preview = chunk[:200] + "..." if len(chunk) > 200 else chunk
    limit = max(int(num_questions), 0)
    flashcards = []
    for i, highlight in enumerate(sentences[:limit]):
        answer = highlight.strip()
        if not answer:
            continue  # e.g. an empty piece produced by ". . " in the text
        # Rebuild the chunk with only this piece wrapped in <hl> markers. The
        # answer then appears once, at its real position, which matches the
        # hl-format prompt "generate question: ... <hl> answer <hl> ...". It is
        # no longer duplicated in front of the whole context.
        marked = sentences[:i] + [f"<hl> {highlight} <hl>"] + sentences[i + 1:]
        input_text = "generate question: " + ". ".join(marked)
        try:
            outputs = qa_generator(
                input_text,
                max_length=128,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7
            )
            question = outputs[0]['generated_text'].strip()
        except Exception as e:
            print(f"Error generating QA: {e}")
            continue
        # Strip a leading "question:" / "q:" label, if present.
        question = re.sub(r'^(question:|q:)', '', question, flags=re.IGNORECASE).strip()
        if question and len(question) > 10:
            flashcards.append({
                "question": question,
                "answer": answer,
                "context": context_preview,
            })
    return flashcards
def process_pdf(pdf_file, questions_per_chunk: int = 2, max_chunks: int = 20):
    """Run the PDF-to-flashcards pipeline, streaming progress to the UI.

    This function is a generator. Every outcome (progress, error or final
    result) is delivered with ``yield`` as a ``(status, csv, json)`` tuple.
    Inside a generator, ``return <value>`` only sets ``StopIteration.value``,
    which the UI never displays. Results are therefore yielded and followed by
    a bare ``return``.

    Args:
        pdf_file: The uploaded PDF, as a path or file object accepted by
            ``extract_text_from_pdf``, or ``None`` when nothing was uploaded.
        questions_per_chunk: Questions to attempt per chunk. The value is
            coerced to ``int``.
        max_chunks: Maximum number of chunks to process, which bounds CPU
            time. The value is coerced to ``int``.

    Yields:
        Progress and error messages as ``(message, None, None)``. On success,
        a final ``(markdown_display, csv_text, json_text)`` tuple.
    """
    if pdf_file is None:
        yield "Please upload a PDF file.", None, None
        return
    try:
        # Coerce the slider values so they are safe for slicing and range(),
        # in case the UI delivers them as floats.
        questions_per_chunk = int(questions_per_chunk)
        max_chunks = int(max_chunks)

        yield "πŸ“„ Extracting text from PDF...", None, None
        raw_text = extract_text_from_pdf(pdf_file)
        # extract_text_from_pdf reports failures as text with this prefix.
        # Matching the full prefix avoids rejecting documents whose extracted
        # text merely starts with the word "Error".
        if raw_text.startswith("Error reading PDF:"):
            yield raw_text, None, None
            return
        if len(raw_text.strip()) < 100:
            yield "PDF appears to be empty or contains no extractable text.", None, None
            return

        yield "🧹 Cleaning text...", None, None
        cleaned_text = clean_text(raw_text)

        yield "βœ‚οΈ Chunking text into sections...", None, None
        # Limit chunks for CPU performance.
        chunks = chunk_text(cleaned_text)[:max_chunks]

        all_flashcards = []
        total_chunks = len(chunks)
        for i, chunk in enumerate(chunks, start=1):
            yield f"🎴 Generating flashcards... ({i}/{total_chunks} chunks processed)", None, None
            all_flashcards.extend(generate_qa_pairs(chunk, questions_per_chunk))

        if not all_flashcards:
            yield "Could not generate flashcards from this PDF. Try a PDF with more textual content.", None, None
            return

        yield "βœ… Finalizing...", None, None
        display_text = format_flashcards_display(all_flashcards)
        json_output = json.dumps(all_flashcards, indent=2, ensure_ascii=False)
        csv_output = _flashcards_to_csv(all_flashcards)
    except Exception as e:
        yield f"Error processing PDF: {str(e)}", None, None
        return
    yield display_text, csv_output, json_output


def _flashcards_to_csv(flashcards: List[Dict[str, str]]) -> str:
    """Serialise flashcards as "Question,Answer" CSV text, with every data field double-quoted."""
    rows = ["Question,Answer"]
    for card in flashcards:
        # CSV escapes an embedded double quote by doubling it.
        q = card['question'].replace('"', '""')
        a = card['answer'].replace('"', '""')
        rows.append(f'"{q}","{a}"')
    return "\n".join(rows)
def format_flashcards_display(flashcards: List[Dict]) -> str:
    """Render flashcards as a Markdown document for display.

    Each card shows its question, its answer and a context preview of at most
    100 characters. An ellipsis is added only when the preview is actually
    truncated. Contexts built by ``generate_qa_pairs`` can already end in
    "...", so unconditionally adding another produced "......".

    Args:
        flashcards: Dicts with ``"question"``, ``"answer"`` and ``"context"``
            keys.

    Returns:
        A Markdown string: a heading with the card count, then one section per
        card.
    """
    lines = [f"## 🎴 Generated {len(flashcards)} Flashcards\n"]
    for number, card in enumerate(flashcards, 1):
        context = card['context']
        preview = context if len(context) <= 100 else context[:100] + "..."
        lines.append(f"### Card {number}")
        lines.append(f"**Q:** {card['question']}")
        lines.append(f"**A:** {card['answer']}")
        lines.append(f"*Context: {preview}*\n")
        lines.append("---\n")
    return "\n".join(lines)
def create_sample_flashcard():
    """Return the Markdown rendering of one hard-coded demo card, used in the UI's example section."""
    demo_card = {
        "question": "What is the capital of France?",
        "answer": "Paris is the capital and most populous city of France.",
        "context": "Paris is the capital and most populous city of France...",
    }
    return format_flashcards_display([demo_card])
# Extra CSS injected into the page through ``gr.Blocks(css=custom_css)``.
# NOTE(review): the selectors below (.flashcard-container, .question, .answer)
# are not attached to any component in the UI code of this file (no
# ``elem_classes`` are set). These rules therefore appear to have no visible
# effect; confirm before relying on them.
custom_css = """
.flashcard-container {
border: 2px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin: 10px 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.question {
font-size: 1.2em;
font-weight: bold;
margin-bottom: 10px;
}
.answer {
font-size: 1em;
opacity: 0.9;
}
"""
# Gradio Interface
# The layout has four parts:
#   - a header;
#   - an input column (upload, sliders, button, tips) beside an output column
#     (status plus flashcard Markdown);
#   - a row of CSV/JSON export text boxes;
#   - a static example rendered by create_sample_flashcard().
with gr.Blocks(css=custom_css, title="PDF to Flashcards") as demo:
    gr.Markdown("""
# πŸ“š PDF to Flashcards Generator
Upload any PDF document and automatically generate study flashcards (Q&A pairs) using AI.
**Features:**
- 🧠 Uses local CPU-friendly AI (no GPU needed)
- πŸ“„ Extracts text from any PDF
- βœ‚οΈ Intelligently chunks content
- 🎴 Generates question-answer pairs
- πŸ’Ύ Export to CSV (Anki-compatible) or JSON
*Note: Processing is done entirely on CPU, so large PDFs may take a few minutes.*
""")
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            # With type="filepath", process_pdf presumably receives a path string.
            pdf_input = gr.File(
                label="Upload PDF",
                file_types=[".pdf"],
                type="filepath"
            )
            with gr.Row():
                # Passed to process_pdf as questions_per_chunk.
                questions_per_chunk = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Questions per section"
                )
                # Passed to process_pdf as max_chunks; caps how many chunks are processed.
                max_chunks = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Max sections to process"
                )
            process_btn = gr.Button("πŸš€ Generate Flashcards", variant="primary")
            gr.Markdown("""
### πŸ’‘ Tips:
- Text-based PDFs work best (scanned images won't work)
- Academic papers and articles work great
- Adjust "Questions per section" based on content density
""")
        # Right column: live status text and the rendered flashcards.
        with gr.Column(scale=2):
            status_text = gr.Textbox(
                label="Status",
                value="Ready to process PDF...",
                interactive=False
            )
            # Only filled by the .then() step below; process_pdf does not write here directly.
            output_display = gr.Markdown(
                label="Generated Flashcards",
                value="Your flashcards will appear here..."
            )
    # Export boxes holding the CSV and JSON text produced by process_pdf.
    with gr.Row():
        with gr.Column():
            csv_output = gr.Textbox(
                label="CSV Format (for Anki import)",
                lines=10,
                visible=True
            )
            gr.Markdown("*Copy the CSV content and save as `.csv` file to import into Anki*")
        with gr.Column():
            json_output = gr.Textbox(
                label="JSON Format",
                lines=10,
                visible=True
            )
            gr.Markdown("*Raw JSON data for custom applications*")
    # Event handlers
    # process_pdf is a generator. Each tuple it yields is routed to
    # (status_text, csv_output, json_output). Its first element is a progress
    # message, an error message or the final flashcard Markdown, and it lands
    # in status_text.
    process_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, questions_per_chunk, max_chunks],
        outputs=[status_text, csv_output, json_output]
    ).then(
        # After the click handler finishes, copy the final status_text value
        # into output_display. If the status is still the first
        # "πŸ“„ Extracting..." progress message, gr.update() leaves the display
        # unchanged instead.
        # NOTE(review): this couples the display to a literal prefix of a
        # process_pdf progress string; editing that string changes this behaviour.
        fn=lambda x: x if not isinstance(x, str) or not x.startswith("πŸ“„") else gr.update(),
        inputs=status_text,
        outputs=output_display
    )
    # Example section
    gr.Markdown("---")
    gr.Markdown("### 🎯 Example Output Format")
    gr.Markdown(create_sample_flashcard())

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()