"""Gradio app that turns a text-based PDF into study flashcards.

Pipeline: extract text with PyPDF2 -> normalize it -> split it into
sentence-aligned chunks -> ask the T5 question-generation model
``valhalla/t5-small-qg-hl`` for one question per selected answer sentence ->
render the cards as Markdown, CSV and JSON.
"""

import json
import os
import re
import tempfile
from typing import Dict, List, Optional

import gradio as gr
import PyPDF2
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer once at import time so every request reuses them.
print("Loading models... This may take a minute on first run.")
model_name = "valhalla/t5-small-qg-hl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Inference only, on CPU: eval() disables dropout.
model.eval()
device = torch.device("cpu")
model.to(device)

# Marker the "-hl" question-generation model uses to delimit the answer span.
_HL_TOKEN = "<hl>"


def _highlight_answer(context: str, answer: str) -> str:
    """Return *context* with the first occurrence of *answer* wrapped in ``<hl>`` markers.

    If *answer* does not occur verbatim in *context*, it is prepended as a
    highlighted span so the model still sees which text to ask about.
    """
    start = context.find(answer)
    if start == -1:
        return f"{_HL_TOKEN} {answer} {_HL_TOKEN} {context}"
    end = start + len(answer)
    return f"{context[:start]} {_HL_TOKEN} {answer} {_HL_TOKEN} {context[end:]}"


def generate_questions(context: str, answer: str, max_length: int = 128) -> str:
    """Generate one question whose answer is *answer*, using *context* as background.

    Args:
        context: Passage the answer was taken from.
        answer: The text the generated question should be answered by.
        max_length: Maximum length, in tokens, of the generated question.

    Returns:
        The question text, or "" if generation raised or the output was 10
        characters or shorter.
    """
    try:
        # The "-hl" model is trained on inputs where the answer span is
        # highlighted inside the passage: "generate question: ... <hl> answer <hl> ...".
        input_text = f"generate question: {_highlight_answer(context, answer)}"

        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(device)

        with torch.no_grad():
            # Plain (deterministic) beam search so the same PDF yields the same cards.
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
            )

        question = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Drop a leading "question:" / "q:" label if the model emits one.
        question = re.sub(r'^(question:|q:)', '', question, flags=re.IGNORECASE).strip()
        return question if len(question) > 10 else ""
    except Exception as e:
        # Best effort: one failed generation must not abort the whole PDF.
        print(f"Error generating question: {e}")
        return ""


def _read_pdf_text(pdf_file) -> str:
    """Return the text of every page that has extractable text, each followed by a newline.

    Raises whatever PyPDF2 raises for unreadable input.
    """
    reader = PyPDF2.PdfReader(pdf_file)
    page_texts = (page.extract_text() for page in reader.pages)
    return "".join(f"{page_text}\n" for page_text in page_texts if page_text)


def extract_text_from_pdf(pdf_file) -> str:
    """Extract text from an uploaded PDF (a file path or a binary file object).

    Returns:
        The extracted text, or a string starting with "Error reading PDF:" if
        reading failed. Callers that need to tell errors apart from document
        text should prefer catching the exception from ``_read_pdf_text``.
    """
    try:
        return _read_pdf_text(pdf_file)
    except Exception as e:
        return f"Error reading PDF: {str(e)}"


def clean_text(text: str) -> str:
    """Normalize extracted PDF text into a single line of plain prose.

    Curly apostrophes become straight ones. Characters other than word
    characters, whitespace and ``.,;:!?'()%-`` are removed. Whitespace runs
    are then collapsed to single spaces.
    """
    text = text.replace("\u2019", "'")
    # Keep apostrophes, colons and parentheses so contractions ("don't"),
    # possessives and labels survive into the generated answers.
    text = re.sub(r"[^\w\s.,;:!?'()%-]", '', text)
    # Collapse whitespace last so removed symbols don't leave double spaces.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def _split_into_chunks(text: str, max_chunk_size: int = 512) -> List[str]:
    """Greedily pack whole sentences into chunks shorter than *max_chunk_size* characters.

    Sentences end at ``.``, ``!`` or ``?`` followed by whitespace. A single
    sentence longer than the limit becomes its own oversized chunk.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks: List[str] = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) < max_chunk_size:
            current_chunk += " " + sentence
        else:
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            current_chunk = sentence
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks


def _add_overlap(chunks: List[str], overlap: int = 50) -> List[str]:
    """Prefix every chunk after the first with the tail of the previous chunk.

    If the previous chunk contains more than one ". "-separated sentence, its
    last two sentences are used as the overlap (regardless of *overlap*).
    Otherwise its last *overlap* characters are used. With ``overlap <= 0``
    the chunks are returned unchanged.
    """
    if overlap <= 0:
        return list(chunks)
    overlapped = chunks[:1]
    for prev_chunk, chunk in zip(chunks, chunks[1:]):
        prev_sentences = prev_chunk.split('. ')
        if len(prev_sentences) > 1:
            overlap_text = '. '.join(prev_sentences[-2:])
        else:
            overlap_text = prev_chunk[-overlap:]
        overlapped.append(overlap_text + " " + chunk)
    return overlapped


def chunk_text(text: str, max_chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split text into sentence-aligned chunks, each prefixed with context from the previous one."""
    return _add_overlap(_split_into_chunks(text, max_chunk_size), overlap)


def generate_qa_pairs(
    chunk: str,
    num_questions: int = 2,
    answer_source: Optional[str] = None,
) -> List[Dict[str, str]]:
    """Generate question/answer flashcards from a text chunk.

    Args:
        chunk: Passage given to the model as context. Chunks under 20 words
            produce no cards.
        num_questions: Number of leading candidate sentences to try as answers.
        answer_source: Text to draw answer sentences from; defaults to *chunk*.
            Passing only the non-overlapping part of a chunk stops the
            overlap (which exists purely as model context) from being reused
            as answers.

    Returns:
        A list of dicts with "question", "answer" and "context" keys. The
        context is the chunk truncated to 200 characters plus "...".
    """
    flashcards: List[Dict[str, str]] = []

    # Too little text for the model to produce a meaningful question.
    if len(chunk.split()) < 20:
        return []

    source = chunk if answer_source is None else answer_source
    try:
        # Candidate answers: ". "-separated sentences longer than 20 characters.
        sentences = [s.strip() for s in source.split('. ') if len(s.strip()) > 20]
        if not sentences:
            return []

        for answer in sentences[:num_questions]:
            if len(answer.split()) < 3:
                continue
            question = generate_questions(chunk, answer)
            # Reject empty output and questions that merely echo the answer.
            if question and question != answer:
                flashcards.append({
                    "question": question,
                    "answer": answer,
                    "context": chunk[:200] + "..." if len(chunk) > 200 else chunk,
                })
    except Exception as e:
        print(f"Error generating QA: {e}")

    return flashcards


def _flashcards_to_csv(flashcards: List[Dict[str, str]]) -> str:
    """Render cards as CSV: an unquoted "Question,Answer" header, then fully quoted rows."""
    csv_lines = ["Question,Answer"]
    for card in flashcards:
        q = card["question"].replace('"', '""')
        a = card["answer"].replace('"', '""')
        csv_lines.append(f'"{q}","{a}"')
    return "\n".join(csv_lines)


def process_pdf(pdf_file, questions_per_chunk: int = 2, max_chunks: int = 20):
    """Drive the whole pipeline for the UI, streaming progress as it goes.

    This is a generator. Each yielded tuple is
    ``(status, csv_text, json_text, markdown_display)``, matching the
    ``outputs`` wired to the button click.
    """
    if pdf_file is None:
        # Must be yielded, not returned: a generator's return value never
        # reaches the UI, so the prompt would silently disappear.
        yield "Please upload a PDF file.", "", "", "Your flashcards will appear here..."
        return

    try:
        # Slider values can arrive as floats; slicing and range() need ints.
        questions_per_chunk = int(questions_per_chunk)
        max_chunks = int(max_chunks)

        yield "📄 Extracting text from PDF...", "", "", "Processing..."
        try:
            raw_text = _read_pdf_text(pdf_file)
        except Exception as e:
            # Handle read failures here, so document text that happens to
            # start with "Error" is never mistaken for a failure.
            yield f"Error reading PDF: {str(e)}", "", "", "Error occurred"
            return

        if len(raw_text.strip()) < 100:
            yield "PDF appears to be empty or contains no extractable text.", "", "", "Error occurred"
            return

        yield "🧹 Cleaning text...", "", "", "Processing..."
        cleaned_text = clean_text(raw_text)

        yield "✂️ Chunking text into sections...", "", "", "Processing..."
        # Limit chunks for CPU performance. The overlap of chunk i depends only
        # on chunk i-1, so truncating before adding overlap matches chunk_text().
        core_chunks = _split_into_chunks(cleaned_text)[:max_chunks]
        context_chunks = _add_overlap(core_chunks)

        all_flashcards: List[Dict[str, str]] = []
        # Drop repeated answers, e.g. page headers/footers recurring in many chunks.
        seen_answers = set()
        total_chunks = len(context_chunks)
        for i, (context, core) in enumerate(zip(context_chunks, core_chunks), start=1):
            yield f"🎴 Generating flashcards... ({i}/{total_chunks} chunks processed)", "", "", "Processing..."
            for card in generate_qa_pairs(context, questions_per_chunk, answer_source=core):
                key = card["answer"].casefold()
                if key not in seen_answers:
                    seen_answers.add(key)
                    all_flashcards.append(card)

        if not all_flashcards:
            yield "Could not generate flashcards from this PDF. Try a PDF with more textual content.", "", "", "No flashcards generated"
            return

        yield "✅ Finalizing...", "", "", "Almost done..."
        display_text = format_flashcards_display(all_flashcards)
        json_text = json.dumps(all_flashcards, indent=2, ensure_ascii=False)
        csv_text = _flashcards_to_csv(all_flashcards)

        # Final yield fills every output component.
        yield f"✅ Done! Generated {len(all_flashcards)} flashcards", csv_text, json_text, display_text
    except Exception as e:
        error_msg = f"Error processing PDF: {str(e)}"
        print(error_msg)
        yield error_msg, "", "", error_msg


def format_flashcards_display(flashcards: List[Dict]) -> str:
    """Render flashcards as Markdown: a heading, then one numbered section per card."""
    lines = [f"## 🎴 Generated {len(flashcards)} Flashcards\n"]
    for i, card in enumerate(flashcards, 1):
        lines.append(f"### Card {i}")
        lines.append(f"**Q:** {card['question']}")
        lines.append(f"**A:** {card['answer']}")
        lines.append(f"*Context: {card['context'][:100]}...*\n")
        lines.append("---\n")
    return "\n".join(lines)


def create_sample_flashcard():
    """Return the Markdown for one hard-coded example card, shown in the UI."""
    sample = [{
        "question": "What is the capital of France?",
        "answer": "Paris is the capital and most populous city of France.",
        "context": "Paris is the capital and most populous city of France...",
    }]
    return format_flashcards_display(sample)


# Custom CSS for better styling
custom_css = """
.flashcard-container {
    border: 2px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin: 10px 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
}
.question {
    font-size: 1.2em;
    font-weight: bold;
    margin-bottom: 10px;
}
.answer {
    font-size: 1em;
    opacity: 0.9;
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css, title="PDF to Flashcards") as demo:
    gr.Markdown("""
    # 📚 PDF to Flashcards Generator

    Upload any PDF document and automatically generate study flashcards (Q&A pairs) using AI.

    **Features:**
    - 🧠 Uses local CPU-friendly AI (no GPU needed)
    - 📄 Extracts text from any PDF
    - ✂️ Intelligently chunks content
    - 🎴 Generates question-answer pairs
    - 💾 Export to CSV (Anki-compatible) or JSON

    *Note: Processing is done entirely on CPU, so large PDFs may take a few minutes.*
    """)

    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(
                label="Upload PDF",
                file_types=[".pdf"],
                type="filepath",
            )

            with gr.Row():
                questions_per_chunk = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Questions per section",
                )
                max_chunks = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Max sections to process",
                )

            process_btn = gr.Button("🚀 Generate Flashcards", variant="primary")

            gr.Markdown("""
            ### 💡 Tips:
            - Text-based PDFs work best (scanned images won't work)
            - Academic papers and articles work great
            - Adjust "Questions per section" based on content density
            """)

        with gr.Column(scale=2):
            status_text = gr.Textbox(
                label="Status",
                value="Ready to process PDF...",
                interactive=False,
            )
            output_display = gr.Markdown(
                label="Generated Flashcards",
                value="Your flashcards will appear here...",
            )

    with gr.Row():
        with gr.Column():
            csv_output = gr.Textbox(
                label="CSV Format (for Anki import)",
                lines=10,
                visible=True,
            )
            gr.Markdown("*Copy the CSV content and save as `.csv` file to import into Anki*")

        with gr.Column():
            json_output = gr.Textbox(
                label="JSON Format",
                lines=10,
                visible=True,
            )
            gr.Markdown("*Raw JSON data for custom applications*")

    # process_pdf is a generator: Gradio streams each yielded 4-tuple into these
    # outputs, so progress and final results share a single binding.
    process_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, questions_per_chunk, max_chunks],
        outputs=[status_text, csv_output, json_output, output_display],
    )

    # Example section
    gr.Markdown("---")
    gr.Markdown("### 🎯 Example Output Format")
    gr.Markdown(create_sample_flashcard())


if __name__ == "__main__":
    demo.launch()