heerjtdev committed on
Commit
c1ea31c
·
verified ·
1 Parent(s): acf77a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +282 -373
app.py CHANGED
@@ -1,415 +1,324 @@
1
  import gradio as gr
2
- import fitz
3
- import torch
4
- import os
5
  import re
6
- import numpy as np
7
- from collections import Counter
8
- import onnxruntime as ort
9
- from onnxruntime import SessionOptions, GraphOptimizationLevel
10
- from langchain_text_splitters import RecursiveCharacterTextSplitter
11
- from langchain_community.vectorstores import FAISS
12
- from langchain_core.embeddings import Embeddings
13
- from transformers import AutoTokenizer
14
- from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForCausalLM
15
- from huggingface_hub import snapshot_download
16
- from sentence_transformers import SentenceTransformer # Add this for cross-encoder
17
-
18
- PROVIDERS = ["CPUExecutionProvider"]
19
-
20
- # ---------------------------------------------------------
21
- # 1. EMBEDDINGS (Your existing code - good)
22
- # ---------------------------------------------------------
23
class OnnxBgeEmbeddings(Embeddings):
    """LangChain-compatible embedding model backed by an ONNX export of
    bge-small-en-v1.5, running entirely on CPU via onnxruntime.

    Vectors are L2-normalised, so a dot product between two embeddings
    equals their cosine similarity.
    """

    def __init__(self):
        # Pre-exported ONNX checkpoint: export=False skips any torch->ONNX
        # conversion at startup; PROVIDERS[0] is the CPU execution provider.
        model_name = "Xenova/bge-small-en-v1.5"
        print(f"🔄 Loading Embeddings: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            model_name, export=False, provider=PROVIDERS[0]
        )

    def _process_batch(self, texts):
        """Embed a list of strings; returns a list of plain-float vectors."""
        # Truncate at the model's 512-token context window.
        inputs = self.tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            # Take the first token's hidden state (presumably the CLS token
            # for BGE models — TODO confirm) as the sentence embedding.
            embeddings = outputs.last_hidden_state[:, 0]
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.numpy().tolist()

    def embed_documents(self, texts):
        """Embed many documents (LangChain Embeddings interface)."""
        return self._process_batch(texts)

    def embed_query(self, text):
        """Embed a single query string (LangChain Embeddings interface)."""
        return self._process_batch([text])[0]
 
 
 
 
 
45
 
46
- # ---------------------------------------------------------
47
- # 2. RULE-BASED GRADING ENGINE (NEW - No LLM needed)
48
- # ---------------------------------------------------------
49
class RuleBasedGrader:
    """
    Extracts key concepts from context and checks student answer coverage.
    Works 100% on CPU, deterministic, explainable.
    """

    def __init__(self):
        # Load a small NER or keyword extraction model if needed
        # Or use simple TF-IDF/RAKE algorithm
        pass

    def extract_key_concepts(self, text, top_k=10):
        """
        Extract key noun phrases and important terms from context.
        Uses simple but effective heuristics.

        Returns up to ``top_k`` deduplicated unigrams/bigrams that occur
        more than once in the text.
        """
        # Clean text: lowercase and replace punctuation with spaces.
        text = re.sub(r'[^\w\s]', ' ', text.lower())
        words = text.split()

        # Remove stopwords
        stopwords = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might', 'must', 'shall', 'can', 'need', 'dare', 'ought', 'used', 'it', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'we', 'they'}

        # Get word frequencies (excluding stopwords and tokens of length <= 2)
        words = [w for w in words if w not in stopwords and len(w) > 2]
        word_freq = Counter(words)

        # Get bigrams (two-word phrases)
        # NOTE(review): bigrams are built from the stopword-FILTERED list, so a
        # pair may span a removed word (e.g. "speed of light" -> "speed light").
        bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words)-1)]
        bigram_freq = Counter(bigrams)

        # Combine unigrams and bigrams that appear more than once.
        concepts = []
        for word, count in word_freq.most_common(top_k):
            if count > 1:  # Only include words that appear multiple times
                concepts.append(word)

        for bigram, count in bigram_freq.most_common(top_k//2):
            if count > 1:
                concepts.append(bigram)

        # NOTE(review): set() discards the frequency ordering, so which items
        # survive the [:top_k] cut is not frequency-ranked.
        return list(set(concepts))[:top_k]  # Remove duplicates, limit to top_k

    def check_concept_coverage(self, student_answer, key_concepts):
        """
        Check which key concepts from context appear in student answer.
        Returns (coverage ratio in [0, 1], found concepts, missing concepts).
        """
        student_lower = student_answer.lower()
        found_concepts = []
        missing_concepts = []

        for concept in key_concepts:
            # Check for exact substring match first.
            if concept in student_lower:
                found_concepts.append(concept)
            else:
                # Check for crude suffix variants (e.g. "run" matches
                # "runs"/"ran" is NOT handled; only +s/+es/+ed/+ing).
                concept_words = concept.split()
                if all(any(word in student_lower for word in [cw, cw+'s', cw+'es', cw+'ed', cw+'ing']) for cw in concept_words):
                    found_concepts.append(concept)
                else:
                    missing_concepts.append(concept)

        # Guard against division by zero when no concepts were extracted.
        coverage = len(found_concepts) / len(key_concepts) if key_concepts else 0
        return coverage, found_concepts, missing_concepts

    def detect_contradictions(self, context, student_answer):
        """
        Simple contradiction detection using negation patterns.

        NOTE(review): heuristic only. It flags cases where a context sentence
        contains a negation word and the student's answer re-uses any of the
        first five words of the de-negated sentence; the reverse direction
        (positive context, negated answer) is intentionally not implemented.
        """
        context_lower = context.lower()
        answer_lower = student_answer.lower()

        # Common negation patterns
        negation_words = ['not', 'no', 'never', 'none', 'nothing', 'nobody', 'neither', 'nowhere', 'hardly', 'scarcely', 'barely', "doesn't", "isn't", "wasn't", "shouldn't", "wouldn't", "couldn't", "can't", "don't", "didn't", "hasn't", "haven't", "hadn't", "won't"]

        contradictions = []

        # Extract sentences from context that contain key facts
        context_sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 10]

        for sent in context_sentences:
            sent_lower = sent.lower()
            # Check if student says opposite
            for neg in negation_words:
                if neg in sent_lower:
                    # Context has negation, check if student affirms
                    positive_version = sent_lower.replace(neg, '').strip()
                    if any(word in answer_lower for word in positive_version.split()[:5]):
                        contradictions.append(f"Context says: '{sent}' but student contradicts this")
                else:
                    # Context is positive, check if student negates
                    # This is harder - would need semantic understanding
                    pass

        return contradictions

    def calculate_semantic_similarity(self, context, student_answer, embeddings_model):
        """
        Use embeddings to calculate cosine similarity between the context
        and the student answer. Returns a plain float.
        """
        context_emb = embeddings_model.embed_query(context)
        answer_emb = embeddings_model.embed_query(student_answer)

        # Cosine similarity
        similarity = np.dot(context_emb, answer_emb) / (np.linalg.norm(context_emb) * np.linalg.norm(answer_emb))
        return float(similarity)

    def grade(self, context, question, student_answer, max_marks, embeddings_model):
        """
        Main grading function combining multiple signals.

        Returns (final_score, markdown feedback string).
        """
        # 1. Extract key concepts from context
        key_concepts = self.extract_key_concepts(context)

        # 2. Check concept coverage
        coverage, found, missing = self.check_concept_coverage(student_answer, key_concepts)

        # 3. Check for contradictions
        contradictions = self.detect_contradictions(context, student_answer)

        # 4. Calculate semantic similarity
        semantic_sim = self.calculate_semantic_similarity(context, student_answer, embeddings_model)

        # 5. Calculate final score
        # Weight: 60% concept coverage, 40% semantic similarity
        # Penalty for contradictions: -50% of max marks per contradiction
        base_score = (coverage * 0.6 + semantic_sim * 0.4) * max_marks

        # Apply contradiction penalties
        contradiction_penalty = len(contradictions) * (max_marks * 0.5)
        final_score = max(0, base_score - contradiction_penalty)

        # Generate feedback
        feedback = f"""
**Grading Analysis:**

**Key Concepts Found ({len(found)}/{len(key_concepts)}):** {', '.join(found) if found else 'None'}
**Key Concepts Missing:** {', '.join(missing) if missing else 'None'}

**Concept Coverage:** {coverage:.1%}
**Semantic Similarity:** {semantic_sim:.1%}

**Contradictions Detected:** {len(contradictions)}
{chr(10).join(['- ' + c for c in contradictions]) if contradictions else 'None'}

**Calculation:** ({coverage:.1%} × 0.6 + {semantic_sim:.1%} × 0.4) × {max_marks} - {contradiction_penalty:.1f} penalty = **{final_score:.1f}/{max_marks}**
"""

        return final_score, feedback
201
-
202
- # ---------------------------------------------------------
203
- # 3. LLM EVALUATOR (Fallback for edge cases)
204
- # ---------------------------------------------------------
205
class LLMEvaluator:
    """Fallback grader backed by a small Qwen2.5-0.5B-Instruct ONNX model on CPU.

    Used only for nuanced cases; clear 0 or full-mark rule-based results are
    answered directly without invoking the model.
    """

    def __init__(self):
        self.repo_id = "onnx-community/Qwen2.5-0.5B-Instruct"
        self.local_dir = "onnx_qwen_local"

        # Download only the files required for fp16 ONNX inference.
        if not os.path.exists(self.local_dir):
            snapshot_download(
                repo_id=self.repo_id,
                local_dir=self.local_dir,
                allow_patterns=["config.json", "generation_config.json", "tokenizer*", "special_tokens_map.json", "*.jinja", "onnx/model_fp16.onnx*"]
            )

        self.tokenizer = AutoTokenizer.from_pretrained(self.local_dir)

        sess_options = SessionOptions()
        sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

        self.model = ORTModelForCausalLM.from_pretrained(
            self.local_dir,
            subfolder="onnx",
            file_name="model_fp16.onnx",
            use_cache=True,
            use_io_binding=False,
            provider=PROVIDERS[0],
            session_options=sess_options
        )

    def evaluate(self, context, question, student_answer, max_marks, rule_based_score):
        """
        Use LLM only for ambiguous cases or to verify edge cases.
        Simplified prompt for 0.5B model.

        Returns a "Score: X/Y\\nFeedback: ..." string.
        """
        # BUG FIX: these two early returns were plain strings, so the literal
        # text "{max_marks}" appeared in the output instead of the number.
        # They are now f-strings.
        if rule_based_score == 0:
            return f"Score: 0/{max_marks}\nFeedback: Answer contains significant errors or contradictions with the reference text."
        if rule_based_score == max_marks:
            return f"Score: {max_marks}/{max_marks}\nFeedback: Excellent answer that fully covers the reference material."

        # Otherwise, use LLM for nuanced cases
        prompt = f"""Grade this answer based ONLY on the context provided.

Context: {context[:500]}
Question: {question}
Student Answer: {student_answer}

Rules:
1. Give 0 if answer contradicts context or adds outside information
2. Give full marks only if answer matches context exactly
3. Give partial marks for partial matches

Output exactly:
Score: X/{max_marks}
Feedback: One sentence explanation"""

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=50,
                temperature=0.1,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract just the generated part (after the prompt)
        response = response[len(self.tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)):]
        return response.strip()
274
-
275
- # ---------------------------------------------------------
276
- # 4. MAIN APPLICATION
277
- # ---------------------------------------------------------
278
class VectorSystem:
    """Application facade: indexes uploaded content into FAISS and grades answers."""

    def __init__(self):
        self.vector_store = None          # FAISS index; None until content is indexed
        self.embeddings = OnnxBgeEmbeddings()
        self.rule_grader = RuleBasedGrader()
        self.llm = LLMEvaluator()
        self.all_chunks = []              # raw text chunks, indexed by metadata "id"
        self.total_chunks = 0

    def process_content(self, file_obj, raw_text):
        """Ingest EITHER a PDF/TXT file OR pasted text and build the FAISS index.

        Returns a user-facing status string (errors are reported as strings,
        not raised, because the result feeds a Gradio textbox).
        """
        has_file = file_obj is not None
        has_text = raw_text is not None and len(raw_text.strip()) > 0

        if has_file and has_text:
            return "❌ Error: Provide EITHER file OR text, not both."

        if not has_file and not has_text:
            return "⚠️ No content provided."

        try:
            text = ""
            if has_file:
                if file_obj.name.endswith('.pdf'):
                    # fitz == PyMuPDF; concatenates text of every page.
                    doc = fitz.open(file_obj.name)
                    for page in doc:
                        text += page.get_text()
                elif file_obj.name.endswith('.txt'):
                    with open(file_obj.name, 'r', encoding='utf-8') as f:
                        text = f.read()
                else:
                    return "❌ Only .pdf and .txt supported."
            else:
                text = raw_text

            # Larger chunks for better context
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                separators=["\n\n", "\n", ". ", " ", ""]
            )
            self.all_chunks = text_splitter.split_text(text)
            self.total_chunks = len(self.all_chunks)

            if not self.all_chunks:
                return "Content empty."

            # Metadata "id" maps each vector back to its chunk in all_chunks.
            metadatas = [{"id": i} for i in range(self.total_chunks)]
            self.vector_store = FAISS.from_texts(
                self.all_chunks,
                self.embeddings,
                metadatas=metadatas
            )

            return f"✅ Indexed {self.total_chunks} chunks."
        except Exception as e:
            return f"Error: {str(e)}"

    def process_query(self, question, student_answer, max_marks):
        """Retrieve the top chunks for the question and grade the student answer.

        Returns (evidence_markdown, grade_markdown) for the two UI panels.
        """
        if not self.vector_store:
            return "⚠️ Upload content first.", ""
        if not question:
            return "⚠️ Enter a question.", ""
        if not student_answer:
            return "⚠️ Enter a student answer.", ""

        # Retrieve relevant context
        results = self.vector_store.similarity_search_with_score(question, k=2)

        # Combine top 2 chunks for better context
        context_parts = []
        for doc, score in results:
            context_parts.append(self.all_chunks[doc.metadata['id']])

        expanded_context = "\n".join(context_parts)

        # Use rule-based grading (fast, deterministic)
        score, feedback = self.rule_grader.grade(
            expanded_context,
            question,
            student_answer,
            max_marks,
            self.embeddings
        )

        # Optional: Use LLM for ambiguous cases (score between 20-80%)
        # Uncomment if you want LLM verification
        # if 0.2 < (score/max_marks) < 0.8:
        #     llm_feedback = self.llm.evaluate(expanded_context, question, student_answer, max_marks, score)
        #     feedback += f"\n\n**LLM Verification:**\n{llm_feedback}"

        evidence_display = f"### 📚 Context Used:\n{expanded_context[:800]}..."
        grade_display = f"### 📝 Grade: {score:.1f}/{max_marks}\n\n{feedback}"

        return evidence_display, grade_display
372
-
373
# Initialize and launch
# Single shared VectorSystem instance; Gradio callbacks are bound methods on it.
system = VectorSystem()

with gr.Blocks(title="EduGenius AI Grader") as demo:
    gr.Markdown("# ⚡ EduGenius: CPU Optimized RAG")
    gr.Markdown("Hybrid Rule-Based + LLM Grading (ONNX Optimized)")

    with gr.Row():
        # Left column: content ingestion (file OR pasted text).
        with gr.Column(scale=1):
            gr.Markdown("### Source Input")
            pdf_input = gr.File(label="Upload Chapter (PDF/TXT)")
            gr.Markdown("**OR**")
            text_input = gr.Textbox(
                label="Paste Context",
                placeholder="Paste text here...",
                lines=5
            )
            upload_btn = gr.Button("Index Content", variant="primary")
            status_msg = gr.Textbox(label="Status", interactive=False)

        # Right column: question / answer entry and grading trigger.
        with gr.Column(scale=2):
            q_input = gr.Textbox(label="Question", scale=2)
            max_marks = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Max Marks")
            a_input = gr.TextArea(label="Student Answer", lines=5)
            run_btn = gr.Button("Retrieve & Grade", variant="secondary")

    with gr.Row():
        evidence_box = gr.Markdown()
        grade_box = gr.Markdown()

    upload_btn.click(
        system.process_content,
        inputs=[pdf_input, text_input],
        outputs=[status_msg]
    )
    run_btn.click(
        system.process_query,
        inputs=[q_input, a_input, max_marks],
        outputs=[evidence_box, grade_box]
    )

if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
2
+ import PyPDF2
 
 
3
  import re
4
+ import json
5
+ from typing import List, Dict, Tuple
6
+ from transformers import pipeline
7
+ import tempfile
8
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
# Initialize the question generation pipeline using a small CPU-friendly model
# NOTE(review): t5-small-qg-hl expects its input wrapped in <hl> highlight
# markers (see generate_qa_pairs) — verify format against the model card.
print("Loading models... This may take a minute on first run.")
qa_generator = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-qg-hl",
    tokenizer="valhalla/t5-small-qg-hl",
    device=-1  # Force CPU
)
18
 
19
def extract_text_from_pdf(pdf_file) -> str:
    """Extract text from an uploaded PDF.

    Accepts either a filesystem path or an open binary file object —
    PyPDF2.PdfReader handles both directly, so the previous isinstance
    branching (which had two identical branches) is unnecessary.

    Returns the concatenated page text, or a string starting with
    "Error reading PDF" on failure. Callers key off the "Error" prefix,
    so that error contract must not change.
    """
    text = ""
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        for page in pdf_reader.pages:
            # BUG FIX: extract_text() can return None (e.g. image-only
            # pages); concatenating None raised TypeError before.
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n"
    except Exception as e:
        return f"Error reading PDF: {str(e)}"

    return text
35
 
36
def clean_text(text: str) -> str:
    """Normalise raw PDF text: collapse whitespace runs and drop stray symbols."""
    # Collapse every run of whitespace (newlines, tabs, repeats) to one space.
    collapsed = re.sub(r'\s+', ' ', text)
    # Keep word characters, whitespace, and basic punctuation; drop the rest.
    cleaned = re.sub(r'[^\w\s.,;!?-]', '', collapsed)
    return cleaned.strip()
43
 
44
def chunk_text(text: str, max_chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split *text* into sentence-aligned chunks.

    Sentences are packed greedily up to ``max_chunk_size`` characters; each
    chunk after the first is prefixed with the tail of its predecessor so
    neighbouring chunks share context.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)

    # Greedily pack sentences until the next one would exceed the size cap.
    chunks = []
    buffer = ""
    for sentence in sentences:
        if len(buffer) + len(sentence) < max_chunk_size:
            buffer = buffer + " " + sentence
        else:
            if buffer:
                chunks.append(buffer.strip())
            buffer = sentence
    if buffer:
        chunks.append(buffer.strip())

    # Prepend a little trailing context from the previous chunk.
    result = []
    for idx, chunk in enumerate(chunks):
        if idx and overlap > 0:
            prev_parts = chunks[idx - 1].split('. ')
            if len(prev_parts) > 1:
                lead = '. '.join(prev_parts[-2:])
            else:
                lead = chunks[idx - 1][-overlap:]
            chunk = lead + " " + chunk
        result.append(chunk)

    return result
71
+
72
def generate_qa_pairs(chunk: str, num_questions: int = 2) -> List[Dict[str, str]]:
    """Generate question-answer flashcards from one text chunk.

    Each card uses one sentence of the chunk as the highlighted answer and
    asks the T5 question-generation model (module-global ``qa_generator``)
    to produce a matching question. Returns a possibly-empty list of dicts
    with "question", "answer" and "context" keys.
    """
    cards = []

    # Too little text to produce a meaningful question.
    if len(chunk.split()) < 20:
        return cards

    try:
        sentences = chunk.split('. ')
        if len(sentences) < 2:
            return cards

        # One question per leading sentence, up to num_questions.
        for idx in range(min(num_questions, len(sentences))):
            answer_sentence = sentences[idx]

            # t5-small-qg-hl input format:
            # "generate question: <hl> answer <hl> context"
            prompt = f"generate question: <hl> {answer_sentence} <hl> {chunk}"
            generated = qa_generator(
                prompt,
                max_length=128,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7
            )

            question = generated[0]['generated_text'].strip()
            # Strip any leading "question:"/"q:" label the model may emit.
            question = re.sub(r'^(question:|q:)', '', question, flags=re.IGNORECASE).strip()

            # Keep only non-trivial questions.
            if question and len(question) > 10:
                cards.append({
                    "question": question,
                    "answer": answer_sentence.strip(),
                    "context": chunk[:200] + "..." if len(chunk) > 200 else chunk
                })

    except Exception as e:
        print(f"Error generating QA: {e}")

    return cards
121
+
122
def process_pdf(pdf_file, questions_per_chunk: int = 2, max_chunks: int = 20):
    """Drive the full PDF -> flashcards pipeline, streaming progress updates.

    This is a GENERATOR (Gradio renders each yielded tuple live), so every
    outcome — including errors and the final result — must be *yielded*.

    BUG FIX: the previous version used ``return value, None, None`` for all
    terminal paths; inside a generator that value is swallowed by
    StopIteration and never reaches the UI, so errors and the finished
    flashcards were never displayed. All those returns are now yields
    followed by a bare return.

    Yields: (status_or_display_markdown, csv_text_or_None, json_text_or_None)
    """
    if pdf_file is None:
        yield "Please upload a PDF file.", None, None
        return

    try:
        # Extract text
        yield "📄 Extracting text from PDF...", None, None
        raw_text = extract_text_from_pdf(pdf_file)

        # extract_text_from_pdf reports failures as "Error..."-prefixed strings.
        if raw_text.startswith("Error"):
            yield raw_text, None, None
            return

        if len(raw_text.strip()) < 100:
            yield "PDF appears to be empty or contains no extractable text.", None, None
            return

        # Clean text
        yield "🧹 Cleaning text...", None, None
        cleaned_text = clean_text(raw_text)

        # Chunk text
        yield "✂️ Chunking text into sections...", None, None
        chunks = chunk_text(cleaned_text)

        # Limit chunks for CPU performance
        chunks = chunks[:max_chunks]

        # Generate flashcards, reporting per-chunk progress.
        all_flashcards = []
        total_chunks = len(chunks)

        for i, chunk in enumerate(chunks):
            yield f"🎴 Generating flashcards... ({i+1}/{total_chunks} chunks processed)", None, None
            all_flashcards.extend(generate_qa_pairs(chunk, questions_per_chunk))

        if not all_flashcards:
            yield "Could not generate flashcards from this PDF. Try a PDF with more textual content.", None, None
            return

        yield "✅ Finalizing...", None, None

        # Create formatted display
        display_text = format_flashcards_display(all_flashcards)

        # Create JSON download
        json_output = json.dumps(all_flashcards, indent=2, ensure_ascii=False)

        # Create Anki/CSV format; embedded quotes are doubled per RFC 4180.
        csv_lines = ["Question,Answer"]
        for card in all_flashcards:
            q = card['question'].replace('"', '""')
            a = card['answer'].replace('"', '""')
            csv_lines.append(f'"{q}","{a}"')
        csv_output = "\n".join(csv_lines)

        # Final result (previously a lost `return`).
        yield display_text, csv_output, json_output

    except Exception as e:
        yield f"Error processing PDF: {str(e)}", None, None
184
 
185
def format_flashcards_display(flashcards: List[Dict]) -> str:
    """Render the flashcards as a Markdown document for the results panel."""
    parts = [f"## 🎴 Generated {len(flashcards)} Flashcards\n"]

    for number, card in enumerate(flashcards, start=1):
        parts.extend([
            f"### Card {number}",
            f"**Q:** {card['question']}",
            f"**A:** {card['answer']}",
            f"*Context: {card['context'][:100]}...*\n",
            "---\n",
        ])

    return "\n".join(parts)
197
 
198
def create_sample_flashcard():
    """Build the one-card demo rendering shown in the UI's example section."""
    demo_card = {
        "question": "What is the capital of France?",
        "answer": "Paris is the capital and most populous city of France.",
        "context": "Paris is the capital and most populous city of France..."
    }
    return format_flashcards_display([demo_card])
206
 
207
# Custom CSS for better styling
# NOTE(review): the .flashcard-container/.question/.answer classes are not
# attached to any component below (no elem_classes usage) — presumably kept
# for future styled card rendering; confirm before removing.
custom_css = """
.flashcard-container {
    border: 2px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin: 10px 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
}
.question {
    font-size: 1.2em;
    font-weight: bold;
    margin-bottom: 10px;
}
.answer {
    font-size: 1em;
    opacity: 0.9;
}
"""
227
 
228
# Gradio Interface
with gr.Blocks(css=custom_css, title="PDF to Flashcards") as demo:
    gr.Markdown("""
    # 📚 PDF to Flashcards Generator

    Upload any PDF document and automatically generate study flashcards (Q&A pairs) using AI.

    **Features:**
    - 🧠 Uses local CPU-friendly AI (no GPU needed)
    - 📄 Extracts text from any PDF
    - ✂️ Intelligently chunks content
    - 🎴 Generates question-answer pairs
    - 💾 Export to CSV (Anki-compatible) or JSON

    *Note: Processing is done entirely on CPU, so large PDFs may take a few minutes.*
    """)

    with gr.Row():
        # Left column: input controls.
        with gr.Column(scale=1):
            pdf_input = gr.File(
                label="Upload PDF",
                file_types=[".pdf"],
                type="filepath"
            )

            with gr.Row():
                questions_per_chunk = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Questions per section"
                )
                max_chunks = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Max sections to process"
                )

            process_btn = gr.Button("🚀 Generate Flashcards", variant="primary")

            gr.Markdown("""
            ### 💡 Tips:
            - Text-based PDFs work best (scanned images won't work)
            - Academic papers and articles work great
            - Adjust "Questions per section" based on content density
            """)

        # Right column: streaming status plus rendered flashcards.
        with gr.Column(scale=2):
            status_text = gr.Textbox(
                label="Status",
                value="Ready to process PDF...",
                interactive=False
            )

            output_display = gr.Markdown(
                label="Generated Flashcards",
                value="Your flashcards will appear here..."
            )

    with gr.Row():
        with gr.Column():
            csv_output = gr.Textbox(
                label="CSV Format (for Anki import)",
                lines=10,
                visible=True
            )
            gr.Markdown("*Copy the CSV content and save as `.csv` file to import into Anki*")

        with gr.Column():
            json_output = gr.Textbox(
                label="JSON Format",
                lines=10,
                visible=True
            )
            gr.Markdown("*Raw JSON data for custom applications*")

    # Event handlers
    # process_pdf streams (status, csv, json) tuples into the three outputs;
    # the chained .then mirrors the final status text into the Markdown panel
    # unless it is still a progress message (one starting with the 📄 emoji).
    process_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, questions_per_chunk, max_chunks],
        outputs=[status_text, csv_output, json_output]
    ).then(
        fn=lambda x: x if not isinstance(x, str) or not x.startswith("📄") else gr.update(),
        inputs=status_text,
        outputs=output_display
    )

    # Example section
    gr.Markdown("---")
    gr.Markdown("### 🎯 Example Output Format")
    gr.Markdown(create_sample_flashcard())

if __name__ == "__main__":
    demo.launch()