Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| st.set_page_config(page_title="RAG Book Analyzer", layout="wide") | |
| import torch | |
| import numpy as np | |
| import faiss | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from sentence_transformers import SentenceTransformer | |
| import fitz # PyMuPDF | |
| import docx2txt | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| # ------------------------ | |
| # Configuration (optimized for reliability) | |
| # ------------------------ | |
# Hugging Face Hub identifiers of the two models used by the app.
MODEL_NAME = "microsoft/phi-2"  # causal LM for summaries and answers
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # sentence embedder

# Chunk size and overlap handed to the text splitter.
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bound (in characters) on text placed into a prompt, to limit memory use.
MAX_TEXT_LENGTH = 3000
| # ------------------------ | |
| # Model Loading with Robust Error Handling | |
| # ------------------------ | |
@st.cache_resource
def _load_models_cached():
    """Instantiate the tokenizer, language model and embedder.

    Decorated with ``st.cache_resource`` so the objects are built once and
    reused across script reruns instead of being reloaded on every widget
    interaction. Exceptions propagate to the caller so that failures are
    reported outside the cached function.

    Returns:
        A ``(tokenizer, model, embedder)`` tuple.
    """
    # NOTE(review): trust_remote_code=True runs Python code shipped in the
    # model repository; confirm it is still required for this model.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        padding_side="left",
    )
    # Reuse EOS as the padding token so tokenizer.pad_token is always set.
    tokenizer.pad_token = tokenizer.eos_token

    on_gpu = DEVICE == "cuda"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        # Half precision on GPU; full precision on CPU.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        trust_remote_code=True,
        device_map="auto" if on_gpu else None,
        low_cpu_mem_usage=True,
    )

    embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
    return tokenizer, model, embedder


def load_models():
    """Return the (cached) tokenizer, language model and embedder.

    Any exception raised while loading is shown with ``st.error`` and the
    script run is halted with ``st.stop()``.

    Returns:
        A ``(tokenizer, model, embedder)`` tuple.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()


tokenizer, model, embedder = load_models()
| # ------------------------ | |
| # Text Processing Functions | |
| # ------------------------ | |
def split_text(text):
    """Split ``text`` into overlapping chunks.

    Uses a ``RecursiveCharacterTextSplitter`` configured with ``CHUNK_SIZE``
    and ``CHUNK_OVERLAP``, measuring length with ``len``.

    Returns:
        Whatever the splitter's ``split_text`` returns for ``text``.
    """
    return RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len,
    ).split_text(text)
def extract_text(file):
    """Extract plain text from an uploaded PDF, TXT or DOCX file.

    The format is chosen from ``file.type`` (a MIME type string).

    Args:
        file: Uploaded file object exposing ``type`` and ``read()``.

    Returns:
        The extracted text, or ``""`` when the type is unsupported or
        extraction raised. In both cases an error is shown with ``st.error``.
    """
    try:
        if file.type == "application/pdf":
            # Open the document in a context manager so it is closed even if
            # a page fails to render; the original never closed it.
            with fitz.open(stream=file.read(), filetype="pdf") as doc:
                return "\n".join(page.get_text() for page in doc)
        elif file.type == "text/plain":
            return file.read().decode("utf-8")
        elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            return docx2txt.process(file)
        else:
            st.error(f"Unsupported file type: {file.type}")
            return ""
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return ""
def build_index(chunks):
    """Embed ``chunks`` and store the vectors in a flat L2 FAISS index.

    Args:
        chunks: Non-empty sequence of text chunks to embed.

    Returns:
        A ``faiss.IndexFlatL2`` holding one vector per chunk, in the same
        order as ``chunks``.

    Raises:
        ValueError: If ``chunks`` is empty. The original failed later with an
            obscure error when reading the embedding dimension.
    """
    if not chunks:
        raise ValueError("Cannot build an index from an empty list of chunks")

    embeddings = embedder.encode(chunks, show_progress_bar=False)
    # FAISS expects a C-contiguous float32 matrix; make that explicit rather
    # than relying on the encoder's output dtype and layout.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
| # ------------------------ | |
| # AI Generation Functions (with safeguards) | |
| # ------------------------ | |
def generate_summary(text):
    """Generate a short summary of ``text`` with the language model.

    Only the first ``MAX_TEXT_LENGTH`` characters of ``text`` are placed in
    the prompt, and the tokenised prompt is truncated to 1024 tokens.

    Args:
        text: Document text to summarise.

    Returns:
        The model's continuation of the prompt, stripped of surrounding
        whitespace.
    """
    text = text[:MAX_TEXT_LENGTH]  # Prevent long inputs
    prompt = f"Instruction: Summarize this book in a concise paragraph\nText: {text}\nSummary:"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The generated sequence starts with the prompt tokens. Decode only what
    # follows them instead of decoding everything and splitting on the last
    # "Summary:" marker: that split returns the wrong span when the
    # continuation repeats the marker, and stripping the prompt by string
    # replacement fails once the prompt has been truncated.
    prompt_length = inputs["input_ids"].shape[1]
    summary = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    return summary.strip()
def generate_answer(query, context):
    """Answer ``query`` using ``context`` as supporting text.

    Only the first ``MAX_TEXT_LENGTH`` characters of ``context`` are placed
    in the prompt, and the tokenised prompt is truncated to 1024 tokens.

    Args:
        query: The user's question.
        context: Retrieved passages the answer should be based on.

    Returns:
        The model's continuation of the prompt, stripped of surrounding
        whitespace.
    """
    context = context[:MAX_TEXT_LENGTH]  # Limit context size
    prompt = f"Instruction: Answer this question based on the context. If unsure, say 'I don't know'.\nQuestion: {query}\nContext: {context}\nAnswer:"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.4,
        top_p=0.85,
        repetition_penalty=1.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The generated sequence starts with the prompt tokens. Decode only what
    # follows them instead of decoding everything and splitting on the last
    # "Answer:" marker: that split drops the real answer when the
    # continuation repeats the marker, and stripping the prompt by string
    # replacement fails once the prompt has been truncated.
    prompt_length = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    return answer.strip()
| # ------------------------ | |
| # Streamlit UI | |
| # ------------------------ | |
# NOTE(review): the leading "π" / "β" / "π¬" characters in the UI strings
# below look like mis-decoded emoji; restore them from the original source.
st.title("π RAG-Based Book Analyzer")
st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")
st.warning("Note: First run will download models (~1.5GB). Please be patient!")

uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])

if uploaded_file:
    # Streamlit re-executes this script on every interaction (including each
    # question typed below). Process a given upload only once and keep the
    # results in session state, rather than re-extracting, re-summarising and
    # re-indexing on every rerun.
    # NOTE(review): name + size is assumed to identify an upload well enough.
    doc_key = (uploaded_file.name, uploaded_file.size)
    needs_processing = st.session_state.get("doc_key") != doc_key

    if needs_processing:
        with st.spinner("Extracting text from file..."):
            text = extract_text(uploaded_file)
        # Whitespace-only text (e.g. a PDF without a text layer) yields no
        # chunks to index, so treat it as a failed extraction as well.
        if not text.strip():
            st.error("Failed to extract text. Please try another file.")
            st.stop()
        st.session_state.doc_length = len(text)
    st.success(f"β Extracted {st.session_state.doc_length} characters")

    if needs_processing:
        with st.spinner("Generating summary (this may take a minute)..."):
            st.session_state.summary = generate_summary(text)
    st.markdown("### Book Summary")
    st.info(st.session_state.summary)

    if needs_processing:
        with st.spinner("Preparing document for questions..."):
            chunks = split_text(text)
            st.session_state.index = build_index(chunks)
            st.session_state.chunks = chunks
        # Mark the upload as processed only after every step has succeeded.
        st.session_state.doc_key = doc_key
    st.success(f"β Document indexed with {len(st.session_state.chunks)} chunks")
    st.divider()

if "chunks" in st.session_state:
    st.markdown("### β Ask a Question about the Book")
    query = st.text_input("Enter your question:", key="question")
    if query:
        with st.spinner("Searching for answers..."):
            # Retrieve the top 3 most similar chunks.
            query_embedding = embedder.encode([query])
            _distances, indices = st.session_state.index.search(query_embedding, k=3)

            # FAISS pads the result with -1 when the index holds fewer than k
            # vectors. A bare `i < len(chunks)` check lets -1 through and
            # silently picks the last chunk, so require a non-negative index.
            doc_chunks = st.session_state.chunks
            retrieved_chunks = [
                doc_chunks[i] for i in indices[0] if 0 <= i < len(doc_chunks)
            ]
            context = "\n\n".join(retrieved_chunks)

            answer = generate_answer(query, context)

        st.markdown("### π¬ Answer")
        st.success(answer)
        with st.expander("View context used for answer"):
            st.text(context)