|
|
import gradio as gr |
|
|
import PyPDF2 |
|
|
import re |
|
|
import json |
|
|
from typing import List, Dict |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import torch |
|
|
import tempfile |
|
|
import os |
|
|
|
|
|
|
|
|
# --- Model setup (runs once at import time) ---
# Small, CPU-friendly T5 fine-tuned for highlight-based question generation;
# downloaded from the Hugging Face hub on first run, cached afterwards.
print("Loading models... This may take a minute on first run.")

model_name = "valhalla/t5-small-qg-hl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Inference only: switch off dropout/batch-norm training behavior and pin
# the model to CPU (the app is advertised as GPU-free).
model.eval()
device = torch.device("cpu")
model.to(device)
|
|
|
|
|
def generate_questions(context: str, answer: str, max_length: int = 128) -> str:
    """Generate one question whose answer is *answer*, given *context*.

    Uses the ``<hl>``-highlight prompt format expected by the
    t5-small-qg-hl checkpoint. Returns "" when generation fails or
    produces an implausibly short question (<= 10 characters).
    """
    try:
        # Mark the answer span with <hl> tokens, per the qg-hl format.
        prompt = f"generate question: <hl> {answer} <hl> {context}"

        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(device)

        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
                do_sample=True,
                temperature=0.7
            )

        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        # The model sometimes echoes a "question:" / "q:" prefix; drop it.
        decoded = re.sub(r'^(question:|q:)', '', decoded, flags=re.IGNORECASE).strip()

        return decoded if len(decoded) > 10 else ""
    except Exception as e:
        # Best-effort: log and return "" so one bad chunk doesn't abort a run.
        print(f"Error generating question: {e}")
        return ""
|
|
|
|
|
def extract_text_from_pdf(pdf_file) -> str:
    """Extract all page text from a PDF.

    Args:
        pdf_file: A filesystem path string or a binary file-like object;
            PyPDF2's ``PdfReader`` accepts both, so no type branching is
            needed (the original code had two identical branches here).

    Returns:
        The concatenated page text, one page per line, or a string starting
        with ``"Error reading PDF:"`` on failure — callers (process_pdf)
        detect errors via this prefix.
    """
    text = ""
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        for page in pdf_reader.pages:
            page_text = page.extract_text()
            # extract_text() can return None/"" for image-only pages; skip those.
            if page_text:
                text += page_text + "\n"
    except Exception as e:
        return f"Error reading PDF: {str(e)}"

    return text
|
|
|
|
|
def clean_text(text: str) -> str:
    """Normalize whitespace and strip unusual symbols from extracted text."""
    # Collapse every run of whitespace (newlines, tabs, ...) to one space.
    collapsed = re.sub(r'\s+', ' ', text)
    # Keep only word characters, whitespace, and common punctuation.
    filtered = re.sub(r'[^\w\s.,;!?-]', '', collapsed)
    return filtered.strip()
|
|
|
|
|
def chunk_text(text: str, max_chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split text into sentence-aligned chunks.

    Sentences are greedily packed into chunks of roughly ``max_chunk_size``
    characters; when ``overlap`` > 0, every chunk after the first is
    prefixed with trailing context (the previous chunk's last sentences,
    or its last ``overlap`` characters as a fallback).
    """
    # Split on whitespace that follows sentence-ending punctuation.
    pieces = re.split(r'(?<=[.!?])\s+', text)

    raw_chunks: List[str] = []
    buffer = ""
    for piece in pieces:
        if len(buffer) + len(piece) < max_chunk_size:
            buffer += " " + piece
        else:
            if buffer:
                raw_chunks.append(buffer.strip())
            buffer = piece
    if buffer:
        raw_chunks.append(buffer.strip())

    # Second pass: prepend overlap context from the preceding chunk.
    result: List[str] = []
    for idx, chunk in enumerate(raw_chunks):
        if idx > 0 and overlap > 0:
            prior_sentences = raw_chunks[idx - 1].split('. ')
            if len(prior_sentences) > 1:
                tail = '. '.join(prior_sentences[-2:])
            else:
                tail = raw_chunks[idx - 1][-overlap:]
            chunk = tail + " " + chunk
        result.append(chunk)
    return result
|
|
|
|
|
def generate_qa_pairs(chunk: str, num_questions: int = 2) -> List[Dict[str, str]]:
    """Build up to *num_questions* flashcards from a single text chunk.

    Each sentence of the chunk serves as a candidate answer; the model is
    asked to produce a matching question. Returns a (possibly empty) list
    of {"question", "answer", "context"} dicts.
    """
    cards: List[Dict[str, str]] = []

    # Very short chunks rarely contain enough material for a question.
    if len(chunk.split()) < 20:
        return []

    try:
        # Candidate answers: non-trivial sentences (> 20 chars after strip).
        candidates = [s.strip() for s in chunk.split('. ') if len(s.strip()) > 20]
        if not candidates:
            return []

        for answer in candidates[:num_questions]:
            # Skip answers that are too short to be meaningful.
            if len(answer.split()) < 3:
                continue

            question = generate_questions(chunk, answer)

            # Keep only real questions that aren't just an echo of the answer.
            if question and question != answer:
                preview = chunk[:200] + "..." if len(chunk) > 200 else chunk
                cards.append({
                    "question": question,
                    "answer": answer,
                    "context": preview,
                })
    except Exception as e:
        print(f"Error generating QA: {e}")

    return cards
|
|
|
|
|
def process_pdf(pdf_file, questions_per_chunk: int = 2, max_chunks: int = 20):
    """Drive the PDF -> flashcards pipeline, streaming progress to Gradio.

    This is a generator: each ``yield`` emits a 4-tuple of
    (status, csv_text, json_text, display_markdown) matching the four
    output components wired up in the UI. Because it is a generator,
    every exit path must ``yield`` its final state — a bare
    ``return value`` inside a generator discards the value.
    """
    if pdf_file is None:
        # BUG FIX: the original used `return <tuple>` here, which inside a
        # generator silently drops the message; yield it so the UI updates.
        yield "Please upload a PDF file.", "", "", "Your flashcards will appear here..."
        return

    try:
        # Step 1: extract raw text from the PDF.
        yield "π Extracting text from PDF...", "", "", "Processing..."
        raw_text = extract_text_from_pdf(pdf_file)

        # extract_text_from_pdf signals failure via an "Error..." prefix.
        if raw_text.startswith("Error"):
            yield raw_text, "", "", "Error occurred"
            return

        if len(raw_text.strip()) < 100:
            yield "PDF appears to be empty or contains no extractable text.", "", "", "Error occurred"
            return

        # Step 2: normalize whitespace / strip noise.
        yield "π§Ή Cleaning text...", "", "", "Processing..."
        cleaned_text = clean_text(raw_text)

        # Step 3: split into model-sized sections, capped at max_chunks.
        # NOTE(review): Gradio sliders may deliver floats; coerce before
        # slicing so `chunks[:max_chunks]` cannot raise TypeError.
        yield "βοΈ Chunking text into sections...", "", "", "Processing..."
        chunks = chunk_text(cleaned_text)[:int(max_chunks)]

        # Step 4: generate Q/A pairs chunk by chunk, reporting progress.
        all_flashcards = []
        total_chunks = len(chunks)
        for i, chunk in enumerate(chunks):
            progress = f"π΄ Generating flashcards... ({i+1}/{total_chunks} chunks processed)"
            yield progress, "", "", "Processing..."

            cards = generate_qa_pairs(chunk, int(questions_per_chunk))
            all_flashcards.extend(cards)

        if not all_flashcards:
            yield "Could not generate flashcards from this PDF. Try a PDF with more textual content.", "", "", "No flashcards generated"
            return

        # BUG FIX: the two status strings below were unterminated string
        # literals in the original (a garbled emoji split across a line
        # break — a SyntaxError); restored as plain text.
        yield "Finalizing...", "", "", "Almost done..."

        # Step 5: render every export format.
        display_text = format_flashcards_display(all_flashcards)
        json_output = json.dumps(all_flashcards, indent=2, ensure_ascii=False)

        # Anki-compatible CSV: double embedded quotes, quote every field.
        csv_lines = ["Question,Answer"]
        for card in all_flashcards:
            q = card['question'].replace('"', '""')
            a = card['answer'].replace('"', '""')
            csv_lines.append(f'"{q}","{a}"')
        csv_output = "\n".join(csv_lines)

        yield "Done! Generated {} flashcards".format(len(all_flashcards)), csv_output, json_output, display_text

    except Exception as e:
        error_msg = f"Error processing PDF: {str(e)}"
        print(error_msg)
        yield error_msg, "", "", error_msg
|
|
|
|
|
def format_flashcards_display(flashcards: List[Dict]) -> str:
    """Render a list of flashcard dicts as a Markdown study sheet."""
    parts = [f"## π΄ Generated {len(flashcards)} Flashcards\n"]

    card_number = 1
    for card in flashcards:
        parts.extend([
            f"### Card {card_number}",
            f"**Q:** {card['question']}",
            f"**A:** {card['answer']}",
            f"*Context: {card['context'][:100]}...*\n",
            "---\n",
        ])
        card_number += 1

    return "\n".join(parts)
|
|
|
|
|
def create_sample_flashcard():
    """Return demo Markdown illustrating the flashcard output format."""
    demo_card = {
        "question": "What is the capital of France?",
        "answer": "Paris is the capital and most populous city of France.",
        "context": "Paris is the capital and most populous city of France...",
    }
    return format_flashcards_display([demo_card])
|
|
|
|
|
|
|
|
# CSS handed to gr.Blocks(css=...): gradient "card" container plus
# question/answer typography for the rendered flashcard Markdown.
custom_css = """
.flashcard-container {
    border: 2px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin: 10px 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
}
.question {
    font-size: 1.2em;
    font-weight: bold;
    margin-bottom: 10px;
}
.answer {
    font-size: 1em;
    opacity: 0.9;
}
"""
|
|
|
|
|
|
|
|
# --- Gradio UI definition (built at import time) ---
with gr.Blocks(css=custom_css, title="PDF to Flashcards") as demo:
    # Header / feature overview.
    gr.Markdown("""
    # π PDF to Flashcards Generator

    Upload any PDF document and automatically generate study flashcards (Q&A pairs) using AI.

    **Features:**
    - π§ Uses local CPU-friendly AI (no GPU needed)
    - π Extracts text from any PDF
    - βοΈ Intelligently chunks content
    - π΄ Generates question-answer pairs
    - πΎ Export to CSV (Anki-compatible) or JSON

    *Note: Processing is done entirely on CPU, so large PDFs may take a few minutes.*
    """)

    with gr.Row():
        # Left column: upload widget and generation controls.
        with gr.Column(scale=1):
            pdf_input = gr.File(
                label="Upload PDF",
                file_types=[".pdf"],
                type="filepath"  # process_pdf receives a path string
            )

            with gr.Row():
                questions_per_chunk = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Questions per section"
                )
                max_chunks = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Max sections to process"
                )

            process_btn = gr.Button("π Generate Flashcards", variant="primary")

            gr.Markdown("""
            ### π‘ Tips:
            - Text-based PDFs work best (scanned images won't work)
            - Academic papers and articles work great
            - Adjust "Questions per section" based on content density
            """)

        # Right column: live status line plus the rendered flashcards.
        with gr.Column(scale=2):
            status_text = gr.Textbox(
                label="Status",
                value="Ready to process PDF...",
                interactive=False
            )

            output_display = gr.Markdown(
                label="Generated Flashcards",
                value="Your flashcards will appear here..."
            )

    # Export panes: copy-paste CSV (Anki) and raw JSON.
    with gr.Row():
        with gr.Column():
            csv_output = gr.Textbox(
                label="CSV Format (for Anki import)",
                lines=10,
                visible=True
            )
            gr.Markdown("*Copy the CSV content and save as `.csv` file to import into Anki*")

        with gr.Column():
            json_output = gr.Textbox(
                label="JSON Format",
                lines=10,
                visible=True
            )
            gr.Markdown("*Raw JSON data for custom applications*")

    # process_pdf is a generator, so each yield streams a UI update into
    # the four output components in order.
    process_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, questions_per_chunk, max_chunks],
        outputs=[status_text, csv_output, json_output, output_display]
    )

    # Static example so visitors see the output format before uploading.
    gr.Markdown("---")
    gr.Markdown("### π― Example Output Format")
    gr.Markdown(create_sample_flashcard())
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()