# Author: ST-THOMAS-OF-AQUINAS
# Commit: 01d723b (verified) — "Update app.py"
import gradio as gr
import PyPDF2
import re
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
# ------------------------------------------------------------
# Sentence-embedding model used to embed both marking-scheme
# chunks and student answers
# ------------------------------------------------------------
embed_model = SentenceTransformer("all-mpnet-base-v2")

# ------------------------------------------------------------
# Process-wide retrieval state; vectorize_pdf() fills these in
# ------------------------------------------------------------
# FAISS index / list of chunk texts / chunk embedding matrix.
vector_store = chunks_store = embeddings_store = None

TOP_K = 3  # how many marking-scheme chunks to retrieve per answer
# ----------------------------
# PDF Loader and Chunker
# ----------------------------
def load_pdf(file):
    """Extract the text of every page of a PDF.

    Args:
        file: Anything ``PyPDF2.PdfReader`` accepts, i.e. a filesystem path or
            a binary file-like object (presumably what the Gradio ``File``
            input passes; TODO confirm for the installed Gradio version).

    Returns:
        A list with one string per page, in page order. Pages with no
        extractable text yield ``""`` (never ``None``), so callers can
        safely ``join`` or ``split`` the result.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # extract_text() can come back empty or None for image-only pages,
    # depending on the PyPDF2 version; normalise so string ops never fail.
    return [page.extract_text() or "" for page in pdf_reader.pages]
def chunk_text(text_pages, chunk_size=200, overlap=50):
    """Split page texts into overlapping windows of words.

    Each page is chunked independently, so no chunk spans a page boundary.
    Consecutive windows on a page share ``overlap`` words. Once a window
    reaches the end of a page, that page is finished. This avoids emitting a
    trailing chunk that only repeats the tail of the previous one.

    Args:
        text_pages: Iterable of page strings. Falsy entries (``None``/``""``)
            are skipped.
        chunk_size: Maximum number of words per chunk (must be >= 1).
        overlap: Words shared by consecutive chunks
            (``0 <= overlap < chunk_size``).

    Returns:
        List of chunk strings, with words joined by single spaces.

    Raises:
        ValueError: If ``chunk_size``/``overlap`` would stop the window from
            advancing. Without this check the loop would never terminate.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    step = chunk_size - overlap
    chunks = []
    for page in text_pages:
        if not page:
            continue
        # str.split() ignores leading/trailing whitespace. re.split(r"\s+")
        # would yield empty tokens and whitespace-only chunks.
        words = page.split()
        for start in range(0, len(words), step):
            chunks.append(" ".join(words[start:start + chunk_size]))
            if start + chunk_size >= len(words):
                break  # this window already reached the end of the page
    return chunks
# ----------------------------
# Vectorization
# ----------------------------
def vectorize_pdf(marking_scheme_file):
    """Index a marking-scheme PDF for retrieval and store it module-wide.

    The PDF text is split with ``chunk_text``. Each chunk is embedded with
    ``embed_model`` and added to an exact L2 ``faiss.IndexFlatL2`` index. On
    success, ``vector_store``, ``chunks_store`` and ``embeddings_store`` are
    replaced, so later ``generate_enhanced_prompts`` calls retrieve from this
    scheme.

    Args:
        marking_scheme_file: The uploaded PDF (anything ``load_pdf`` accepts),
            or ``None`` when nothing was uploaded.

    Returns:
        A JSON-serialisable dict with ``num_chunks`` and ``preview``. The
        preview covers up to the first 10 chunks: 1-based id, the first 50
        characters of the text, and the first 5 embedding values rounded to
        4 decimals. If nothing was uploaded or no text could be extracted,
        the dict also has an ``error`` message, and any previously loaded
        index is left untouched.
    """
    global vector_store, chunks_store, embeddings_store

    if marking_scheme_file is None:
        return {"error": "No marking scheme PDF uploaded.", "num_chunks": 0, "preview": []}

    chunks = chunk_text(load_pdf(marking_scheme_file))
    if not chunks:
        # e.g. a scanned/image-only PDF: nothing to embed, and the embedding
        # matrix would have no dimension to size the index with.
        return {
            "error": "No extractable text found in the marking scheme PDF.",
            "num_chunks": 0,
            "preview": [],
        }

    # FAISS expects float32 vectors.
    embeddings = np.asarray(embed_model.encode(chunks, convert_to_numpy=True), dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    # Publish the new state only after the index was built successfully.
    vector_store, chunks_store, embeddings_store = index, chunks, embeddings

    preview = [
        {
            "chunk_id": i + 1,
            "text_preview": chunk[:50].replace("\n", " ") + ("..." if len(chunk) > 50 else ""),
            "embedding_preview": np.round(embeddings[i][:5], 4).tolist(),
        }
        for i, chunk in enumerate(chunks[:10])
    ]
    return {"num_chunks": len(chunks), "preview": preview}
# ----------------------------
# Parse student PDF (Question + Answer)
# ----------------------------
# Compiled once at import time. The question group may not run across
# another "Question:" label, and the answer group may be empty. Together
# these stop an unanswered question (blank answer, or no "Answer:" label)
# from swallowing the next question's text.
QNA_PATTERN = re.compile(
    r"Question:\s*((?:(?!Question:).)+?)\s*Answer:\s*(.*?)(?=Question:|\Z)",
    re.DOTALL | re.IGNORECASE,
)


def parse_student_pdf_qna(student_pdf_file):
    """Extract (question, answer) pairs from a student's answer PDF.

    The PDF text is expected to look like::

        Question: <text>
        Answer: <text>

    Labels are matched case-insensitively. An answer runs until the next
    "Question:" label or the end of the document, and may be empty. A
    "Question:" with no following "Answer:" label is skipped.

    Args:
        student_pdf_file: The uploaded PDF (anything ``load_pdf`` accepts).

    Returns:
        A list of ``(question, answer)`` tuples in document order, with
        surrounding whitespace stripped.
    """
    # `or ""` guards against pages whose text extraction produced None.
    text = "\n".join(page or "" for page in load_pdf(student_pdf_file))
    return [(q.strip(), a.strip()) for q, a in QNA_PATTERN.findall(text)]
# ----------------------------
# Retrieve relevant chunks and generate enhanced prompt
# ----------------------------
# Grading prompt sent per question; {{ }} are literal braces for the JSON shape.
PROMPT_TEMPLATE = """Instruction: You are a national exam marker. Compare the student's answer with the marking scheme and award marks according to the scheme. Provide rationale. Award partial marks if some points are covered. Output in JSON.
Question: {question}
Answer: {answer}
Marking Scheme Context: {context}
Maximum Marks: {max_marks}
Guidelines: If answer contains part of correct points, award partial marks proportionally.
Output Format:
{{
"score": <numeric>,
"rationale": "<explanation>"
}}
"""


def generate_enhanced_prompts(student_pdf_file, top_k=TOP_K, max_marks=4):
    """Build one grading prompt per student answer, grounded in the scheme.

    Each answer is embedded with ``embed_model``. Its ``top_k`` nearest
    marking-scheme chunks (from the index built by ``vectorize_pdf``) are
    retrieved and inserted into ``PROMPT_TEMPLATE``.

    Args:
        student_pdf_file: The uploaded student PDF, or ``None``.
        top_k: Number of scheme chunks to retrieve per answer. It is capped
            at the number of indexed chunks.
        max_marks: Maximum marks stated in every prompt.

    Returns:
        A dict mapping question text to prompt text. A repeated question text
        gets a " (2)", " (3)", ... suffix instead of overwriting the earlier
        prompt. On a precondition failure, an ``"Error: ..."`` string is
        returned instead.
    """
    if vector_store is None or chunks_store is None:
        return "Error: No marking scheme vector store loaded. Please upload PDF first."
    if student_pdf_file is None:
        return "Error: No student answer PDF uploaded."

    qas = parse_student_pdf_qna(student_pdf_file)
    if not qas:
        return "Error: No 'Question: ... Answer: ...' pairs found in the student PDF."

    # Never ask FAISS for more neighbours than the index holds. Missing
    # results come back as id -1, which would silently pick chunks_store[-1].
    k = min(top_k, len(chunks_store))
    answers = [answer for _, answer in qas]
    if k > 0:
        # Embed and search all answers in one batch rather than one per answer.
        query_vecs = np.asarray(embed_model.encode(answers, convert_to_numpy=True), dtype=np.float32)
        _, neighbour_ids = vector_store.search(query_vecs, k)
    else:
        neighbour_ids = [[] for _ in answers]

    prompts = {}
    for (question, answer_text), ids in zip(qas, neighbour_ids):
        context = " ".join(chunks_store[i] for i in ids if i >= 0)
        key, n = question, 2
        while key in prompts:  # keep every prompt when a question text repeats
            key, n = f"{question} ({n})", n + 1
        prompts[key] = PROMPT_TEMPLATE.format(
            question=question,
            answer=answer_text,
            context=context,
            max_marks=max_marks,
        )
    return prompts
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Vectorization + Retrieval + Enhanced Prompt Generation")

    # Step 1: upload the marking scheme and build the retrieval index.
    # Only .pdf uploads are accepted, because PyPDF2 cannot read anything else.
    pdf_file = gr.File(label="Upload Marking Scheme PDF", file_types=[".pdf"])
    vector_output = gr.JSON(label="Vectorization Info")
    submit_vector = gr.Button("Vectorize PDF")
    submit_vector.click(vectorize_pdf, inputs=[pdf_file], outputs=[vector_output])

    # Step 2: upload a student's answers and build one marking prompt per question.
    student_pdf = gr.File(label="Upload Student Answer PDF", file_types=[".pdf"])
    prompts_output = gr.JSON(label="Generated Prompts for Marking")
    submit_prompts = gr.Button("Generate Enhanced Prompts")
    submit_prompts.click(generate_enhanced_prompts, inputs=[student_pdf], outputs=[prompts_output])

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860, and surface errors in the UI.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)