import os
import re

import faiss
import gradio as gr
import pdfplumber
import docx
from typing import List
from sentence_transformers import SentenceTransformer
from transformers import pipeline
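# Assumed dependencies (inferred from the imports above; the source does not
# pin versions): gradio, faiss-cpu (or faiss-gpu), sentence-transformers,
# transformers, pdfplumber, and python-docx, which provides the `docx` module.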
# Utility: clean text helper
def clean_text(text: str) -> str:
    text = re.sub(r'\s+', ' ', text)  # collapse whitespace
    text = text.strip()
    return text
# Text chunking (smaller chunks for better semantic search)
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + chunk_size, len(words))
        chunk = ' '.join(words[start:end])
        chunks.append(clean_text(chunk))
        start += chunk_size - overlap
    return chunks
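# Worked example (illustrative): with the defaults chunk_size=300 and
# overlap=50, `start` advances by 250 words per iteration, so consecutive
# chunks share a 50-word window:
#   chunk 0 = words[0:300], chunk 1 = words[250:550], chunk 2 = words[500:800], ...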
# Document loader for .txt, .pdf, and .docx files
def load_document(file_path: str) -> str:
    ext = os.path.splitext(file_path)[1].lower()
    text = ""
    if ext == ".txt":
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
    elif ext == ".pdf":
        with pdfplumber.open(file_path) as pdf:
            # call extract_text() only once per page; it is expensive to repeat
            pages = [t for t in (page.extract_text() for page in pdf.pages) if t]
            text = "\n".join(pages)
    elif ext == ".docx":
        doc = docx.Document(file_path)
        paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
        text = "\n".join(paragraphs)
    else:
        raise ValueError(f"Unsupported file type: {ext}")
    return clean_text(text)
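# pdfplumber's page.extract_text() returns None for pages with no extractable
# text (e.g. scanned images), which is why the PDF branch filters out falsy
# pages before joining.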
class SmartDocumentRAG:
    def __init__(self):
        print("Loading embedder and models...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')  # small, fast
        self.documents = []
        self.embeddings = None
        self.index = None
        self.is_indexed = False
        # Load QA pipeline
        self.model_type = "distilbert-qa"  # change to "flan-t5" for generative
        if self.model_type == "distilbert-qa":
            self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
        elif self.model_type == "flan-t5":
            self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
        else:
            self.qa_pipeline = None
        self.document_summary = ""
    def process_documents(self, files: List[gr.File]) -> str:
        if not files:
            return "⚠️ No files uploaded."
        print(f"Processing {len(files)} files...")
        all_text = ""
        for file in files:
            try:
                # Depending on the Gradio version, an upload arrives either as
                # a tempfile-like object with a .name attribute or a plain path
                path = file.name if hasattr(file, 'name') else file
                text = load_document(path)
                all_text += text + "\n"
            except Exception as e:
                print(f"Error loading {file}: {e}")
        all_text = clean_text(all_text)
        chunks = chunk_text(all_text)
        if not chunks:
            return "⚠️ No text extracted from documents."
        self.documents = chunks
        print(f"Created {len(chunks)} text chunks.")
        # Embed and build the FAISS index
        self.embeddings = self.embedder.encode(self.documents, convert_to_numpy=True)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # cosine similarity with normalized vectors
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.is_indexed = True
        # Generate a summary (simple: first 3 chunks joined)
        summary_text = " ".join(self.documents[:3])
        self.document_summary = summary_text if summary_text else "Summary not available."
        return f"✅ Processed {len(files)} files and created an index with {len(chunks)} chunks."
    def find_relevant_content(self, query: str, k: int = 5) -> str:
        if not self.is_indexed:
            return ""
        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_emb)
        k = min(k, len(self.documents))
        scores, indices = self.index.search(query_emb, k)
        relevant_chunks = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with idx == -1 when it returns fewer than k hits;
            # the scores are cosine similarities, so drop weak matches
            if score > 0.1 and 0 <= idx < len(self.documents):
                relevant_chunks.append(self.documents[idx])
        context = " ".join(relevant_chunks)
        print(f"Found {len(relevant_chunks)} relevant chunks with similarity > 0.1")
        return context
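    # The 0.1 similarity floor is deliberately permissive; raising it (say to
    # 0.3) is a reasonable knob if off-topic chunks leak into the context, at
    # the cost of returning empty context for loosely phrased questions.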
    def answer_question(self, query: str) -> str:
        if not query.strip():
            return "❌ Please ask a valid question."
        if not self.is_indexed:
            return "📁 Please upload and process documents first."
        query_lower = query.lower()
        if any(word in query_lower for word in ['summary', 'summarize', 'overview', 'about']):
            return f"📄 Document Summary:\n\n{self.document_summary}"
        context = self.find_relevant_content(query, k=5)
        print(f"Context for query: {context[:500]}...")
        if not context:
            return "🔍 Sorry, no relevant information found. Try rephrasing your question."
        try:
            if self.model_type == "distilbert-qa":
                result = self.qa_pipeline(question=query, context=context)
                print(f"QA pipeline result: {result}")
                answer = result.get('answer', '').strip()
                score = result.get('score', 0.0)
                if not answer or score < 0.05:
                    return "🤔 I couldn't find a confident answer based on the documents."
                snippet = context[:300].strip()
                if len(context) > 300:
                    snippet += "..."
                return f"**Answer:** {answer}\n\n*Context snippet:* {snippet}"
            elif self.model_type == "flan-t5":
                prompt = (
                    f"Answer the question based on the context below.\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {query}\nAnswer:"
                )
                result = self.qa_pipeline(prompt, max_length=200, num_return_sequences=1)
                print(f"Generative pipeline result: {result}")
                # text2text models do not echo the prompt, so the replace()
                # below is only a defensive no-op
                answer = result[0]['generated_text'].replace(prompt, '').strip()
                if not answer:
                    return "🤔 I couldn't find a confident answer based on the documents."
                return f"**Answer:** {answer}"
            else:
                return "⚠️ Unsupported model type."
        except Exception as e:
            print(f"Exception in answer_question: {e}")
            return f"❌ Error: {str(e)}"
# Create Gradio UI
def create_interface():
    rag = SmartDocumentRAG()
    with gr.Blocks(title="🧠 Enhanced Document Q&A") as demo:
        gr.Markdown(
            """
            # 🧠 Enhanced Document Q&A System
            **Features:**
            - Semantic search with FAISS + SentenceTransformer
            - Supports PDF, DOCX, TXT uploads
            - Uses DistilBERT or Flan-T5 for Q&A
            - Shows the answer with a context snippet
            """
        )
| with gr.Tab("Upload & Process"): | |
| file_upload = gr.File(file_types=['.pdf', '.docx', '.txt'], label="Upload Documents", file_count="multiple") | |
| process_btn = gr.Button("Process Documents") | |
| process_status = gr.Textbox(label="Processing Status", interactive=False, lines=4) | |
| process_btn.click(fn=rag.process_documents, inputs=[file_upload], outputs=[process_status]) | |
| with gr.Tab("Q&A"): | |
| question_input = gr.Textbox(label="Ask your question", lines=2, placeholder="Type your question here...") | |
| ask_btn = gr.Button("Get Answer") | |
| answer_output = gr.Textbox(label="Answer", lines=8, interactive=False) | |
| ask_btn.click(fn=rag.answer_question, inputs=[question_input], outputs=[answer_output]) | |
| with gr.Tab("Summary"): | |
| summary_btn = gr.Button("Get Document Summary") | |
| summary_output = gr.Textbox(label="Summary", lines=6, interactive=False) | |
| summary_btn.click(fn=lambda: rag.answer_question("summary"), inputs=[], outputs=[summary_output]) | |
| return demo | |
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
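# To run locally: `python app.py`, then open http://localhost:7860. When
# hosted on Hugging Face Spaces, the public URL comes from the Space itself,
# so share=True is typically redundant there.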