sammoftah's picture
Run RAG Space with Docker OCR support
5c44ad5 verified
import gradio as gr
import PyPDF2
import io
from huggingface_hub import InferenceClient
import os
import sys
import math
import re
from collections import Counter
# Optional dependencies: each import falls back to None on any failure so the
# module still loads; the extraction helpers test these names against None
# before using them.
try:
    from PIL import Image
except Exception:  # pragma: no cover - optional runtime fallback
    Image = None
try:
    import fitz  # PyMuPDF
except Exception:  # pragma: no cover - optional runtime fallback
    fitz = None
try:
    import pytesseract
except Exception:  # pragma: no cover - optional runtime fallback
    pytesseract = None
# Put the parent directory of this file on sys.path before importing from
# `shared` (presumably the package lives one level up; verify repo layout).
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from shared.components import create_method_panel, create_premium_hero
# os.getenv returns None when HF_TOKEN is unset, so the client may be built
# without a token.
client = InferenceClient(token=os.getenv("HF_TOKEN"))
# Global storage: three lists that process_pdfs() rebuilds together.
chunks = []  # chunk texts
sources = []  # source file name for each chunk, same order as `chunks`
chunk_vectors = []  # Counter term vector for each chunk, same order as `chunks`
def tokenize(text):
    """Split text into lowercase word tokens for lexical retrieval.

    The text is NFKC-normalized first so that typographic ligatures and other
    compatibility forms that PDF extraction often produces (for example the
    single-character "fi" ligature) become their plain equivalents, then it is
    lowercased. Tokens are maximal runs of Unicode word characters (letters,
    digits, underscore), so accented and non-Latin words stay whole instead of
    being cut at the first non-ASCII letter. Pure-ASCII input tokenizes
    exactly as it did with the previous ASCII-only pattern.

    Args:
        text: Arbitrary string; may be empty.

    Returns:
        List of token strings in order of appearance (empty if none match).
    """
    import unicodedata  # local import keeps this block self-contained

    normalized = unicodedata.normalize("NFKC", text).lower()
    return re.findall(r"\w+", normalized)
def vectorize(text):
    """Build a bag-of-words term-frequency vector (a Counter) for ``text``."""
    term_counts = Counter()
    term_counts.update(tokenize(text))
    return term_counts
def cosine_similarity(left, right):
    """Return the cosine similarity of two sparse term-count mappings.

    Both arguments map token -> weight (Counters in this file). The result is
    0.0 when either mapping is empty or has zero magnitude.
    """
    if not (left and right):
        return 0.0

    # Only tokens present in both vectors contribute to the dot product.
    shared_dot = sum(
        weight * right[term] for term, weight in left.items() if term in right
    )

    left_size, right_size = (
        math.sqrt(sum(weight * weight for weight in vector.values()))
        for vector in (left, right)
    )
    if not left_size or not right_size:
        return 0.0
    return shared_dot / (left_size * right_size)
def read_uploaded_pdf(pdf_file):
    """Turn any supported upload shape into ``(raw_bytes, display_name)``.

    Handled shapes, checked in this order:
      * file-like objects exposing ``read`` (rewound first when seekable);
      * filesystem paths given as ``str`` or ``os.PathLike``;
      * objects carrying a ``path`` attribute;
      * anything else, which is passed through ``bytes(...)``.

    The display name is always reduced to its final path component.
    """
    default_name = "uploaded.pdf"

    if hasattr(pdf_file, "read"):
        if hasattr(pdf_file, "seek"):
            pdf_file.seek(0)
        data = pdf_file.read()
        label = getattr(pdf_file, "name", default_name)
    elif isinstance(pdf_file, (str, os.PathLike)) or hasattr(pdf_file, "path"):
        # Both remaining path-based shapes read the file from disk.
        is_plain_path = isinstance(pdf_file, (str, os.PathLike))
        location = pdf_file if is_plain_path else pdf_file.path
        label = location
        with open(location, "rb") as stream:
            data = stream.read()
    else:
        data = bytes(pdf_file)
        label = default_name

    return data, os.path.basename(str(label))
def extract_with_pypdf(payload):
    """Return the embedded text of a PDF (given as bytes) as read by PyPDF2.

    Each page contributes its extracted text followed by a newline; a page
    for which PyPDF2 returns nothing contributes just the newline.
    """
    reader = PyPDF2.PdfReader(io.BytesIO(payload))
    page_texts = [(page.extract_text() or "") + "\n" for page in reader.pages]
    return "".join(page_texts)
def extract_with_pymupdf(payload):
    """Extract the text layer with PyMuPDF as a second pass after PyPDF2.

    Returns ``(text, page_count)``. When PyMuPDF could not be imported the
    result is ``("", 0)``.
    """
    if fitz is None:
        return "", 0

    with fitz.open(stream=payload, filetype="pdf") as pdf:
        pieces = [page.get_text("text") + "\n" for page in pdf]
        total_pages = pdf.page_count
    return "".join(pieces), total_pages
def extract_with_ocr(payload, max_pages=12):
    """Render PDF pages to images and OCR them with Tesseract.

    Used when a PDF has no usable embedded text layer.

    Args:
        payload: Raw PDF bytes.
        max_pages: Upper bound on the number of pages rendered and OCR'd.
            Non-positive values process no pages.

    Returns:
        ``(text, pages_processed, warning)``. ``text`` joins the non-empty
        per-page OCR results with newlines. ``warning`` is empty on a full
        run, explains which dependency is missing when OCR cannot run, or
        reports that only a prefix of the document was processed. The
        truncation notice is returned as the warning and is not appended to
        the text, so it is never indexed as document content and an
        unreadable scan is not reported as having extracted words.
    """
    if fitz is None or Image is None:
        return "", 0, "OCR dependencies are not available in this runtime."
    if pytesseract is None:
        return "", 0, "OCR engine is not available in this runtime."

    page_texts = []
    pages_processed = 0
    warning = ""
    with fitz.open(stream=payload, filetype="pdf") as document:
        total_pages = document.page_count
        page_limit = max(0, min(total_pages, max_pages))
        for page_index in range(page_limit):
            page = document.load_page(page_index)
            # 2x zoom gives Tesseract more pixels per glyph; alpha is dropped
            # so the samples buffer is plain 3-byte RGB.
            pixmap = page.get_pixmap(matrix=fitz.Matrix(2, 2), alpha=False)
            image = Image.frombytes(
                "RGB",
                (pixmap.width, pixmap.height),
                pixmap.samples,
            )
            page_text = pytesseract.image_to_string(image, config="--psm 6").strip()
            if page_text:
                page_texts.append(page_text)
            pages_processed += 1
        if total_pages > page_limit:
            warning = (
                f"OCR covered only the first {page_limit} of {total_pages} pages "
                "to keep the Space responsive"
            )
    return "\n".join(page_texts), pages_processed, warning
def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF upload, escalating through three extractors.

    Order: PyPDF2 text layer, then PyMuPDF text layer, then OCR over rendered
    pages. A later extractor is tried only while the best result so far has
    fewer than five words, and it replaces that result only when it yields
    more words. This keeps a short but real text layer from being thrown away
    when a fallback is unavailable or finds nothing.

    The OCR page budget comes from the ``OCR_MAX_PAGES`` environment variable;
    a missing, non-numeric or non-positive value falls back to 12.

    Args:
        pdf_file: Any upload shape accepted by ``read_uploaded_pdf``.

    Returns:
        ``(text, source_name, method, warning, page_count)`` where ``method``
        names the extractor that produced ``text``, ``warning`` is the OCR
        stage's message (empty if OCR was not attempted or had nothing to
        report), and ``page_count`` is the PyMuPDF page count (0 when PyMuPDF
        was not consulted or is unavailable).
    """
    min_words = 5
    default_ocr_pages = 12

    payload, source_name = read_uploaded_pdf(pdf_file)
    text = extract_with_pypdf(payload).strip()
    method = "PyPDF2 text layer"
    page_count = 0
    warning = ""

    if len(text.split()) < min_words:
        candidate, page_count = extract_with_pymupdf(payload)
        candidate = candidate.strip()
        if len(candidate.split()) > len(text.split()):
            text = candidate
            method = "PyMuPDF text layer"

    if len(text.split()) < min_words:
        try:
            max_pages = int(os.getenv("OCR_MAX_PAGES", str(default_ocr_pages)))
        except ValueError:
            max_pages = default_ocr_pages
        if max_pages < 1:
            max_pages = default_ocr_pages
        candidate, pages_processed, warning = extract_with_ocr(payload, max_pages=max_pages)
        candidate = candidate.strip()
        if len(candidate.split()) > len(text.split()):
            text = candidate
            method = f"OCR over rendered PDF pages ({pages_processed} page{'s' if pages_processed != 1 else ''})"

    return text, source_name, method, warning, page_count
def chunk_text(text, chunk_size=500, overlap=50):
    """Split text into overlapping word-based chunks.

    Consecutive chunks start ``chunk_size - overlap`` words apart. Chunking
    stops as soon as a chunk reaches the end of the text, so no trailing chunk
    is emitted that would be entirely contained in the previous one.

    Args:
        text: Input string; split on whitespace.
        chunk_size: Maximum number of words per chunk (must be positive).
        overlap: Number of words shared by neighbouring chunks
            (``0 <= overlap < chunk_size``).

    Returns:
        List of chunk strings, empty when ``text`` has no words.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. Without
            this check the stride would be zero or negative, or words would be
            skipped between chunks.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    stride = chunk_size - overlap
    pieces = []
    for start in range(0, len(words), stride):
        pieces.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            # This chunk already covers the tail; later windows would only
            # repeat words it contains.
            break
    return pieces
def process_pdfs(pdf_files, progress=gr.Progress()):
    """Extract, chunk and index the uploaded PDFs for lexical retrieval.

    Replaces the module-level ``chunks``, ``sources`` and ``chunk_vectors``
    with an index built from ``pdf_files``. The new index is assembled in
    local lists and published in one step, so the three globals always stay
    aligned. A PDF that cannot be read is reported in the status text and
    skipped; it does not abort the rest of the batch.

    Args:
        pdf_files: Uploaded files (a list of paths from the Gradio File
            component; a single path is also accepted).
        progress: Gradio progress tracker.

    Returns:
        Human-readable status string with one note per input file.
    """
    global chunks, sources, chunk_vectors

    if not pdf_files:
        return "❌ No PDFs uploaded"
    if isinstance(pdf_files, (str, os.PathLike)):
        # A bare path would otherwise be iterated character by character.
        pdf_files = [pdf_files]

    new_chunks = []
    new_sources = []
    extraction_notes = []
    total = len(pdf_files)

    progress(0, desc="Extracting text from PDFs...")
    for position, pdf_file in enumerate(pdf_files, start=1):
        try:
            text, source_name, method, warning, _page_count = extract_text_from_pdf(pdf_file)
        except Exception as exc:  # boundary: one bad upload must not sink the batch
            label = os.path.basename(str(getattr(pdf_file, "name", None) or pdf_file))
            extraction_notes.append(f"- {label}: could not read PDF ({exc})")
        else:
            pdf_chunks = chunk_text(text)
            new_chunks.extend(pdf_chunks)
            new_sources.extend([source_name] * len(pdf_chunks))

            word_count = len(text.split())
            if word_count:
                note = f"- {source_name}: {word_count:,} words extracted via {method}"
                if warning:
                    note += f" ({warning})"
                extraction_notes.append(note)
            else:
                detail = warning or "no text layer or OCR-readable text was found"
                extraction_notes.append(f"- {source_name}: {detail}.")
        # Extraction fills the first 90% of the bar so the indexing step
        # below never moves it backwards.
        progress(0.9 * position / total, desc=f"Processed {position}/{total} PDFs")

    if not new_chunks:
        chunks, sources, chunk_vectors = [], [], []
        return (
            "❌ No text extracted from PDFs\n\n"
            + "\n".join(extraction_notes)
            + "\n\nThis Space now tries text extraction and OCR automatically. If this still fails, the PDF may contain "
            "low-resolution images, protected content, or pages whose text is too blurred for OCR."
        )

    progress(0.9, desc="Building lexical retrieval index...")
    new_vectors = [vectorize(chunk) for chunk in new_chunks]
    chunks, sources, chunk_vectors = new_chunks, new_sources, new_vectors
    return f"✅ Processed {total} PDFs into {len(chunks)} chunks\n\n" + "\n".join(extraction_notes)
def retrieve_chunks(query, top_k=3):
    """Return the indexed chunks most lexically similar to ``query``.

    Args:
        query: Free-text question.
        top_k: Maximum number of chunks to return; values below zero are
            treated as zero.

    Returns:
        ``(chunks, sources, scores)``, three parallel lists ordered by
        descending cosine similarity. All three are empty when nothing has
        been indexed, so callers can always unpack three values.
    """
    if not chunk_vectors or not chunks:
        return [], [], []

    query_vector = vectorize(query)
    scored = [
        (idx, cosine_similarity(query_vector, chunk_vector))
        for idx, chunk_vector in enumerate(chunk_vectors)
    ]
    # Stable sort: chunks with equal scores keep their document order.
    scored.sort(key=lambda item: item[1], reverse=True)
    top = scored[:max(top_k, 0)]

    retrieved_chunks = [chunks[i] for i, _ in top]
    retrieved_sources = [sources[i] for i, _ in top]
    retrieved_scores = [score for _, score in top]
    return retrieved_chunks, retrieved_sources, retrieved_scores
def answer_question(question, progress=gr.Progress()):
    """Answer a question with the retrieve-then-generate pipeline.

    Retrieves the top three chunks, shows them for inspection, and asks the
    hosted model to answer from that context. When no ``HF_TOKEN`` is
    configured, or the hosted call fails, the best retrieved chunk is returned
    instead, together with the actual reason for the fallback.

    Args:
        question: The user's question.
        progress: Gradio progress tracker.

    Returns:
        ``(answer, chunks_markdown, citations_markdown)``.
    """
    if not question or not question.strip():
        return "Please enter a question", "", ""
    if not chunk_vectors:
        return "Please upload and process PDFs first", "", ""

    progress(0, desc="🔍 Step 1: Retrieving relevant chunks...")
    retrieved_chunks, retrieved_sources, scores = retrieve_chunks(question, top_k=3)
    if not retrieved_chunks:
        return "No relevant information found", "", ""

    # Format retrieved chunks for display
    display_parts = []
    for number, (chunk, source, score) in enumerate(
        zip(retrieved_chunks, retrieved_sources, scores), start=1
    ):
        preview = chunk[:300] + ("..." if len(chunk) > 300 else "")
        display_parts.append(
            f"**Chunk {number}** (from {source}, lexical similarity: {score:.3f})\n"
            f"{preview}\n\n"
        )
    chunks_display = "".join(display_parts)

    # Sorted so the citation list is deterministic on both the generated and
    # the fallback path.
    citations = "\n\n**Sources:**\n" + "".join(
        f"- {source}\n" for source in sorted(set(retrieved_sources))
    )

    progress(0.5, desc="🤖 Step 2: Generating answer...")
    # Create prompt for generation
    context = "\n\n".join(retrieved_chunks)
    prompt = f"""Based on the following context, answer the question. If the answer is not in the context, say so.
Context:
{context}
Question: {question}
Answer:"""

    if os.getenv("HF_TOKEN"):
        try:
            response = "".join(
                client.text_generation(
                    prompt,
                    model="meta-llama/Llama-3.2-3B-Instruct",
                    max_new_tokens=300,
                    stream=True,
                )
            )
        except Exception as exc:  # boundary: degrade to extractive answer
            reason = f"Hosted generation failed ({type(exc).__name__}: {exc})"
        else:
            progress(1.0, desc="✅ Done!")
            return response.strip(), chunks_display, citations
    else:
        reason = "No hosted generation token is configured"

    best = retrieved_chunks[0]
    evidence = best[:900] + ("..." if len(best) > 900 else "")
    fallback = (
        f"{reason}, so this Space is returning the most relevant retrieved evidence instead.\n\n"
        f"**Question:** {question}\n\n"
        f"**Best evidence:** {evidence}"
    )
    return fallback, chunks_display, citations
# Gradio Interface
# NOTE(review): the source had lost its indentation, so the nesting below is
# reconstructed from the component order. In particular, confirm that the
# "How it works" text belongs under the two-column row and not inside the
# right-hand column.
with gr.Blocks(title="RAG from Scratch", theme=gr.themes.Soft()) as demo:
    create_premium_hero(
        "RAG from Scratch",
        "A transparent Retrieval-Augmented Generation lab: chunk PDFs, retrieve passages, and answer with cited context.",
        "📚",
        badge="Retrieval Systems",
        highlights=["Lexical retrieval", "Chunk inspection", "HF Inference"],
    )
    create_method_panel({
        "Pipeline": "PDF text extraction → overlapping chunks → lexical retrieval → grounded generation.",
        "What it proves": "You can build and explain the moving parts behind production RAG systems.",
        "Community value": "A teaching Space for debugging retrieval quality before adding orchestration complexity.",
    })

    with gr.Row():
        # Left column: inputs for both pipeline steps.
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Upload PDFs")
            pdf_input = gr.File(
                file_count="multiple",
                file_types=[".pdf"],
                type="filepath",
                label="Upload PDF files"
            )
            process_btn = gr.Button("Process PDFs", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### Step 2: Ask Questions")
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="What is this document about?",
                lines=2
            )
            ask_btn = gr.Button("Get Answer", variant="primary")

        # Right column: answer, citations and the retrieved evidence.
        with gr.Column(scale=2):
            gr.Markdown("### Answer")
            answer_output = gr.Textbox(label="Generated Answer", lines=6)
            citations_output = gr.Markdown(label="Citations")
            with gr.Accordion("🔍 Retrieved Chunks (View Pipeline)", open=False):
                chunks_output = gr.Markdown(label="Chunks Used")

    gr.Markdown("""
### 💡 How it works:
- **Indexing**: Convert chunks into transparent lexical vectors
- **Retrieval**: Find chunks with the strongest term overlap
- **Generation**: LLM uses retrieved chunks to answer
This is the foundation of most modern Q&A systems!
""")

    # Wire the buttons to the pipeline functions defined above.
    process_btn.click(
        process_pdfs,
        inputs=[pdf_input],
        outputs=[status]
    )
    ask_btn.click(
        answer_question,
        inputs=[question_input],
        outputs=[answer_output, chunks_output, citations_output]
    )

if __name__ == "__main__":
    demo.launch()