Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import PyPDF2 | |
| import io | |
| from huggingface_hub import InferenceClient | |
| import os | |
| import sys | |
| import math | |
| import re | |
| from collections import Counter | |
| try: | |
| from PIL import Image | |
| except Exception: # pragma: no cover - optional runtime fallback | |
| Image = None | |
| try: | |
| import fitz # PyMuPDF | |
| except Exception: # pragma: no cover - optional runtime fallback | |
| fitz = None | |
| try: | |
| import pytesseract | |
| except Exception: # pragma: no cover - optional runtime fallback | |
| pytesseract = None | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..')) | |
| from shared.components import create_method_panel, create_premium_hero | |
# Hosted-inference client; the token is None when HF_TOKEN is unset.
client = InferenceClient(token=os.getenv("HF_TOKEN"))

# Global in-memory index, rebuilt by process_pdfs() and read by answer_question().
# The three lists are parallel: position i in each one refers to the same chunk.
chunks, sources, chunk_vectors = [], [], []
def tokenize(text):
    """Split *text* into lowercase word tokens for CPU-friendly lexical retrieval.

    The text is first NFKC-normalized, so PDF ligatures such as "ﬁ" become "fi"
    and decomposed accents are composed. Tokens are then maximal runs of Unicode
    word characters (letters, digits, underscore). Accented and non-Latin words
    therefore stay whole instead of being split at every non-ASCII character.
    For pure-ASCII input the result is identical to ``[a-zA-Z0-9_]+`` matching.
    """
    import unicodedata  # local import: keeps the file's import header untouched

    normalized = unicodedata.normalize("NFKC", text)
    return re.findall(r"\w+", normalized.lower())
def vectorize(text):
    """Return a bag-of-words term-frequency vector (a Counter) for *text*."""
    term_counts = Counter()
    term_counts.update(tokenize(text))
    return term_counts
def cosine_similarity(left, right):
    """Cosine similarity between two sparse token->weight mappings.

    Returns 0.0 when either mapping is empty or has zero magnitude.
    """
    if not (left and right):
        return 0.0

    def magnitude(vector):
        return math.sqrt(sum(weight * weight for weight in vector.values()))

    shared_terms = set(left).intersection(right)
    numerator = sum(left[term] * right[term] for term in shared_terms)
    left_length = magnitude(left)
    right_length = magnitude(right)
    if not (left_length and right_length):
        return 0.0
    return numerator / (left_length * right_length)
def read_uploaded_pdf(pdf_file):
    """Turn any Gradio upload shape into ``(pdf_bytes, display_name)``.

    Accepted shapes, checked in this order:
    - file-like objects with ``read`` (rewound first when they support ``seek``);
    - ``str`` / ``os.PathLike`` filesystem paths;
    - objects exposing a ``path`` attribute;
    - anything else is passed to ``bytes()`` and named "uploaded.pdf".
    The display name is always reduced to its basename.
    """
    if hasattr(pdf_file, "read"):
        if hasattr(pdf_file, "seek"):
            pdf_file.seek(0)
        data = pdf_file.read()
        label = getattr(pdf_file, "name", "uploaded.pdf")
        return data, os.path.basename(str(label))

    if isinstance(pdf_file, (str, os.PathLike)):
        location = pdf_file
    elif hasattr(pdf_file, "path"):
        location = pdf_file.path
    else:
        return bytes(pdf_file), "uploaded.pdf"

    with open(location, "rb") as stream:
        data = stream.read()
    return data, os.path.basename(str(location))
def extract_with_pypdf(payload):
    """Return PyPDF2's embedded text for every page, each followed by a newline.

    Pages where PyPDF2 finds no text contribute just the newline.
    """
    reader = PyPDF2.PdfReader(io.BytesIO(payload))
    page_texts = ((page.extract_text() or "") for page in reader.pages)
    return "".join(f"{page_text}\n" for page_text in page_texts)
def extract_with_pymupdf(payload):
    """Second-pass extraction with PyMuPDF for PDFs that PyPDF2 parses poorly.

    Returns ``(text, page_count)``; each page's text is followed by a newline.
    Returns ``("", 0)`` when PyMuPDF (``fitz``) could not be imported.
    """
    if fitz is None:
        return "", 0
    with fitz.open(stream=payload, filetype="pdf") as document:
        combined = "".join(page.get_text("text") + "\n" for page in document)
        total_pages = document.page_count
    return combined, total_pages
def extract_with_ocr(payload, max_pages=12):
    """Render PDF pages to images and OCR them when no embedded text exists.

    Args:
        payload: raw PDF bytes.
        max_pages: maximum number of leading pages to render; negative values
            are treated as 0.

    Returns:
        ``(text, pages_processed, warning)``. ``warning`` is "" on success,
        otherwise a human-readable reason why OCR could not run. When pages
        were skipped because of ``max_pages``, a note saying so is appended
        to ``text``.
    """
    if fitz is None or Image is None:
        return "", 0, "OCR dependencies are not available in this runtime."
    if pytesseract is None:
        return "", 0, "OCR engine is not available in this runtime."

    max_pages = max(int(max_pages), 0)
    ocr_text = []
    pages_processed = 0
    with fitz.open(stream=payload, filetype="pdf") as document:
        total_pages = document.page_count
        for page_index in range(min(total_pages, max_pages)):
            page = document.load_page(page_index)
            # Render at 2x scale (presumably to give Tesseract enough resolution);
            # alpha=False keeps the samples free of an alpha channel so they match
            # the "RGB" mode below (assumes PyMuPDF's default RGB colorspace).
            pixmap = page.get_pixmap(matrix=fitz.Matrix(2, 2), alpha=False)
            image = Image.frombytes("RGB", (pixmap.width, pixmap.height), pixmap.samples)
            try:
                page_text = pytesseract.image_to_string(image, config="--psm 6").strip()
            except pytesseract.TesseractNotFoundError:
                # The pytesseract module imports even when the tesseract binary is
                # missing; it only fails here. Report it like the module-missing case
                # instead of aborting the caller's whole batch.
                return (
                    "\n".join(ocr_text),
                    pages_processed,
                    "OCR engine is not available in this runtime.",
                )
            if page_text:
                ocr_text.append(page_text)
            pages_processed += 1

    if total_pages > max_pages:
        ocr_text.append(
            f"\n[OCR note: processed first {max_pages} of {total_pages} pages to keep the Space responsive.]"
        )
    return "\n".join(ocr_text), pages_processed, ""
# Below this many words, a text layer is treated as missing and the next
# extractor in the chain is tried.
_MIN_TEXT_WORDS = 5


def _ocr_page_limit(default=12):
    """Read OCR_MAX_PAGES from the environment.

    Returns *default* when the variable is unset or not an integer; negative
    values are clamped to 0.
    """
    raw = os.getenv("OCR_MAX_PAGES")
    if raw is None:
        return default
    try:
        return max(int(raw), 0)
    except ValueError:
        return default


def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF upload, falling back to PyMuPDF and then OCR.

    Each fallback runs only while the current text has fewer than
    ``_MIN_TEXT_WORDS`` words. A fallback's output replaces the current text
    only if it recovered strictly more words, so a missing or failing
    extractor never discards partial text found earlier.

    Returns:
        ``(text, source_name, method, warning, page_count)``. ``method`` names
        the extractor whose text was kept. ``warning`` is the OCR warning, if
        OCR ran. ``page_count`` comes from PyMuPDF and is 0 when it did not run.
    """
    payload, source_name = read_uploaded_pdf(pdf_file)
    text = extract_with_pypdf(payload).strip()
    method = "PyPDF2 text layer"
    page_count = 0
    warning = ""

    if len(text.split()) < _MIN_TEXT_WORDS:
        mupdf_text, page_count = extract_with_pymupdf(payload)
        mupdf_text = mupdf_text.strip()
        if len(mupdf_text.split()) > len(text.split()):
            text, method = mupdf_text, "PyMuPDF text layer"

    if len(text.split()) < _MIN_TEXT_WORDS:
        ocr_text, pages_processed, warning = extract_with_ocr(
            payload, max_pages=_ocr_page_limit()
        )
        ocr_text = ocr_text.strip()
        if len(ocr_text.split()) > len(text.split()):
            text = ocr_text
            method = f"OCR over rendered PDF pages ({pages_processed} page{'s' if pages_processed != 1 else ''})"

    return text, source_name, method, warning, page_count
def chunk_text(text, chunk_size=500, overlap=50):
    """Split *text* into windows of ``chunk_size`` words that overlap by ``overlap`` words.

    Words are split on any whitespace and re-joined with single spaces. The
    last window is the first one that reaches the final word, so no trailing
    chunk is a pure subset of the chunk before it. Empty or whitespace-only
    text yields ``[]``.

    Raises:
        ValueError: if ``chunk_size < 1``, ``overlap < 0``, or
            ``overlap >= chunk_size``. In those cases the window would never
            advance, or would skip words.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, chunk_size={chunk_size}"
        )

    words = text.split()
    step = chunk_size - overlap
    pieces = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        pieces.append(" ".join(words[start:end]))
        if end >= len(words):
            break  # this window already covers the last word
    return pieces
def _upload_label(pdf_file):
    """Best-effort display name for an upload whose contents could not be read."""
    if isinstance(pdf_file, (str, os.PathLike)):
        return os.path.basename(str(pdf_file))
    for attribute in ("name", "path"):
        value = getattr(pdf_file, attribute, None)
        if value:
            return os.path.basename(str(value))
    return "uploaded.pdf"


def _extraction_note(source_name, text, method, warning):
    """One markdown bullet summarizing what was extracted from a single PDF."""
    word_count = len(text.split())
    if not word_count:
        detail = warning or "no text layer or OCR-readable text was found"
        return f"- {source_name}: {detail}."
    note = f"- {source_name}: {word_count:,} words extracted via {method}"
    if warning:
        note += f" ({warning})"
    return note


def process_pdfs(pdf_files, progress=gr.Progress()):
    """Extract, chunk and index the uploaded PDFs, replacing the global index.

    A file that fails to parse is reported in the status notes instead of
    aborting the whole batch. The new index is built in locals, and the
    globals ``chunks``, ``sources`` and ``chunk_vectors`` are then replaced
    together in one statement, so readers never see a half-built index.

    Returns:
        A status message for the UI listing per-file extraction notes.
    """
    global chunks, sources, chunk_vectors
    if not pdf_files:
        return "❌ No PDFs uploaded"

    new_chunks = []
    new_sources = []
    extraction_notes = []
    total = len(pdf_files)
    progress(0, desc="Extracting text from PDFs...")
    for done, pdf_file in enumerate(pdf_files, start=1):
        try:
            text, source_name, method, warning, _page_count = extract_text_from_pdf(pdf_file)
        except Exception as exc:  # parser libraries raise many error types for corrupt/encrypted files
            extraction_notes.append(f"- {_upload_label(pdf_file)}: could not read PDF ({exc}).")
        else:
            pdf_chunks = chunk_text(text)
            new_chunks.extend(pdf_chunks)
            new_sources.extend([source_name] * len(pdf_chunks))
            extraction_notes.append(_extraction_note(source_name, text, method, warning))
        # Extraction fills 0-70% of the bar; indexing uses the remainder.
        progress(0.7 * done / total, desc=f"Processed {done}/{total} PDFs")

    notes = "\n".join(extraction_notes)
    if not new_chunks:
        # This upload replaces any previous index, even when nothing was extracted.
        chunks, sources, chunk_vectors = [], [], []
        return (
            "❌ No text extracted from PDFs\n\n"
            + notes
            + "\n\nThis Space now tries text extraction and OCR automatically. If this still fails, the PDF may contain "
            "low-resolution images, protected content, or pages whose text is too blurred for OCR."
        )

    progress(0.7, desc="Building lexical retrieval index...")
    new_vectors = [vectorize(chunk) for chunk in new_chunks]
    chunks, sources, chunk_vectors = new_chunks, new_sources, new_vectors
    return f"✅ Processed {total} PDFs into {len(new_chunks)} chunks\n\n" + notes
def retrieve_chunks(query, top_k=3):
    """Return the ``top_k`` indexed chunks most lexically similar to *query*.

    Returns three parallel lists ``(chunks, sources, scores)``, best match
    first. The empty-index case also returns three lists, so callers can
    always unpack. Chunks with zero similarity (no shared terms) are
    excluded, so an unrelated query yields empty lists. A ``top_k`` of 0 or
    less returns nothing.
    """
    if not chunk_vectors or not chunks:
        return [], [], []

    query_vector = vectorize(query)
    scored = [
        (idx, score)
        for idx, score in (
            (idx, cosine_similarity(query_vector, chunk_vector))
            for idx, chunk_vector in enumerate(chunk_vectors)
        )
        if score > 0
    ]
    # Stable sort: equal scores keep document order.
    scored.sort(key=lambda item: item[1], reverse=True)
    top = scored[:max(top_k, 0)]

    retrieved_chunks = [chunks[i] for i, _ in top]
    retrieved_sources = [sources[i] for i, _ in top]
    retrieved_scores = [score for _, score in top]
    return retrieved_chunks, retrieved_sources, retrieved_scores
def _preview(text, limit):
    """Return *text* cut to *limit* characters, with "..." only when something was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def answer_question(question, progress=gr.Progress()):
    """Answer *question* with the RAG pipeline over the indexed PDFs.

    Returns:
        ``(answer, chunks_markdown, citations_markdown)`` for the three Gradio
        outputs. When HF_TOKEN is unset, or the hosted generation call raises,
        the answer is the best retrieved passage, and the message states which
        of the two happened.
    """
    if not question or not question.strip():
        return "Please enter a question", "", ""
    if not chunk_vectors:
        return "Please upload and process PDFs first", "", ""

    progress(0, desc="🔍 Step 1: Retrieving relevant chunks...")
    retrieved_chunks, retrieved_sources, scores = retrieve_chunks(question, top_k=3)
    if not retrieved_chunks:
        return "No relevant information found", "", ""

    chunks_display = "".join(
        f"**Chunk {rank}** (from {source}, lexical similarity: {score:.3f})\n"
        f"{_preview(chunk, 300)}\n\n"
        for rank, (chunk, source, score) in enumerate(
            zip(retrieved_chunks, retrieved_sources, scores), start=1
        )
    )
    # dict.fromkeys de-duplicates while keeping retrieval-rank order (a set would not).
    citations = "\n\n**Sources:**\n" + "".join(
        f"- {source}\n" for source in dict.fromkeys(retrieved_sources)
    )

    progress(0.5, desc="🤖 Step 2: Generating answer...")
    context = "\n\n".join(retrieved_chunks)
    prompt = (
        "Based on the following context, answer the question. "
        "If the answer is not in the context, say so.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        "Answer:"
    )

    if not os.getenv("HF_TOKEN"):
        reason = "No hosted generation token is configured"
    else:
        try:
            response = "".join(
                client.text_generation(
                    prompt,
                    model="meta-llama/Llama-3.2-3B-Instruct",
                    max_new_tokens=300,
                    stream=True,
                )
            )
        except Exception as exc:  # network, auth, rate-limit or unsupported-model errors
            reason = f"Hosted generation failed ({type(exc).__name__})"
        else:
            progress(1.0, desc="✅ Done!")
            return response.strip(), chunks_display, citations

    fallback = (
        f"{reason}, so this Space is returning the most relevant retrieved evidence instead.\n\n"
        f"**Question:** {question}\n\n"
        f"**Best evidence:** {_preview(retrieved_chunks[0], 900)}"
    )
    return fallback, chunks_display, citations
# Gradio Interface
# NOTE(review): the scraped source had mis-encoded emoji/arrows (e.g. "β", "π");
# they are restored here to the most plausible originals.
with gr.Blocks(title="RAG from Scratch", theme=gr.themes.Soft()) as demo:
    create_premium_hero(
        "RAG from Scratch",
        "A transparent Retrieval-Augmented Generation lab: chunk PDFs, retrieve passages, and answer with cited context.",
        "📚",
        badge="Retrieval Systems",
        highlights=["Lexical retrieval", "Chunk inspection", "HF Inference"],
    )
    create_method_panel({
        "Pipeline": "PDF text extraction → overlapping chunks → lexical retrieval → grounded generation.",
        "What it proves": "You can build and explain the moving parts behind production RAG systems.",
        "Community value": "A teaching Space for debugging retrieval quality before adding orchestration complexity.",
    })

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Upload PDFs")
            pdf_input = gr.File(
                file_count="multiple",
                file_types=[".pdf"],
                type="filepath",
                label="Upload PDF files",
            )
            process_btn = gr.Button("Process PDFs", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### Step 2: Ask Questions")
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="What is this document about?",
                lines=2,
            )
            ask_btn = gr.Button("Get Answer", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Answer")
            answer_output = gr.Textbox(label="Generated Answer", lines=6)
            citations_output = gr.Markdown(label="Citations")
            with gr.Accordion("🔍 Retrieved Chunks (View Pipeline)", open=False):
                chunks_output = gr.Markdown(label="Chunks Used")

    gr.Markdown(
        "### 💡 How it works:\n"
        "- **Indexing**: Convert chunks into transparent lexical vectors\n"
        "- **Retrieval**: Find chunks with the strongest term overlap\n"
        "- **Generation**: LLM uses retrieved chunks to answer\n"
        "\n"
        "This is the foundation of most modern Q&A systems!\n"
    )

    # Event wiring must happen inside the Blocks context.
    process_btn.click(
        process_pdfs,
        inputs=[pdf_input],
        outputs=[status],
    )
    ask_btn.click(
        answer_question,
        inputs=[question_input],
        outputs=[answer_output, chunks_output, citations_output],
    )

if __name__ == "__main__":
    demo.launch()