# example/app.py — PDF to Flashcards Generator (Hugging Face Space)
# Provenance: uploaded by heerjtdev, commit 3eeedea ("Update app.py"), 11.7 kB.
import gradio as gr
import PyPDF2
import re
import json
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import tempfile
import os
# Initialize the model and tokenizer directly
# Loaded once at import time so the Gradio handlers can reuse them;
# first run downloads the checkpoint from the Hugging Face hub.
print("Loading models... This may take a minute on first run.")
# Small T5 checkpoint fine-tuned for highlight-based question generation
# (expects the answer span wrapped in <hl> tokens in the input).
model_name = "valhalla/t5-small-qg-hl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Set to evaluation mode and CPU
model.eval()  # disable dropout etc. for inference
device = torch.device("cpu")  # Space runs CPU-only; no GPU assumed
model.to(device)
def generate_questions(context: str, answer: str, max_length: int = 128) -> str:
    """Generate a single question for *answer* within *context* using the T5 model.

    Returns the cleaned question text, or "" when generation fails or the
    output is degenerately short.
    """
    try:
        # The valhalla QG checkpoints expect the target answer wrapped in
        # <hl> highlight tokens: "generate question: <hl> answer <hl> context".
        prompt = f"generate question: <hl> {answer} <hl> {context}"

        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
                do_sample=True,
                temperature=0.7,
            )

        text = tokenizer.decode(generated[0], skip_special_tokens=True)
        # Strip any leading "question:"/"q:" prefix the model may emit.
        text = re.sub(r'^(question:|q:)', '', text, flags=re.IGNORECASE).strip()
        # Discard trivially short generations.
        return text if len(text) > 10 else ""
    except Exception as e:
        print(f"Error generating question: {e}")
        return ""
def extract_text_from_pdf(pdf_file) -> str:
    """Extract all page text from an uploaded PDF.

    Args:
        pdf_file: A filesystem path or file-like object accepted by
            PyPDF2.PdfReader.

    Returns:
        The concatenated text of all pages (newline after each page), or a
        string starting with "Error" on failure — process_pdf relies on that
        prefix to detect extraction errors.
    """
    try:
        # FIX: the original isinstance(pdf_file, str) branch was dead code —
        # both arms called PyPDF2.PdfReader(pdf_file) identically, since
        # PdfReader accepts paths and file-like objects alike.
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        pages = []
        for page in pdf_reader.pages:
            page_text = page.extract_text()
            if page_text:  # extract_text() may return None/"" for image pages
                pages.append(page_text + "\n")
        # join instead of repeated += (avoids quadratic concatenation).
        return "".join(pages)
    except Exception as e:
        return f"Error reading PDF: {str(e)}"
def clean_text(text: str) -> str:
    """Normalize extracted PDF text for downstream chunking.

    Collapses all runs of whitespace to single spaces, drops any character
    that is not a word character, whitespace, or basic sentence punctuation,
    and trims the result.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    # Keep only word chars, whitespace and . , ; ! ? - so sentence
    # boundaries survive while layout junk is removed.
    filtered = re.sub(r'[^\w\s.,;!?-]', '', collapsed)
    return filtered.strip()
def chunk_text(text: str, max_chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split *text* into sentence-aligned chunks of roughly max_chunk_size chars.

    Each chunk after the first is prefixed with a bit of trailing context
    from the previous chunk (the last two sentences when the previous chunk
    splits on '. ', otherwise its last *overlap* characters).
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)

    # Pass 1: greedily pack whole sentences until the size budget is hit.
    chunks: List[str] = []
    buffer = ""
    for sentence in sentences:
        if len(buffer) + len(sentence) < max_chunk_size:
            buffer += " " + sentence
        else:
            if buffer:
                chunks.append(buffer.strip())
            buffer = sentence
    if buffer:
        chunks.append(buffer.strip())

    # Pass 2: prepend carried-over context from the preceding chunk.
    result: List[str] = []
    for index, chunk in enumerate(chunks):
        if index > 0 and overlap > 0:
            prev_parts = chunks[index - 1].split('. ')
            if len(prev_parts) > 1:
                carry = '. '.join(prev_parts[-2:])
            else:
                carry = chunks[index - 1][-overlap:]
            chunk = carry + " " + chunk
        result.append(chunk)
    return result
def generate_qa_pairs(chunk: str, num_questions: int = 2) -> List[Dict[str, str]]:
    """Build up to *num_questions* flashcards from one text chunk.

    Uses the chunk's longer sentences as candidate answers and asks the
    model to generate a question for each. Returns a list of dicts with
    "question", "answer" and "context" keys (context truncated to 200 chars).
    """
    flashcards: List[Dict[str, str]] = []
    # Very short chunks rarely produce meaningful questions.
    if len(chunk.split()) < 20:
        return []
    try:
        # Candidate answers: sentences longer than 20 characters.
        sentences = [s.strip() for s in chunk.split('. ') if len(s.strip()) > 20]
        if not sentences:
            return []
        for answer in sentences[:num_questions]:
            # Skip answers of fewer than three words.
            if len(answer.split()) < 3:
                continue
            question = generate_questions(chunk, answer)
            # Keep only non-empty questions that differ from their answer.
            if question and question != answer:
                flashcards.append({
                    "question": question,
                    "answer": answer,
                    "context": chunk[:200] + "..." if len(chunk) > 200 else chunk,
                })
    except Exception as e:
        print(f"Error generating QA: {e}")
    return flashcards
def process_pdf(pdf_file, questions_per_chunk: int = 2, max_chunks: int = 20):
    """Generator driving the UI: yields (status, csv, json, display) tuples.

    Every yield streams an update to the four bound output components;
    the final yield carries the finished flashcards in all formats.

    Args:
        pdf_file: Path to the uploaded PDF (gr.File with type="filepath"),
            or None when nothing was uploaded.
        questions_per_chunk: Questions to attempt per text chunk.
        max_chunks: Hard cap on chunks processed (CPU-only inference).
    """
    if pdf_file is None:
        # BUG FIX: this function is a generator (it contains `yield`), so the
        # original `return "Please upload...", ...` was swallowed by
        # StopIteration and the message never reached the UI. Yield it instead.
        yield "Please upload a PDF file.", "", "", "Your flashcards will appear here..."
        return
    try:
        yield "📄 Extracting text from PDF...", "", "", "Processing..."
        raw_text = extract_text_from_pdf(pdf_file)
        # extract_text_from_pdf signals failure with an "Error..." prefix.
        if raw_text.startswith("Error"):
            yield raw_text, "", "", "Error occurred"
            return
        if len(raw_text.strip()) < 100:
            yield "PDF appears to be empty or contains no extractable text.", "", "", "Error occurred"
            return

        yield "🧹 Cleaning text...", "", "", "Processing..."
        cleaned_text = clean_text(raw_text)

        yield "✂️ Chunking text into sections...", "", "", "Processing..."
        # Cap the number of chunks to keep CPU runtime bounded.
        chunks = chunk_text(cleaned_text)[:max_chunks]

        all_flashcards = []
        total_chunks = len(chunks)
        for i, chunk in enumerate(chunks):
            yield f"🎴 Generating flashcards... ({i+1}/{total_chunks} chunks processed)", "", "", "Processing..."
            all_flashcards.extend(generate_qa_pairs(chunk, questions_per_chunk))

        if not all_flashcards:
            yield "Could not generate flashcards from this PDF. Try a PDF with more textual content.", "", "", "No flashcards generated"
            return

        yield "✅ Finalizing...", "", "", "Almost done..."
        display_text = format_flashcards_display(all_flashcards)
        json_output = json.dumps(all_flashcards, indent=2, ensure_ascii=False)

        # CSV with doubled quotes so the file imports cleanly into Anki.
        csv_lines = ["Question,Answer"]
        for card in all_flashcards:
            q = card['question'].replace('"', '""')
            a = card['answer'].replace('"', '""')
            csv_lines.append(f'"{q}","{a}"')
        csv_output = "\n".join(csv_lines)

        # FINAL OUTPUT — updates all four bound components at once.
        yield "✅ Done! Generated {} flashcards".format(len(all_flashcards)), csv_output, json_output, display_text
    except Exception as e:
        error_msg = f"Error processing PDF: {str(e)}"
        print(error_msg)
        yield error_msg, "", "", error_msg
def format_flashcards_display(flashcards: List[Dict]) -> str:
    """Render the flashcards as a Markdown document for the results panel."""
    parts = [f"## 🎴 Generated {len(flashcards)} Flashcards\n"]
    for number, card in enumerate(flashcards, start=1):
        parts.extend([
            f"### Card {number}",
            f"**Q:** {card['question']}",
            f"**A:** {card['answer']}",
            # Context preview capped at 100 characters.
            f"*Context: {card['context'][:100]}...*\n",
            "---\n",
        ])
    return "\n".join(parts)
def create_sample_flashcard():
    """Return a rendered one-card example shown beneath the interface."""
    demo_card = {
        "question": "What is the capital of France?",
        "answer": "Paris is the capital and most populous city of France.",
        "context": "Paris is the capital and most populous city of France...",
    }
    return format_flashcards_display([demo_card])
# Custom CSS for better styling
# Injected via gr.Blocks(css=...). NOTE(review): the .flashcard-container /
# .question / .answer classes are not attached to any component below —
# presumably left over from an earlier HTML-based layout; verify before removing.
custom_css = """
.flashcard-container {
border: 2px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin: 10px 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.question {
font-size: 1.2em;
font-weight: bold;
margin-bottom: 10px;
}
.answer {
font-size: 1em;
opacity: 0.9;
}
"""
# Gradio Interface
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below is reconstructed from standard gr.Blocks conventions — confirm the
# intended layout against the deployed Space.
with gr.Blocks(css=custom_css, title="PDF to Flashcards") as demo:
    gr.Markdown("""
# 📚 PDF to Flashcards Generator
Upload any PDF document and automatically generate study flashcards (Q&A pairs) using AI.
**Features:**
- 🧠 Uses local CPU-friendly AI (no GPU needed)
- 📄 Extracts text from any PDF
- ✂️ Intelligently chunks content
- 🎴 Generates question-answer pairs
- 💾 Export to CSV (Anki-compatible) or JSON
*Note: Processing is done entirely on CPU, so large PDFs may take a few minutes.*
""")
    with gr.Row():
        # Left column: upload + generation controls.
        with gr.Column(scale=1):
            pdf_input = gr.File(
                label="Upload PDF",
                file_types=[".pdf"],
                type="filepath"
            )
            with gr.Row():
                questions_per_chunk = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Questions per section"
                )
                max_chunks = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Max sections to process"
                )
            process_btn = gr.Button("🚀 Generate Flashcards", variant="primary")
            gr.Markdown("""
### 💡 Tips:
- Text-based PDFs work best (scanned images won't work)
- Academic papers and articles work great
- Adjust "Questions per section" based on content density
""")
        # Right column: live status and rendered flashcards.
        with gr.Column(scale=2):
            status_text = gr.Textbox(
                label="Status",
                value="Ready to process PDF...",
                interactive=False
            )
            output_display = gr.Markdown(
                label="Generated Flashcards",
                value="Your flashcards will appear here..."
            )
    # Export row: raw CSV/JSON text for copy-paste.
    # NOTE(review): could also have been nested inside the scale=2 column;
    # layout intent is ambiguous in this copy.
    with gr.Row():
        with gr.Column():
            csv_output = gr.Textbox(
                label="CSV Format (for Anki import)",
                lines=10,
                visible=True
            )
            gr.Markdown("*Copy the CSV content and save as `.csv` file to import into Anki*")
        with gr.Column():
            json_output = gr.Textbox(
                label="JSON Format",
                lines=10,
                visible=True
            )
            gr.Markdown("*Raw JSON data for custom applications*")
    # FIXED: Direct binding without the broken .then() chain
    # process_pdf is a generator, so each yield streams an update to the
    # four bound output components.
    process_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, questions_per_chunk, max_chunks],
        outputs=[status_text, csv_output, json_output, output_display]
    )
    # Example section
    gr.Markdown("---")
    gr.Markdown("### 🎯 Example Output Format")
    gr.Markdown(create_sample_flashcard())

if __name__ == "__main__":
    demo.launch()